KNN Classifier Python Implementation from Scratch (96.6% Accuracy) | Machine Learning


To beginners, the term “Machine Learning” can sound very complicated and difficult. There is no doubt that it is one of the most rapidly developing fields, but that doesn’t mean it has to be complex. In this tutorial, we will look at a very simple yet useful algorithm called the “K-Nearest Neighbors algorithm” (KNN).

We have all heard the quote:

“you are defined by the company you keep”

KNN takes this literally 😁. This will be clearer when we look at the algorithm.

The entire code used in this tutorial can be found here.


Understanding the Algorithm:

KNN is a supervised algorithm, i.e., it requires a labeled training dataset to work. Let’s create a story for ease of understanding. Below, we can see 3 different settlements (3 classes of data points: red, green, and purple).

A representation of finely clustered data points

OMG, a wild blob appears!! (a test data point). It looks lost. What will we do now? 😲

A blob in the wild

Fear not, for we know the KNN algorithm.

We simply calculate the distance (the Euclidean distance, in mathematical terms) from the wild blob to every house in each settlement.

distances from each cluster

!! Remember, we need to calculate the distance to every other data point; the illustration shows fewer lines only because the illustrator was, ‘ehm ehm’, a little lazy 🙃 !!
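Here is a minimal NumPy sketch of this step, assuming the training points are stored as the rows of an array (`train_points` and `blob` are hypothetical names with made-up values). Broadcasting subtracts the blob from every row at once, so we get exactly one distance per training point and none are skipped.

```python
import numpy as np

# Hypothetical 2-D "houses": each row is one training point (x, y)
train_points = np.array([
    [1.0, 1.0],
    [2.0, 1.0],
    [8.0, 8.0],
    [9.0, 8.0],
])
# Hypothetical test point (the "wild blob")
blob = np.array([1.0, 4.0])

# Broadcasting subtracts blob from every row; summing the squared
# differences along axis 1 gives one squared distance per training point
distances = np.sqrt(((train_points - blob) ** 2).sum(axis=1))
print(distances)
```

The same result can be obtained with `np.linalg.norm(train_points - blob, axis=1)`.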

Now, all we do is select the K closest points to our blob. (K here is a hyperparameter, i.e., a number we must choose ourselves, ideally by trying a few values and keeping the one that performs best.)

picking the closest points

Now let’s see which settlement appears most often among the K closest points. The Red settlement has the most points in the vicinity of our wild blob, so the wild blob becomes part of the Red settlement (it is given the label Red).

cluster assigned
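Putting the "pick the K closest" and "majority vote" steps together for a single point gives the small sketch below. It assumes NumPy arrays and uses `collections.Counter` for the vote; `knn_predict_one` and the toy settlements are hypothetical, for illustration only, and this is not the classifier we write later in this tutorial.

```python
from collections import Counter

import numpy as np


def knn_predict_one(train_points, train_labels, query, k=3):
    """Label `query` by majority vote among its k nearest training points."""
    # Euclidean distance from the query to every training point
    distances = np.sqrt(((train_points - query) ** 2).sum(axis=1))
    # Indices of the k smallest distances (argsort sorts in ascending order)
    nearest = np.argsort(distances)[:k]
    # Count the labels of those k neighbours and return the most common one
    votes = Counter(train_labels[i] for i in nearest)
    return votes.most_common(1)[0][0]


# Toy settlements: three red houses near the origin, two green ones far away
train_points = np.array([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [9.0, 9.0], [8.0, 9.0]])
train_labels = np.array(["red", "red", "red", "green", "green"])

print(knn_predict_one(train_points, train_labels, np.array([1.0, 1.0]), k=3))  # red
```

An odd K avoids ties when there are only two classes. On a tie, `Counter.most_common` lists equally common labels in the order they were first counted, so here the label of the closest tied neighbour wins.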

We do this for every test data point. And that is it, that is the algorithm.
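For reference, the whole procedure can be written compactly. The notation is introduced here for illustration and is not from the original: training pairs (x_i, y_i), a test point q with n features, and N_K(q), the set of indices of the K training points closest to q.

```latex
d(\mathbf{x}_i, \mathbf{q}) = \sqrt{\sum_{j=1}^{n} \left(x_{ij} - q_j\right)^2}
\qquad
\hat{y}(\mathbf{q}) = \arg\max_{c} \sum_{i \in N_K(\mathbf{q})} \mathbf{1}\left[y_i = c\right]
```

The first formula is the distance from the blob to each house; the second counts, for every label c, how many of the K nearest neighbours carry it and picks the most frequent one.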

Okay, story time is over; let’s get to coding.

Let us import the necessary libraries.

```python
# Importing the required libraries
import pandas as pd
import numpy as np
from collections import Counter
```

The dataset we will use for this demo is the Iris dataset, an open-source dataset available on Kaggle: Iris Flower Dataset | Kaggle

Let us explore this dataset.

The dataset has 4 input features (sepal length, sepal width, petal length, and petal width), and the name of the flower species serves as the output label.

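```python
# Setting names for the CSV columns
headernames = ['sepal-length', 'sepal-width', 'petal-length', 'petal-width', 'Class']
# Opening the CSV file; this assumes iris.csv has exactly these five columns
# and no header row of its own. If your copy starts with a header line,
# add header=0 so that line is replaced by these names instead of read as data.
dataset = pd.read_csv("iris.csv", names = headernames)
dataset.head()
```
Data Sample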

Next, we convert the text output labels to a numeric representation.

```python
# Separating the input features and output labels
X = dataset.iloc[:, :-1].values
y = dataset.iloc[:, 4].values
# Converting text labels to numeric form: labels holds the integer codes,
# unique holds the matching class names
labels, unique = pd.factorize(y)
```
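If `pd.factorize` is new to you, here is a tiny illustration on a hypothetical list of species names. It returns an array of integer codes plus the distinct values in order of first appearance, so code 0 is whichever class appears first in the file.

```python
import pandas as pd

# Hypothetical labels, standing in for the "Class" column
species = ["Iris-setosa", "Iris-versicolor", "Iris-setosa", "Iris-virginica"]

# codes: one integer per label; uniques: the distinct labels, in order of first appearance
codes, uniques = pd.factorize(species)
print(codes)  # [0 1 0 2]
print(list(uniques))
```

To turn a numeric prediction back into a species name, index the second return value with it (in the tutorial's code, `unique[code]`).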

Let’s code our simple algorithm.

As always, we need to split our data into test and train samples.

```python
# Splitting the data into train and test segments
from sklearn.model_selection import train_test_split
# No random_state is set, so a different random split is drawn on every run
X_train, X_test, y_train, y_test = train_test_split(X, labels, test_size = 0.40)
```

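We have used a 60–40 split: 60% of the rows are used for training and 40% for testing. For the 150 samples of the Iris dataset, that means 90 training and 60 test samples.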
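To make the split concrete, here is a hand-rolled sketch of a shuffled 60–40 split in NumPy. It is not scikit-learn's code: `shuffle_split` is a hypothetical helper and the data are dummy values. Like `train_test_split` with a fractional `test_size`, it rounds the number of test rows up, and the fixed `seed` plays the role of `random_state`, producing the same split on every run.

```python
import numpy as np


def shuffle_split(X, y, test_size=0.40, seed=0):
    """Hypothetical helper: shuffle the row indices, then cut off the test rows."""
    n = len(X)
    # Number of test rows, rounded up: 0.40 * 150 gives 60
    n_test = int(np.ceil(test_size * n))
    # A seeded generator produces the same shuffle on every run
    rng = np.random.default_rng(seed)
    order = rng.permutation(n)
    test_idx, train_idx = order[:n_test], order[n_test:]
    return X[train_idx], X[test_idx], y[train_idx], y[test_idx]


# 150 dummy rows with 4 features each, the same shape as the Iris data
X_demo = np.arange(600, dtype=float).reshape(150, 4)
y_demo = np.repeat([0, 1, 2], 50)

X_tr, X_te, y_tr, y_te = shuffle_split(X_demo, y_demo)
print(X_tr.shape, X_te.shape)  # (90, 4) (60, 4)
```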

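Now for the classifier itself. For a single test example, it computes the Euclidean distance to every training example, keeps the k closest ones (k = 8 by default), and returns the label that occurs most often among them.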

```python
# Note: the defaults Y_train and X_train are bound when the function is defined,
# so define it after the train/test split. X_test here is ONE test example.
def KNNClassify(X_test, Y_train = y_train, X_train = X_train, k = 8):
    min_dist = []
    # For every example in the training set, calculate the Euclidean distance to the test example
    for i, point in enumerate(X_train):
        d1 = (point[0] - X_test[0])**2
        d2 = (point[1] - X_test[1])**2
        d3 = (point[2] - X_test[2])**2
        d4 = (point[3] - X_test[3])**2
        dist = np.sqrt(d1 + d2 + d3 + d4)
        # Append the (index, distance) pair to a list
        min_dist.append((i, dist))
    # Sort distances in ascending order
    min_dist.sort(key = takeSecond)
    # Get the top k nearest neighbours
    neighbours = min_dist[:k]
    # Get the training-set indices of the k smallest distances
    idx = []
    for tup in neighbours:
        idx.append(tup[0])
    # Check which label has the majority
    output = Y_train[idx]
    values, counts = np.unique(output, return_counts=True)
    # Return the label with the most occurrences (np.argmax picks the first on a tie)
    max_idx = np.argmax(counts)
    return values[max_idx]
```

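The code above uses a small helper function, takeSecond, as the sort key: it returns the second element (the distance) of each (index, distance) tuple. Make sure it is defined before KNNClassify is called.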

```python
# Creating a helper function: returns the distance from an (index, distance) tuple
def takeSecond(elem):
    return elem[1]
```
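As an aside, Python's standard library already ships an equivalent: `operator.itemgetter(1)` returns a function that fetches element 1 of its argument, so `min_dist.sort(key=itemgetter(1))` would sort exactly the same way. A quick check on a hypothetical list of (index, distance) pairs:

```python
from operator import itemgetter


def takeSecond(elem):
    return elem[1]


# Hypothetical (index, distance) pairs, like the ones KNNClassify builds
pairs = [(0, 2.5), (1, 0.7), (2, 1.9)]

by_helper = sorted(pairs, key=takeSecond)
by_itemgetter = sorted(pairs, key=itemgetter(1))
print(by_helper)  # [(1, 0.7), (2, 1.9), (0, 2.5)]
```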

With all the steps in place, it’s time to test the accuracy.

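```python
# Getting predicted values by applying our classifier to every test example
predictions = list(map(KNNClassify, X_test))
```

Next, a small function that computes the percentage of correct predictions:

```python
def accuracy(pred, y_test):
    # Count how many predictions match the true labels
    count = 0
    for i in range(len(pred)):
        if pred[i] == y_test[i]:
            count += 1
    acc = (count / len(pred)) * 100
    print("Accuracy =", acc, "%")
    return acc
```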
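Since the predictions and the true labels are both integer codes, the same number can also be computed in one line with NumPy: comparing two arrays element-wise gives a boolean array, and its mean is the fraction of matches. A sketch with hypothetical values (with the tutorial's variables, wrap `predictions` in `np.array(...)` first, since it is a list):

```python
import numpy as np

# Hypothetical predicted and true label codes
pred = np.array([0, 1, 2, 2, 1])
truth = np.array([0, 1, 2, 1, 1])

# Element-wise comparison gives True/False per sample; its mean is the accuracy
acc = np.mean(pred == truth) * 100
print("Accuracy =", acc, "%")  # Accuracy = 80.0 %
```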

Finally, we call the accuracy function on our predictions:

```python
# Calling the accuracy function
accuracy(predictions, y_test)
```

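96.67%!! That’s a very good number: with 60 test samples, it means 58 of them were classified correctly. Since train_test_split shuffles the data randomly, the exact figure will change from run to run; pass a fixed random_state to make it reproducible.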


This might just be the shortest code we will ever write for a machine learning algorithm, yet it yields an accuracy of 96.67% on this split. Our initial statement stands: ML algorithms don’t need to be complex; they can be as simple as the KNN we just learned. It would be fun to try this algorithm on other datasets and see how it performs on those.

Thank you for reading.
