Lesson 31 - The MNIST Dataset

The following topics are discussed in this notebook:

  • Using a neural network for classifying image data.
  • Scaling features.
In [1]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import pandas as pd

import keras
from keras.models import Sequential
from keras.layers import Dense, Flatten
from tensorflow import set_random_seed

from sklearn.model_selection import train_test_split
from keras.datasets import mnist
Using TensorFlow backend.

Load the MNIST Data

The MNIST dataset consists of 70,000, 28x28 black-and-white images of handwritten digits.

In [2]:
(X_train, y_train), (X_holdout, y_holdout) = mnist.load_data()

X_val, X_test, y_val, y_test = train_test_split(X_holdout, y_holdout, test_size = 0.5, random_state=1)

print(X_train.shape)
print(y_train.shape)
print(X_val.shape)
print(y_val.shape)
print(X_test.shape)
print(y_test.shape)
(60000, 28, 28)
(60000,)
(5000, 28, 28)
(5000,)
(5000, 28, 28)
(5000,)

How is a digit represented?

In [3]:
np.set_printoptions(linewidth=120)
mydigit = X_train[0]
print(mydigit)
np.set_printoptions(linewidth=75)
[[  0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0   0   0   3  18  18  18 126 136 175  26 166 255 247 127   0   0   0   0]
 [  0   0   0   0   0   0   0   0  30  36  94 154 170 253 253 253 253 253 225 172 253 242 195  64   0   0   0   0]
 [  0   0   0   0   0   0   0  49 238 253 253 253 253 253 253 253 253 251  93  82  82  56  39   0   0   0   0   0]
 [  0   0   0   0   0   0   0  18 219 253 253 253 253 253 198 182 247 241   0   0   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0  80 156 107 253 253 205  11   0  43 154   0   0   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0  14   1 154 253  90   0   0   0   0   0   0   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0   0 139 253 190   2   0   0   0   0   0   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0   0  11 190 253  70   0   0   0   0   0   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0   0   0  35 241 225 160 108   1   0   0   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0   0   0   0  81 240 253 253 119  25   0   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0   0   0   0   0  45 186 253 253 150  27   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0   0   0   0   0   0  16  93 252 253 187   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0 249 253 249  64   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0   0   0   0   0  46 130 183 253 253 207   2   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0   0   0  39 148 229 253 253 253 250 182   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0  24 114 221 253 253 253 253 201  78   0   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0  23  66 213 253 253 253 253 198  81   2   0   0   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0  18 171 219 253 253 253 253 195  80   9   0   0   0   0   0   0   0   0   0   0   0   0]
 [  0   0   0   0  55 172 226 253 253 253 253 244 133  11   0   0   0   0   0   0   0   0   0   0   0   0   0   0]
 [  0   0   0   0 136 253 253 253 212 135 132  16   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0]]
In [4]:
plt.imshow(mydigit, cmap=cm.binary)
plt.axis('off')
plt.show()
In [5]:
sel = np.random.choice(range(60000), 40, replace=False)
X_sel = X_train[sel]
y_sel = y_train[sel]

plt.close()
plt.rcParams["figure.figsize"] = [16,10]
for i in range(40):
    plt.subplot(5,8,i+1)
    plt.imshow(X_sel[i], cmap=cm.binary)
    plt.text(-1, 10, s = str(int(y_sel[i])), fontsize=16, color='b')
    plt.axis('off')
plt.show()

Scale the Data

In [6]:
Xs_train = X_train / 255
Xs_val = X_val / 255
Xs_test = X_test / 255

Training the network.

In [7]:
%%time

np.random.seed(1)
set_random_seed(1)

model = Sequential()
model.add(Flatten(input_shape=(28,28)))
model.add(Dense(512, activation='sigmoid'))
model.add(Dense(256, activation='sigmoid'))
model.add(Dense(10, activation='softmax'))

opt = keras.optimizers.Adam(lr = 0.001)
model.compile(loss='sparse_categorical_crossentropy', optimizer=opt, metrics=['accuracy'])

h = model.fit(Xs_train, y_train, batch_size=1024, epochs=50, validation_data=(Xs_val, y_val), verbose=2)
Train on 60000 samples, validate on 5000 samples
Epoch 1/50
 - 10s - loss: 1.3205 - acc: 0.6802 - val_loss: 0.5565 - val_acc: 0.8656
Epoch 2/50
 - 7s - loss: 0.4240 - acc: 0.8892 - val_loss: 0.3392 - val_acc: 0.9106
Epoch 3/50
 - 9s - loss: 0.3126 - acc: 0.9121 - val_loss: 0.2765 - val_acc: 0.9202
Epoch 4/50
 - 7s - loss: 0.2686 - acc: 0.9220 - val_loss: 0.2479 - val_acc: 0.9260
Epoch 5/50
 - 8s - loss: 0.2411 - acc: 0.9296 - val_loss: 0.2270 - val_acc: 0.9314
