Image Segmentation using K means clustering algorithm | Python

In a previous article, we saw how to implement the K-means algorithm from scratch in Python. We looked closely at how the algorithm works and discussed some possible practical applications. In this tutorial, we will see one such application at work: using K-means clustering to separate an image into segments based on its pixel values.

If you are new to machine learning or K-means, you can read the original article here.

The complete code used in this article can be found here.

Since we have discussed all the nitty-gritty details in the original article, I will keep this one fairly short and simple.

This tutorial will be split into 3 short portions.

  • Preprocessing step for the image files
  • Working of the Algorithm
  • Results

Let’s begin!!

Loading our image will be our first step.

We will use the following image downloaded from

Image of a red car
image credits: Josh Rinard (
    import pandas as pd
    import numpy as np
    import cv2
    import matplotlib.pyplot as plt
    from tqdm import tqdm

    # load the image and normalise the pixel values to the range [0, 1]
    img = cv2.imread("josh-rinard-H9mp1P1VUj4-unsplash.jpg") / 255

    # since OpenCV only supports up to float32 pixel values, we have to typecast our data
    img = img.astype(np.float32)

    # to preserve the original colors, convert to RGB, since OpenCV loads images in BGR format by default
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    # rescaling the image reduces the data size and speeds up the algorithm
    scale_percent = 20  # percent of original size
    width = int(img.shape[1] * scale_percent / 100)
    height = int(img.shape[0] * scale_percent / 100)
    dim = (width, height)

    # use the OpenCV resize function to resize the image
    resized = cv2.resize(img, dim, interpolation=cv2.INTER_AREA)

    # the later steps refer to img, so point it at the downscaled image
    # (assumption: the original snippet never used resized afterwards)
    img = resized

We cannot deal with images the same way as we dealt with ordinary data points, because a colored image is a 3-dimensional matrix: height, width, and color channel. The third dimension holds the RGB color channels; it adds no spatial depth, but it is important for our mathematics to work.

Figure: RGB (dimensional) representation of an image

Each channel holds, for every pixel, the intensity of that channel's color (red, green, or blue).

We need to convert this into 2-dimensional data: one row per pixel and one column per color channel.

    # convert the M x N x 3 image to a P x 3 array, where P = M * N is the number of pixels
    vectorised = img.reshape((-1, 3))

    # convert the array to a data frame with one column per channel
    img_df = pd.DataFrame(vectorised)
    img_df.rename(columns={0: 'R', 1: 'G', 2: 'B'}, inplace=True)
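To make the flattening step concrete, here is a minimal sketch on a made-up 2x2 image (the array `tiny` is purely illustrative and stands in for the real photo). It shows that the reshape produces one row per pixel and that the original shape can be recovered later.

```python
import numpy as np

# Hypothetical 2x2 RGB image with values in [0, 1]; not the car photo.
tiny = np.array([
    [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]],
    [[0.0, 0.0, 1.0], [1.0, 1.0, 1.0]],
], dtype=np.float32)

# Flatten to one row per pixel, with columns R, G, B.
flat = tiny.reshape((-1, 3))

# Reshaping back with the original shape undoes the flattening.
restored = flat.reshape(tiny.shape)

print(tiny.shape, flat.shape)
```

Pixels are flattened row by row, so the third row of `flat` is the first pixel of the second image row.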

We now have the flattened data in a data frame. It is time to write the algorithm. The algorithm remains the same as in the original article; for an in-depth look into K-means clustering, read the original article here.

    k = 5
    diff = 1
    j = 0

    # X and the initial centroids are not defined in this snippet.
    # Assumption: X is the pixel data frame and the starting centroids
    # are k randomly sampled pixels.
    X = img_df
    centroids = X.sample(k)

    while abs(diff) > 0.05:
        XD = X
        i = 1
        # iterate over each centroid point
        for index1, row_c in centroids.iterrows():
            ED = []
            # iterate over each data point
            print("Calculating distance")
            for index2, row_d in tqdm(XD.iterrows()):
                # calculate the distance between the current point and the centroid
                d1 = (row_c["R"] - row_d["R"])**2
                d2 = (row_c["G"] - row_d["G"])**2
                d3 = (row_c["B"] - row_d["B"])**2
                d = np.sqrt(d1 + d2 + d3)
                # append the distance to the list 'ED'
                ED.append(d)
            # store the distances to this centroid as column i of the data frame
            X[i] = ED
            i = i + 1

        C = []
        print("Getting Centroid")
        for index, row in tqdm(X.iterrows()):
            # start with the distance to the first centroid
            min_dist = row[1]
            pos = 1
            # loop to locate the closest centroid to the current point
            for i in range(k):
                # if this centroid is closer than the closest one found so far
                if row[i+1] < min_dist:
                    # the smaller distance becomes the minimum distance
                    min_dist = row[i+1]
                    pos = i + 1
            C.append(pos)

        # assign the closest cluster to each data point
        X["Cluster"] = C

        # group each cluster by its mean value to create the new centroids
        centroids_new = X.groupby(["Cluster"]).mean()[["R", "G", "B"]]
        if j == 0:
            diff = 1
            j = j + 1
        else:
            # check if there is a difference between the old and new centroids
            diff = (centroids_new['R'] - centroids['R']).sum() + (centroids_new['G'] - centroids['G']).sum() + (centroids_new['B'] - centroids['B']).sum()
            print(diff)
        centroids = centroids_new

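The algorithm is run for 5 clusters, as denoted by the variable ‘k’. Instead of waiting for the centroids to reach a fully stable value, this time we stop as soon as the change between iterations falls below a threshold (0.05 in the loop above). The resulting centroids are not guaranteed to be optimal, but they are close enough for our purposes.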

This is done because, for image data, the algorithm takes a very long time to settle on the exact centroids. So we give it a small margin: if the difference between the new and the old centroids is within a certain range, we quit the loop and keep the last centroids found.
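The stopping rule can be sketched on its own as follows. Note that this sketch measures the change with absolute differences, which is a deliberate variation on the signed sum used in the loop above (signed differences can cancel each other out). The function name `centroid_shift` and the sample values are made up for illustration.

```python
import numpy as np

def centroid_shift(old, new):
    """Total absolute change across all centroid coordinates."""
    return float(np.abs(np.asarray(new) - np.asarray(old)).sum())

# Hypothetical centroids from two consecutive iterations (k = 2, RGB).
old = np.array([[0.20, 0.10, 0.10], [0.80, 0.75, 0.70]])
new = np.array([[0.21, 0.10, 0.09], [0.80, 0.76, 0.70]])

threshold = 0.05
shift = centroid_shift(old, new)

# Stop iterating once the centroids have moved less than the threshold.
converged = shift <= threshold
print(shift, converged)
```

A smaller threshold gives centroids closer to the fully converged ones at the cost of more iterations.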

It is important to note at this point that the algorithm we have written above is quite slow. Speed and efficiency are not the goals of this tutorial; here we are only looking at the working of the K-means algorithm. One way to improve the speed is to use the “itertuples” function instead of “iterrows”: “iterrows” builds a new pandas Series for every row, while “itertuples” returns lightweight named tuples, which is usually much faster. If you know any other ways to improve the efficiency of this algorithm, do comment down below.
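Another option, beyond what this tutorial uses, is to drop the row-by-row loops entirely and let NumPy compute all the distances at once through broadcasting. The sketch below covers only the assignment step; `assign_clusters` and the sample arrays are hypothetical names and values, and labels are 1-based to match the convention used in this tutorial.

```python
import numpy as np

def assign_clusters(pixels, centroids):
    """Return the 1-based index of the nearest centroid for every pixel.

    pixels has shape (N, 3) and centroids has shape (k, 3).
    """
    # Broadcasting builds an (N, k, 3) array of per-channel differences.
    diffs = pixels[:, None, :] - centroids[None, :, :]
    # Euclidean distance from every pixel to every centroid, shape (N, k).
    dists = np.sqrt((diffs ** 2).sum(axis=2))
    # argmin is 0-based, so add 1 to match the tutorial's cluster labels.
    return dists.argmin(axis=1) + 1

# Hypothetical data: three pixels and two centroids (pure red, pure blue).
pixels = np.array([[0.9, 0.1, 0.1], [0.1, 0.1, 0.9], [0.8, 0.2, 0.1]])
centroids = np.array([[1.0, 0.0, 0.0], [0.0, 0.0, 1.0]])

labels = assign_clusters(pixels, centroids)
print(labels)
```

The intermediate array holds N x k x 3 values, so for very large images it may be worth processing the pixels in chunks.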

After a little wait, the algorithm has finished its work. Let’s see what centroids we have in the end.

    centroids = centroids.to_numpy()
    print(centroids)

Figure: Final centroids found

Now that we have the centroids, all we need to do is plot the result. The plotting is done in such a way that all the pixels in a particular cluster are overwritten by the centroid of that cluster.

Let’s see what the final image looks like.

    labels = X["Cluster"].to_numpy()

    # overwrite each pixel with the centroid of its cluster (cluster labels start at 1)
    segmented_image = centroids[labels - 1]
    segmented_image = segmented_image.reshape(img.shape)

    # plot the image
    plt.imshow(segmented_image)

Figure: Final segmented image


The final image has only 5 colors in total (one per cluster). These 5 colors represent the major colors that were present in the original image.
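One way to check this claim is to count the distinct colors in the segmented image. The sketch below does so on made-up data with 3 clusters; `demo_centroids`, `demo_labels` and the 6x4 image size are assumptions for illustration, not values from the run above.

```python
import numpy as np

# Hypothetical centroids (3 clusters) and 1-based labels for a 6x4 image.
demo_centroids = np.array([[0.8, 0.1, 0.1], [0.2, 0.2, 0.2], [0.9, 0.9, 0.9]])
rng = np.random.default_rng(0)
demo_labels = rng.integers(1, 4, size=6 * 4)
demo_labels[:3] = [1, 2, 3]  # make sure every cluster appears at least once

# Overwrite each pixel with the centroid of its cluster.
demo_segmented = demo_centroids[demo_labels - 1].reshape((6, 4, 3))

# Count the distinct RGB rows in the segmented image.
unique_colors = np.unique(demo_segmented.reshape((-1, 3)), axis=0)
print(len(unique_colors))
```

The number of distinct colors equals the number of clusters that received at least one pixel.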

The final image looks like something out of an Instagram filter. We can use the centroids and the clusters obtained to create multiple masks for the image and separate it into portions. For example, we can see that the body of the car is red but split into 2 different tones. We could run the algorithm with fewer clusters so that the red body of the car falls into a single cluster, and then separate it from the rest of the image altogether.
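A mask for a single cluster could be built along the following lines. This is only a sketch under stated assumptions: `cluster_mask` and `apply_mask` are hypothetical helpers, labels are 1-based as in this tutorial, and the tiny 2x2 image stands in for the real photo.

```python
import numpy as np

def cluster_mask(labels, image_shape, cluster_id):
    """Boolean (height, width) mask that is True where a pixel is in cluster_id."""
    height, width = image_shape[:2]
    return (labels == cluster_id).reshape((height, width))

def apply_mask(image, mask):
    """Keep the pixels selected by the mask and black out everything else."""
    return np.where(mask[..., None], image, 0.0)

# Hypothetical 2x2 RGB image and its flattened 1-based cluster labels.
image = np.array([
    [[0.9, 0.1, 0.1], [0.2, 0.2, 0.2]],
    [[0.3, 0.3, 0.3], [0.8, 0.1, 0.2]],
])
labels = np.array([1, 2, 2, 1])

mask = cluster_mask(labels, image.shape, cluster_id=2)
masked = apply_mask(image, mask)
print(mask)
```

With the real data, the masked array could be passed to `plt.imshow` to view one segment at a time.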

Similarly, we can find multiple use cases for image segmentation in this way.
