Lesson 06 - Introduction to the MNIST Dataset

The following topics are discussed in this notebook:

  • Using a neural network for classifying image data.
  • Scaling features.
In [1]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import pandas as pd

import keras
from keras.models import Sequential
from keras.layers import Dense, Flatten
from tensorflow import set_random_seed

from sklearn.model_selection import train_test_split
from keras.datasets import mnist
Using TensorFlow backend.

Load the MNIST Data

The MNIST dataset consists of 70,000, 28x28 black-and-white images of handwritten digits.

In [2]:
(X_train, y_train), (X_holdout, y_holdout) = mnist.load_data()

X_val, X_test, y_val, y_test = train_test_split(X_holdout, y_holdout, test_size = 0.5, random_state=1)

print(X_train.shape)
print(y_train.shape)
print(X_val.shape)
print(y_val.shape)
print(X_test.shape)
print(y_test.shape)
(60000, 28, 28)
(60000,)
(5000, 28, 28)
(5000,)
(5000, 28, 28)
(5000,)

How is a digit represented?

In [3]:
np.set_printoptions(linewidth=120)
mydigit = X_train[0]
print(mydigit)
np.set_printoptions(linewidth=75)
[[  0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0   0   0   3  18  18  18 126 136 175  26 166 255 247 127   0   0   0   0]
 [  0   0   0   0   0   0   0   0  30  36  94 154 170 253 253 253 253 253 225 172 253 242 195  64   0   0   0   0]
 [  0   0   0   0   0   0   0  49 238 253 253 253 253 253 253 253 253 251  93  82  82  56  39   0   0   0   0   0]
 [  0   0   0   0   0   0   0  18 219 253 253 253 253 253 198 182 247 241   0   0   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0  80 156 107 253 253 205  11   0  43 154   0   0   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0  14   1 154 253  90   0   0   0   0   0   0   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0   0 139 253 190   2   0   0   0   0   0   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0   0  11 190 253  70   0   0   0   0   0   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0   0   0  35 241 225 160 108   1   0   0   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0   0   0   0  81 240 253 253 119  25   0   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0   0   0   0   0  45 186 253 253 150  27   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0   0   0   0   0   0  16  93 252 253 187   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0 249 253 249  64   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0   0   0   0   0  46 130 183 253 253 207   2   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0   0   0  39 148 229 253 253 253 250 182   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0  24 114 221 253 253 253 253 201  78   0   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0  23  66 213 253 253 253 253 198  81   2   0   0   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0  18 171 219 253 253 253 253 195  80   9   0   0   0   0   0   0   0   0   0   0   0   0]
 [  0   0   0   0  55 172 226 253 253 253 253 244 133  11   0   0   0   0   0   0   0   0   0   0   0   0   0   0]
 [  0   0   0   0 136 253 253 253 212 135 132  16   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0]]
In [4]:
plt.imshow(mydigit, cmap=cm.binary)
plt.axis('off')
plt.show()
In [5]:
sel = np.random.choice(range(60000), 40, replace=False)
X_sel = X_train[sel]
y_sel = y_train[sel]

plt.close()
plt.rcParams["figure.figsize"] = [16,10]
for i in range(40):
    plt.subplot(5,8,i+1)
    plt.imshow(X_sel[i], cmap=cm.binary)
    plt.text(-1, 10, s = str(int(y_sel[i])), fontsize=16, color='b')
    plt.axis('off')
plt.show()

Scale the Data

In [6]:
Xs_train = X_train / 255
Xs_val = X_val / 255
Xs_test = X_test / 255

Training the network.

In [7]:
%%time

np.random.seed(1)
set_random_seed(1)

model = Sequential()
model.add(Flatten(input_shape=(28,28)))
model.add(Dense(512, activation='relu'))
model.add(Dense(256, activation='relu'))
model.add(Dense(10, activation='softmax'))

opt = keras.optimizers.Adam(lr = 0.001)
model.compile(loss='sparse_categorical_crossentropy', optimizer=opt, metrics=['accuracy'])

h = model.fit(Xs_train, y_train, batch_size=1024, epochs=20, validation_data=(Xs_val, y_val), verbose=2)
Train on 60000 samples, validate on 5000 samples
Epoch 1/20
 - 2s - loss: 0.5188 - acc: 0.8607 - val_loss: 0.2156 - val_acc: 0.9340
Epoch 2/20
 - 1s - loss: 0.1752 - acc: 0.9492 - val_loss: 0.1386 - val_acc: 0.9600
Epoch 3/20
 - 1s - loss: 0.1199 - acc: 0.9650 - val_loss: 0.1103 - val_acc: 0.9648
Epoch 4/20
 - 1s - loss: 0.0864 - acc: 0.9748 - val_loss: 0.0892 - val_acc: 0.9736
Epoch 5/20
 - 1s - loss: 0.0656 - acc: 0.9811 - val_loss: 0.0804 - val_acc: 0.9754
Epoch 6/20
 - 1s - loss: 0.0501 - acc: 0.9853 - val_loss: 0.0742 - val_acc: 0.9782
Epoch 7/20
 - 1s - loss: 0.0401 - acc: 0.9886 - val_loss: 0.0704 - val_acc: 0.9794