Epoch 6/50
 - 8s - loss: 0.2192 - acc: 0.9362 - val_loss: 0.2126 - val_acc: 0.9344
Epoch 7/50
 - 6s - loss: 0.2002 - acc: 0.9414 - val_loss: 0.1929 - val_acc: 0.9406
Epoch 8/50
 - 7s - loss: 0.1836 - acc: 0.9468 - val_loss: 0.1763 - val_acc: 0.9476
Epoch 9/50
 - 10s - loss: 0.1684 - acc: 0.9509 - val_loss: 0.1639 - val_acc: 0.9482
Epoch 10/50
 - 9s - loss: 0.1565 - acc: 0.9546 - val_loss: 0.1545 - val_acc: 0.9542
Epoch 11/50
 - 8s - loss: 0.1441 - acc: 0.9582 - val_loss: 0.1439 - val_acc: 0.9560
Epoch 12/50
 - 6s - loss: 0.1331 - acc: 0.9612 - val_loss: 0.1391 - val_acc: 0.9570
Epoch 13/50
 - 9s - loss: 0.1236 - acc: 0.9639 - val_loss: 0.1287 - val_acc: 0.9602
Epoch 14/50
 - 8s - loss: 0.1137 - acc: 0.9667 - val_loss: 0.1260 - val_acc: 0.9630
Epoch 15/50
 - 7s - loss: 0.1067 - acc: 0.9690 - val_loss: 0.1153 - val_acc: 0.9638
Epoch 16/50
 - 7s - loss: 0.0993 - acc: 0.9708 - val_loss: 0.1100 - val_acc: 0.9654
Epoch 17/50
 - 5s - loss: 0.0909 - acc: 0.9739 - val_loss: 0.1042 - val_acc: 0.9672
Epoch 18/50
 - 7s - loss: 0.0849 - acc: 0.9759 - val_loss: 0.1011 - val_acc: 0.9678
Epoch 19/50
 - 6s - loss: 0.0784 - acc: 0.9778 - val_loss: 0.0960 - val_acc: 0.9680
Epoch 20/50
 - 7s - loss: 0.0724 - acc: 0.9793 - val_loss: 0.0939 - val_acc: 0.9706
Epoch 21/50
 - 6s - loss: 0.0675 - acc: 0.9807 - val_loss: 0.0896 - val_acc: 0.9716
Epoch 22/50
 - 6s - loss: 0.0621 - acc: 0.9823 - val_loss: 0.0859 - val_acc: 0.9710
Epoch 23/50
 - 6s - loss: 0.0580 - acc: 0.9836 - val_loss: 0.0806 - val_acc: 0.9746
Epoch 24/50
 - 5s - loss: 0.0535 - acc: 0.9849 - val_loss: 0.0796 - val_acc: 0.9740
Epoch 25/50
 - 7s - loss: 0.0497 - acc: 0.9862 - val_loss: 0.0756 - val_acc: 0.9748
Epoch 26/50
 - 8s - loss: 0.0481 - acc: 0.9863 - val_loss: 0.0764 - val_acc: 0.9764
Epoch 27/50
 - 7s - loss: 0.0433 - acc: 0.9883 - val_loss: 0.0724 - val_acc: 0.9760
Epoch 28/50
 - 8s - loss: 0.0397 - acc: 0.9894 - val_loss: 0.0705 - val_acc: 0.9774
Epoch 29/50
 - 7s - loss: 0.0365 - acc: 0.9901 - val_loss: 0.0676 - val_acc: 0.9772
Epoch 30/50
 - 8s - loss: 0.0333 - acc: 0.9915 - val_loss: 0.0676 - val_acc: 0.9770
Epoch 31/50
 - 7s - loss: 0.0314 - acc: 0.9920 - val_loss: 0.0668 - val_acc: 0.9792
Epoch 32/50
 - 7s - loss: 0.0286 - acc: 0.9931 - val_loss: 0.0661 - val_acc: 0.9788
Epoch 33/50
 - 7s - loss: 0.0265 - acc: 0.9935 - val_loss: 0.0645 - val_acc: 0.9798
Epoch 34/50
 - 8s - loss: 0.0246 - acc: 0.9942 - val_loss: 0.0628 - val_acc: 0.9794
Epoch 35/50
 - 8s - loss: 0.0224 - acc: 0.9948 - val_loss: 0.0664 - val_acc: 0.9786
Epoch 36/50
 - 10s - loss: 0.0209 - acc: 0.9953 - val_loss: 0.0609 - val_acc: 0.9794
Epoch 37/50
 - 10s - loss: 0.0192 - acc: 0.9958 - val_loss: 0.0684 - val_acc: 0.9786
Epoch 38/50
 - 9s - loss: 0.0178 - acc: 0.9963 - val_loss: 0.0634 - val_acc: 0.9804
Epoch 39/50
 - 7s - loss: 0.0162 - acc: 0.9969 - val_loss: 0.0588 - val_acc: 0.9804
