Lesson 28 - K-Means Clustering for Image Compression

In [1]:
import numpy as np
import matplotlib.pyplot as plt

from PIL import Image

Load the Image

In [2]:
# Open the source photo that will be colour-quantized below.
img = Image.open('Images/mosque.jpg')

# Show the original at a large size, without axis ticks or labels.
fig, ax = plt.subplots(figsize=[18, 12])
ax.imshow(img)
ax.axis('off')
plt.show()
In [3]:
print(img.__class__)  # which PIL image class Image.open returned
<class 'PIL.JpegImagePlugin.JpegImageFile'>
In [4]:
# Materialize the PIL image as a NumPy array of pixel values and report its shape.
img_array = np.array(img)
print(np.shape(img_array))
(1024, 678, 3)
In [5]:
# Flatten the (height, width, 3) image into one row per pixel, scaling channel
# values by 1/255 (assumes 8-bit channels, as a JPEG normally has).
# The pixel count is derived via -1 rather than hard-coding 1024 * 678, so the
# cell works for any RGB image. Guard the channel count: an RGBA or grayscale
# image would otherwise fail or be silently regrouped by the 3-column reshape.
assert img_array.ndim == 3 and img_array.shape[-1] == 3, \
    "expected an RGB image of shape (height, width, 3)"
X = img_array.reshape(-1, 3) / 255
print(X.shape)
(694272, 3)

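Perform Clustering

Each pixel is treated as a point in RGB space. K-means groups the pixels into 32 clusters, and every pixel is then replaced by the centre colour of its cluster, reducing the image to a 32-colour palette.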

In [6]:
from sklearn.cluster import MiniBatchKMeans
In [7]:
N_COLORS = 32  # number of clusters = size of the compressed colour palette
SEED = 0       # fixed seed so Restart & Run All reproduces the same palette

kmeans = MiniBatchKMeans(n_clusters=N_COLORS, random_state=SEED)
kmeans.fit(X)

# Assign every pixel to a cluster, then replace it with that cluster's centre
# colour. X_rc keeps one row per pixel, matching X.
pred_clusters = kmeans.predict(X)
X_rc = kmeans.cluster_centers_[pred_clusters]
In [8]:
# Restore the original image layout (taken from img_array rather than
# hard-coded) and map the [0, 1] colours back to the 0-255 range.
# Round before casting: astype('uint8') alone truncates, biasing every channel
# down (e.g. 0.9999 * 255 -> 254). The clip guards against float round-off.
img_array_rc = np.clip(np.rint(X_rc.reshape(img_array.shape) * 255), 0, 255)
img_array_rc = img_array_rc.astype('uint8')

# A uint8 array of shape (height, width, 3) is inferred as RGB by Pillow;
# the explicit mode argument is redundant (and deprecated in recent Pillow).
image_rc = Image.fromarray(img_array_rc)

# Show the colour-reduced image the same way as the original.
plt.figure(figsize=[18,12])
plt.imshow(img_array_rc)
plt.axis('off')
plt.show()
In [9]:
# Save next to the source image. The input was read from 'Images/' (capital I);
# writing to 'images/' would target a different, likely missing, directory on
# case-sensitive filesystems.
image_rc.save('Images/mosque_rc.jpg')