Epoch 8/20
 - 1s - loss: 0.0318 - acc: 0.9913 - val_loss: 0.0637 - val_acc: 0.9798
Epoch 9/20
 - 1s - loss: 0.0256 - acc: 0.9935 - val_loss: 0.0681 - val_acc: 0.9800
Epoch 10/20
 - 1s - loss: 0.0214 - acc: 0.9947 - val_loss: 0.0635 - val_acc: 0.9812
Epoch 11/20
 - 1s - loss: 0.0156 - acc: 0.9964 - val_loss: 0.0630 - val_acc: 0.9818
Epoch 12/20
 - 1s - loss: 0.0125 - acc: 0.9976 - val_loss: 0.0673 - val_acc: 0.9800
Epoch 13/20
 - 1s - loss: 0.0099 - acc: 0.9984 - val_loss: 0.0665 - val_acc: 0.9802
Epoch 14/20
 - 1s - loss: 0.0073 - acc: 0.9990 - val_loss: 0.0653 - val_acc: 0.9820
Epoch 15/20
 - 1s - loss: 0.0062 - acc: 0.9991 - val_loss: 0.0644 - val_acc: 0.9820
Epoch 16/20
 - 1s - loss: 0.0047 - acc: 0.9995 - val_loss: 0.0676 - val_acc: 0.9804
Epoch 17/20
 - 1s - loss: 0.0047 - acc: 0.9994 - val_loss: 0.0663 - val_acc: 0.9828
Epoch 18/20
 - 1s - loss: 0.0030 - acc: 0.9998 - val_loss: 0.0666 - val_acc: 0.9822
Epoch 19/20
 - 1s - loss: 0.0023 - acc: 0.9999 - val_loss: 0.0681 - val_acc: 0.9820
Epoch 20/20
 - 1s - loss: 0.0021 - acc: 0.9999 - val_loss: 0.0677 - val_acc: 0.9832
Wall time: 14.4 s

Visualize Training Process

In [8]:
plt.rcParams["figure.figsize"] = [8,4]
plt.subplot(1,2,1)
plt.plot(h.history['acc'], label='Training')
plt.plot(h.history['val_acc'], label='Validation')
plt.title('model accuracy')
plt.ylabel('accuracy')
plt.xlabel('epoch')
plt.legend()

plt.subplot(1,2,2)
plt.plot(h.history['loss'], label='Training')
plt.plot(h.history['val_loss'], label='Validation')
plt.title('model loss')
plt.ylabel('loss')
plt.xlabel('epoch')
plt.legend()
plt.show()
In [9]:
score = model.evaluate(X_test, y_test, verbose=0)
print('Testing loss:    ', score[0])
print('Testing accuracy:', score[1])
Testing loss:     0.3494494982240365
Testing accuracy: 0.978

Confusion Matrix

In [10]:
from sklearn.metrics import confusion_matrix, classification_report

y_pred = model.predict_classes(X_test)

confmat = confusion_matrix(y_test, y_pred)
df = pd.DataFrame(confmat)

df
Out[10]:
0 1 2 3 4 5 6 7 8 9
0 495 0 0 0 1 0 0 1 1 1
1 0 570 3 0 0 1 1 0 1 0
2 3 2 517 1 2 0 1 5 2 0
3 0 0 3 498 1 2 0 2 0 1
4 2 0 2 1 470 0 0 0 1 6
5 0 0 0 3 1 393 1 1 0 2
6 2 0 1 0 2 3 469 0 0 0
7 0 1 6 1 1 0 0 508 2 3
8 5 1 3 4 3 2 0 2 486 3
9 1 1 0 2 5 0 0 1 0 484
In [11]:
print(classification_report(y_test, y_pred))
              precision    recall  f1-score   support

           0       0.97      0.99      0.98       499
           1       0.99      0.99      0.99       576
           2       0.97      0.97      0.97       533
           3       0.98      0.98      0.98       507
           4       0.97      0.98      0.97       482
           5       0.98      0.98      0.98       401
           6       0.99      0.98      0.99       477
           7       0.98      0.97      0.98       522
           8       0.99      0.95      0.97       509
           9       0.97      0.98      0.97       494

   micro avg       0.98      0.98      0.98      5000
   macro avg       0.98      0.98      0.98      5000
weighted avg       0.98      0.98      0.98      5000

Exploring the misclassified digits:

In [12]:
import math
# Find misclassified samples in the test set.
sel = y_pred != y_test
n_mc = np.sum(sel) # number misclassified

X_mc = X_test[sel,:]
y_mc = y_test[sel]
yp_mc = y_pred[sel]

idx = np.argsort(y_mc)
X_mc = X_mc[idx,:]
y_mc = y_mc[idx]
yp_mc = yp_mc[idx]

rows = math.ceil(n_mc / 6)

plt.figure(figsize=(12,30))
for i in range(0, n_mc):
    plt.subplot(rows,6,i+1)
    plt.imshow(X_mc[i], cmap=cm.binary)
    plt.text(-1, 10, s = str(int(y_mc[i])), fontsize=16, color='b')
    plt.text(-1, 16, s = str(int(yp_mc[i])), fontsize=16, color='r')
    plt.axis('off')