Target audience: Intermediate
Estimated reading time: 6'
This article serves as the continuation of our journey into the Java-based implementation of K-means clustering. Building upon the components outlined in the preceding post, we now focus on constructing a classifier.
Overview
The basic components of the K-means clustering implementation were introduced in the previous post, K-means clustering in Java: Components.
This section describes the implementation of training and inference tasks for the model:
- training: executed off-line during analysis of historical data
- classification: executed at run-time to classify new observations
Note:
For the sake of readability, non-essential code such as error checking, exception handling, validation of class and method arguments, scoping qualifiers and import statements is omitted.
Training
The learning method, train, implements the clustering algorithm. It iterates to minimize the sum of the distances between the observations of each cluster and its centroid. The train method first:
- assigns the observations to the clusters in a round-robin fashion
- computes the centroid of each cluster, computeCentroid
- computes the total distance between all the observations and their respective centroids, computeTotalDistance
Then, for each iteration (or epoch), it:
- finds the closest cluster for each observation
- re-assigns the observation to that cluster and updates the centroids, updateCentroids
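The quantity the training loop drives down can be written as a loss over the K clusters C_k with centroids μ_k. The article's total distance sums whatever computeDistance returns; the formulation below assumes the textbook K-means objective, which uses the squared Euclidean distance. For that choice of distance, the mean of the observations of a cluster is the centroid that minimizes the cluster's contribution to the loss:

```latex
J = \sum_{k=1}^{K} \sum_{\mathbf{x} \in C_k} \lVert \mathbf{x} - \boldsymbol{\mu}_k \rVert^2,
\qquad
\boldsymbol{\mu}_k = \frac{1}{|C_k|} \sum_{\mathbf{x} \in C_k} \mathbf{x}
```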
public int train() {
    int numIters = _maxIterations, k = 0;
    boolean inProgress = true;

    // Initialize the centroids at evenly spaced points of the data range
    initialize();

    // Initial assignment: distribute the observations across the clusters
    // in a round-robin fashion
    while(inProgress) {
        for(KmeansCluster cluster : _clusters) {
            cluster.attach(_obsList[k]);
            if(++k >= _obsList.length) {
                inProgress = false;
                break;
            }
        }
    }
    computeTotalDistance();

    // Compute the centroid of each cluster from its assigned observations
    for(KmeansCluster cluster : _clusters) {
        cluster.computeCentroid();
    }
    computeTotalDistance();

    List<Observation> obsList = null;
    KmeansCluster closestCluster = null;

    // Main iterative loop: traverse all the clusters, compute the distance
    // of each observation to every centroid and re-assign the observation
    // to the closest cluster.
    for(int i = 0; i < _maxIterations; i++) {
        for(KmeansCluster cluster : _clusters) {
            // Copy the observations because updateCentroids may detach
            // them from this cluster while we iterate
            obsList = new ArrayList<Observation>();
            for(Observation point : cluster.getDataPointsList()) {
                obsList.add(point);
            }
            for(Observation point : obsList) {
                double minDistance = Double.MAX_VALUE, distance = 0.0;
                closestCluster = null;
                for(KmeansCluster cursor : _clusters) {
                    distance = point.computeDistance(cursor.getCentroid());
                    if(minDistance > distance) {
                        minDistance = distance;
                        closestCluster = cursor;
                    }
                }
                updateCentroids(point, cluster, closestCluster);
            }
        }
        // Simple convergence criterion: the total distance has been stable
        // for at least _minNumConvergeIters consecutive updates
        if(_convergeCounter >= _minNumConvergeIters) {
            numIters = i + 1;
            break;
        }
    }
    return numIters;
}
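For comparison, here is a minimal, self-contained sketch of batch K-means (Lloyd's algorithm) on plain double[] vectors. It is not the article's train method: the class name LloydSketch and its methods are hypothetical, it assumes squared Euclidean distance, and it re-assigns all observations before recomputing all centroids once per iteration, whereas train updates the centroids after every single re-assignment. It also swaps in a different stopping rule: it stops as soon as no observation changes cluster.

```java
import java.util.Arrays;

// Hypothetical sketch of batch K-means (Lloyd's algorithm), not the article's API
final class LloydSketch {

    // Squared Euclidean distance between two vectors of the same dimension
    static double sqDistance(double[] a, double[] b) {
        double sum = 0.0;
        for (int i = 0; i < a.length; i++) {
            double d = a[i] - b[i];
            sum += d * d;
        }
        return sum;
    }

    // Returns the cluster index of each observation.
    // 'centroids' holds the initial centroids and is updated in place.
    static int[] train(double[][] obs, double[][] centroids, int maxIters) {
        int k = centroids.length, dim = obs[0].length;
        int[] assignment = new int[obs.length];
        Arrays.fill(assignment, -1);

        for (int iter = 0; iter < maxIters; iter++) {
            boolean changed = false;
            // Assignment step: attach each observation to its closest centroid
            for (int n = 0; n < obs.length; n++) {
                int best = 0;
                double bestDist = Double.MAX_VALUE;
                for (int c = 0; c < k; c++) {
                    double d = sqDistance(obs[n], centroids[c]);
                    if (d < bestDist) {
                        bestDist = d;
                        best = c;
                    }
                }
                if (assignment[n] != best) {
                    assignment[n] = best;
                    changed = true;
                }
            }
            // Converged: no observation moved during this iteration
            if (!changed) {
                break;
            }
            // Update step: each centroid becomes the mean of its observations
            double[][] sums = new double[k][dim];
            int[] counts = new int[k];
            for (int n = 0; n < obs.length; n++) {
                counts[assignment[n]]++;
                for (int j = 0; j < dim; j++) {
                    sums[assignment[n]][j] += obs[n][j];
                }
            }
            for (int c = 0; c < k; c++) {
                if (counts[c] > 0) {  // an empty cluster keeps its previous centroid
                    for (int j = 0; j < dim; j++) {
                        centroids[c][j] = sums[c][j] / counts[c];
                    }
                }
            }
        }
        return assignment;
    }
}
```

With the observations {1, 1}, {1.5, 2}, {8, 8}, {9, 8.5} and initial centroids {0, 0} and {10, 10}, the sketch assigns the first two points to cluster 0 and the last two to cluster 1, and moves the centroids to {1.25, 1.5} and {8.5, 8.25}.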
Classification
Classifying a new observation is straightforward: compute the distance between the new data point and each centroid, then select the cluster with the shortest distance. The classify method returns the index (or label) of the cluster closest to the new observation.
public int classify(double[] obs) {
    double bestScore = Double.MAX_VALUE, distance = 0.0;
    int clusterId = -1;
    // Select the index of the centroid closest to the new observation
    for(int k = 0; k < _centroids.length; k++) {
        distance = _centroids[k].computeDistance(obs);
        if(distance < bestScore) {
            bestScore = distance;
            clusterId = k;
        }
    }
    return clusterId;
}
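The same nearest-centroid rule can be sketched without the article's classes. This version assumes the centroids are plain double[] vectors and uses the Euclidean distance; the class name NearestCentroid is hypothetical. It returns -1 when there are no centroids, mirroring the initial value of clusterId in classify.

```java
// Hypothetical standalone sketch of nearest-centroid classification
final class NearestCentroid {

    // Euclidean distance between two vectors of the same dimension
    static double distance(double[] a, double[] b) {
        double sum = 0.0;
        for (int i = 0; i < a.length; i++) {
            double d = a[i] - b[i];
            sum += d * d;
        }
        return Math.sqrt(sum);
    }

    // Returns the index of the closest centroid, or -1 if there is none
    static int classify(double[][] centroids, double[] obs) {
        double bestScore = Double.MAX_VALUE;
        int clusterId = -1;
        for (int k = 0; k < centroids.length; k++) {
            double d = distance(centroids[k], obs);
            if (d < bestScore) {
                bestScore = d;
                clusterId = k;
            }
        }
        return clusterId;
    }
}
```

For example, with centroids {1.25, 1.5} and {8.5, 8.25}, the observation {2, 2} is assigned to cluster 0 and {7, 9} to cluster 1. Because the square root is monotonic, comparing squared distances would select the same cluster while skipping the Math.sqrt call.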
The code snippets below implement the supporting methods:
- computeTotalDistance: computes the value of the loss function (total distance)
- initialize: initializes the centroid of each cluster
- updateCentroids: updates the values of the centroids
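private void computeTotalDistance() {
    // Loss function: sum of the distances of all observations to their centroid
    double totalDistance = 0.0;
    for(KmeansCluster cluster : _clusters) {
        totalDistance += cluster.getSumDistances();
    }
    // Count the consecutive updates for which the loss changes by less
    // than _convergeCriteria; reset the counter otherwise
    double error = Math.abs(_totalDistance - totalDistance);
    _convergeCounter = (error < _convergeCriteria) ? _convergeCounter + 1 : 0;
    _totalDistance = totalDistance;
}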
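The convergence test in computeTotalDistance can be isolated in a small helper. The sketch below is a hypothetical restatement (the class ConvergenceTracker and its fields are not part of the article's code): it counts how many consecutive updates change the total distance by less than a tolerance and reports convergence once that count reaches a threshold. Unlike the original, it never counts the very first update as stable, since there is no previous loss to compare against.

```java
// Hypothetical sketch of the "stable loss for N consecutive updates" criterion
final class ConvergenceTracker {
    private final double tolerance;
    private final int minStableUpdates;
    private double previousLoss = Double.NaN;  // no loss recorded yet
    private int stableCount = 0;

    ConvergenceTracker(double tolerance, int minStableUpdates) {
        this.tolerance = tolerance;
        this.minStableUpdates = minStableUpdates;
    }

    // Records a new total distance; returns true once the loss has been
    // stable for at least minStableUpdates consecutive updates
    boolean update(double totalDistance) {
        boolean stable = !Double.isNaN(previousLoss)
                && Math.abs(previousLoss - totalDistance) < tolerance;
        stableCount = stable ? stableCount + 1 : 0;
        previousLoss = totalDistance;
        return stableCount >= minStableUpdates;
    }
}
```

With a tolerance of 0.01 and a threshold of 2, the sequence of losses 10.0, 5.0, 4.995, 4.994 reports convergence only on the last update.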
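private void initialize() {
    // getParameters() is assumed to return the maximum values of the
    // variables followed by their minimum values
    double[] params = getParameters();
    int numVariables = params.length >> 1;

    // Range (max - min) of each variable
    double[] range = new double[numVariables];
    for(int k = 0, j = numVariables; k < numVariables; k++, j++) {
        range[k] = params[k] - params[j];
    }

    // Place the centroids at evenly spaced points between min and max:
    // x = min + m * range / (number of clusters + 1), m = 1, 2, ...
    int sz_1 = _clusters.length + 1, m = 1;
    for(KmeansCluster cluster : _clusters) {
        // Allocate a new array for each centroid so the clusters do not share it
        double[] x = new double[numVariables];
        for(int k = 0, j = numVariables; k < numVariables; k++, j++) {
            x[k] = ((range[k]/sz_1)*m) + params[j];
        }
        cluster.setCentroid(x);
        m++;
    }
}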
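The initialization places the K centroids at evenly spaced points along the diagonal of the bounding box of the data, min + m * (max - min) / (K + 1) for m = 1, ..., K. The sketch below reproduces that rule under one assumption: the minimum and maximum of each variable are computed from the observations, whereas the article reads them from getParameters(). The class DiagonalInit is hypothetical.

```java
import java.util.Arrays;

// Hypothetical sketch of the evenly spaced (diagonal) centroid initialization
final class DiagonalInit {

    static double[][] initCentroids(double[][] obs, int k) {
        int dim = obs[0].length;
        double[] min = new double[dim];
        double[] max = new double[dim];
        Arrays.fill(min, Double.POSITIVE_INFINITY);
        Arrays.fill(max, Double.NEGATIVE_INFINITY);

        // Bounding box of the observations
        for (double[] x : obs) {
            for (int j = 0; j < dim; j++) {
                min[j] = Math.min(min[j], x[j]);
                max[j] = Math.max(max[j], x[j]);
            }
        }
        // Centroid m (1-based) sits at min + m * range / (k + 1);
        // each centroid gets its own row, so no array is shared
        double[][] centroids = new double[k][dim];
        for (int m = 1; m <= k; m++) {
            for (int j = 0; j < dim; j++) {
                centroids[m - 1][j] = min[j] + m * (max[j] - min[j]) / (k + 1);
            }
        }
        return centroids;
    }
}
```

For the observations {0, 0} and {3, 6} with K = 2, the centroids are {1, 2} and {2, 4}. This deterministic scheme is simple but ignores how the data is distributed inside the box; randomized seeding such as k-means++ is a commonly used alternative.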
private void updateCentroids(Observation point,
                             KmeansCluster cluster,
                             KmeansCluster bestCluster) {
    // Move the observation only if a different, closer cluster was found
    boolean update = bestCluster != null && bestCluster != cluster;
    if(update) {
        bestCluster.attach(point);
        cluster.detach(point);
        // Recompute the centroids and the loss after the re-assignment
        for(KmeansCluster cursor : _clusters) {
            cursor.computeCentroid();
        }
        computeTotalDistance();
    }
}
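updateCentroids recomputes every centroid after each re-assignment, although only the two clusters involved in the move change. Assuming computeCentroid computes the arithmetic mean of the cluster's observations, both centroids can instead be updated in O(d) with the running-mean identities sketched below. The class IncrementalMean is hypothetical, and remove requires the source cluster to keep at least one observation, since the mean of an empty cluster is undefined.

```java
// Hypothetical sketch: update a cluster mean when one observation moves
final class IncrementalMean {

    // Mean after adding x to a cluster of n (n >= 0) observations with mean 'mean'
    static double[] add(double[] mean, int n, double[] x) {
        double[] updated = new double[mean.length];
        for (int j = 0; j < mean.length; j++) {
            updated[j] = mean[j] + (x[j] - mean[j]) / (n + 1);
        }
        return updated;
    }

    // Mean after removing x from a cluster of n (n >= 2) observations with mean 'mean'
    static double[] remove(double[] mean, int n, double[] x) {
        double[] updated = new double[mean.length];
        for (int j = 0; j < mean.length; j++) {
            updated[j] = mean[j] - (x[j] - mean[j]) / (n - 1);
        }
        return updated;
    }
}
```

For example, moving the value 5 from the cluster {1, 3, 5} (mean 3) to the cluster {10} (mean 10) yields the means 2 and 7.5, matching the means of {1, 3} and {10, 5}.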
Thank you for reading this article. For more information ...