Cross-validation using KNN

Deepak Jain
Towards Data Science
6 min read · Jul 21, 2020


This is the 3rd article in the KNN series. In case you haven’t read the first 2 parts, I suggest you go through them first: Part-1, Part-2

In this article, we will understand what cross-validation is, why it’s needed, and what k-fold cross-validation is.

In order to better understand the need for cross-validation, let me first recap how we determine the right value of “K”.

Suppose we are given a dataset and we split it into training data and test data in the ratio 70:30. Meaning, I keep 70% of my total data to train my model and the remaining 30% to test it. Next, I train my model with different values of “K” and capture its accuracy on my test data. Assume we get a table like the one below:

Image by Author

Now, if you observe, at K=3 I get the highest accuracy of 90%, and after that we see a decreasing trend in the accuracy. Based on that, we conclude that the appropriate value is K=3.
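Here is a minimal sketch of this procedure in scikit-learn. The `load_iris` dataset and the range of K values are just stand-ins for illustration, not the data behind the table above:

```python
# A sketch of the naive approach: tune K directly on the test data.
# load_iris and range(1, 11) are stand-ins (assumptions), not the
# article's actual dataset.
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier

X, y = load_iris(return_X_y=True)

# 70:30 train/test split, as described above
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.3, random_state=42
)

# Train KNN with different values of K and record accuracy on the test data
for K in range(1, 11):
    knn = KNeighborsClassifier(n_neighbors=K)
    knn.fit(X_train, y_train)
    print("K =", K, "test accuracy =", knn.score(X_test, y_test))
```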

This all sounds good, but there is a small problem with it. Let me ask you a very simple yet interesting question.

What is the objective of machine learning?

Considering the above example, the objective is to predict the class label of a future unseen data point as accurately as possible. When an algorithm performs well on unseen data points, it’s called generalization. The whole objective of machine learning is to generalize.

If you think about KNN, we used the train data to find the nearest neighbors and the test data to determine the right value of K. We got an accuracy of 90% on the test data, the same data we used to determine the right value of “K”. Can I say that, since I achieved an accuracy of 90% on my test data, I will be able to retain this approximate level of accuracy on future unseen data? The answer is NO.

In order to confidently say that we can achieve an accuracy of approximately 90% on future unseen data, I need to first test this model on unseen data.

But how do I do that? I already split my dataset in a 70:30 ratio of training and test data, and I have no more data available with me.

In order to solve this problem, I introduce you to the concept of cross-validation.

In cross-validation, instead of splitting the data into two parts, we split it into three: training data, cross-validation data, and test data. Here, we use the training data to find the nearest neighbors, the cross-validation data to find the best value of “K”, and finally we test our model on the totally unseen test data. This test data is equivalent to future unseen data points.

Consider the diagram below, which clearly distinguishes the two splitting strategies:

Image by author

The section on the left is without cross-validation and the one on the right is with cross-validation.

Under the cross-validation approach, we use D_Train and D_CV to build and tune KNN, but we don’t touch D_Test. Once we find an appropriate value of “K”, we apply it to D_Test, which acts as future unseen data, to measure how accurately the model performs.
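Here is a hedged sketch of this three-way split in scikit-learn (again, the dataset and the K range are stand-ins I am assuming for illustration):

```python
# A sketch of the 60:20:20 split (D_Train / D_CV / D_Test) described above.
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier

X, y = load_iris(return_X_y=True)  # stand-in dataset (assumption)

# First carve off 20% as D_Test, then split the remaining 80% in a
# 75:25 ratio so the overall split is 60:20:20
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42
)
X_tr, X_cv, y_tr, y_cv = train_test_split(
    X_train, y_train, test_size=0.25, random_state=42
)

# Use D_Train and D_CV to choose K; D_Test stays untouched
best_K, best_acc = None, 0.0
for K in range(1, 11):
    knn = KNeighborsClassifier(n_neighbors=K).fit(X_tr, y_tr)
    acc = knn.score(X_cv, y_cv)
    if acc > best_acc:
        best_K, best_acc = K, acc

# Only now do we evaluate once on D_Test, our proxy for future unseen data
final = KNeighborsClassifier(n_neighbors=best_K).fit(X_tr, y_tr)
print("K =", best_K, "test accuracy =", final.score(X_test, y_test))
```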

Now that we have understood the concept of cross-validation and the need for it, let’s address a problem associated with it.

Problem:

If you refer to the above image, we have split the data in the ratio of 60:20:20, where we use 60% of the data to train our model, 20% for cross-validation, and the remaining 20% for testing. In this process of cross-validation, we are losing 20% of the data that we would have otherwise used for training, and it’s a known fact that the more training data we have, the better the model performs. So, is there a way to somehow bring that 20% of cross-validation data back into the training data?

And the answer to that question is k-fold cross-validation.

So, what happens in a k-fold cross-validation? Consider the below example:

Image by author

After splitting the total dataset (D_n) into training (D_Train) and test (D_Test) datasets in the ratio of 80:20, we further randomly split the training data into 4 equal parts.

Image by Author

D1, D2, D3, and D4 are the four randomly split equal parts of D_Train. Once done with the splitting, we proceed as follows:

Image by Author

Step-1: For K=1, I pick D1, D2, and D3 as my training data, set D4 aside as my cross-validation data, find the nearest neighbors, and calculate the accuracy.

Step-2: Again, for K=1, I pick D1, D2, and D4 as my training data, set D3 aside as my cross-validation data, find the nearest neighbors, and calculate the accuracy.

I repeat the above steps with D2 and then D1 as my cross-validation data and calculate the corresponding accuracies. Once done, I have 4 accuracies for the same value of K=1, so I take their mean and assign it as the final accuracy for K=1.

Now, I repeat the above steps for K=2 and find the mean accuracy for K=2. And so on: I calculate the mean accuracy for each value of K.

Now, note that for each value of K I had to compute the accuracy 4 times, because I randomly split my training data into 4 equal parts. Had I randomly split it into 5 equal parts, I would have to compute 5 different accuracies for each value of K and take their mean.
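Here is a rough sketch of this 4-fold procedure using scikit-learn’s KFold. It assumes the `X_train` and `y_train` arrays from the earlier sketch as D_Train:

```python
# A sketch of the 4-fold procedure above; (X_train, y_train) is the 80%
# D_Train split from the earlier sketch (an assumption of this example).
import numpy as np
from sklearn.model_selection import KFold
from sklearn.neighbors import KNeighborsClassifier

kf = KFold(n_splits=4, shuffle=True, random_state=42)  # D1, D2, D3, D4

for K in range(1, 11):
    fold_accs = []
    # Each iteration holds out one of the 4 parts as cross-validation data
    for tr_idx, cv_idx in kf.split(X_train):
        knn = KNeighborsClassifier(n_neighbors=K)
        knn.fit(X_train[tr_idx], y_train[tr_idx])
        fold_accs.append(knn.score(X_train[cv_idx], y_train[cv_idx]))
    # The mean of the 4 accuracies is the final score for this K
    print("K =", K, "mean CV accuracy =", np.mean(fold_accs))
```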

Please note: capital “K” stands for the K value in KNN, and lowercase “k” stands for the k value in k-fold cross-validation.

So, the k value in k-fold cross-validation for the above example is 4 (i.e., k=4); had we split the training data into 5 equal parts, the value would be k=5.

k = number of parts we randomly split our training data set into.

Now we are using the entire 80% of our data both to compute the nearest neighbors and to choose the K value in KNN.
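For reference, scikit-learn wraps this whole loop in a single helper. A minimal equivalent sketch, under the same assumptions as before (`cv=4` matches the 4 parts in our example):

```python
# Equivalent sketch using scikit-learn's built-in cross_val_score helper.
from sklearn.model_selection import cross_val_score
from sklearn.neighbors import KNeighborsClassifier

for K in range(1, 11):
    scores = cross_val_score(
        KNeighborsClassifier(n_neighbors=K), X_train, y_train, cv=4
    )
    print("K =", K, "mean CV accuracy =", scores.mean())
```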

The one drawback of k-fold cross-validation is that we repeat the computation k times for each value of K (of KNN), which increases the time complexity.

Conclusion

We have understood the concept of cross-validation, why we need it, and what we mean by k-fold cross-validation. Despite its higher time complexity, the process is worth it, as it improves the generalization of the model.

Thanks for Reading!

I hope that you find these resources and ideas helpful in your data science journey.

Deepak Jain
