Naive Bayes MNIST Digit Classification with scikit-learn

Objective

  • Use a Naive Bayes model to classify handwritten digits with scikit-learn

In this post we will try to classify handwritten digits using a simple Naive Bayes model. The purpose of this post is to get familiar with preparing the data, splitting it into training and testing sets, creating and training a model with the scikit-learn API, and evaluating it.

First, let's download the data
[pre class="brush:python"] from sklearn.datasets import fetch_mldata mnist = fetch_mldata("MNIST original") print("Column names = ", mnist.COL_NAMES) print("DESC = ", mnist.DESCR) print("Shape of data = ", mnist["data"].shape) print("Shape of target = ", mnist["target"].shape) [/pre]
Column names = ['label', 'data']
DESC = mldata.org dataset: mnist-original
Shape of data = (70000, 784)
Shape of target = (70000,)
From the output we can see that there are 70,000 samples in total. Each sample is a 1D vector of size 784. The MNIST data-set contains gray-scale images of the digits 0 to 9, and each image is 28 * 28 pixels. When we "flatten" a 2D matrix of size 28 * 28, we get a vector of size 784. That is why each sample in the data-set is a single vector.
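A note if you are following along on a recent scikit-learn release: mldata.org has shut down and fetch_mldata has been removed, so you would load MNIST with fetch_openml instead. A minimal sketch, assuming scikit-learn >= 0.23 for the as_frame argument:

[pre class="brush:python"]
# fetch_openml replaces the removed fetch_mldata;
# as_frame=False keeps the data as NumPy arrays
from sklearn.datasets import fetch_openml

mnist = fetch_openml("mnist_784", version=1, as_frame=False)
print(mnist.data.shape)    # (70000, 784)
print(mnist.target.shape)  # (70000,)
[/pre]

Note that fetch_openml returns the labels as strings rather than floats, which should not affect the rest of this post's code.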

Let's visualize some data
[pre class="brush:python"]
%matplotlib inline
import numpy as np
import matplotlib.pyplot as plt

def plot_images(images, labels):
    n_cols = min(5, len(images))
    n_rows = len(images) // n_cols
    fig = plt.figure(figsize=(8, 8))

    for i in range(n_rows * n_cols):
        sp = fig.add_subplot(n_rows, n_cols, i+1)
        plt.axis("off")
        plt.imshow(images[i], cmap=plt.cm.gray)
        sp.set_title(labels[i])
    plt.show()

# pick 20 random indices
p = np.random.permutation(len(mnist.data))
p = p[:20]

# reshape each flattened sample back into a 28 x 28 image before plotting
plot_images(mnist.data[p].reshape(-1, 28, 28), mnist.target[p])
[/pre]
Random images and their labels

Next, we'll normalize the data-set so that the values lie between 0 and 1. Since each pixel value lies between 0 and 255, dividing every pixel by 255 achieves this. Then we will split the data into training and testing sets: the training set will contain 70% of the total data-set and the testing set the remaining 30%.

[pre class="brush:python"]
X = mnist.data / 255.0
Y = mnist.target
from sklearn.model_selection import train_test_split
X_train, X_test, Y_train, Y_test = train_test_split(X, Y, test_size=0.3)

print("Total training samples = ", X_train.shape[0])
print("Total testing samples = ", X_test.shape[0])
[/pre]
Total training samples = 49000
Total testing samples = 21000
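Keep in mind that train_test_split shuffles the data randomly, so the exact samples that land in each set vary from run to run. If you want a reproducible split, you can pass random_state; a small variation not used in the original run, with an arbitrary seed:

[pre class="brush:python"]
# same 70/30 split, but reproducible across runs
X_train, X_test, Y_train, Y_test = train_test_split(
    X, Y, test_size=0.3, random_state=42)
[/pre]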

Now we are ready to build the model. We'll use a simple Multinomial Naive Bayes model with the default parameters. It trains within a second on my machine!

[pre class="brush:python"]
from sklearn.naive_bayes import MultinomialNB

classifier = MultinomialNB()
classifier.fit(X_train, Y_train)
[/pre]
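The main default we are accepting here is the additive (Laplace) smoothing parameter alpha=1.0. If you want to experiment, a cross-validated search over a few values might look like the sketch below; the candidate values are arbitrary and GridSearchCV is not part of the original post:

[pre class="brush:python"]
from sklearn.model_selection import GridSearchCV

# try a few smoothing strengths with 3-fold cross-validation
grid = GridSearchCV(MultinomialNB(), {"alpha": [0.01, 0.1, 1.0]}, cv=3)
grid.fit(X_train, Y_train)
print(grid.best_params_)
[/pre]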

Every scikit-learn model has a method called fit to train it. It takes two parameters: the training data and its corresponding labels.
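Trained classifiers likewise expose predict and score methods. As a one-line check, score reports the mean accuracy directly, equivalent to the accuracy computation below:

[pre class="brush:python"]
# score() returns the mean accuracy on the given data and labels
print(classifier.score(X_test, Y_test))
[/pre]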

The model is trained, and now we can evaluate it on the testing set

[pre class="brush:python"]
from sklearn.metrics import accuracy_score, classification_report
predicted = classifier.predict(X_test)

print("Accuracy of model = %2f%%" % (accuracy_score(Y_test, predicted)*100))
print(classification_report(Y_test, predicted))
[/pre]

Accuracy of model = 82.404762%
             precision    recall  f1-score   support
        0.0       0.92      0.91      0.92      2067
        1.0       0.88      0.93      0.91      2314
        2.0       0.87      0.82      0.85      2115
        3.0       0.79      0.81      0.80      2178
        4.0       0.82      0.74      0.78      2024
        5.0       0.86      0.65      0.74      1889
        6.0       0.88      0.91      0.89      2022
        7.0       0.94      0.83      0.88      2212
        8.0       0.65      0.78      0.71      2068
        9.0       0.70      0.83      0.76      2111
avg / total       0.83      0.82      0.82     21000

The model achieved an accuracy of 82%. Not bad for such a simple model. If we look at the classification report, we can see that the digits 8 and 9 have fairly low precision compared to the others. Similarly, the recall for digits 4 and 5 is also low. For details about precision and recall, check out https://en.wikipedia.org/wiki/Precision_and_recall
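To dig into which digits get mistaken for which, for example what the 8s are being predicted as, a confusion matrix is handy. A small addition; confusion_matrix is not used in the original post:

[pre class="brush:python"]
from sklearn.metrics import confusion_matrix

# rows are the true digits 0-9, columns are the predicted digits;
# large off-diagonal entries show which pairs get confused
print(confusion_matrix(Y_test, predicted))
[/pre]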

Before we wrap up, let's plot some predictions.
[pre class="brush:python"]
# pick 20 random test samples
p = np.random.permutation(len(X_test))
p = p[:20]
plot_images(X_test[p].reshape(-1, 28, 28), predicted[p])
[/pre]


Predictions made by the model
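A small variation, not in the original post, that makes mistakes easy to spot is to title each image with both the predicted and the true label:

[pre class="brush:python"]
# label each image with the predicted and the true digit
titles = ["pred %s / true %s" % (pr, tr) for pr, tr in zip(predicted[p], Y_test[p])]
plot_images(X_test[p].reshape(-1, 28, 28), titles)
[/pre]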




