k-means is a fairly simple algorithm, and works as follows. \(k\) is the number of clusters we want, and is specified beforehand.
- By some method, select \(k\) different points from the data set to be the initial \(k\) means.
- For each data point, find the closest of the \(k\) means and assign the point to it.
- For each mean, recompute it as the average of all the points assigned to it.
- If any of the points changed clusters, repeat from step 2 (the assignment step).
That is it! Despite its simplicity and naivety, the algorithm is used all the time for many different purposes. This particular naive implementation is often known as Lloyd's algorithm.
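To make the four steps concrete, here is a minimal pure-Python sketch of Lloyd's algorithm. It is not JSAT's NaiveKMeans. The function names (`sq_dist`, `lloyd_kmeans`), the `max_iter` cap, and the seeding rule (simply taking the first \(k\) points) are assumptions made for illustration.

```python
from typing import List, Sequence, Tuple


def sq_dist(a: Sequence[float], b: Sequence[float]) -> float:
    """Squared Euclidean distance (enough for finding the closest mean)."""
    return sum((x - y) ** 2 for x, y in zip(a, b))


def lloyd_kmeans(data: List[Sequence[float]], k: int,
                 max_iter: int = 100) -> Tuple[List[List[float]], List[int]]:
    # Step 1 (assumed seeding rule): use the first k points as initial means.
    means = [list(p) for p in data[:k]]
    assign = [-1] * len(data)
    for _ in range(max_iter):
        changed = False
        # Step 2: assign every point to its closest mean (n * k distances).
        for i, p in enumerate(data):
            best = min(range(k), key=lambda j: sq_dist(p, means[j]))
            if best != assign[i]:
                assign[i] = best
                changed = True
        # Step 3: recompute each mean from the points assigned to it.
        for j in range(k):
            members = [data[i] for i in range(len(data)) if assign[i] == j]
            if members:  # keep the old mean if a cluster ends up empty
                means[j] = [sum(col) / len(members) for col in zip(*members)]
        # Step 4: stop once no point changed clusters.
        if not changed:
            break
    return means, assign
```

Every pass of step 2 computes all \(n \cdot k\) point-to-mean distances, and that is exactly the cost the faster exact algorithms try to cut down.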
The main problem with this algorithm is that it can be a bit slow, and a lot of work has gone into implementing faster versions. Some try to get an exact result, and others attempt to get an approximate solution. Of the exact solutions, the one I see most often compared against is an algorithm by Hartigan and Wong from 1979. H&W report their algorithm to be up to twice as fast as the naive implementation, and it's one of the most common benchmarks to beat when showing a new approximate or exact algorithm.
This has always bothered me though, especially because of an algorithm Elkan published in 2003 that computes the exact k-means result while using the triangle inequality to prune out distance computations. In JSAT the naive algorithm is implemented as NaiveKMeans, and K-Means is an implementation of Elkan's algorithm.
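The core trick can be sketched with one of the triangle-inequality lemmas from Elkan's paper: if the current best mean \(b\) and another mean \(c\) satisfy \(d(b, c) \ge 2\,d(x, b)\), then \(d(x, c) \ge d(b, c) - d(x, b) \ge d(x, b)\), so \(d(x, c)\) never needs to be computed. The sketch below applies only that lemma to a single assignment pass. Elkan's full algorithm also carries per-point upper and lower bounds across iterations. The function names are hypothetical, and this is not JSAT's code.

```python
import math
from typing import List, Sequence, Tuple


def dist(a: Sequence[float], b: Sequence[float]) -> float:
    """Euclidean distance; the pruning rule needs true distances, not squared."""
    return math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)))


def assign_with_pruning(data: List[Sequence[float]],
                        means: List[Sequence[float]]) -> Tuple[List[int], int]:
    """One assignment pass that skips means ruled out by the triangle
    inequality. Returns (assignments, number of point-to-mean distances)."""
    k = len(means)
    # k x k table of distances between means: the k^2 part of Elkan's memory.
    between = [[dist(means[a], means[b]) for b in range(k)] for a in range(k)]
    evaluations = 0
    assign = []
    for p in data:
        best, best_d = 0, dist(p, means[0])
        evaluations += 1
        for j in range(1, k):
            # If d(best, j) >= 2 d(p, best), then d(p, j) >= d(p, best):
            # mean j cannot be strictly closer, so skip computing d(p, j).
            if between[best][j] >= 2.0 * best_d:
                continue
            d = dist(p, means[j])
            evaluations += 1
            if d < best_d:
                best, best_d = j, d
        assign.append(best)
    return assign, evaluations
```

Because a mean is only skipped when it provably cannot be strictly closer, the assignments match a brute-force scan; only the number of distance evaluations changes.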
Below are some performance comparisons between the two, with \(k\) set both to the number of classes in the data set and to 10 times that number. The seeds were selected deterministically, so that the initial points wouldn't add variance to the run time. The Euclidean distance was used for all problems.
Data Set | n | d | k | Lloyd Run Time | Elkan Run Time | Elkan Speed Up |
---|---|---|---|---|---|---|
a9a | 32,561 | 123 | 2 | 2.868 s | 1.206 s | 2.38 |
a9a | 32,561 | 123 | 20 | 24.015 s | 6.928 s | 3.47 |
covtype | 581,012 | 54 | 7 | 10 min 45.696 s | 46.005 s | 14 |
covtype | 581,012 | 54 | 70 | 8 hrs 13 min 8.75 s | 16 min 57.652 s | 29 |
mnist | 60,000 | 780 | 10 | 15 min 22.244 s | 2 min 23.583 s | 6.42 |
mnist | 60,000 | 780 | 100 | 5 hrs 32 min 6.34 s | 8 min 40 s | 38 |
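The Speed Up column is the Lloyd run time divided by the Elkan run time. A small helper (hypothetical names, parsing the time strings exactly as they are written in the table) reproduces it:

```python
import re

# Seconds per unit, for the unit labels used in the table.
UNIT_SECONDS = {"hrs": 3600.0, "min": 60.0, "s": 1.0}


def to_seconds(text: str) -> float:
    """Convert a string such as '10 min 45.696 s' into seconds."""
    pairs = re.findall(r"(\d+(?:\.\d+)?)\s*(hrs|min|s)\b", text)
    return sum(float(value) * UNIT_SECONDS[unit] for value, unit in pairs)


def speed_up(lloyd_time: str, elkan_time: str) -> float:
    """Elkan speed up = Lloyd run time / Elkan run time."""
    return to_seconds(lloyd_time) / to_seconds(elkan_time)
```

For example, `speed_up("8 hrs 13 min 8.75 s", "16 min 57.652 s")` is about 29.08, which the table rounds to 29.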
As you can see, Elkan's algorithm is faster, and much better behaved as \(k\) increases: its speed up grows from 2.38 to 3.47 on a9a, from 14 to 29 on covtype, and from 6.42 to 38 on mnist. So why don't more people implement Elkan's? The only two minor downsides to the algorithm are:
- It's harder to implement in parallel, but JSAT gets linear speedups on my machines just fine using a producer/consumer model.
- It requires \(O(n \cdot k + k^2)\) extra memory: a lower bound for every point and mean pair, plus the distances between every pair of means. However, unless \(k\) gets very large, it's really not an unbearable amount of extra memory. Given the enormous time savings and the cheapness of memory, I'd call that a good trade-off.
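To put the \(O(n \cdot k + k^2)\) term into numbers, here is a back-of-the-envelope sketch. It assumes every lower bound and inter-mean distance is stored as an 8-byte double and ignores the \(O(n)\) per-point upper bounds and labels; `elkan_extra_bytes` is a hypothetical helper.

```python
def elkan_extra_bytes(n: int, k: int, bytes_per_value: int = 8) -> int:
    """Approximate extra memory of Elkan's algorithm in bytes.

    Assumes one 8-byte double per lower bound (n * k of them) and per
    inter-mean distance (k * k of them); O(n) per-point state is ignored.
    """
    lower_bounds = n * k
    between_means = k * k
    return (lower_bounds + between_means) * bytes_per_value


for name, n, k in [("covtype", 581_012, 70), ("mnist", 60_000, 100)]:
    extra = elkan_extra_bytes(n, k)
    print(f"{name}, k={k}: {extra:,} bytes (~{extra / 2**20:.0f} MiB)")
```

Under these assumptions the largest covtype run needs 325,405,920 bytes (roughly 310 MiB) and the largest mnist run 48,080,000 bytes (roughly 46 MiB).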
Even so, there is another paper by Greg Hamerly that builds on Elkan's work. His algorithm is faster for low-dimensional problems, and still very fast for higher-dimensional ones. More importantly, it only needs \(O(n)\) extra memory.
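For a sense of where the \(O(n)\) memory comes from, here is a rough pure-Python sketch based on my reading of Hamerly's algorithm, not on any reference implementation. Each point keeps just one upper bound (to its assigned mean) and one lower bound (to every other mean), instead of Elkan's \(k\) lower bounds per point. The names, the seeding rule, and the bound update (subtracting the single largest mean movement from every lower bound, a simpler but still valid rule than the paper's) are assumptions, and the sketch needs \(k \ge 2\).

```python
import math
from typing import List, Sequence, Tuple


def dist(a: Sequence[float], b: Sequence[float]) -> float:
    return math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)))


def hamerly_kmeans(data: List[Sequence[float]], k: int,
                   max_iter: int = 100) -> Tuple[List[List[float]], List[int]]:
    n = len(data)
    centers = [list(p) for p in data[:k]]  # assumed seeding: first k points
    # Per-point state, O(n) in total (Elkan would keep n * k lower bounds).
    assign = [0] * n
    upper = [0.0] * n  # upper bound on distance to the assigned center
    lower = [0.0] * n  # lower bound on distance to any other center

    def full_scan(i: int) -> None:
        # Exact distances to all centers; needs k >= 2 for a second closest.
        ds = [dist(data[i], c) for c in centers]
        order = sorted(range(k), key=ds.__getitem__)
        assign[i], upper[i], lower[i] = order[0], ds[order[0]], ds[order[1]]

    for i in range(n):
        full_scan(i)

    for _ in range(max_iter):
        # s[j]: half the distance from center j to its nearest other center.
        s = [0.5 * min(dist(centers[j], centers[o]) for o in range(k) if o != j)
             for j in range(k)]
        changed = False
        for i in range(n):
            bound = max(s[assign[i]], lower[i])
            if upper[i] <= bound:
                continue  # pruned: no other center can be strictly closer
            upper[i] = dist(data[i], centers[assign[i]])  # tighten, retest
            if upper[i] <= bound:
                continue
            old = assign[i]
            full_scan(i)
            changed = changed or assign[i] != old
        # Recompute each center and record how far it moved.
        moved = [0.0] * k
        for j in range(k):
            members = [data[i] for i in range(n) if assign[i] == j]
            if members:
                new_center = [sum(col) / len(members) for col in zip(*members)]
                moved[j] = dist(centers[j], new_center)
                centers[j] = new_center
        max_move = max(moved)
        if not changed and max_move == 0.0:
            break
        # Keep the bounds valid after the centers moved. Simplification
        # (assumption): every lower bound drops by the largest movement.
        for i in range(n):
            upper[i] += moved[assign[i]]
            lower[i] -= max_move
    return centers, assign
```

The per-center arrays `s` and `moved` add only \(O(k)\) memory, so for \(k \ll n\) the extra storage is dominated by the two bounds per point.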
I haven't implemented Hamerly's algorithm yet (it's on my giant TODO list), but I would like to see new k-means algorithms compared against these two when people publish. In fact, in one of my next posts, I'll be doing this comparison for you!