K-means clustering for image reduction
K-means algorithm
The K-means algorithm is a method to automatically cluster similar data examples together. Concretely, you are given a training set $\{x^{(1)}, \dots, x^{(m)}\}$ (where $x^{(i)} \in \mathbb{R}^n$) and want to group the data into a few cohesive "clusters".
K-means is an iterative procedure that starts by guessing the initial centroids and then refines this guess by repeatedly assigning examples to their closest centroids and recomputing the centroids based on those assignments.
The algorithm repeatedly carries out two steps:
- Assigning each training example to its closest centroid
- Recomputing the mean of each centroid using the points assigned to it.
The K-means algorithm will always converge to some final set of means for the centroids. Note that the converged solution may not always be ideal and depends on the initial setting of the centroids. Therefore, in practice the K-means algorithm is usually run a few times with different random initializations. One way to choose between these different solutions from different random initializations is to choose the one with the lowest cost function value (distortion).
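The selection among random initializations can be sketched as follows. This is a minimal, self-contained illustration rather than the code used later in this document: the toy data, `n_restarts`, and the inline assignment and update loops are assumptions made for the example, and distortion is taken to be the mean squared distance from each example to its assigned centroid.

```matlab
% Toy 2-D data with two well-separated groups (hypothetical values).
X = [0 0; 0 1; 1 0; 10 10; 10 11; 11 10];
K = 2;
n_restarts = 5;   % hypothetical number of random initializations
max_iters = 10;
m = size(X, 1);

best_cost = Inf;
best_centroids = [];
for r = 1:n_restarts
    % Random initialization: use K distinct examples as centroids
    perm = randperm(m);
    centroids = X(perm(1:K), :);

    for it = 1:max_iters
        % Assignment step: squared distance to every centroid (m x K)
        D = zeros(m, K);
        for j = 1:K
            D(:, j) = sum(bsxfun(@minus, X, centroids(j, :)).^2, 2);
        end
        [~, idx] = min(D, [], 2);

        % Update step: mean of the points assigned to each centroid
        for j = 1:K
            if any(idx == j)
                centroids(j, :) = mean(X(idx == j, :), 1);
            end
        end
    end

    % Distortion: mean squared distance to the closest final centroid
    for j = 1:K
        D(:, j) = sum(bsxfun(@minus, X, centroids(j, :)).^2, 2);
    end
    cost = mean(min(D, [], 2));

    % Keep the run with the lowest distortion
    if cost < best_cost
        best_cost = cost;
        best_centroids = centroids;
    end
end
fprintf('Lowest distortion over %d runs: %f\n', n_restarts, best_cost);
```

On real data the runs generally end at different distortions, which is why keeping the minimum is useful.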
Finding closest centroids
In the "cluster assignment" phase of the K-means algorithm, the algorithm assigns every training example $x^{(i)}$ to its closest centroid, given the current positions of the centroids. Specifically, for every example $i$ we set
$$c^{(i)} := \arg\min_j \left\| x^{(i)} - \mu_j \right\|^2,$$
where $c^{(i)}$ is the index of the centroid that is closest to $x^{(i)}$ and $\mu_j$ is the position (value) of the $j$-th centroid.
fprintf('Finding closest centroids.\n\n');
% Load an example dataset that we will be using
load('data.mat');
% Select an initial set of centroids
K = 3; % 3 Centroids
initial_centroids = [3 3; 6 2; 8 5];
% Find the closest centroids for the examples using the
% initial_centroids
idx = findClosestCentroids(X, initial_centroids);
% The implementation of findClosestCentroids is given in the final section
fprintf('Closest centroids for the first 3 examples: \n')
fprintf(' %d', idx(1:3));
The function findClosestCentroids takes the data matrix X and the locations of all centroids in centroids, and outputs a one-dimensional array idx that holds, for every training example, the index of its closest centroid (a value in {1,...,K}, where K is the total number of centroids).
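As a small illustration of the assignment rule, the sketch below builds the full matrix of squared distances and takes the row-wise minimum. The inputs are hypothetical toy values, and this variant loops only over centroids; it is an alternative to, not a copy of, the loop-based implementation given in the final section.

```matlab
% Hypothetical toy inputs: three 2-D examples and two centroids.
X = [1 1; 4 4; 9 9];
centroids = [0 0; 10 10];
K = size(centroids, 1);

% D(i, j) holds the squared distance from example i to centroid j
D = zeros(size(X, 1), K);
for j = 1:K
    D(:, j) = sum(bsxfun(@minus, X, centroids(j, :)).^2, 2);
end

% The column index of each row-wise minimum is the assigned centroid
[~, idx] = min(D, [], 2);
disp(idx');
```

Because `min` returns the first minimum in each row, ties are resolved in favor of the lower centroid index.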
Computing centroid means
Given assignments of every point to a centroid, the second phase of the algorithm recomputes, for each centroid, the mean of the points that were assigned to it. Specifically, for every centroid $k$ we set
$$\mu_k := \frac{1}{|C_k|} \sum_{i \in C_k} x^{(i)},$$
where $C_k$ is the set of examples that are assigned to centroid $k$. Concretely, if two examples, say $x^{(3)}$ and $x^{(5)}$, are assigned to centroid $k = 2$, then you should update $\mu_2 = \frac{1}{2}\left(x^{(3)} + x^{(5)}\right)$.
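The update step amounts to a per-cluster mean, which can be sketched with logical indexing. The data and assignments below are hypothetical toy values chosen so that exactly two examples belong to centroid 2. The `any(members)` guard is an assumption about how to treat empty clusters: their centroid is left at zero here.

```matlab
% Hypothetical toy data: five 2-D examples and their assignments.
X = [1 2; 3 4; 5 6; 7 8; 9 10];
idx = [1; 1; 2; 1; 2];
K = 2;

centroids = zeros(K, size(X, 2));
for k = 1:K
    members = (idx == k);   % logical mask of examples assigned to centroid k
    if any(members)
        centroids(k, :) = mean(X(members, :), 1);
    end
end
disp(centroids);
```

Centroid 2 becomes the average of rows 3 and 5 of X, that is, [7 8].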
fprintf('\nComputing centroids means.\n\n');
% Compute means based on the closest centroids found in the previous part.
centroids = computeCentroids(X, idx, K);
fprintf('Centroids computed after initial finding of closest centroids: \n')
fprintf(' %f %f \n' , centroids');
Visualizing K-means
Let's run K-means on an example dataset.
fprintf('\nRunning K-Means clustering on example dataset.\n\n');
% Load an example dataset
load('data.mat');
% Settings for running K-Means
K = 3;
max_iters = 10;
% For consistency, here we set the centroids to specific values,
% but in practice you want to generate them automatically, for example by
% setting them to be random examples (as can be seen in
% kMeansInitCentroids).
initial_centroids = [3 3; 6 2; 8 5];
% Run K-Means algorithm. The 'true' at the end tells our function to plot
% the progress of K-Means
[centroids, idx] = runkMeans(X, initial_centroids, max_iters, true);
fprintf('\nK-Means Done.\n\n');
Image compression with K-means
In a straightforward 24-bit color representation of an image, each pixel is represented as three 8-bit unsigned integers (ranging from 0 to 255) that specify the red, green and blue intensity values. This encoding is often referred to as the RGB encoding. Our image contains thousands of colors, and in this part of the exercise, you will reduce the number of colors to 16.
By making this reduction, it is possible to represent (compress) the photo in an efficient way. Specifically, you only need to store the RGB values of the 16 selected colors, and for each pixel in the image you now need to only store the index of the color at that location (where only 4 bits are necessary to represent 16 possibilities).
We will use the K-means algorithm to select the 16 colors that will be used to represent the compressed image. Concretely, we will treat every pixel in the original image as a data example and use the K-means algorithm to find the 16 colors that best group (cluster) the pixels in the 3-dimensional RGB space. Once we have computed the cluster centroids on the image, we will then use the 16 colors to replace the pixels in the original image.
% Load an image of a bird
A = double(imread('bird_small.png'));
A = A / 255; % Divide by 255 so that all values are in the range 0 - 1
% Size of the image
img_size = size(A);
% Reshape the image into an Nx3 matrix where N = number of pixels.
% Each row will contain the Red, Green and Blue pixel values
% This gives us our dataset matrix X that we will use K-Means on.
X = reshape(A, img_size(1) * img_size(2), 3);
% Run K-Means algorithm on this data
K = 16;
max_iters = 10;
% When using K-Means, it is important to initialize the centroids
% randomly.
initial_centroids = kMeansInitCentroids(X, K);
% Run K-Means
[centroids, idx] = runkMeans(X, initial_centroids, max_iters);
K-means on pixels
A is a three-dimensional matrix whose first two indices identify a pixel position and whose last index represents red, green, or blue. For example, A(50, 33, 3) gives the blue intensity of the pixel at row 50 and column 33.
After loading the image we reshape it to create an m × 3 matrix of pixel colors (where m = 16384 = 128 × 128), and call our K-means function on it.
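To see what the reshape does to pixel order, here is a sketch on a hypothetical 2 × 2 image with made-up channel values. It relies only on `reshape` operating in column-major order.

```matlab
% Hypothetical 2 x 2 "image" with 3 color channels.
A = zeros(2, 2, 3);
A(:, :, 1) = [0.1 0.2; 0.3 0.4];   % red channel
A(:, :, 2) = [0.5 0.6; 0.7 0.8];   % green channel
A(:, :, 3) = [0.9 1.0; 0.0 0.5];   % blue channel

img_size = size(A);
X = reshape(A, img_size(1) * img_size(2), 3);

% Each row of X is the [R G B] triple of one pixel. Pixels are taken in
% column-major order, so row 2 of X is the pixel at row 2, column 1.
disp(X(2, :));

% Reshaping back with the same dimensions restores the original array
A_back = reshape(X, img_size(1), img_size(2), 3);
disp(isequal(A, A_back));
```

This round trip is what allows the recovered m × 3 matrix to be turned back into an image after clustering.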
% Find closest cluster members
idx = findClosestCentroids(X, centroids);
% Essentially, we have now represented the image X in terms of the
% indices in idx.
% We can now recover the image from the indices (idx) by mapping each pixel
% (specified by its index in idx) to the centroid value
X_recovered = centroids(idx,:);
% Reshape the recovered image into proper dimensions
X_recovered = reshape(X_recovered, img_size(1), img_size(2), 3);
% Display the original image
subplot(1, 2, 1);
imagesc(A);
title('Original');
% Display compressed image side by side
subplot(1, 2, 2);
imagesc(X_recovered)
title(sprintf('Compressed, with %d colors.', K));
After finding the top K = 16 colors to represent the image, we can now assign each pixel position to its closest centroid using the findClosestCentroids function. This allows us to represent the original image using the centroid assignments of each pixel. Notice that we have significantly reduced the number of bits that are required to describe the image. The original image required 24 bits for each one of the 128 × 128 pixel locations, resulting in a total size of 128 × 128 × 24 = 393,216 bits. The new representation requires some overhead storage in the form of a dictionary of 16 colors, each of which requires 24 bits, but the image itself then only requires 4 bits per pixel location. The final number of bits used is therefore 16 × 24 + 128 × 128 × 4 = 65,920 bits, which corresponds to compressing the original image by about a factor of 6.
Finally, we can view the effects of the compression by reconstructing the image based only on the centroid assignments. Specifically, we replace each pixel with the value of the centroid assigned to it. The side-by-side figure produced by the code above shows the reconstruction we obtained. Even though the resulting image retains most of the characteristics of the original, some compression artifacts are visible.
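The bit counts can be reproduced with a few lines of arithmetic. The variable names are illustrative only, and the number of index bits is computed as `ceil(log2(K))`, which equals 4 for K = 16.

```matlab
height = 128;
width = 128;
bits_per_pixel = 3 * 8;   % three 8-bit channels
K = 16;                   % number of colors kept

% Original image: 24 bits for every pixel
original_bits = height * width * bits_per_pixel;

% Compressed image: a K-entry color dictionary plus one index per pixel
index_bits = ceil(log2(K));
compressed_bits = K * bits_per_pixel + height * width * index_bits;

ratio = original_bits / compressed_bits;
fprintf('Original:   %d bits\n', original_bits);
fprintf('Compressed: %d bits\n', compressed_bits);
fprintf('Ratio:      %.2f\n', ratio);
```

The dictionary overhead (384 bits) is negligible next to the per-pixel indices, so the ratio stays close to 24 / 4 = 6.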
Implementation of functions
Implementation of kMeansInitCentroids
function centroids = kMeansInitCentroids(X, K)
%KMEANSINITCENTROIDS This function initializes K centroids that are to be
%used in K-Means on the dataset X
% centroids = KMEANSINITCENTROIDS(X, K) returns K initial centroids to be
% used with the K-Means on the dataset X
%
% You should return these values correctly
centroids = zeros(K, size(X, 2));
% Initialize the centroids to be random examples
% Randomly reorder the indices of examples
randidx = randperm(size(X, 1));
% Take the first K examples as centroids
centroids = X(randidx(1:K), :);
end
Implementation of findClosestCentroids
function idx = findClosestCentroids(X, centroids)
%FINDCLOSESTCENTROIDS computes the centroid memberships for every example
% idx = FINDCLOSESTCENTROIDS (X, centroids) returns the closest centroids
% in idx for a dataset X where each row is a single example. idx = m x 1
% vector of centroid assignments (i.e. each entry in range [1..K])
%
% Set K
K = size(centroids, 1);
% You need to return the following variables correctly.
idx = zeros(size(X,1), 1);
for i = 1:size(X,1)
min_cost = Inf; % start above any possible squared distance
indx = -1;
for j = 1:K
cost = sum((X(i,:) - centroids(j,:)).^2);
if cost<min_cost
indx = j;
min_cost = cost;
end
end
idx(i) = indx;
end
end
Implementation of computeCentroids
function centroids = computeCentroids(X, idx, K)
%COMPUTECENTROIDS returns the new centroids by computing the means of the
%data points assigned to each centroid.
% centroids = COMPUTECENTROIDS(X, idx, K) returns the new centroids by
% computing the means of the data points assigned to each centroid. It is
% given a dataset X where each row is a single data point, a vector
% idx of centroid assignments (i.e. each entry in range [1..K]) for each
% example, and K, the number of centroids. You should return a matrix
% centroids, where each row of centroids is the mean of the data points
% assigned to it.
%
% Useful variables
[m n] = size(X);
% You need to return the following variables correctly.
centroids = zeros(K, n);
count = zeros(K,1);
for i = 1:m
grp = idx(i);
centroids(grp,:) = centroids(grp,:)+ X(i,:);
count(grp) = count(grp) + 1;
end
for i = 1:K
if count(i)>0
centroids(i,:) = centroids(i,:)/count(i);
end
end
end
Implementation of runkMeans
function [centroids, idx] = runkMeans(X, initial_centroids, ...
max_iters, plot_progress)
%RUNKMEANS runs the K-Means algorithm on data matrix X, where each row of X
%is a single example
% [centroids, idx] = RUNKMEANS(X, initial_centroids, max_iters, ...
% plot_progress) runs the K-Means algorithm on data matrix X, where each
% row of X is a single example. It uses initial_centroids as the
% initial centroids. max_iters specifies the total number of iterations
% of K-Means to execute. plot_progress is a true/false flag that
% indicates if the function should also plot its progress as the
% learning happens. This is set to false by default. runkMeans returns
% centroids, a Kxn matrix of the computed centroids and idx, a m x 1
% vector of centroid assignments (i.e. each entry in range [1..K])
%
% Set default value for plot progress
if ~exist('plot_progress', 'var') || isempty(plot_progress)
plot_progress = false;
end
% Plot the data if we are plotting progress
%if plot_progress
%figure;
%hold on;
%end
% Initialize values
[m n] = size(X);
K = size(initial_centroids, 1);
centroids = initial_centroids;
previous_centroids = centroids;
idx = zeros(m, 1);
% Run K-Means
for i=1:max_iters
% Output progress
fprintf('K-Means iteration %d/%d...\n', i, max_iters);
if exist('OCTAVE_VERSION')
fflush(stdout);
end
% For each example in X, assign it to the closest centroid
idx = findClosestCentroids(X, centroids);
% Optionally, plot progress here
if plot_progress
if i==1
figure();
else
f = gcf;
f2 = figure();
f2 = copy(f);
end
% Only plot (and track the previous centroids) when plotting is requested
plotProgresskMeans(X, centroids, previous_centroids, idx, K, i);
previous_centroids = centroids;
end
%fprintf('Press enter to continue.\n');
% pause
% Given the memberships, compute new centroids
centroids = computeCentroids(X, idx, K);
end
%Hold off if we are plotting progress
%if plot_progress
%hold off;
%end
end
Implementation of plotProgresskMeans
function plotProgresskMeans(X, centroids, previous, idx, K, i)
%PLOTPROGRESSKMEANS is a helper function that displays the progress of
%k-Means as it is running. It is intended for use only with 2D data.
% PLOTPROGRESSKMEANS(X, centroids, previous, idx, K, i) plots the data
% points with colors assigned to each centroid. With the previous
% centroids, it also plots a line between the previous locations and
% current locations of the centroids.
%
% Plot the examples
hold on;
plotDataPoints(X, idx, K);
% Plot the centroids as black x's
plot(centroids(:,1), centroids(:,2), 'x', ...
'MarkerEdgeColor','k', ...
'MarkerSize', 10, 'LineWidth', 3);
% Plot the history of the centroids with lines
for j=1:size(centroids,1)
drawLine(centroids(j, :), previous(j, :));
end
hold off;
% Title
title(sprintf('Iteration number %d', i))
end
Implementation of plotDataPoints
function plotDataPoints(X, idx, K)
%PLOTDATAPOINTS plots data points in X, coloring them so that those with the same
%index assignments in idx have the same color
% PLOTDATAPOINTS(X, idx, K) plots data points in X, coloring them so that those
% with the same index assignments in idx have the same color
% Create palette
palette = hsv(K + 1);
colors = palette(idx, :);
% Plot the data
scatter(X(:,1), X(:,2), 15, colors);
end