K-means clustering is an unsupervised learning technique that groups data points around K centroids: each data point is assigned to the cluster whose centroid is closest to it in n-dimensional space. The goal of K-means clustering is to find the set of K centroids that minimizes the variation within clusters, that is, the sum of squared distances between each data point and the centroid of its cluster. Each cluster centroid is the mean of the points assigned to it and therefore serves as the “prototype” of the cluster.
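The objective can be written as the sum, over every cluster k and every point x in that cluster, of ||x - mu_k||^2, where mu_k is the cluster's centroid. The classic way to approach it is Lloyd's algorithm, which alternates two steps: assign each point to its nearest centroid, then move each centroid to the mean of its points. Below is a minimal NumPy sketch of that loop. The helper name lloyd_kmeans and the choice to initialize from randomly picked data points are illustrative assumptions; scikit-learn's KMeans, used later in this post, defaults to k-means++ initialization and has more careful stopping rules.

```python
import numpy as np


def assign(X, centroids):
    """Return the index of the nearest centroid (squared Euclidean distance) for each row of X."""
    dists = ((X[:, None, :] - centroids[None, :, :]) ** 2).sum(axis=2)
    return dists.argmin(axis=1)


def lloyd_kmeans(X, k, n_iter=100, seed=0):
    """Minimal Lloyd's algorithm sketch (hypothetical helper, not scikit-learn's implementation).

    X: (n_samples, n_features) array. Returns (centroids, labels).
    """
    rng = np.random.default_rng(seed)
    # Initialize centroids as k distinct data points chosen at random
    centroids = X[rng.choice(len(X), size=k, replace=False)].astype(float)
    for _ in range(n_iter):
        # Assignment step: each point joins the cluster of its nearest centroid
        labels = assign(X, centroids)
        # Update step: each centroid moves to the mean of its points (kept in place if empty)
        new_centroids = np.array([
            X[labels == j].mean(axis=0) if np.any(labels == j) else centroids[j]
            for j in range(k)
        ])
        if np.allclose(new_centroids, centroids):
            break
        centroids = new_centroids
    # Final assignment so the returned labels match the returned centroids
    return centroids, assign(X, centroids)


# Two obvious groups of 2-D points
X = np.array([[0.0, 0.0], [0.1, 0.0], [5.0, 5.0], [5.1, 5.0]])
centroids, labels = lloyd_kmeans(X, k=2)
print(labels)
```

Each centroid ends up at the mean of its group, (0.05, 0.0) and (5.05, 5.0), which is exactly the "prototype" role described above.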
To illustrate this principle, we will apply K-means clustering to pixel data from an image. Every pixel is treated as an individual data point described by its RGB values, i.e. a point in 3-dimensional color space. Clustering these points groups pixels of similar color, and each cluster centroid is a representative color for its cluster; together, the K centroids form the color palette of the image.
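The nearest-centroid rule in RGB space can be seen with a tiny example. This sketch assumes a hypothetical three-color palette (pure red, green and blue, as floats in [0, 1]) standing in for cluster centroids, and three hypothetical pixels; each pixel is assigned to the palette color at the smallest Euclidean distance.

```python
import numpy as np

# Hypothetical 3-color palette (RGB floats in [0, 1]) acting as centroids: red, green, blue
palette = np.array([
    [1.0, 0.0, 0.0],
    [0.0, 1.0, 0.0],
    [0.0, 0.0, 1.0],
])

# A few hypothetical pixels, one RGB triple per row
pixels = np.array([
    [0.9, 0.1, 0.2],  # reddish
    [0.2, 0.8, 0.3],  # greenish
    [0.1, 0.2, 0.7],  # bluish
])

# Euclidean distance from every pixel to every palette color in RGB space
distances = np.linalg.norm(pixels[:, None, :] - palette[None, :, :], axis=2)
# Each pixel belongs to the palette color (centroid) it is closest to
nearest = distances.argmin(axis=1)
print(nearest)  # [0 1 2]
```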
The code to implement K-means clustering using scikit-learn follows. An interactive version of this notebook can be found on Colab.
%matplotlib inline
# Import libraries
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
from imageio import imread
from skimage.transform import resize
from sklearn.cluster import KMeans
from matplotlib.colors import to_hex
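K-means starts from randomly chosen initial centroids, so different runs can converge to different clusters and therefore produce different color palettes. To make the results reproducible, the random seed is fixed.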
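In this notebook, the random_state=0 argument passed to KMeans is what actually fixes its initialization; np.random.seed(0) only affects code that draws from NumPy's global generator (which includes KMeans when random_state is left as None). The sketch below, using hypothetical random pixel data, checks that two fits with the same random_state return the same palette. n_init=10 is set explicitly as an assumption, to avoid depending on a default that has changed across scikit-learn versions.

```python
import numpy as np
from sklearn.cluster import KMeans

# Hypothetical "pixels": 500 random RGB triples in [0, 1]
rng = np.random.default_rng(42)
pixels = rng.random((500, 3))

# Two fits with the same random_state start from the same initial centroids...
a = KMeans(n_clusters=5, random_state=0, n_init=10).fit(pixels)
b = KMeans(n_clusters=5, random_state=0, n_init=10).fit(pixels)

# ...so they produce identical cluster centers (the palette)
print(np.allclose(a.cluster_centers_, b.cluster_centers_))  # True
```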
# Fix random seed
np.random.seed(0)
Read the image file as a 3-D array of RGB values (height x width x 3 color channels).
# Read image file as a 3-D array (height x width x RGB)
filepath = 'https://images.unsplash.com/photo-1522410818928-5522dacd5066'
img = imread(filepath)
# Show image
plt.axis('off')
plt.imshow(img);
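Resize the image to 200 by 200 pixels. Clustering 40,000 pixels instead of the full-resolution image keeps K-means fast, and the distorted aspect ratio does not matter because only the colors are used. The resize function also converts the pixel values to floating-point numbers between 0 and 1.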
img = resize(img, (200, 200))
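Because resize rescales 8-bit input to floats in [0, 1] (with its default preserve_range=False), the cluster centers later come out in the range that matplotlib's to_hex expects. If you skip the resize step, you would have to do that scaling yourself. A minimal sketch, assuming an 8-bit (uint8) image like the one imread returns for a JPEG, with a hypothetical 2x2 image standing in for real data:

```python
import numpy as np

# Hypothetical 2x2 8-bit RGB image (values 0-255), standing in for imread's output
img_uint8 = np.array([
    [[255, 0, 0], [0, 255, 0]],
    [[0, 0, 255], [128, 128, 128]],
], dtype=np.uint8)

# Scale to floats in [0, 1], the range matplotlib's to_hex expects
img_float = img_uint8.astype(np.float64) / 255.0
print(img_float.dtype, img_float.min(), img_float.max())  # float64 0.0 1.0
```

K-means itself would give equivalent clusters on 0-255 values, since uniform scaling does not change which centroid is nearest; only the centers would be on the 0-255 scale.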
Flatten the image into a table with one row per pixel and one column for each of its R, G, and B values.
data = pd.DataFrame(img.reshape(-1, 3),
                    columns=['R', 'G', 'B'])
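reshape(-1, 3) lays the pixels out in row-major order (row by row, left to right), so reshaping the rows back with the original height and width restores the image exactly; the recreation step at the end of this post relies on that. A small sketch with a hypothetical 2x3 image:

```python
import numpy as np
import pandas as pd

# Hypothetical 2x3 RGB image (height 2, width 3), values in [0, 1]
img = np.arange(18, dtype=float).reshape(2, 3, 3) / 17.0

# Flatten to one row per pixel: shape (height * width, 3)
data = pd.DataFrame(img.reshape(-1, 3), columns=['R', 'G', 'B'])
print(data.shape)  # (6, 3)

# Row i * width + j holds pixel (i, j); reshaping back restores the image exactly
restored = data[['R', 'G', 'B']].values.reshape(2, 3, 3)
print(np.array_equal(restored, img))  # True
```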
Cluster the pixels into 5 colors based on their RGB values.
kmeans = KMeans(n_clusters=5,
                random_state=0)
# Fit and assign clusters
data['Cluster'] = kmeans.fit_predict(data)
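The labels returned by fit_predict are simply each pixel's nearest cluster center in RGB space. The sketch below checks this by hand on hypothetical pixel data made of three tight color groups (dark, red, blue) with a little noise; n_clusters=3 and n_init=10 are assumptions for this toy data, not the settings used above.

```python
import numpy as np
from sklearn.cluster import KMeans

rng = np.random.default_rng(0)
# Hypothetical pixels: three tight groups of colors (dark, red, blue) plus small noise
base_colors = np.array([[0.1, 0.1, 0.1], [0.9, 0.2, 0.2], [0.2, 0.2, 0.9]])
pixels = base_colors[rng.integers(0, 3, size=300)] + rng.normal(0, 0.02, size=(300, 3))

kmeans = KMeans(n_clusters=3, random_state=0, n_init=10)
labels = kmeans.fit_predict(pixels)

# Recompute each pixel's nearest center (squared Euclidean distance) by hand
dists = ((pixels[:, None, :] - kmeans.cluster_centers_[None, :, :]) ** 2).sum(axis=2)
print(np.array_equal(labels, dists.argmin(axis=1)))  # True
```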
Get the color palette from the cluster centers. Each center is the average RGB value of the pixels in its cluster.
palette = kmeans.cluster_centers_
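Since each center is the mean of its cluster's pixels, the palette can also be recomputed from the labeled DataFrame with a groupby. With the data built above, data.groupby('Cluster')[['R', 'G', 'B']].mean() should match kmeans.cluster_centers_ up to the convergence tolerance. A self-contained sketch with hypothetical pixels and hand-assigned labels:

```python
import numpy as np
import pandas as pd

# Hypothetical pixels with cluster labels already assigned
data = pd.DataFrame({
    'R': [1.0, 0.8, 0.0, 0.2],
    'G': [0.0, 0.2, 0.0, 0.2],
    'B': [0.0, 0.0, 1.0, 0.8],
    'Cluster': [0, 0, 1, 1],
})

# Centroid of each cluster = mean R, G, B of its member pixels (rows ordered by cluster label)
palette = data.groupby('Cluster')[['R', 'G', 'B']].mean().to_numpy()
print(palette)
```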
# Wrap each color in nested lists so imshow treats it as a 1x1 RGB image
palette_list = list()
for color in palette:
    palette_list.append([[tuple(color)]])
Show the color palette, along with the hexadecimal code for each color.
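to_hex maps each float channel in [0, 1] to two hexadecimal digits. A minimal sketch of the same conversion done by hand, assuming channels in [0, 1] and round-to-nearest scaling by 255 (the helper name rgb_to_hex is hypothetical, and matplotlib's handling of edge cases may differ):

```python
import numpy as np


def rgb_to_hex(color):
    """Convert an RGB triple of floats in [0, 1] to a '#rrggbb' string (hypothetical helper)."""
    # Scale to 0-255, round, and clamp so out-of-range floats stay valid
    channels = np.clip(np.round(np.asarray(color, dtype=float) * 255), 0, 255)
    r, g, b = (int(c) for c in channels)
    return '#{:02x}{:02x}{:02x}'.format(r, g, b)


print(rgb_to_hex([1.0, 0.0, 0.0]))  # #ff0000
print(rgb_to_hex([0.0, 0.0, 1.0]))  # #0000ff
```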
# Show color palette
for color in palette_list:
    print(to_hex(color[0][0]))
    plt.figure(figsize=(1, 1))
    plt.axis('off')
    plt.imshow(color);
    plt.show();
#c1a4a9
#38293c
#f2ebed
#df8c18
#1565a1
Recreate the image using only colors from the color palette.
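The three apply calls below can be replaced by indexing the palette array with the label column; in this notebook that would be roughly kmeans.cluster_centers_[data['Cluster'].to_numpy()].reshape(200, 200, 3). A sketch of that NumPy fancy-indexing approach, with a hypothetical two-color palette and labels for a 2x3 image:

```python
import numpy as np

# Hypothetical palette (one RGB row per cluster) and per-pixel cluster labels
palette = np.array([[0.9, 0.1, 0.1], [0.1, 0.1, 0.9]])
labels = np.array([0, 1, 1, 0, 0, 1])  # 6 pixels of a 2x3 image, in row-major order

# Indexing with the labels picks each pixel's centroid color in one step,
# then reshape restores the (height, width, 3) layout
img_c = palette[labels].reshape(2, 3, 3)
print(img_c.shape)  # (2, 3, 3)
```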
# Replace every pixel's color with the color of its cluster centroid
data['R_cluster'] = data['Cluster'].apply(lambda x: palette_list[x][0][0][0])
data['G_cluster'] = data['Cluster'].apply(lambda x: palette_list[x][0][0][1])
data['B_cluster'] = data['Cluster'].apply(lambda x: palette_list[x][0][0][2])
# Convert the dataframe back to a numpy array
img_c = data[['R_cluster', 'G_cluster', 'B_cluster']].values
# Reshape the data back to a 200x200 image
img_c = img_c.reshape(200, 200, 3)
# Upscale to 800 x 1200 (height x width); change these numbers to match the original photo's aspect ratio
img_c = resize(img_c, (800, 1200))
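One caveat: for non-boolean images, resize interpolates by default (order=1, bilinear), so upscaling blends neighboring colors at cluster boundaries and introduces a few colors that are not in the palette. Passing order=0 to resize selects nearest-neighbor sampling, which keeps every pixel on a palette color at the cost of a blockier look. Saving the original shape before the 200 x 200 resize (for example, a hypothetical orig_h, orig_w = img.shape[:2]) would also let you restore the true aspect ratio instead of hard-coding it. The nearest-neighbor idea itself, sketched in plain NumPy with a hypothetical helper name:

```python
import numpy as np


def upscale_nearest(img, out_h, out_w):
    """Nearest-neighbor resize using NumPy indexing (hypothetical helper).

    Every output pixel copies an existing input pixel, so no new colors appear.
    """
    in_h, in_w = img.shape[:2]
    # Map each output row/column to the input row/column it falls in
    rows = (np.arange(out_h) * in_h) // out_h
    cols = (np.arange(out_w) * in_w) // out_w
    return img[rows[:, None], cols[None, :]]


palette = np.array([[0.9, 0.1, 0.1], [0.1, 0.1, 0.9]])
small = palette[np.array([[0, 1], [1, 0]])]  # 2x2 image using only palette colors
big = upscale_nearest(small, 4, 6)
print(big.shape)  # (4, 6, 3)
```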
# Display the image
plt.axis('off')
plt.imshow(img_c)
plt.show()
Voila!