# Image Segmentation using K means clustering algorithm | Python

In a previous article, we saw how to implement the K-means algorithm from scratch in Python. We delved deep into how the algorithm works and discussed some possible practical applications. In this tutorial, we will see one such application at work: using K-means clustering to separate an image into segments based on its pixel values.

If you are new to machine learning or K-means, you can read the original article here.

Since we have discussed all the nitty-gritty details in the original article, I will keep this one fairly short and simple.

This tutorial is split into 3 short parts:

• Preprocessing step for the image files
• Working of the Algorithm
• Results

Let’s begin!!

```python
import pandas as pd
import numpy as np
import cv2
import matplotlib.pyplot as plt
from tqdm import tqdm
```
```python
# load the image
# NOTE: "input.jpg" is a placeholder path; replace it with your own image file
img = cv2.imread("input.jpg")

# cv2.cvtColor accepts floating-point images only as float32, so we typecast our data
img = img.astype(np.float32)

# OpenCV loads images in BGR order by default; convert to RGB to preserve the original colors
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

# rescaling the image reduces the data size and speeds up the algorithm
scale_percent = 20  # percent of original size
width = int(img.shape[1] * scale_percent / 100)   # NumPy shape is (height, width, channels)
height = int(img.shape[0] * scale_percent / 100)
dim = (width, height)  # cv2.resize expects (width, height)
# use the OpenCV resize function to resize the image
resized = cv2.resize(img, dim, interpolation=cv2.INTER_AREA)
```

We cannot treat an image the same way as the ordinary data points we clustered before, because a color image is a 3-dimensional array of shape M × N × 3. The third dimension holds the RGB color channels; it adds no spatial depth, but it is what our distance calculations operate on.

Each channel contains, for every pixel, the intensity of that channel's color (Red, Green, or Blue).

We therefore need to convert the image into 2-dimensional data: one row per pixel and one column per color channel.

```python
# work on the downscaled image from here on to keep the data small
img = resized
# convert the M x N x 3 image into an (M*N) x 3 array: one row per pixel, one column per channel
vectorised = img.reshape((-1, 3))
# convert the array to a data frame with one column per color channel
img_df = pd.DataFrame(vectorised, columns=["R", "G", "B"])
```
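To check that flattening loses no information, here is a small sketch on a made-up 2 × 3 image (the array values are arbitrary): each row of the data frame is one pixel, read in row-major order, and reshaping back recovers the original array exactly.

```python
import numpy as np
import pandas as pd

# a made-up 2x3 image with 3 channels; values 0..17 make each pixel easy to identify
image = np.arange(2 * 3 * 3, dtype=np.float32).reshape(2, 3, 3)

# flatten: one row per pixel, one column per channel -> shape (6, 3)
pixels = image.reshape((-1, 3))
df = pd.DataFrame(pixels, columns=["R", "G", "B"])

print(df.shape)             # (6, 3)
print(df.iloc[4].tolist())  # [12.0, 13.0, 14.0] -> the pixel at row 1, column 1

# the reshape is lossless: going back recovers the original array
restored = df.to_numpy().reshape(image.shape)
print(np.array_equal(restored, image))  # True
```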

We now have the flattened data in a data frame, so it is time to write the algorithm. The algorithm remains the same as in the original article; for an in-depth look into K-means clustering, read the original article here.

```python
k = 5  # number of clusters, i.e. colors in the segmented image
X = img_df  # one row per pixel, columns R, G, B
# initialise the centroids with k randomly chosen pixels
centroids = X.sample(n=k)
diff = 1
j = 0

while diff > 0.05:
    XD = X
    i = 1
    # iterate over each centroid point
    for index1, row_c in centroids.iterrows():
        ED = []
        # iterate over each data point
        print("Calculating distance")
        for index2, row_d in tqdm(XD.iterrows(), total=len(XD)):
            # Euclidean distance between the current point and the centroid
            d1 = (row_c["R"] - row_d["R"]) ** 2
            d2 = (row_c["G"] - row_d["G"]) ** 2
            d3 = (row_c["B"] - row_d["B"]) ** 2
            d = np.sqrt(d1 + d2 + d3)
            # append the distance to the list 'ED'
            ED.append(d)
        # store the distances to centroid i in column i of the data frame
        X[i] = ED
        i = i + 1

    C = []
    print("Getting Centroid")
    for index, row in tqdm(X.iterrows(), total=len(X)):
        # start with the distance to the first centroid (column label 1)
        min_dist = row.loc[1]
        pos = 1
        # loop to locate the closest centroid to the current point
        for i in range(k):
            # if another centroid is closer than the current minimum
            if row.loc[i + 1] < min_dist:
                # the smaller distance becomes the minimum distance
                min_dist = row.loc[i + 1]
                pos = i + 1
        C.append(pos)
    # assign the closest cluster to each data point
    X["Cluster"] = C
    # the mean of each cluster becomes its new centroid
    centroids_new = X.groupby(["Cluster"]).mean()[["R", "G", "B"]]
    if j == 0:
        diff = 1
        j = j + 1
    else:
        # total absolute change between old and new centroids
        # (absolute values stop positive and negative shifts from cancelling out)
        diff = (centroids_new - centroids).abs().sum().sum()
        print(diff)
    centroids = centroids_new
```

The algorithm is run for 5 clusters, as set by the variable ‘k’. Instead of waiting for the centroids to stop changing entirely, this time we stop once they change by less than a threshold, which gives us approximately converged centroids.

We do this because, on image data, the algorithm can take a very long time to settle on exact centroids. So we allow a small margin: if the total difference between the new and the old centroids falls within that margin (0.05 here), we quit the loop and keep the last centroids found.
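How "the difference between the new and the old centroids" is measured matters: if signed differences are summed, a centroid moving up can cancel one moving down and end the loop too early. The sketch below uses hypothetical helpers, `centroid_shift` and `has_converged`, that sum absolute differences instead, with the same 0.05 threshold; the two-cluster data is made up for illustration.

```python
import pandas as pd


def centroid_shift(old: pd.DataFrame, new: pd.DataFrame) -> float:
    """Total absolute movement of all centroids over the R, G, B columns.

    Assumes both frames share the same index (cluster labels) and columns.
    """
    return float((new - old).abs().to_numpy().sum())


def has_converged(old: pd.DataFrame, new: pd.DataFrame, tol: float = 0.05) -> bool:
    return centroid_shift(old, new) <= tol


# made-up centroids for 2 clusters: cluster 1 moves +5 in R, cluster 2 moves -5 in R
old = pd.DataFrame({"R": [10.0, 200.0], "G": [50.0, 50.0], "B": [0.0, 0.0]}, index=[1, 2])
new = pd.DataFrame({"R": [15.0, 195.0], "G": [50.0, 50.0], "B": [0.0, 0.0]}, index=[1, 2])

signed = (new - old).to_numpy().sum()
print(signed)                    # 0.0 -> a signed sum would wrongly report "no change"
print(centroid_shift(old, new))  # 10.0
print(has_converged(old, new))   # False
```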

It is important to note that the algorithm we have written above is quite slow, since speed and efficiency are not the goals of this tutorial; here we only want to look at how the K-means algorithm works. One way to speed it up is to use the “itertuples” function instead of “iterrows”: “iterrows” builds a new pandas Series for every row, whereas “itertuples” yields lightweight named tuples, which is typically much faster. If you know any other ways to improve the efficiency of this algorithm, do comment down below.
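The sketch below compares an “itertuples”-based assignment step with a NumPy-vectorised one. The vectorised version is a different technique from the one suggested above: it computes all pixel-to-centroid distances at once through broadcasting. The function names and the random test data are assumptions for illustration; both functions return 1-based cluster labels, like the loop above.

```python
import numpy as np
import pandas as pd


def assign_itertuples(df, centroids):
    """Closest-centroid label (1..k) for each pixel, iterating with itertuples."""
    labels = []
    for row in df.itertuples(index=False):
        dists = [np.sqrt((row.R - c[0]) ** 2 + (row.G - c[1]) ** 2 + (row.B - c[2]) ** 2)
                 for c in centroids]
        labels.append(int(np.argmin(dists)) + 1)
    return np.array(labels)


def assign_vectorized(pixels, centroids):
    """Same labels via broadcasting: (N, 1, 3) - (1, k, 3) -> (N, k) distances."""
    diff = pixels[:, None, :] - centroids[None, :, :]
    dists = np.sqrt((diff ** 2).sum(axis=2))
    return dists.argmin(axis=1) + 1


# random stand-in for a flattened image: 1000 pixels with R, G, B in [0, 255)
rng = np.random.default_rng(0)
pixels = rng.uniform(0, 255, size=(1000, 3))
df = pd.DataFrame(pixels, columns=["R", "G", "B"])
# 5 initial centroids picked from the pixels themselves
centroids = pixels[rng.choice(len(pixels), size=5, replace=False)]

a = assign_itertuples(df, centroids)
b = assign_vectorized(pixels, centroids)
print(np.array_equal(a, b))
```

The vectorised version removes the Python-level loop entirely, at the cost of an intermediate `(N, k, 3)` array whose memory grows with the number of pixels times `k`.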

After a little wait, the algorithm finishes. Let’s see which centroids we end up with.

```python
# convert the centroids data frame (one row per cluster) into a k x 3 NumPy array
centroids = centroids.to_numpy()
print(centroids)
```

Now that we have the centroids, all we need to do is plot the result. The plotting is done in such a way that all the pixels in a particular cluster are overwritten by the centroid of that cluster.

Let’s see what the final image looks like.

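```python
labels = X["Cluster"].to_numpy()

# overwrite each pixel with the centroid of its cluster (cluster labels start at 1)
segmented_image = centroids[labels - 1]
segmented_image = segmented_image.reshape(img.shape)

# plot the image; imshow expects floats in [0, 1] or integers in [0, 255]
plt.imshow(np.clip(segmented_image, 0, 255).astype(np.uint8))
plt.axis("off")
plt.show()
```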
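The line `centroids[labels - 1]` relies on NumPy integer-array ("fancy") indexing: it picks one centroid row per pixel. The toy example below uses made-up centroids and labels for a 2 × 2 image to show the effect, and also converts the result to `uint8`, since `plt.imshow` clips float RGB data outside [0, 1].

```python
import numpy as np

# made-up centroids for k = 2 clusters
centroids = np.array([[200.0, 30.0, 30.0],   # cluster 1: a red tone
                      [20.0, 20.0, 20.0]])   # cluster 2: near black
labels = np.array([1, 2, 2, 1])              # 1-based cluster of each of the 4 pixels

# one centroid row per pixel, then back to image shape (2, 2, 3)
segmented = centroids[labels - 1].reshape(2, 2, 3)

print(segmented[0, 0].tolist())  # [200.0, 30.0, 30.0]
print(segmented[0, 1].tolist())  # [20.0, 20.0, 20.0]
# the segmented image contains exactly k distinct colors
print(len(np.unique(segmented.reshape(-1, 3), axis=0)))  # 2

# integer pixel values in [0, 255] are what imshow expects for non-normalised data
display = np.clip(segmented, 0, 255).astype(np.uint8)
```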

### Conclusion

The final image has only 5 colors in total (one per cluster), and these 5 colors represent the major colors present in the original image.

The final image looks like something out of an Instagram filter. We can use the centroids and clusters we obtained to create multiple masks for the image and separate it into portions. For example, the body of the car is red but split into 2 different tones. If we run the algorithm with fewer clusters, the whole red body of the car should fall into a single cluster, and we can then separate it from the rest of the image.
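As a sketch of the masking idea, assuming we have the flattened 1-based `labels` array from the plotting step, a boolean mask for one cluster can be built and applied as follows. The helper names, the tiny image, and its labels are hypothetical.

```python
import numpy as np


def cluster_mask(labels, shape_hw, cluster):
    """Boolean (H, W) mask that is True where a pixel belongs to `cluster` (1-based)."""
    return (labels == cluster).reshape(shape_hw)


def extract_segment(image, mask, background=0):
    """Copy of `image` in which pixels outside the mask are set to `background`."""
    out = np.full_like(image, background)
    out[mask] = image[mask]
    return out


# hypothetical 2x3 RGB image: reddish pixels in cluster 1, dark pixels in cluster 2
image = np.array([[[250, 10, 10], [240, 20, 20], [10, 10, 10]],
                  [[15, 15, 15], [245, 5, 5], [12, 12, 12]]], dtype=np.uint8)
labels = np.array([1, 1, 2, 2, 1, 2])

mask = cluster_mask(labels, image.shape[:2], cluster=1)
red_only = extract_segment(image, mask)
print(mask.sum())  # 3 pixels belong to cluster 1
```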

Similarly, we can find multiple use cases for image segmentation in this way.