Epoch 40/50
 - 6s - loss: 0.0146 - acc: 0.9976 - val_loss: 0.0639 - val_acc: 0.9796
Epoch 41/50
 - 6s - loss: 0.0135 - acc: 0.9977 - val_loss: 0.0618 - val_acc: 0.9804
Epoch 42/50
 - 7s - loss: 0.0123 - acc: 0.9981 - val_loss: 0.0604 - val_acc: 0.9816
Epoch 43/50
 - 7s - loss: 0.0113 - acc: 0.9982 - val_loss: 0.0612 - val_acc: 0.9804
Epoch 44/50
 - 11s - loss: 0.0106 - acc: 0.9985 - val_loss: 0.0589 - val_acc: 0.9816
Epoch 45/50
 - 8s - loss: 0.0098 - acc: 0.9987 - val_loss: 0.0603 - val_acc: 0.9804
Epoch 46/50
 - 7s - loss: 0.0088 - acc: 0.9989 - val_loss: 0.0595 - val_acc: 0.9806
Epoch 47/50
 - 7s - loss: 0.0079 - acc: 0.9992 - val_loss: 0.0606 - val_acc: 0.9808
Epoch 48/50
 - 9s - loss: 0.0074 - acc: 0.9992 - val_loss: 0.0585 - val_acc: 0.9818
Epoch 49/50
 - 8s - loss: 0.0066 - acc: 0.9994 - val_loss: 0.0595 - val_acc: 0.9816
Epoch 50/50
 - 10s - loss: 0.0062 - acc: 0.9995 - val_loss: 0.0622 - val_acc: 0.9810
Wall time: 6min 22s

Visualize Training Process

In [8]:
plt.rcParams["figure.figsize"] = [8,4]
plt.subplot(1,2,1)
plt.plot(h.history['acc'], label='Training')
plt.plot(h.history['val_acc'], label='Validation')
plt.title('model accuracy')
plt.ylabel('accuracy')
plt.xlabel('epoch')
plt.legend()

plt.subplot(1,2,2)
plt.plot(h.history['loss'], label='Training')
plt.plot(h.history['val_loss'], label='Validation')
plt.title('model loss')
plt.ylabel('loss')
plt.xlabel('epoch')
plt.legend()
plt.show()
In [9]:
score = model.evaluate(X_test, y_test, verbose=0)
print('Testing loss:    ', score[0])
print('Testing accuracy:', score[1])
Testing loss:     0.10246038201781921
Testing accuracy: 0.9734

Confusion Matrix

In [10]:
from sklearn.metrics import confusion_matrix, classification_report

y_pred = model.predict_classes(X_test)

confmat = confusion_matrix(y_test, y_pred)
df = pd.DataFrame(confmat)

df
Out[10]:
0 1 2 3 4 5 6 7 8 9
0 495 0 0 0 1 0 0 0 1 2
1 0 572 1 0 0 1 0 2 0 0
2 2 1 518 2 2 0 1 4 3 0
3 0 0 1 503 0 0 0 3 0 0
4 2 0 2 0 467 0 3 0 0 8
5 0 0 0 3 1 393 1 0 2 1
6 3 0 1 0 0 2 471 0 0 0
7 1 2 5 0 1 1 0 508 2 2
8 4 3 7 6 1 5 2 2 477 2
9 0 3 0 3 8 1 0 15 1 463
In [11]:
print(classification_report(y_test, y_pred))
              precision    recall  f1-score   support

           0       0.98      0.99      0.98       499
           1       0.98      0.99      0.99       576
           2       0.97      0.97      0.97       533
           3       0.97      0.99      0.98       507
           4       0.97      0.97      0.97       482
           5       0.98      0.98      0.98       401
           6       0.99      0.99      0.99       477
           7       0.95      0.97      0.96       522
           8       0.98      0.94      0.96       509
           9       0.97      0.94      0.95       494

   micro avg       0.97      0.97      0.97      5000
   macro avg       0.97      0.97      0.97      5000
weighted avg       0.97      0.97      0.97      5000

Exploring the misclassified digits:

In [12]:
import math
# Find misclassified samples in the test set.
sel = y_pred != y_test
n_mc = np.sum(sel) # number misclassified

X_mc = X_test[sel,:]
y_mc = y_test[sel]
yp_mc = y_pred[sel]

idx = np.argsort(y_mc)
X_mc = X_mc[idx,:]
y_mc = y_mc[idx]
yp_mc = yp_mc[idx]

rows = math.ceil(n_mc / 6)

plt.figure(figsize=(12,30))
for i in range(0, n_mc):
    plt.subplot(rows,6,i+1)
    plt.imshow(X_mc[i], cmap=cm.binary)
    plt.text(-1, 10, s = str(int(y_mc[i])), fontsize=16, color='b')
    plt.text(-1, 16, s = str(int(yp_mc[i])), fontsize=16, color='r')
    plt.axis('off')