KNN in Machine Learning (from Scratch!!)

K-Nearest Neighbours: Introduction

Birds of a feather flock together.

William Turner

The above quote perfectly sums up the algorithm that we are going to talk about in this post. KNN stands for K-Nearest Neighbours. It is a simple, easy-to-implement supervised machine learning algorithm that can be used to solve both classification and regression problems.

Note: I assume that you have a basic idea of supervised machine learning and know the difference between a classification problem and a regression problem. If this is not the case, please refer to this link; it will help you clear up the basics.

In this blog post, we will discuss the following:

  1. K-Nearest Neighbours: Theory
  2. K-Nearest Neighbours: Algorithm
  3. K-Nearest Neighbours: Code
  4. K-Nearest Neighbours: Comparison

In the next section, we start the discussion on the theoretical aspects of the K-Nearest Neighbours algorithm.

K-Nearest Neighbours: Theory

The KNN algorithm is a lazy algorithm. What do I mean by the term lazy algorithm? A lazy algorithm does not induce a concise hypothesis from the training set; instead, the inductive process is delayed until a test instance is given. In simple words, the algorithm does not start generalising until a query is made to it. The opposite of a lazy algorithm is an eager algorithm, where the system tries to construct a general, input-independent target function during training.
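To make the lazy/eager distinction concrete, here is a minimal sketch. Both class names are hypothetical. `LazyKNN.fit` only stores the data and defers all distance computations to query time, while `EagerNearestCentroid` is a nearest-centroid classifier, a different technique swapped in purely for contrast: it summarises the training set into one centroid per class during `fit`, after which the raw data is no longer needed.

```python
import numpy as np


class LazyKNN:
    """Lazy learner: fit() only memorises the data; all work happens at query time."""

    def fit(self, X, y):
        self.X = np.asarray(X, dtype=float)
        self.y = np.asarray(y)
        return self

    def predict_one(self, x, k=3):
        # Generalisation is deferred until a query arrives: distances are computed now.
        d = np.linalg.norm(self.X - x, axis=1)
        nearest_labels = self.y[np.argsort(d)[:k]]
        values, counts = np.unique(nearest_labels, return_counts=True)
        return values[np.argmax(counts)]


class EagerNearestCentroid:
    """Eager learner (illustrative contrast, not KNN): fit() builds one centroid per class."""

    def fit(self, X, y):
        X = np.asarray(X, dtype=float)
        y = np.asarray(y)
        self.classes = np.unique(y)
        # The whole training set is compressed into one mean vector per class here.
        self.centroids = np.array([X[y == c].mean(axis=0) for c in self.classes])
        return self

    def predict_one(self, x):
        d = np.linalg.norm(self.centroids - x, axis=1)
        return self.classes[np.argmin(d)]


X = np.array([[0.0, 0.0], [0.0, 1.0], [1.0, 0.0], [5.0, 5.0], [5.0, 6.0], [6.0, 5.0]])
y = np.array([0, 0, 0, 1, 1, 1])
query = np.array([0.5, 0.5])
print(LazyKNN().fit(X, y).predict_one(query))               # 0
print(EagerNearestCentroid().fit(X, y).predict_one(query))  # 0
```

Both models predict class 0 for this query; the difference lies in when the work happens, not in the answer.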

The KNN algorithm assumes that similar things exist in close proximity. For example, in a binary classification task where we want to assign a test data point to one of two classes, the K-Nearest Neighbours algorithm finds the K training data points closest to the test point and assigns it to the class that is most common among those K points. The following image shows the decision boundaries of a K-Nearest Neighbours classifier with k = 15 on the three-class iris dataset.

Figure: 3-class classification (k = 15, weights = 'uniform')
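To make the voting rule concrete, here is a small worked sketch on six made-up 2-D points (hypothetical values, not the iris data). With K = 3, two of the three nearest neighbours of the test point belong to class 1 and one belongs to class 0, so the majority vote predicts class 1.

```python
import numpy as np
from collections import Counter

# Six made-up training points: three class-0 points in the lower left,
# three class-1 points in the upper right
X_train = np.array([[1.0, 1.0], [1.5, 2.0], [2.0, 1.0], [6.0, 6.0], [7.0, 7.5], [6.5, 5.0]])
y_train = np.array([0, 0, 0, 1, 1, 1])
test_point = np.array([4.0, 4.0])
K = 3

# Euclidean distance from the test point to every training point
distances = np.sqrt(np.sum((X_train - test_point) ** 2, axis=1))
# Indices of the K closest training points
nearest = np.argsort(distances)[:K]
# Count the labels of those neighbours; the most common one is the prediction
votes = Counter(y_train[nearest].tolist())
prediction = votes.most_common(1)[0][0]

print("Nearest indices:", nearest.tolist())  # [5, 3, 1]
print("Votes:", dict(votes))                 # {1: 2, 0: 1}
print("Prediction:", prediction)             # 1
```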

The KNN algorithm uses various distance functions to compute this proximity amongst the data points. Some of these functions are listed in the table below.

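Identifier       Class Name            Distance Function
“euclidean”      EuclideanDistance     sqrt(sum((x - y)^2))
“manhattan”      ManhattanDistance     sum(|x - y|)
“chebyshev”      ChebyshevDistance     max(|x - y|)
“minkowski”      MinkowskiDistance     sum(w * |x - y|^p)^(1/p)
“wminkowski”     WMinkowskiDistance    sum(|w * (x - y)|^p)^(1/p)
“seuclidean”     SEuclideanDistance    sqrt(sum((x - y)^2 / V))
“mahalanobis”    MahalanobisDistance   sqrt((x - y)' V^-1 (x - y))

Here, x and y are two data points, p is the order of the Minkowski distance, w is a vector of per-feature weights, and V is the vector of per-feature variances (seuclidean) or the covariance matrix (mahalanobis).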
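A minimal NumPy sketch of the formulas in the table, written by hand for a single pair of points. The function names mirror the table's identifiers but are my own; this is not scikit-learn's `DistanceMetric` API. The variance vector `V` and the inverse covariance matrix `VI` are assumed to be supplied by the caller.

```python
import numpy as np


def euclidean(x, y):
    return np.sqrt(np.sum((x - y) ** 2))


def manhattan(x, y):
    return np.sum(np.abs(x - y))


def chebyshev(x, y):
    return np.max(np.abs(x - y))


def minkowski(x, y, p=2, w=None):
    # Weights multiply |x - y|^p; all-ones weights give the ordinary Minkowski distance
    w = np.ones_like(x, dtype=float) if w is None else w
    return np.sum(w * np.abs(x - y) ** p) ** (1.0 / p)


def wminkowski(x, y, p=2, w=None):
    # Weights multiply (x - y) before taking the absolute value and the power
    w = np.ones_like(x, dtype=float) if w is None else w
    return np.sum(np.abs(w * (x - y)) ** p) ** (1.0 / p)


def seuclidean(x, y, V):
    # V: per-feature variances
    return np.sqrt(np.sum((x - y) ** 2 / V))


def mahalanobis(x, y, VI):
    # VI: inverse of the covariance matrix V
    d = x - y
    return np.sqrt(d @ VI @ d)


x = np.array([1.0, 2.0, 3.0])
y = np.array([4.0, 0.0, 3.0])
print(euclidean(x, y), manhattan(x, y), chebyshev(x, y), minkowski(x, y, p=1))
```

Note how Minkowski generalises the first two rows: p = 1 gives the Manhattan distance and p = 2 the Euclidean distance, while unit variances or an identity inverse covariance reduce seuclidean and mahalanobis to the Euclidean distance as well.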

In the next section, we discuss the K-Nearest Neighbours algorithm and then implement it from scratch using the Python programming language.

K-Nearest Neighbours: Algorithm

The KNN algorithm can be explained in the following steps:

  1. Load the data.
  2. Initialise K to your chosen number of neighbours.
  3. For each data point in the training set, calculate and store the distance between the query data point and that training data point.
  4. Sort these distances from smallest to largest, i.e., in ascending order.
  5. Extract the indices of the first K entries of this sorted collection of distances.
  6. Get the labels for these K indices from the target variable.
  7. If the problem is a,
    • Classification problem, then return the mode of the extracted labels.
    • Regression problem, then return the mean of the extracted target values.
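The steps above can be sketched for a single query point as follows. `knn_predict` is a hypothetical helper name; it assumes Euclidean distance and numeric NumPy arrays, and switches between the mode (classification) and the mean (regression) at the last step.

```python
import numpy as np
from collections import Counter


def knn_predict(X_train, y_train, query, k=5, task="classification"):
    """Predict one query point with K-Nearest Neighbours (illustrative helper)."""
    X_train = np.asarray(X_train, dtype=float)
    y_train = np.asarray(y_train)
    # Distance from the query to every training point
    distances = np.sqrt(np.sum((X_train - query) ** 2, axis=1))
    # Sort ascending and keep the indices of the k smallest distances
    k_indices = np.argsort(distances)[:k]
    # Labels (or target values) of those k neighbours
    k_labels = y_train[k_indices]
    # Mode for classification, mean for regression
    if task == "classification":
        return Counter(k_labels.tolist()).most_common(1)[0][0]
    if task == "regression":
        return float(np.mean(k_labels))
    raise ValueError("task must be 'classification' or 'regression'")


X = np.array([[0.0], [1.0], [2.0], [10.0], [11.0], [12.0]])
y_cls = np.array([0, 0, 0, 1, 1, 1])
y_reg = np.array([1.0, 2.0, 3.0, 10.0, 11.0, 12.0])
print(knn_predict(X, y_cls, np.array([1.5]), k=3))                     # 0
print(knn_predict(X, y_reg, np.array([1.5]), k=3, task="regression"))  # 2.0
```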

Note: Here, k is a hyperparameter, and its ideal value varies from one use case to another. It can be determined by trying different values of k and selecting the one that gives the best performance, for example with a grid search over k using cross-validation.
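A minimal sketch of such a search using scikit-learn's `GridSearchCV`, which evaluates each candidate k with cross-validation. The candidate range of odd values and the 5-fold split are my own assumptions; the search runs on the training split only, so the test set stays untouched for the final evaluation.

```python
from sklearn.datasets import make_classification
from sklearn.model_selection import GridSearchCV, train_test_split
from sklearn.neighbors import KNeighborsClassifier

# Synthetic data of the same kind as the dataset used later in this post
X, y = make_classification(n_samples=500, n_features=2, n_redundant=0, random_state=10)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.25, random_state=42)

# Candidate values of k: odd numbers from 1 to 29 (an arbitrary illustrative range)
param_grid = {"n_neighbors": list(range(1, 30, 2))}

# 5-fold cross-validated grid search on the training split, scored by accuracy
search = GridSearchCV(KNeighborsClassifier(), param_grid, cv=5, scoring="accuracy")
search.fit(X_train, y_train)

print("Best k:", search.best_params_["n_neighbors"])
print("Mean cross-validated accuracy: {:.3f}".format(search.best_score_))
```

Odd values of k avoid tied votes in a two-class problem.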

In the next section, we implement the K-Nearest Neighbours algorithm from scratch using Python for a classification problem.

K-Nearest Neighbours: Code

Here, we take the case of a simple binary classification problem and use our custom KNN class to classify the data points. First, we generate synthetic data using scikit-learn’s make_classification() function and plot it using a scatter plot. The code cell below demonstrates this:

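# Imports
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.datasets import make_classification

# Make Custom Dataset: 500 points, 2 informative features, 2 classes
X, y = make_classification(n_samples=500, n_features=2, n_redundant=0, random_state=10)

# Plot the data, coloured by class label
sns.scatterplot(x=X[:, 0], y=X[:, 1], hue=y)
plt.xlabel("Feature 1")
plt.ylabel("Feature 2")
plt.title("Custom Dataset")
plt.show()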

Next, we implement a custom KNN class for making the classifications. The code cell below shows the implementation.

import numpy as np
from collections import Counter

# Euclidean distance
# We use Euclidean distance as the metric to compute the distance between a test point
# and every data point in the training set.
def euclidean_distance(x1, x2):
    # x1: one point of shape (n_features,); x2: array of shape (n_samples, n_features).
    # Broadcasting subtracts x1 from every row of x2, giving one distance per training point.
    return np.sqrt(np.sum((x1 - x2) ** 2, axis=1))

# Create a custom KNN Class
class KNN_Custom:
    # Constructor: k is the number of neighbours that vote
    def __init__(self, k=5):
        self.k = k

    # Fit function
    # Takes in the training data and simply stores it (KNN is a lazy algorithm)
    def fit(self, X, y):
        self.X_train = X
        self.y_train = y
        return self

    # Make Predictions
    # Generates predictions on the test data using the training data
    def predict(self, test_points):
        # Compute the distance between each test point and all training points
        euc_distances = [euclidean_distance(test_point, self.X_train) for test_point in test_points]
        # Indices of the k nearest neighbours of each test point
        k_indices = [np.argsort(distances)[: self.k] for distances in euc_distances]
        k_labels = [[self.y_train[i] for i in indices] for indices in k_indices]
        # Majority vote among the k labels gives the prediction for each test point
        final_predictions = []
        for labels in k_labels:
            predicted_label = Counter(labels).most_common(1)[0][0]
            final_predictions.append(predicted_label)
        return np.array(final_predictions)
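The list comprehensions in `predict` compute distances one test point at a time. As a design alternative, NumPy broadcasting can build the whole test-by-train distance matrix in one expression. The function below is a hypothetical sketch, not part of the original class; it assumes non-negative integer class labels (needed by `np.bincount`) and trades memory (an n_test × n_train matrix) for speed.

```python
import numpy as np


def predict_vectorised(X_train, y_train, X_test, k=5):
    """Hypothetical vectorised KNN classifier; assumes non-negative integer labels."""
    X_train = np.asarray(X_train, dtype=float)
    X_test = np.asarray(X_test, dtype=float)
    y_train = np.asarray(y_train)
    # (n_test, 1, n_features) - (1, n_train, n_features) -> (n_test, n_train, n_features)
    diff = X_test[:, None, :] - X_train[None, :, :]
    dists = np.sqrt(np.sum(diff ** 2, axis=2))      # shape (n_test, n_train)
    # Indices of the k smallest distances in each row
    k_idx = np.argsort(dists, axis=1)[:, :k]        # shape (n_test, k)
    k_labels = y_train[k_idx]                       # shape (n_test, k)
    # Majority vote per row: count each label and take the most frequent
    # (argmax returns the smallest label if two labels tie)
    return np.array([np.bincount(row).argmax() for row in k_labels])


X_train = np.array([[0.0, 0.0], [0.0, 1.0], [1.0, 0.0], [5.0, 5.0], [5.0, 6.0], [6.0, 5.0]])
y_train = np.array([0, 0, 0, 1, 1, 1])
X_test = np.array([[0.2, 0.3], [5.5, 5.5]])
print(predict_vectorised(X_train, y_train, X_test, k=3))  # [0 1]
```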

Let’s now test the algorithm on a held-out test set and compute the model’s performance. I have chosen accuracy as the performance metric because the two classes are roughly balanced.

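# Split the data into train and test sets
from sklearn.model_selection import train_test_split
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.25, random_state=42)

# Create the custom KNN model (k = 5) and fit it on the training data
knn = KNN_Custom(k=5)
knn.fit(X_train, y_train)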

# Make Predictions
knn_predictions = knn.predict(X_test)

# Compute Accuracy
knn_acc = np.sum(knn_predictions == y_test)/X_test.shape[0] * 100
print("Accuracy of the custom KNN classifier is {} %".format(knn_acc))
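Because the choice of accuracy rests on the classes being balanced, it is worth checking the class counts. A minimal sketch, assuming the same `make_classification` call and split as above: `np.bincount` counts each label, and scikit-learn's `accuracy_score` returns the same value as the manual formula (as a fraction, so it is multiplied by 100 here). scikit-learn's KNN is used only to produce predictions to score.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier

X, y = make_classification(n_samples=500, n_features=2, n_redundant=0, random_state=10)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.25, random_state=42)

# Class counts in the full data and in the test split
print("Class counts (all data):", np.bincount(y))
print("Class counts (test set):", np.bincount(y_test))

# Any fitted classifier would do here; we only need predictions to score
preds = KNeighborsClassifier(n_neighbors=5).fit(X_train, y_train).predict(X_test)
manual_acc = np.sum(preds == y_test) / X_test.shape[0] * 100
print(manual_acc, accuracy_score(y_test, preds) * 100)
```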

We are able to get an accuracy of 96% using our custom KNN class. Now, let’s visualise how the algorithm works using a scatter plot.

# Let's visualise the working of the algorithm on a single new point
test_point = np.array([3, -2]).reshape(1, 2)

# Classify the test point
test_prediction = knn.predict(test_point)
print("Predicted Class: ", test_prediction[0])

# Plot the test point along with the data (seaborn expects keyword x=/y= arguments)
sns.scatterplot(x=X[:, 0], y=X[:, 1], hue=y)
sns.scatterplot(x=test_point[:, 0], y=test_point[:, 1], color='green', marker='X', s=200, label='Test Point')
plt.xlabel("Feature 1")
plt.ylabel("Feature 2")
plt.title("Custom Dataset")
plt.show()

From the above plot, we can see that the test point lies close to the data points belonging to class 1, and our KNN model also assigns the test point to class 1.
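To see which neighbours drive that decision, one can list the test point's five nearest training points and their labels. A minimal sketch that recomputes the same data and split as above instead of reusing the `knn` object; `k = 5` matches the default of `KNN_Custom`, and the printed values depend on your run, so none are claimed here.

```python
import numpy as np
from collections import Counter
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split

X, y = make_classification(n_samples=500, n_features=2, n_redundant=0, random_state=10)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.25, random_state=42)

test_point = np.array([3.0, -2.0])
k = 5

# Euclidean distance from the test point to every training point
distances = np.sqrt(np.sum((X_train - test_point) ** 2, axis=1))
# Indices of the k nearest training points, closest first
nearest = np.argsort(distances)[:k]

for i in nearest:
    print("index {:3d}  distance {:.3f}  label {}".format(i, distances[i], y_train[i]))
print("Votes:", Counter(y_train[nearest].tolist()))
```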

Next, we compare the performance of our model with scikit-learn’s built-in KNeighborsClassifier class.

K-Nearest Neighbours: Comparison

Below, we use scikit-learn’s built-in KNN classifier, check its performance on the test set, and compare it with our custom KNN model.

# Import KNN from sklearn
from sklearn.neighbors import KNeighborsClassifier

# Create the object (n_neighbors=5, matching our custom model) and fit the model
knn_2 = KNeighborsClassifier(n_neighbors=5, metric='euclidean')
knn_2.fit(X_train, y_train)

# Make Predictions
knn_predictions_2 = knn_2.predict(X_test)

# Compute Accuracy
knn_acc_2 = np.sum(knn_predictions_2 == y_test)/X_test.shape[0] * 100
print("Accuracy of sklearn's KNN classifier is {} %".format(knn_acc_2))

We can see that scikit-learn’s KNN model gives exactly the same accuracy as our custom KNN model. You can access the entire code here.
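Matching accuracy does not by itself guarantee identical predictions, so a stricter check compares the two models point by point. A self-contained sketch: `custom_knn_predict` is a hypothetical compact stand-in for `KNN_Custom.predict` (same Euclidean distance and majority vote, k = 5), not the class itself. With k = 5 and two classes the vote cannot tie, although exact distance ties could in principle still make the two models pick different neighbours.

```python
import numpy as np
from collections import Counter
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier

X, y = make_classification(n_samples=500, n_features=2, n_redundant=0, random_state=10)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.25, random_state=42)


def custom_knn_predict(X_train, y_train, X_test, k=5):
    # Compact stand-in for KNN_Custom.predict: Euclidean distance + majority vote
    preds = []
    for x in X_test:
        d = np.sqrt(np.sum((X_train - x) ** 2, axis=1))
        labels = y_train[np.argsort(d)[:k]]
        preds.append(Counter(labels.tolist()).most_common(1)[0][0])
    return np.array(preds)


custom_preds = custom_knn_predict(X_train, y_train, X_test, k=5)
sk_model = KNeighborsClassifier(n_neighbors=5, metric="euclidean").fit(X_train, y_train)
sk_preds = sk_model.predict(X_test)

# Fraction of test points on which both models predict the same class
agreement = np.mean(custom_preds == sk_preds)
print("Fraction of test points where predictions agree: {:.3f}".format(agreement))
```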


So, in this blog post, we discussed the K-Nearest Neighbours algorithm in detail. We also implemented it from scratch and used it to solve a simple binary classification problem. One point to note is that KNN is often used as a baseline model because of its simplicity. The code used in the above sections can be found on my Kaggle profile using this link.

I hope you find this blog post helpful. Please do subscribe to my blog; this really motivates me to bring more informative content on Data Science and Machine Learning. You can connect with me on LinkedIn as well. I will be very happy to have a conversation with you. I will catch you in another blog post; till then, Happy Learning 🙂
