Color Palettes Using K-Means Clustering

October 01, 2019


K-means clustering is an unsupervised learning technique that groups data points around K centroids: each data point is assigned to the cluster whose centroid is closest to it in n-dimensional space. The goal of K-means clustering is to find the set of K centroids that minimizes the within-cluster variation, that is, the sum of squared distances between each data point and the centroid of its cluster. Each centroid is the mean of the points assigned to it, so it serves as the “prototype” of its cluster.
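The assignment and update steps behind K-means fit in a few lines of NumPy. The following is a minimal sketch of Lloyd's algorithm with random initialization, written for illustration only. The function name `kmeans_lloyd` is hypothetical. scikit-learn's `KMeans`, which this post uses, adds k-means++ initialization and multiple restarts on top of this basic loop.

```python
import numpy as np

def kmeans_lloyd(X, k, n_iter=100, seed=0):
    """Hypothetical minimal K-means (Lloyd's algorithm) for illustration."""
    rng = np.random.default_rng(seed)
    # Initialize the centroids with k distinct data points chosen at random
    centroids = X[rng.choice(len(X), size=k, replace=False)].astype(float)
    for _ in range(n_iter):
        # Assignment step: each point joins the cluster of its nearest centroid
        dists = ((X[:, None, :] - centroids[None, :, :]) ** 2).sum(axis=2)
        labels = dists.argmin(axis=1)
        # Update step: move each centroid to the mean of its assigned points
        # (an empty cluster keeps its previous centroid)
        new_centroids = np.array([
            X[labels == j].mean(axis=0) if np.any(labels == j) else centroids[j]
            for j in range(k)
        ])
        converged = np.allclose(new_centroids, centroids)
        centroids = new_centroids
        if converged:
            break
    # Final assignment so the labels match the returned centroids
    dists = ((X[:, None, :] - centroids[None, :, :]) ** 2).sum(axis=2)
    labels = dists.argmin(axis=1)
    # Within-cluster sum of squared distances: the quantity K-means minimizes
    inertia = ((X - centroids[labels]) ** 2).sum()
    return centroids, labels, inertia

# Two obvious groups of 2-D points
X = np.array([[0.0, 0.0], [0.1, 0.0], [5.0, 5.0], [5.1, 5.0]])
centroids, labels, inertia = kmeans_lloyd(X, k=2)
print(labels, inertia)
```

Lloyd's algorithm only finds a local minimum, which is why practical implementations restart from several initializations and keep the run with the lowest inertia.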

To illustrate this principle, we will apply K-means clustering to the pixels of an image. Every pixel is treated as an individual data point with three features: its red, green, and blue (RGB) values. Clustering the pixels in this 3-dimensional RGB space groups similar colors together, and each cluster centroid is the average color of its group. The K centroids therefore form the color palette of the image.
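As a toy version of this idea, the sketch builds a hypothetical 10 × 10 image with a reddish half and a bluish half, flattens it into one RGB row per pixel, and clusters it with K = 2. The two centroids should land close to the two base colors, which is the "palette" of this synthetic image.

```python
import numpy as np
from sklearn.cluster import KMeans

rng = np.random.default_rng(0)

# Hypothetical 10x10 test image: left half reddish, right half bluish, plus a little noise
img = np.zeros((10, 10, 3))
img[:, :5] = [0.9, 0.1, 0.1]
img[:, 5:] = [0.1, 0.2, 0.8]
img = np.clip(img + rng.normal(scale=0.02, size=img.shape), 0, 1)

# Each pixel becomes one 3-D data point (R, G, B)
pixels = img.reshape(-1, 3)
km = KMeans(n_clusters=2, n_init=10, random_state=0).fit(pixels)

# The centroids are the average colors of the two groups (order is arbitrary)
print(np.round(km.cluster_centers_, 2))
```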

The code below implements K-means clustering with scikit-learn. An interactive version of this notebook can be found on Colab.

%matplotlib inline
# Import libraries
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
from matplotlib.colors import to_hex

# imageio >= 2.16: the v2 namespace avoids the deprecation warning on imageio.imread
from imageio.v2 import imread

from skimage.transform import resize
from sklearn.cluster import KMeans

K-means starts from randomly chosen initial centroids, so different runs can produce different color palettes. To make the palette reproducible, fix the random seed.

# Fix random seed
np.random.seed(0)
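Here is a quick check, on hypothetical random data, that fitting twice with the same `random_state` gives the same centroids. When `random_state` is an integer, scikit-learn's `KMeans` draws from its own generator rather than NumPy's global one. In this notebook it is therefore the `random_state=0` argument passed to `KMeans`, rather than `np.random.seed(0)`, that pins down the clustering.

```python
import numpy as np
from sklearn.cluster import KMeans

# Hypothetical "pixels": 500 random RGB triples in [0, 1]
X = np.random.default_rng(42).random((500, 3))

# Same data + same random_state -> same initialization -> same centroids
a = KMeans(n_clusters=5, n_init=10, random_state=0).fit(X).cluster_centers_
b = KMeans(n_clusters=5, n_init=10, random_state=0).fit(X).cluster_centers_
print(np.allclose(a, b))
```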

Read the image file into a NumPy array of shape (height, width, 3) holding the RGB values of each pixel.

# Read the image into a (height, width, 3) array of RGB values
filepath = 'https://images.unsplash.com/photo-1522410818928-5522dacd5066'

img = imread(filepath)

# Show image
plt.axis('off')
plt.imshow(img);

img
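The rest of the notebook assumes `img` has exactly three channels, because `img.reshape(-1, 3)` later splits the data into R, G, and B columns. A JPEG photo usually loads as (height, width, 3). However, a PNG may carry a fourth alpha channel, and a grayscale image has no channel axis at all. A small hypothetical helper, `to_rgb`, normalizes both cases:

```python
import numpy as np

def to_rgb(img):
    """Hypothetical helper: coerce an image array to shape (height, width, 3)."""
    if img.ndim == 2:
        # Grayscale: repeat the single channel three times
        img = np.stack([img, img, img], axis=-1)
    elif img.shape[-1] == 4:
        # RGBA: drop the alpha channel
        img = img[..., :3]
    return img

print(to_rgb(np.zeros((4, 6))).shape)     # → (4, 6, 3)
print(to_rgb(np.zeros((4, 6, 4))).shape)  # → (4, 6, 3)
```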

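Resize the image to 200 × 200 pixels so that clustering runs faster. Besides changing the size, `resize` converts the 8-bit integer values (0 to 255) to floats between 0 and 1. This is the range that `to_hex` and `imshow` expect for float RGB data later on.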

img = resize(img, (200, 200))
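The rescaling done by `resize` is easy to verify on a hypothetical 8-bit image. By default (`preserve_range=False`) the output is float64 in [0, 1]. With `preserve_range=True` the values stay on the 0 to 255 scale, which would later make `to_hex` reject the centroids. Passing only (rows, cols) keeps the channel axis.

```python
import numpy as np
from skimage.transform import resize

# Hypothetical 8-bit RGB image (values 0 to 255), like imread returns for a JPEG
img_u8 = np.random.default_rng(0).integers(0, 256, size=(400, 600, 3), dtype=np.uint8)

small = resize(img_u8, (200, 200))                      # channels are kept: (200, 200, 3)
kept = resize(img_u8, (200, 200), preserve_range=True)  # stays on the 0-255 scale

print(small.shape, small.dtype, small.min(), small.max())  # floats within [0, 1]
print(kept.dtype, kept.max())                               # floats on the 0-255 scale
```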

Flatten the image so that each row holds one pixel's R, G, and B values.

data = pd.DataFrame(img.reshape(-1, 3),
                    columns=['R', 'G', 'B'])
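`reshape(-1, 3)` flattens the pixels in row-major order, so row `i` of the DataFrame is pixel `(i // width, i % width)`. Reshaping back with the same height and width restores the image exactly, which is what the recoloring step relies on. A check on a tiny hypothetical 2 × 3 image:

```python
import numpy as np
import pandas as pd

# Tiny hypothetical 2x3 RGB image with distinct values per channel
img = np.arange(2 * 3 * 3, dtype=float).reshape(2, 3, 3)
data = pd.DataFrame(img.reshape(-1, 3), columns=['R', 'G', 'B'])

# Row 4 = pixel (4 // 3, 4 % 3) = (1, 1)
print(data.iloc[4].tolist(), img[1, 1].tolist())  # → [12.0, 13.0, 14.0] [12.0, 13.0, 14.0]

# Reshaping back with the same height and width restores the original image
restored = data[['R', 'G', 'B']].to_numpy().reshape(2, 3, 3)
print(np.array_equal(restored, img))  # → True
```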

Cluster the pixels into 5 colors based on their RGB values.

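# n_init=10 runs K-means from 10 initializations and keeps the best;
# it was the scikit-learn default when this post was written
kmeans = KMeans(n_clusters=5,
                n_init=10,
                random_state=0)
# Fit the model and assign each pixel to a cluster
data['Cluster'] = kmeans.fit_predict(data)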
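After fitting, scikit-learn exposes the quantity K-means minimizes as `inertia_`, the sum of squared distances from each point to its assigned centroid. On hypothetical random data, recomputing it by hand from `cluster_centers_` and the labels returned by `fit_predict` should agree:

```python
import numpy as np
from sklearn.cluster import KMeans

# Hypothetical pixel data: 300 random RGB triples in [0, 1]
X = np.random.default_rng(0).random((300, 3))

km = KMeans(n_clusters=5, n_init=10, random_state=0)
labels = km.fit_predict(X)

# Sum of squared distances from every point to the centroid of its cluster
manual = ((X - km.cluster_centers_[labels]) ** 2).sum()
print(manual, km.inertia_, np.isclose(manual, km.inertia_))
```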

Get the color palette from the cluster centers.

palette = kmeans.cluster_centers_
# Wrap each color in a 1x1x3 nested list so imshow draws it as a one-pixel image
palette_list = list()
for color in palette:
    palette_list.append([[tuple(color)]])
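Instead of wrapping each color separately, the whole palette can also be drawn as a single strip. Adding a leading axis turns the `(K, 3)` centroid array into a `(1, K, 3)` image, one row of K "pixels". The sketch rebuilds the palette from the hex codes printed in this post using `matplotlib.colors.to_rgb`. With the notebook's variables, `kmeans.cluster_centers_[np.newaxis]` would play the same role.

```python
import numpy as np
from matplotlib import pyplot as plt
from matplotlib.colors import to_rgb

# Palette rebuilt from the hex codes printed in this post, shape (5, 3)
hex_codes = ['#c1a4a9', '#38293c', '#f2ebed', '#df8c18', '#1565a1']
palette = np.array([to_rgb(h) for h in hex_codes])

# Add a leading axis: shape (1, 5, 3) = one row of five "pixels"
strip = palette[np.newaxis, :, :]

fig, ax = plt.subplots(figsize=(5, 1))
ax.imshow(strip)
ax.axis('off')
plt.show()
```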

Show the color palette, along with the hexadecimal code for each color.

# Show color palette
for color in palette_list:
    print(to_hex(color[0][0]))
    plt.figure(figsize=(1, 1))
    plt.axis('off')
    plt.imshow(color);
    plt.show();

#c1a4a9

#38293c

#f2ebed

#df8c18

#1565a1
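`to_hex` scales each channel from [0, 1] to 0 to 255, rounds it, and writes it as two hexadecimal digits. It raises a `ValueError` for values outside [0, 1], which is why the float rescaling done by `resize` matters. A hypothetical helper, `rgb_to_hex`, that does the same conversion by hand:

```python
from matplotlib.colors import to_hex

def rgb_to_hex(rgb):
    """Hypothetical helper: format an (R, G, B) triple in [0, 1] as '#rrggbb'."""
    # Scale each channel to 0-255, round to the nearest integer, write as 2 hex digits
    return '#' + ''.join(f'{round(c * 255):02x}' for c in rgb)

color = (0.7569, 0.6431, 0.6627)  # hypothetical centroid close to the first palette color
print(rgb_to_hex(color), to_hex(color))  # → #c1a4a9 #c1a4a9
```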

Recreate the image using only colors from the color palette.

# Replace every pixel's color with the color of its cluster centroid
data['R_cluster'] = data['Cluster'].apply(lambda x: palette[x][0])
data['G_cluster'] = data['Cluster'].apply(lambda x: palette[x][1])
data['B_cluster'] = data['Cluster'].apply(lambda x: palette[x][2])
# Convert the dataframe back to a numpy array
img_c = data[['R_cluster', 'G_cluster', 'B_cluster']].values
# Reshape the data back to a 200x200 image
img_c = img_c.reshape(200, 200, 3)
# Resize the image back to the original aspect ratio (800 x 1200, height x width);
# order=0 (nearest neighbor) avoids blending neighboring palette colors
img_c = resize(img_c, (800, 1200), order=0)
# Display the image
plt.axis('off')
plt.imshow(img_c)
plt.show()
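The three `apply` calls can also be replaced by a single NumPy indexing operation. Indexing the `(K, 3)` centroid array with the per-pixel label array returns one color per pixel. With the notebook's variables this would read `kmeans.cluster_centers_[data['Cluster'].values].reshape(200, 200, 3)`. The sketch below uses a hypothetical 2-color palette. It also checks that nearest-neighbor upscaling (`order=0`) keeps only palette colors, whereas the default bilinear interpolation would blend colors along edges.

```python
import numpy as np
from skimage.transform import resize

palette = np.array([[1.0, 0.0, 0.0],
                    [0.0, 0.0, 1.0]])   # hypothetical 2-color palette, shape (K, 3)
labels = np.array([0, 1, 1, 0])         # hypothetical cluster label per pixel of a 2x2 image

# Fancy indexing: palette[labels] has shape (N, 3), one centroid color per pixel
img_c = palette[labels].reshape(2, 2, 3)
print(img_c[0, 1])  # → [0. 0. 1.]

# Nearest-neighbor upscaling copies pixels instead of blending them,
# so the enlarged image still contains only palette colors
big = resize(img_c, (4, 4), order=0)
print(np.unique(big.reshape(-1, 3), axis=0))
```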


Voila!
