
Handwritten digit classification using K-nearest neighbors algorithm
http://emaraic.com/blog/handwritten-digit-classification

Introduction

One of the most important aims of machine learning is to classify data into classes. Examples include classifying an email as spam or ham, a tumor as malignant or benign, or a handwritten digit as one of the 10 classes.

K-Nearest Neighbors

In a nutshell, suppose we have N training objects, each represented by a vector "x" and a label "c". To classify a new object xnew with K-nearest neighbors, we find the K training points that are closest to xnew and assign xnew the majority class "c" amongst these neighbors.
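The rule can be sketched in a few lines of Octave for a single new point. This sketch assumes squared Euclidean distance. The 2-D data, the variable names x, c, xnew, and the class codes 1 and 2 are all made up for illustration and are not part of the exercise:

```octave
% Hypothetical 2-D training set: one object per column, one label per column
x = [1 2 1 6 7 6;     % first coordinate
     1 1 2 6 6 7];    % second coordinate
c = [1 1 1 2 2 2];    % illustrative class labels (think "star" = 1, "diamond" = 2)
xnew = [1.5; 1.5];    % new object to classify
K = 3;

% Squared Euclidean distance from xnew to every training column
d = sum(bsxfun(@minus, x, xnew).^2, 1);
[~, idx] = sort(d);               % training objects ordered from nearest to farthest
cnew = mode(c(idx(1:K)))          % majority class among the K nearest neighbours
```

Here the three nearest points all have class 1, so cnew is 1. Moving xnew to [6.5; 6.5] would make the three class-2 points the nearest neighbours instead.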

For instance, in a plot where stars and diamonds denote training points and X1new and X2new denote test points, choosing K=3 assigns X1new to the star class and X2new to the diamond class.
I will explain using an exercise from chapter 14 of the book Bayesian Reasoning and Machine Learning, available from here.

Tools:

1. Octave or Matlab. I used Octave, which you can download from here.
2. BRMLtoolbox, which you can download from here. I used the non-OO code.

Exercise 14.1. The file NNdata.mat contains training and test data for the handwritten digits 5
and 9. Using leave one out cross-validation, find the optimal K in K-nearest neighbors, and use
this to compute the classification accuracy of the method on the test data.

Dataset Description

In this exercise the dataset covers the digits 5 and 9, with a training matrix and a test matrix for each digit. Each training matrix is 784x600: each of its 600 columns is one training object, and its 784 rows are the pixels of a 28x28 digit image. Each test matrix is 784x292: each of its 292 columns is one test object, with the same 784 pixel rows.
You can see the contents of the NNdata.mat file using the following commands.

>> load("./data/NNdata.mat")
>> whos
Variables in the current scope:

   Attr Name        Size                     Bytes  Class
   ==== ====        ====                     =====  =====
        test5     784x292                  1831424  double
        test9     784x292                  1831424  double
        train5    784x600                  3763200  double
        train9    784x600                  3763200  double

Total is 1398656 elements using 11189248 bytes

Here are some samples from the train5, train9, test5, and test9 arrays. I used the code below to generate the images for the class 5 training and test data.

data = load('./data/NNdata.mat');   % load training and testing data

train5 = data.train5;   % training data for digit 5
train9 = data.train9;   % training data for digit 9
test5  = data.test5;    % test data for digit 5
test9  = data.test9;    % test data for digit 9

% Each column is one 784-pixel digit: reshape it to 28x28 and transpose it for display.
% Left column of the figure: training samples of 5; right column: test samples of 5.
subplot(4, 2, 1); imagesc(reshape(train5(:, 2), 28, 28)');
subplot(4, 2, 2); imagesc(reshape(test5(:, 20), 28, 28)');
subplot(4, 2, 3); imagesc(reshape(train5(:, 110), 28, 28)');
subplot(4, 2, 4); imagesc(reshape(test5(:, 222), 28, 28)');
subplot(4, 2, 5); imagesc(reshape(train5(:, 61), 28, 28)');
subplot(4, 2, 6); imagesc(reshape(test5(:, 145), 28, 28)');
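The plotting code uses reshape followed by a transpose for a specific reason. Octave's reshape fills a matrix column by column, so reshape(v, 28, 28) puts v(1:28) in the first column, and the transpose then turns that column into the top row. The check below uses a synthetic vector (1:784, not real pixel data). That the dataset stores each image row by row is an inference from the transpose in the original code, not something the dataset description states.

```octave
v = (1:784)';                        % synthetic stand-in for one 784-pixel column
img = reshape(v, 28, 28)';           % same operation as in the plotting code
first_row_ok = isequal(img(1, :), v(1:28)')   % pixels 1..28 end up in the top row
back = reshape(img', 784, 1);        % undo: transpose back, then flatten column-wise
roundtrip_ok = isequal(back, v)      % recovers the original column
```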
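After understanding the dataset structure, let's start the solution. The first requirement is to find the optimal K using leave-one-out cross-validation, and then to calculate the accuracy of the algorithm on the test data.
BRMLtoolbox has a function nearNeigh(traindata, testdata, trainlabel, k). The following table describes each input parameter of the nearNeigh function.

Parameter    Description
traindata    training objects, one per column (784x1200 in this exercise)
testdata     objects to classify, one per column
trainlabel   class label of each column of traindata (5 or 9 here)
k            number of nearest neighbors that vote

nearNeigh returns the predicted label of each column of testdata.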
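To make this interface concrete, here is a rough stand-in that takes its inputs in the same layout as nearNeigh: objects in columns and one label per training column. It returns one predicted label per test column. This is an assumption-laden sketch, not the BRMLtoolbox source. It uses squared Euclidean distance and breaks vote ties with mode, which picks the smaller label, so its tie behaviour may differ from nearNeigh's. The toy matrices and their names are made up.

```octave
% Made-up data in the nearNeigh layout: one object per column
traindata_toy  = [0 0.2 0.1 1.0 1.1 0.9;
                  0 0.1 0.3 1.0 0.8 1.2];
trainlabel_toy = [5 5 5 9 9 9];
testdata_toy   = [0.1 1.0;
                  0.0 1.1];
k = 3;

% Squared Euclidean distances: D(i, j) = ||train_i - test_j||^2, size ntrain x ntest
D = bsxfun(@plus, sum(traindata_toy.^2, 1)', sum(testdata_toy.^2, 1)) ...
    - 2 * (traindata_toy' * testdata_toy);
[~, order] = sort(D, 1);                                  % sort each test column's distances
nearest = order(1:k, :);                                  % indices of the k nearest training objects
votes = reshape(trainlabel_toy(nearest), size(nearest));  % their labels, k x ntest
pred = mode(votes, 1)                                     % majority vote per test object
```

The first test point sits among the 5s and the second among the 9s, so pred is [5 9]. The reshape keeps votes as a k x ntest matrix even when there is only one test column.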

Leave-one-out cross-validation (LOOCV)

In this validation scheme, a dataset of N points is split into N subsets of one point each. The model is therefore trained N separate times, each time on all the data except one point, and a prediction is made for that held-out point.
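The same idea can be written as a formula. The notation here is introduced only for illustration. Let \hat{c}_K(x_n; \mathcal{D}_{-n}) be the label that K-nearest neighbors assigns to training object x_n when it uses every training object except the n-th one. The LOOCV error count for a given K and the chosen K are then:

$$E(K) = \sum_{n=1}^{N} \mathbb{1}\left[\hat{c}_K(x_n;\, \mathcal{D}_{-n}) \neq c_n\right], \qquad K^{*} = \arg\min_{K} E(K)$$

This exercise has N = 1200 training objects (600 fives and 600 nines) and tries K = 1, ..., 20, which amounts to 1200 × 20 = 24,000 separate classifications.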

Steps of leave-one-out cross-validation:

1. Initialize an error counter.
2. Pick one training object and its label from a data matrix containing train5 and train9, and from the corresponding label vector.
3. Remove this object and its label from the two matrices.
4. Predict the label of that object using the remaining objects.
5. If the predicted label does not equal the known label (5 or 9), increment the error counter.
6. Repeat steps 2-5 for every object in the matrix, restoring the removed object each time.
7. Repeat these steps for each candidate K value. This can take a long time, since every K value requires N classifications.
8. Choose the optimal K as the one with the smallest error count.
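The steps above can be sketched on a tiny made-up problem. This Octave example is self-contained. In place of nearNeigh it uses an inline squared-Euclidean majority vote, which is an assumption and not the toolbox's code. It uses one-dimensional points stored in columns and tries only K = 1, 3, 5. The point at 99 is a class-5 object deliberately placed inside the class-9 cluster, so that the choice of K matters.

```octave
% Made-up 1-D training set, one object per column; 99 is a class-5 object
% deliberately placed inside the class-9 cluster.
xtr = [0 3 7 99 100 104 110];
ctr = [5 5 5 5  9   9   9];
kvalues = [1 3 5];                       % candidate K values (step 7)
ntr = size(xtr, 2);
looerr = zeros(1, numel(kvalues));       % one error counter per K (step 1)

for ki = 1:numel(kvalues)
  K = kvalues(ki);
  for t = 1:ntr                          % steps 2 and 6: visit every object
    keep = true(1, ntr);
    keep(t) = false;                     % step 3: leave object t out
    xs = xtr(:, keep);
    cs = ctr(keep);
    d = sum(bsxfun(@minus, xs, xtr(:, t)).^2, 1);  % squared distances to object t
    [~, idx] = sort(d);
    pred = mode(cs(idx(1:K)));           % step 4: majority vote of the K nearest
    if pred ~= ctr(t)                    % step 5: count a wrong prediction
      looerr(ki) = looerr(ki) + 1;
    end
  end
end

[~, best] = min(looerr);                 % step 8: first K with the fewest errors
bestk = kvalues(best);
disp(looerr)                             % error counts for K = 1, 3, 5: 2 1 4
disp(bestk)                              % 3
```

With K = 1, the misplaced point and its closest neighbour (100) are both misclassified. With K = 5, the votes reach the far class-5 cluster and outvote three of the 9s. The error counts therefore come out as 2, 1 and 4, and K = 3 is selected.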

After choosing the optimal K from the LOOCV output, the testing phase uses this K value. Each column of the test5 and test9 matrices is classified against the full training set, and whenever the predicted label does not equal the true label, an error counter is incremented. Finally, the classification accuracy is measured as:
accuracy = (number of test objects - total errors) / number of test objects × 100%
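As a worked instance of this formula, 14 misclassified test digits out of 292 + 292 = 584 give:

$$\text{accuracy} = \frac{584 - 14}{584} \times 100\% = \frac{570}{584} \times 100\% \approx 97.603\%$$

This corresponds to the accuracy reported in the Output section.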
The code

This code depends on the dataset file and other files, which you can get from the GitHub repository here.

data = load('./data/NNdata.mat');   % load training and testing data

train5 = data.train5;   % training data for digit 5
train9 = data.train9;   % training data for digit 9
test5  = data.test5;    % test data for digit 5
test9  = data.test9;    % test data for digit 9

traindata  = [train5 train9];                  % 784x1200, one digit per column
trainlabel = [5*ones(1,600), 9*ones(1,600)];   % label of each training column

% Find the optimal K with leave-one-out cross-validation
kvalues = 1:20;
ntrain = size(traindata, 2);              % 1200 training columns
totalerror = zeros(1, numel(kvalues));
for ki = 1:numel(kvalues)
  ks = kvalues(ki);
  nerr = 0;                               % named nerr to avoid shadowing the built-in error()
  for t = 1:ntrain
    keep = true(1, ntrain);
    keep(t) = false;                      % leave training column t out
    y = nearNeigh(traindata(:, keep), traindata(:, t), trainlabel(keep), ks); % classify the held-out column
    if y ~= trainlabel(t)
      nerr = nerr + 1;                    % portable in both Octave and Matlab
    end
  end
  totalerror(ki) = nerr;
  disp(strcat('For k= ', num2str(ks), ' Error = ', num2str(nerr)));
end

[~, bestidx] = min(totalerror);           % first K with the smallest error (a single value even on ties)
bestk = kvalues(bestidx);
disp(strcat('Best k= ', num2str(bestk)));

%% Get classification accuracy for the best k
testout5 = nearNeigh(traindata, test5, trainlabel, bestk);  % predicted labels for the test 5s
testout9 = nearNeigh(traindata, test9, trainlabel, bestk);  % predicted labels for the test 9s

error5 = sum(testout5 ~= 5);                 % misclassified 5s
error9 = sum(testout9 ~= 9);                 % misclassified 9s
ntest = numel(testout5) + numel(testout9);   % 292 + 292 = 584 test objects

accuracy = ((ntest - (error5 + error9)) / ntest) * 100;
disp(strcat('Accuracy= ', num2str(accuracy), '%'));

%% Plot the LOOCV error (y) for each value of K (x)
figure; hold on;
plot(kvalues, totalerror);
xlabel('K values', 'fontsize', 25);
ylabel('Error', 'fontsize', 25);
set(gca, 'xtick', kvalues);

Output
>>
For k=1 Error =22
For k=2 Error =22
For k=3 Error =21
For k=4 Error =26
For k=5 Error =28
For k=6 Error =29
For k=7 Error =29
For k=8 Error =32
For k=9 Error =32
For k=10 Error =35
For k=11 Error =35
For k=12 Error =36
For k=13 Error =37
For k=14 Error =41
For k=15 Error =42
For k=16 Error =44
For k=17 Error =47
For k=18 Error =50
For k=19 Error =52
For k=20 Error =54
Best k=3
Accuracy=97.603%
From the above results, the optimal K value is 3, and the classification accuracy on the test data is 97.603%.
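As a quick check of the selection step, the sketch below applies min to the LOOCV error counts printed above, copied by hand from the output. Because the candidates are K = 1..20, the index returned by min is K itself.

```octave
% LOOCV error counts for K = 1..20, copied from the output above
looerrors = [22 22 21 26 28 29 29 32 32 35 35 36 37 41 42 44 47 50 52 54];
[minerr, bestk] = min(looerrors)     % smallest error count and the first K that reaches it
```

This gives minerr = 21 at bestk = 3, matching "Best k=3".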
