In a previous article, we saw how to implement the K-means algorithm from scratch in Python. We delved deep into how the algorithm works and discussed some possible practical applications. In this tutorial, we are going to see one of those applications at work: using K-means clustering to separate an image into segments based on its pixel values.

If you are new to machine learning or K-means, you can read the original article here.

The complete code used in this article can be found here.

Since we have discussed all the nitty-gritty details in the original article, I will keep this one fairly short and simple.

This tutorial is split into 3 short parts:

  • Preprocessing step for the image files
  • Working of the Algorithm
  • Results

Let’s begin!

Loading our image will be our first step.

We will use the following image, downloaded from unsplash.com:

[Image: a red car]
Image credits: Josh Rinard (https://unsplash.com/@joshrinard)
import pandas as pd
import numpy as np
import cv2
import matplotlib.pyplot as plt
from tqdm import tqdm

# load the image and normalise pixel values to the range [0, 1]
img = cv2.imread("josh-rinard-H9mp1P1VUj4-unsplash.jpg") / 255

# OpenCV's color conversion supports float images only up to float32, so we cast the data
img = img.astype(np.float32)

# OpenCV loads images in BGR order by default; convert to RGB to preserve the original colors
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

# rescale the image to reduce the data size and speed up the algorithm
scale_percent = 20  # percent of original size
width = int(img.shape[1] * scale_percent / 100)
height = int(img.shape[0] * scale_percent / 100)
dim = (width, height)  # cv2.resize expects (width, height)

# resize with OpenCV and keep working with the smaller image from here on
img = cv2.resize(img, dim, interpolation=cv2.INTER_AREA)
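If OpenCV is not at hand, the normalisation and BGR-to-RGB steps above can be reproduced with plain NumPy. This is a minimal sketch on a hypothetical 2×2 synthetic image (`bgr_uint8` stands in for what `cv2.imread` would return); reversing the last axis swaps the channel order, which is what `cv2.COLOR_BGR2RGB` does for a 3-channel image.

```python
import numpy as np

# Hypothetical 2x2 image in OpenCV's BGR channel order, stored as uint8 (0-255)
bgr_uint8 = np.array(
    [[[255, 0, 0], [0, 255, 0]],
     [[0, 0, 255], [10, 20, 30]]],
    dtype=np.uint8,
)

# Normalise to [0, 1] and cast to float32, mirroring the preprocessing above
img_float = (bgr_uint8 / 255).astype(np.float32)

# Reverse the channel axis: BGR -> RGB
img_rgb = img_float[:, :, ::-1]

print(img_rgb.dtype, img_rgb.shape)  # float32 (2, 2, 3)
print(img_rgb[0, 0])                 # [0. 0. 1.]  (the pure-blue BGR pixel)
```

Note that `[:, :, ::-1]` returns a view with a negative stride; if the result is passed back into OpenCV functions, wrapping it in `np.ascontiguousarray` avoids layout complaints.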

We cannot treat an image the same way as the ordinary data points we used before, because a color image is a 3-dimensional matrix. The third dimension holds the RGB color channels; it does not add spatial depth to the picture, but it is important for our mathematics to work.

[Figure: dimensional (RGB channel) representation of an image]

Each channel consists of pixels holding the intensity of that channel's color (Red, Green, or Blue).
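To make this layout concrete, here is a small sketch on a synthetic array (not the car photo). The shape is (height, width, channels), and indexing the last axis pulls out one channel as a 2-D grid of intensities.

```python
import numpy as np

# Synthetic 2x3 RGB image: shape is (height, width, channels)
image = np.zeros((2, 3, 3), dtype=np.float32)
image[0, 0] = [1.0, 0.0, 0.0]  # a pure red pixel
image[1, 2] = [0.0, 0.0, 1.0]  # a pure blue pixel

red_channel = image[:, :, 0]   # 2x3 grid of red intensities
blue_channel = image[:, :, 2]  # 2x3 grid of blue intensities

print(image.shape)        # (2, 3, 3)
print(red_channel.shape)  # (2, 3)
```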

We need to convert this into 2-dimensional data, with one row per pixel and one column per color channel.

# convert the M x N x 3 image into a K x 3 array, where K = M x N (one row per pixel)
vectorised = img.reshape((-1, 3))
# convert the array to a DataFrame with one column per color channel
img_df = pd.DataFrame(vectorised)
img_df.rename(columns={0: 'R', 1: 'G', 2: 'B'}, inplace=True)
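Because `reshape` only reinterprets the array's layout, flattening to one row per pixel and reshaping back recovers the original image exactly, which the final plotting step relies on. A small sketch with a synthetic random image (the names `image`, `pixels`, and `pixel_df` are illustrative):

```python
import numpy as np
import pandas as pd

# Synthetic 4x5 RGB image with random values in [0, 1)
rng = np.random.default_rng(0)
image = rng.random((4, 5, 3)).astype(np.float32)

# One row per pixel, one column per channel: (4*5) x 3
pixels = image.reshape((-1, 3))
pixel_df = pd.DataFrame(pixels, columns=["R", "G", "B"])
print(pixel_df.shape)  # (20, 3)

# Row i corresponds to pixel (i // width, i % width) in row-major order
row, col = divmod(7, image.shape[1])
assert np.array_equal(pixels[7], image[row, col])

# Reshaping back restores the original image
restored = pixel_df.to_numpy().reshape(image.shape)
assert np.array_equal(restored, image)
```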

We now have the flattened data in a data frame, so it is time to write the algorithm. The algorithm stays the same as in the original article; for an in-depth look at K-means clustering, read the original article here.
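For reference, the three steps the loop below implements can be written out as follows, for a pixel $p = (R_p, G_p, B_p)$ and centroids $c_j = (R_j, G_j, B_j)$, $j = 1, \dots, k$: the Euclidean distance in RGB space, assignment to the nearest centroid, and recomputing each centroid as the mean of its cluster $S_j$.

```latex
d(p, c_j) = \sqrt{(R_p - R_j)^2 + (G_p - G_j)^2 + (B_p - B_j)^2}

\text{cluster}(p) = \arg\min_{j \in \{1, \dots, k\}} d(p, c_j)

c_j^{\text{new}} = \frac{1}{|S_j|} \sum_{p \in S_j} p,
\qquad S_j = \{\, p : \text{cluster}(p) = j \,\}
```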

k = 5
diff = 1
j = 0

# work on the pixel DataFrame and initialise the centroids with k randomly chosen pixels
# (random sampling is assumed here; the initialisation step is not shown in this article)
X = img_df.copy()
centroids = X.sample(n=k)

while abs(diff) > 0.05:
    XD = X
    i = 1
    # iterate over each centroid
    for index1, row_c in centroids.iterrows():
        ED = []
        # iterate over each data point
        print("Calculating distance")
        for index2, row_d in tqdm(XD.iterrows(), total=len(XD)):
            # Euclidean distance between the current point and the centroid
            d1 = (row_c["R"] - row_d["R"]) ** 2
            d2 = (row_c["G"] - row_d["G"]) ** 2
            d3 = (row_c["B"] - row_d["B"]) ** 2
            d = np.sqrt(d1 + d2 + d3)
            # append the distance to the list 'ED'
            ED.append(d)
        # store the distances to this centroid as column i of the data frame
        X[i] = ED
        i = i + 1

    C = []
    print("Getting Centroid")
    for index, row in tqdm(X.iterrows(), total=len(X)):
        # start with the distance to centroid 1
        min_dist = row[1]
        pos = 1
        # find the centroid closest to the current point
        for i in range(k):
            # if another centroid is closer than the current minimum
            if row[i + 1] < min_dist:
                # the smaller distance becomes the new minimum
                min_dist = row[i + 1]
                pos = i + 1
        C.append(pos)
    # assign the closest cluster to each data point
    X["Cluster"] = C
    # average the pixels of each cluster to get the new centroids
    centroids_new = X.groupby(["Cluster"]).mean()[["R", "G", "B"]]
    if j == 0:
        diff = 1
        j = j + 1
    else:
        # total absolute movement of the centroids; absolute values stop
        # positive and negative shifts from cancelling each other out
        diff = (centroids_new - centroids).abs().sum().sum()
        print(diff)
    centroids = centroids_new

The algorithm is run for 5 clusters, as set by the variable ‘k’. Instead of waiting for the centroids to stop changing entirely, this time we stop once the difference between successive centroids falls below a threshold (0.05), which gives us centroids that are close enough to the stable solution.

We do this because, on image data, the algorithm can take a very long time to settle on exact centroids. Giving it a small margin means that once the difference between the new and the old centroids falls within that range, we exit the loop and keep the last centroids found.
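The stopping rule can be isolated into a small helper. This sketch assumes the old and new centroids are pandas DataFrames indexed by cluster label with `R`, `G`, `B` columns, as in the loop above; the function name `centroid_shift` and the example values are hypothetical. It sums absolute shifts, because a signed sum can report zero even when centroids have moved.

```python
import pandas as pd

def centroid_shift(old: pd.DataFrame, new: pd.DataFrame) -> float:
    """Total absolute movement of all centroids across the R, G, B columns."""
    cols = ["R", "G", "B"]
    # pandas aligns on the cluster-label index, so each centroid is compared with itself
    return float((new[cols] - old[cols]).abs().to_numpy().sum())

old = pd.DataFrame({"R": [0.10, 0.80], "G": [0.20, 0.10], "B": [0.30, 0.10]}, index=[1, 2])
new = pd.DataFrame({"R": [0.12, 0.78], "G": [0.20, 0.10], "B": [0.30, 0.10]}, index=[1, 2])

shift = centroid_shift(old, new)
tolerance = 0.05
print(round(shift, 4), shift < tolerance)  # 0.04 True

# The signed sum is (numerically) zero here: the two R shifts cancel out
signed = float((new - old).to_numpy().sum())
```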

It is worth noting that the algorithm above is quite slow: speed and efficiency are not the goals of this tutorial, which only looks at how the K-means algorithm works. One way to speed it up is to use the “itertuples” function instead of “iterrows”. “iterrows” builds a new pandas Series for every row, whereas “itertuples” yields lightweight namedtuples, so traversal is typically much faster. If you know of other ways to improve the efficiency of this algorithm, do comment below.
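To illustrate, here is a sketch of the distance step for a single centroid written three ways on a small random DataFrame (a stand-in for the pixel data): with “iterrows” as in the loop above, with “itertuples”, and fully vectorised with NumPy broadcasting, a technique the article itself does not use. Actual speedups depend on data size and pandas version, so the sketch only checks that the three versions agree.

```python
import numpy as np
import pandas as pd

# Small random stand-in for the pixel DataFrame (values in [0, 1))
rng = np.random.default_rng(42)
X = pd.DataFrame(rng.random((1000, 3)), columns=["R", "G", "B"])
centroid = X.iloc[0]  # an arbitrary example centroid

# 1) iterrows: pandas builds a new Series object for every row
d_iterrows = [
    np.sqrt((centroid["R"] - row["R"]) ** 2
            + (centroid["G"] - row["G"]) ** 2
            + (centroid["B"] - row["B"]) ** 2)
    for _, row in X.iterrows()
]

# 2) itertuples: yields lightweight namedtuples with attribute access
d_itertuples = [
    np.sqrt((centroid["R"] - row.R) ** 2
            + (centroid["G"] - row.G) ** 2
            + (centroid["B"] - row.B) ** 2)
    for row in X.itertuples(index=False)
]

# 3) NumPy broadcasting: one vectorised expression, no Python-level loop
points = X[["R", "G", "B"]].to_numpy()  # shape (n, 3)
d_vectorised = np.sqrt(((points - centroid.to_numpy()) ** 2).sum(axis=1))

# Broadcasting extends to all k centroids at once (k = 5 here)
centroids = X.sample(n=5, random_state=0).to_numpy()  # shape (k, 3)
all_dists = np.linalg.norm(points[:, None, :] - centroids[None, :, :], axis=2)  # (n, k)
labels = all_dists.argmin(axis=1) + 1  # 1-based labels, as in the loop above

assert np.allclose(d_iterrows, d_itertuples)
assert np.allclose(d_itertuples, d_vectorised)
```

The broadcasting version replaces both the distance loop and the nearest-centroid loop, at the cost of an intermediate `(n, k, 3)` array whose memory grows with the number of pixels times `k`.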

After a short wait, the algorithm finishes. Let’s see which centroids we end up with.

centroids = centroids.to_numpy()
print(centroids)
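Because the pixels were divided by 255 during preprocessing, each centroid is an RGB triple in the range [0, 1]. To read the centroids as ordinary 8-bit colors, they can be scaled back; a minimal sketch, using a made-up centroid array in place of the real output:

```python
import numpy as np

# Hypothetical centroids in [0, 1], one RGB row per cluster (not the article's actual values)
centroids = np.array([[0.8, 0.2, 0.2],
                      [0.2, 0.2, 0.4]], dtype=np.float32)

# Scale back to 0-255, round to the nearest integer, and store as 8-bit values
centroids_8bit = np.clip(np.rint(centroids * 255), 0, 255).astype(np.uint8)

# Hex strings are handy for comparing the palette against a color picker
hex_colors = ["#{:02x}{:02x}{:02x}".format(*c) for c in centroids_8bit]

print(centroids_8bit.tolist())
print(hex_colors)
```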
[Output: final centroids found]

Now that we have the centroids, all that is left is to plot the result. To do this, we replace every pixel in a cluster with the color of that cluster’s centroid.
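The replacement works through NumPy fancy indexing: `centroids[labels - 1]` uses each pixel’s cluster label as a row index into the centroid array, so every pixel picks up its centroid’s color. The `- 1` is needed because the clusters in the loop above are numbered from 1, while NumPy rows start at 0. A tiny sketch with made-up values:

```python
import numpy as np

# Hypothetical 3-cluster palette (one RGB row per centroid) and 1-based labels for a 2x2 image
centroids = np.array([[1.0, 0.0, 0.0],
                      [0.0, 1.0, 0.0],
                      [0.0, 0.0, 1.0]])
labels = np.array([1, 3, 3, 2])

# Each label selects a palette row; reshape the pixel list back into an image
segmented = centroids[labels - 1].reshape((2, 2, 3))
print(segmented[0, 1])  # [0. 0. 1.]
```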

Let’s see what the final image looks like.

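labels = X["Cluster"].to_numpy()

# overwrite each pixel with the color of its cluster's centroid
# (clusters are numbered from 1, while NumPy rows start at 0, hence labels - 1)
segmented_image = centroids[labels - 1]
segmented_image = segmented_image.reshape(img.shape)

# plot the image
plt.imshow(segmented_image)
plt.axis("off")
plt.show()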
[Figure: final segmented image]

Conclusion

The final image contains only 5 colors in total, one per cluster, and these 5 colors represent the dominant colors of the original image.
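A quick way to check this is to count the distinct RGB triples in the segmented array with `np.unique(..., axis=0)`; the result should be at most k, and fewer if a cluster ended up empty. Sketched here on a small stand-in array rather than the real `segmented_image`:

```python
import numpy as np

# Stand-in for segmented_image: a 3x3 image painted with a 2-color palette
palette = np.array([[0.8, 0.1, 0.1], [0.2, 0.2, 0.25]], dtype=np.float32)
labels = np.array([1, 1, 2, 2, 1, 2, 1, 1, 1])
segmented = palette[labels - 1].reshape((3, 3, 3))

# Flatten to one row per pixel and keep only the distinct rows
unique_colors = np.unique(segmented.reshape(-1, 3), axis=0)
print(len(unique_colors))  # 2
```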

The final image looks like something out of an Instagram filter. We can use the centroids and clusters we obtained to create multiple masks for the image and separate it into regions. For example, the body of the car is red but split across 2 different tones; if we run the algorithm with fewer clusters, the red body will likely fall into a single cluster, letting us separate it out entirely.
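The masking idea can be sketched as follows. It assumes `labels` holds the 1-based cluster of each pixel, as produced by the loop above, and that `image` is the resized RGB image the labels were computed from; the small arrays and the helper name `isolate_cluster` are hypothetical stand-ins so the snippet runs on its own.

```python
import numpy as np

def isolate_cluster(image: np.ndarray, labels: np.ndarray, cluster: int) -> np.ndarray:
    """Return a copy of `image` with every pixel outside `cluster` set to black."""
    # Boolean mask with the image's height and width; True where the pixel is in the cluster
    mask = (labels == cluster).reshape(image.shape[:2])
    isolated = np.zeros_like(image)
    isolated[mask] = image[mask]
    return isolated

# Hypothetical 2x2 image and its 1-based cluster labels
image = np.array([[[0.9, 0.1, 0.1], [0.2, 0.2, 0.2]],
                  [[0.8, 0.2, 0.1], [0.1, 0.1, 0.9]]], dtype=np.float32)
labels = np.array([1, 2, 1, 3])

red_part = isolate_cluster(image, labels, cluster=1)
```

To keep two tones together instead of re-running with fewer clusters, a mask such as `np.isin(labels, [a, b]).reshape(image.shape[:2])` covers both cluster labels `a` and `b` at once.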

Similarly, this approach to image segmentation can be applied to many other use cases.
