Tuesday, May 28, 2013

Distance Metric Acceleration for all


I recently made a post about performing faster Euclidean Distance computations. I was working on some similar code for accelerating Kernel computations for SVMs, and realized I could adapt the same code to distance computations! I just committed that change (r704), and despite modifying a large number of files, it actually doesn't take a lot to exploit this ability without any silly code duplication.

The old way might look something like this (a quick made up example):

DistanceMetric dm = new EuclideanDistance();
List<Vec> X = dataSet.getDataVectors();
double allPairAvg = 0;
for(int i = 0; i < X.size(); i++)
    for(int j = i+1; j < X.size(); j++)
        allPairAvg += dm.dist(X.get(i), X.get(j));

allPairAvg /= X.size()*(X.size()-1.0)/2;//n(n-1)/2 unordered pairs; the 1.0 keeps the math in double so it can't overflow an int

This computes the average distance between all pairs of points. The code above still works in JSAT, but you can now write it as follows:

DistanceMetric dm = new EuclideanDistance();
List<Vec> X = dataSet.getDataVectors();
List<Double> distCache = dm.getAccelerationCache(X);
double allPairAvg = 0;
for(int i = 0; i < X.size(); i++)
    for(int j = i+1; j < X.size(); j++)
        allPairAvg += dm.dist(i, j, X, distCache);

allPairAvg /= X.size()*(X.size()-1.0)/2;//n(n-1)/2 unordered pairs

As you can see, there is not a huge change in the code here. The distCache holds pre-computed information about each of the vectors; for the Euclidean Distance, that is each vector's self dot product. This is then used in the method call to accelerate the distance computation. What if the metric does not support the acceleration calls? By the interface definition, getAccelerationCache will return null, and when the method is called it must check whether distCache is null. If so, it uses the list to compute the distance the normal way.
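For the Euclidean Distance, the reason the self dot product is worth caching comes from expanding the squared distance. Once \(x_i \cdot x_i\) and \(x_j \cdot x_j\) are stored in distCache, each call only has to compute the cross term \(x_i \cdot x_j\):

\[
\lVert x_i - x_j \rVert = \sqrt{\, x_i \cdot x_i + x_j \cdot x_j - 2\, x_i \cdot x_j \,}
\]

A general floating-point caveat, not something specific to JSAT: for nearly identical vectors, rounding can push the value under the square root slightly below zero, so an implementation of this identity would typically clamp it at zero first.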

This makes it essentially as fast as the original code when acceleration is not supported, since the distance call for unsupported classes just makes the two list dereferences, like the first version. distCache will be null, so it's just an extra null reference on the stack. Nothing big.
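To make the contract concrete, here is a minimal sketch of it, assuming plain double[] vectors instead of JSAT's Vec. The interface SimpleMetric and the classes SimpleEuclidean, SimpleManhattan and AllPairs are hypothetical stand-ins, not JSAT's actual classes: the Euclidean metric caches self dot products, the Manhattan metric returns a null cache, and the all-pairs loop is identical for both.

```java
import java.util.ArrayList;
import java.util.List;

/** Hypothetical, simplified version of the contract described above (not JSAT's real interface). */
interface SimpleMetric {
    double dist(double[] a, double[] b);

    /** Per-vector cached values, or null if this metric has no acceleration. */
    List<Double> getAccelerationCache(List<double[]> X);

    /** Distance between X[i] and X[j]; must fall back to dist(a, b) when cache is null. */
    double dist(int i, int j, List<double[]> X, List<Double> cache);
}

class SimpleEuclidean implements SimpleMetric {
    static double dot(double[] a, double[] b) {
        double s = 0;
        for (int k = 0; k < a.length; k++) s += a[k] * b[k];
        return s;
    }

    public double dist(double[] a, double[] b) {
        double s = 0;
        for (int k = 0; k < a.length; k++) {
            double diff = a[k] - b[k];
            s += diff * diff;
        }
        return Math.sqrt(s);
    }

    public List<Double> getAccelerationCache(List<double[]> X) {
        List<Double> cache = new ArrayList<>(X.size());
        for (double[] x : X) cache.add(dot(x, x)); // self dot product of every vector
        return cache;
    }

    public double dist(int i, int j, List<double[]> X, List<Double> cache) {
        if (cache == null) return dist(X.get(i), X.get(j)); // normal, unaccelerated path
        double sq = cache.get(i) + cache.get(j) - 2 * dot(X.get(i), X.get(j));
        return Math.sqrt(Math.max(sq, 0.0)); // clamp tiny negative rounding error
    }
}

/** A metric without acceleration: the cache is always null and the indexed call just dereferences the list. */
class SimpleManhattan implements SimpleMetric {
    public double dist(double[] a, double[] b) {
        double s = 0;
        for (int k = 0; k < a.length; k++) s += Math.abs(a[k] - b[k]);
        return s;
    }

    public List<Double> getAccelerationCache(List<double[]> X) {
        return null;
    }

    public double dist(int i, int j, List<double[]> X, List<Double> cache) {
        return dist(X.get(i), X.get(j));
    }
}

class AllPairs {
    /** Average distance over all n(n-1)/2 unordered pairs; the same code works for either metric. */
    static double averagePairDistance(SimpleMetric dm, List<double[]> X) {
        List<Double> cache = dm.getAccelerationCache(X); // may be null
        int n = X.size();
        double sum = 0;
        for (int i = 0; i < n; i++)
            for (int j = i + 1; j < n; j++)
                sum += dm.dist(i, j, X, cache);
        return sum / ((double) n * (n - 1) / 2); // double math avoids int overflow for large n
    }
}
```

The caller never checks which kind of metric it has; the null check lives inside the metric, which is the design choice the post describes.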

When the acceleration is supported, distCache uses the DoubleList class in JSAT, which is essentially just a normal array of doubles wrapped by another object, so the memory overhead is very small.
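The idea of a list that is "just a normal array of doubles wrapped by another object" can be sketched as follows. PrimitiveDoubleList and its getD accessor are hypothetical names for illustration; this is not JSAT's DoubleList source. The point is that values live in a primitive double[] rather than as one boxed Double object per entry, as an ArrayList<Double> would store them.

```java
import java.util.AbstractList;
import java.util.Arrays;

/** Hypothetical sketch of a List<Double> backed by a growable primitive double[]. */
class PrimitiveDoubleList extends AbstractList<Double> {
    private double[] values;
    private int size;

    PrimitiveDoubleList(int initialCapacity) {
        values = new double[Math.max(initialCapacity, 1)];
    }

    /** Primitive accessor that avoids boxing entirely. */
    public double getD(int index) {
        if (index < 0 || index >= size) throw new IndexOutOfBoundsException("index " + index);
        return values[index];
    }

    @Override
    public Double get(int index) {
        return getD(index); // boxed only when read through the List interface
    }

    @Override
    public Double set(int index, Double element) {
        double old = getD(index);
        values[index] = element;
        return old;
    }

    @Override
    public boolean add(Double element) {
        if (size == values.length) values = Arrays.copyOf(values, values.length * 2); // double the capacity
        values[size++] = element;
        modCount++; // lets AbstractList's iterator detect concurrent modification
        return true;
    }

    @Override
    public int size() {
        return size;
    }
}
```

Because it still implements List<Double>, code that only knows about List<Double> (like the distCache parameter above) can accept it unchanged.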

I did have to strike a bit of a balance in the interface. One important case is when you have a vector y that is not in the original data set, and you want to compute its distance to many of the points in the data set. Some use cases might need just a single distance, some might need to compute all of them but only keep the minimum result, and others might only compare against \(O(\log n)\) of the original vectors. So it needed to be decently versatile, and my solution ended up looking like this:

DistanceMetric dm = new EuclideanDistance();
List<Vec> X = dataSet.getDataVectors();
List<Double> distCache = dm.getAccelerationCache(X);
Vec y = ...;//some vector from somewhere, not necessarily in X
List<Double> qi = dm.getQueryInfo(y);
double yDistAvg = 0;
for(int i = 0; i < X.size(); i++)
    yDistAvg += dm.dist(i, y, qi, X, distCache);

yDistAvg /= X.size();

In this case, the query vector y gets its own information pre-computed by getQueryInfo, as if it were part of the original collection of vectors. We then simply provide that information (qi) when we do the distance computation. Once again, if the metric does not support acceleration, getQueryInfo will return null, and the method computes the distance the normal way when distCache is null. In this way the same code works in all cases, and you don't have to write any special cases or branching in your own code. The branching done behind the scenes always goes the same way for a given metric, which makes it very easy for the CPU branch predictor, and the JIT may even eliminate it altogether.
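Here is a minimal sketch of the query side for the "only need the minimum" use case, assuming double[] vectors. The method names getAccelerationCache and getQueryInfo mirror the post, but QueryEuclidean, QueryDemo and nearest are hypothetical and the signatures are simplified; this is not JSAT's code. The query's info is computed once and reused for every comparison.

```java
import java.util.ArrayList;
import java.util.List;

/** Hypothetical Euclidean metric with both a data cache and per-query info. */
class QueryEuclidean {
    static double dot(double[] a, double[] b) {
        double s = 0;
        for (int k = 0; k < a.length; k++) s += a[k] * b[k];
        return s;
    }

    static double plainDist(double[] a, double[] b) {
        double s = 0;
        for (int k = 0; k < a.length; k++) {
            double d = a[k] - b[k];
            s += d * d;
        }
        return Math.sqrt(s);
    }

    List<Double> getAccelerationCache(List<double[]> X) {
        List<Double> cache = new ArrayList<>(X.size());
        for (double[] x : X) cache.add(dot(x, x));
        return cache;
    }

    /** The query's own pre-computed info: here a single value, y . y. */
    List<Double> getQueryInfo(double[] y) {
        List<Double> qi = new ArrayList<>(1);
        qi.add(dot(y, y));
        return qi;
    }

    /** Distance from X[i] to y; falls back to the plain computation if either list is null. */
    double dist(int i, double[] y, List<Double> qi, List<double[]> X, List<Double> cache) {
        if (cache == null || qi == null) return plainDist(X.get(i), y);
        double sq = cache.get(i) + qi.get(0) - 2 * dot(X.get(i), y);
        return Math.sqrt(Math.max(sq, 0.0));
    }
}

class QueryDemo {
    /** Index of the vector in X closest to y: one pass, keeping only the minimum. */
    static int nearest(QueryEuclidean dm, List<double[]> X, List<Double> cache, double[] y) {
        List<Double> qi = dm.getQueryInfo(y); // computed once per query
        int best = -1;
        double bestDist = Double.POSITIVE_INFINITY;
        for (int i = 0; i < X.size(); i++) {
            double d = dm.dist(i, y, qi, X, cache);
            if (d < bestDist) {
                bestDist = d;
                best = i;
            }
        }
        return best;
    }
}
```

Passing a null cache gives the same answer through the unaccelerated path, so the caller's loop never changes.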

Using a DoubleList for the query information is a bit much, since it is likely to hold only one or two values, but the overhead is incredibly small relative to everything else going on, so I'm not worried about it.

To test it out, I re-ran some of the code from my previous experiment on how fast k-means is. I've since implemented Hamerly's algorithm and reworked Elkan's to be even more efficient. The two left columns are Elkan's and Hamerly's algorithms without acceleration, and the right three columns are Elkan's and Hamerly's with the acceleration cache, plus the naive algorithm with acceleration.

| Data Set | n | d | k | Elkan | Hamerly | Elkan w/ Cache | Hamerly w/ Cache | Naive w/ Cache |
|---|---|---|---|---|---|---|---|---|
| covtype | 581,012 | 54 | 7 | 18.679 s | 39.319 s | 7.259 s | 3.845 s | 14.309 s |
| covtype | 581,012 | 54 | 70 | 11 min 48 s | 35 min 16 s | 10 min 44 s | 1 min 29 s | 12 min 23 s |
| mnist | 60,000 | 780 | 10 | 51.874 s | 2 min 10 s | 2.57 s | 3.365 s | 10.765 s |
| mnist | 60,000 | 780 | 100 | 4 min 12 s | 34 min 37 s | 22.069 s | 53.70 s | 2 min 16 s |

The point of Elkan's and Hamerly's algorithms is to avoid distance calculations, and the cache acceleration reduces the cost of every distance calculation that remains. This makes the naive algorithm surprisingly fast, much closer than it was before (I didn't feel like waiting hours for the unaccelerated naive version to run again). While no longer orders of magnitude faster, Elkan's and Hamerly's algorithms are still 2 to 10 times faster than the naive one (which isn't bad!). An interesting case is Elkan's on the covtype data set with \(k = 70\). Amusingly, Elkan's avoids so many distance computations in that case (and they have become so cheap) that the bookkeeping becomes the most expensive part. This doesn't usually happen, but it's an interesting case where Hamerly's becomes more efficient.
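To show where the cache pays off in the "Naive w/ Cache" column, here is a minimal sketch of naive (Lloyd-style) k-means iterations, assuming dense double[][] data. NaiveKMeansWithCache is a hypothetical name and this is not JSAT's k-means implementation; Elkan's and Hamerly's bound bookkeeping is not shown. The data's self dot products are computed once, and each iteration only recomputes the k centroid norms. Strictly for the argmin, the \(x \cdot x\) term is constant per point and could be dropped; it is kept because a metric-based k-means (like the accelerated algorithms, which need real distances for their bounds) asks for actual distances.

```java
/** Hypothetical sketch: naive k-means that reuses cached squared norms of the data. */
class NaiveKMeansWithCache {
    static double dot(double[] a, double[] b) {
        double s = 0;
        for (int k = 0; k < a.length; k++) s += a[k] * b[k];
        return s;
    }

    /** Runs `iters` iterations, updating `centroids` in place; returns the final assignments. */
    static int[] run(double[][] X, double[][] centroids, int iters) {
        int n = X.length, k = centroids.length, d = X[0].length;
        double[] xNorm = new double[n];
        for (int i = 0; i < n; i++) xNorm[i] = dot(X[i], X[i]); // computed once, reused every iteration
        int[] assign = new int[n];
        for (int it = 0; it < iters; it++) {
            double[] cNorm = new double[k];
            for (int c = 0; c < k; c++) cNorm[c] = dot(centroids[c], centroids[c]); // only k new norms
            // Assignment step: full Euclidean distance from the cached identity.
            for (int i = 0; i < n; i++) {
                double best = Double.POSITIVE_INFINITY;
                for (int c = 0; c < k; c++) {
                    double sq = xNorm[i] + cNorm[c] - 2 * dot(X[i], centroids[c]);
                    double dist = Math.sqrt(Math.max(sq, 0.0));
                    if (dist < best) {
                        best = dist;
                        assign[i] = c;
                    }
                }
            }
            // Update step: each centroid moves to the mean of its assigned points.
            double[][] sums = new double[k][d];
            int[] counts = new int[k];
            for (int i = 0; i < n; i++) {
                counts[assign[i]]++;
                for (int j = 0; j < d; j++) sums[assign[i]][j] += X[i][j];
            }
            for (int c = 0; c < k; c++)
                if (counts[c] > 0) // an empty cluster keeps its previous centroid
                    for (int j = 0; j < d; j++) centroids[c][j] = sums[c][j] / counts[c];
        }
        return assign;
    }
}
```

Every one of the \(nk\) distance evaluations per iteration goes through the cached expression, which is why the naive algorithm benefits so much once each distance gets cheaper.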

Overall, this code gives a huge speed improvement in a lot of cases. And now that the code is inside the Vector Collections, almost all the rest of JSAT and anyone who uses a VC will automatically get these speed boosts when supported.
