Saturday, September 22, 2012

Vector Collections, KD & VP Tree timings

I've shown a bit of kNN and KDE, both of which need to query a metric space. While these two algorithms depend intrinsically on that ability, many other algorithms in Machine Learning also need to query a space of points and select either the k nearest neighbors or all the points within a certain range. So for this post, I'm going to give a quick intro to the two most useful methods for this in JSAT: KD-Trees and VP-Trees.

JSAT encapsulates this functionality in a VectorCollection, so that an algorithm that needs these queries can be independent of how they are answered.

The naive way to do this is, of course, to linearly scan through the whole collection each time a query is made. The distance from the query to every point in the data set is computed, and the points that match are returned. If we have \(n\) points, this clearly takes \(O(n)\) time. There are also some unfortunate theoretical results here: given an arbitrary set of points in an arbitrary dimension, \(O(n)\) is the theoretical best we can do.
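As a baseline, here is a minimal Java sketch of the naive linear scan, answering both query types mentioned above (k-nearest and range) with plain Euclidean distance. The names `LinearScan`, `kNearest` and `withinRange` are hypothetical illustrations, not JSAT's VectorCollection API.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.PriorityQueue;

public class LinearScan {
    /** Euclidean distance between two vectors of equal length. */
    static double euclidean(double[] a, double[] b) {
        double sum = 0.0;
        for (int i = 0; i < a.length; i++) {
            double d = a[i] - b[i];
            sum += d * d;
        }
        return Math.sqrt(sum);
    }

    /** Indices of the k points closest to query, nearest first: O(n log k). */
    static List<Integer> kNearest(double[][] data, double[] query, int k) {
        // Max-heap of {distance, index}: the root is the worst of the current k candidates.
        PriorityQueue<double[]> heap = new PriorityQueue<>((x, y) -> Double.compare(y[0], x[0]));
        for (int i = 0; i < data.length; i++) {
            double d = euclidean(data[i], query);
            if (heap.size() < k) {
                heap.add(new double[] {d, i});
            } else if (d < heap.peek()[0]) {
                heap.poll(); // evict the current worst candidate
                heap.add(new double[] {d, i});
            }
        }
        List<Integer> result = new ArrayList<>();
        while (!heap.isEmpty()) result.add((int) heap.poll()[1]);
        Collections.reverse(result); // the heap pops the farthest point first
        return result;
    }

    /** Indices of all points within distance r of query: always O(n). */
    static List<Integer> withinRange(double[][] data, double[] query, double r) {
        List<Integer> result = new ArrayList<>();
        for (int i = 0; i < data.length; i++) {
            if (euclidean(data[i], query) <= r) result.add(i);
        }
        return result;
    }

    public static void main(String[] args) {
        double[][] data = {{0, 0}, {1, 0}, {0, 2}, {5, 5}};
        System.out.println(kNearest(data, new double[] {0.9, 0.1}, 2)); // prints [1, 0]
        System.out.println(withinRange(data, new double[] {0, 0}, 1.5)); // prints [0, 1]
    }
}
```

Every query touches all \(n\) points no matter where it lands, which is exactly the cost the trees below try to avoid.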

However, theory and practice rarely match up, and despite this result we can get an improvement on many data sets. The most important factor in whether we actually beat \(O(n)\) is the dimension of the data set. For problems of this nature, 10 is sadly a fairly high dimension, and often too high for the first method I will talk about, KD-Trees.

KD-Tree stands for K-Dimensional tree, and it is a very popular algorithm. It recursively splits the space along one of the axes. By choosing the split so that half of the points fall on one side of the splitting plane and the other half on the opposite side, it hopes to let you ignore whole portions of the space. If you cut out half at every step, you only have to go down \(O(\log n)\) levels of the tree to reach the bottom, which will hopefully hold the nearest neighbor. Which axis to split on? You can simply iterate through the axes in order, or select the axis with the most variance.
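A minimal sketch of a KD-Tree in Java, assuming Euclidean distance and a 1-NN query. `KDTree`, `nearest` and the `useMaxVariance` flag are illustrative names, not JSAT's classes. For brevity it sorts each slice to find the median, which makes the build \(O(n \log^2 n)\); a linear-time median selection would bring it down to \(O(n \log n)\). The `Math.abs(diff) < best.dist` test in `search` is the backtracking check: the far side of a split is only searched when the splitting plane is closer to the query than the best point found so far.

```java
import java.util.Arrays;

public class KDTree {
    // A node stores one point and the axis whose median value it splits on.
    private static final class Node {
        final double[] point;
        final int axis;
        Node left, right;
        Node(double[] point, int axis) { this.point = point; this.axis = axis; }
    }

    // Best candidate found so far during a query.
    private static final class Best {
        double[] point;
        double dist = Double.POSITIVE_INFINITY;
    }

    private final Node root;

    public KDTree(double[][] points, boolean useMaxVariance) {
        // Work on a shallow copy so the caller's array order is left alone.
        root = build(points.clone(), 0, points.length, 0, useMaxVariance);
    }

    static double dist(double[] a, double[] b) {
        double s = 0.0;
        for (int i = 0; i < a.length; i++) s += (a[i] - b[i]) * (a[i] - b[i]);
        return Math.sqrt(s);
    }

    // Split pts[lo, hi) at the median of one axis, then recurse on each half.
    private static Node build(double[][] pts, int lo, int hi, int depth, boolean useMaxVariance) {
        if (lo >= hi) return null;
        // Either cycle through the axes in order, or pick the highest-variance axis.
        int axis = useMaxVariance ? maxVarianceAxis(pts, lo, hi) : depth % pts[lo].length;
        Arrays.sort(pts, lo, hi, (a, b) -> Double.compare(a[axis], b[axis]));
        int mid = (lo + hi) >>> 1;
        Node node = new Node(pts[mid], axis);
        node.left = build(pts, lo, mid, depth + 1, useMaxVariance);
        node.right = build(pts, mid + 1, hi, depth + 1, useMaxVariance);
        return node;
    }

    // The axis with the largest variance among pts[lo, hi).
    private static int maxVarianceAxis(double[][] pts, int lo, int hi) {
        int n = hi - lo, bestAxis = 0;
        double bestVar = -1.0;
        for (int j = 0; j < pts[lo].length; j++) {
            double sum = 0.0, sumSq = 0.0;
            for (int i = lo; i < hi; i++) { sum += pts[i][j]; sumSq += pts[i][j] * pts[i][j]; }
            double mean = sum / n, var = sumSq / n - mean * mean;
            if (var > bestVar) { bestVar = var; bestAxis = j; }
        }
        return bestAxis;
    }

    public double[] nearest(double[] q) {
        Best best = new Best();
        search(root, q, best);
        return best.point;
    }

    private static void search(Node node, double[] q, Best best) {
        if (node == null) return;
        double d = dist(node.point, q);
        if (d < best.dist) { best.dist = d; best.point = node.point; }
        double diff = q[node.axis] - node.point[node.axis];
        Node near = diff < 0 ? node.left : node.right;
        Node far = diff < 0 ? node.right : node.left;
        search(near, q, best);
        // Backtrack: the far side can only hold a closer point if the
        // splitting plane is closer to q than the best point found so far.
        if (Math.abs(diff) < best.dist) search(far, q, best);
    }

    public static void main(String[] args) {
        double[][] data = {{2, 3}, {5, 4}, {9, 6}, {4, 7}, {8, 1}, {7, 2}};
        KDTree tree = new KDTree(data, false);
        System.out.println(Arrays.toString(tree.nearest(new double[] {9, 2}))); // prints [8.0, 1.0]
    }
}
```

In high dimensions the distance to any single splitting plane tends to be small compared with the best distance, so that backtracking test passes at most nodes and the search degrades toward visiting everything.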

The intuition is fairly simple, but unfortunately KD-Trees only work well in low dimensions, usually fewer than 10. Part of the reason is that we split on the axes, which are fixed and cannot adapt to the data at hand. When the dimension gets large, we may not even have enough points to split on every axis, leaving a large amount of work at the terminal nodes. When the search then climbs back up the tree, the triangle inequality cannot rule out the other branches, so it is forced to explore them. Because of this extra work, KD-Trees can be even slower than the naive algorithm. Selecting the attribute with the most variance also stops being helpful: the variance is constrained to the axes, so it can't capture the true variance of the data, and the choice of attribute to split on becomes noisy.

The next method, which I like but which is not widely used, is the Vantage Point Tree (VP-Tree). The logic is very similar, but the method is more general and also more expensive to build. At each level, a point from the data set is selected to be the Vantage Point (VP), the idea being that it has sight of all the other points at its node. The distance from each point to the VP is computed, and the distances are then sorted. The median distance is selected as the splitting value, and two subtrees are formed. This median value is a radius, so we have formed a circle (a hypersphere, in more than two dimensions) around the VP. Inside the radius are the points closer to the VP, and outside are the points farther away. Again, the halving gives us a tree of \(O(\log n)\) depth. We then use the triangle inequality to prune out the inner or outer side, hopefully answering queries in \(O(\log n)\) time.
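A minimal VP-Tree sketch in Java, assuming a randomly chosen Vantage Point and the median split described above. It is generic over the point type and only ever calls the supplied distance function, so it never looks at coordinates. `VPTree` and its constructor arguments are hypothetical names, not JSAT's API. The pruning follows from the triangle inequality: if the query is at distance \(d\) from the VP and the split radius is \(r\), every point inside is at least \(d - r\) from the query and every point outside is at least \(r - d\) away.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Random;
import java.util.function.ToDoubleBiFunction;

public class VPTree<T> {
    private static final class Node<T> {
        T vp;            // the Vantage Point for this node
        double radius;   // median distance from vp to the other points
        Node<T> inside;  // points with dist(vp, x) <= radius
        Node<T> outside; // points with dist(vp, x) >= radius
    }

    private static final class Best<T> {
        T item;
        double dist = Double.POSITIVE_INFINITY;
    }

    private final ToDoubleBiFunction<T, T> dist;
    private final Random rnd;
    private final Node<T> root;

    public VPTree(List<T> items, ToDoubleBiFunction<T, T> dist, long seed) {
        this.dist = dist;
        this.rnd = new Random(seed);
        this.root = build(new ArrayList<>(items));
    }

    private Node<T> build(List<T> items) {
        if (items.isEmpty()) return null;
        Node<T> node = new Node<>();
        // Random VP selection: the cheap alternative to sampling.
        node.vp = items.remove(rnd.nextInt(items.size()));
        if (items.isEmpty()) return node;
        // One distance computation per remaining point, then order them by distance.
        double[] d = new double[items.size()];
        Integer[] order = new Integer[items.size()];
        for (int i = 0; i < d.length; i++) {
            d[i] = dist.applyAsDouble(node.vp, items.get(i));
            order[i] = i;
        }
        Arrays.sort(order, (a, b) -> Double.compare(d[a], d[b]));
        int mid = order.length / 2;
        node.radius = d[order[mid]]; // the median distance is the split radius
        List<T> inner = new ArrayList<>(), outer = new ArrayList<>();
        for (int i = 0; i < order.length; i++) (i < mid ? inner : outer).add(items.get(order[i]));
        node.inside = build(inner);
        node.outside = build(outer);
        return node;
    }

    public T nearest(T q) {
        Best<T> best = new Best<>();
        search(root, q, best);
        return best.item;
    }

    private void search(Node<T> node, T q, Best<T> best) {
        if (node == null) return;
        double d = dist.applyAsDouble(q, node.vp);
        if (d < best.dist) { best.dist = d; best.item = node.vp; }
        boolean insideFirst = d < node.radius;
        search(insideFirst ? node.inside : node.outside, q, best);
        // Triangle inequality: inside points are at least d - radius from q,
        // outside points at least radius - d. Only visit the other side if
        // that lower bound is below the best distance found so far.
        double gap = insideFirst ? node.radius - d : d - node.radius;
        if (gap < best.dist) search(insideFirst ? node.outside : node.inside, q, best);
    }

    public static void main(String[] args) {
        ToDoubleBiFunction<double[], double[]> euclid = (a, b) -> {
            double s = 0.0;
            for (int i = 0; i < a.length; i++) s += (a[i] - b[i]) * (a[i] - b[i]);
            return Math.sqrt(s);
        };
        List<double[]> data = Arrays.asList(new double[] {2, 3}, new double[] {5, 4},
                new double[] {9, 6}, new double[] {4, 7}, new double[] {8, 1}, new double[] {7, 2});
        VPTree<double[]> tree = new VPTree<>(data, euclid, 42L);
        System.out.println(Arrays.toString(tree.nearest(new double[] {9, 2}))); // prints [8.0, 1.0]
    }
}
```

Because the tree only calls `dist`, the same code should work with any true metric (for example, edit distance on strings), which is the generality KD-Trees lack.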

Because we use the actual distances to build the tree, instead of the axes, we get a method that works much better in higher dimensional spaces. However, we pay a price in construction time. We have to do \(O(n \log n)\) distance computations, whereas building a KD-Tree, while also \(O(n \log n)\), needs no distance computations at all. There is also the question of how to select the Vantage Point. The originally proposed method is to sample the eligible points and choose the candidate VP that produces the lowest variance on the sample. You can also select the VP randomly, and save a lot of expensive computation.
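To see where the \(O(n \log n)\) distance computations come from, here is a minimal accounting, assuming each node holding \(m\) points computes the \(m - 1\) distances from its VP to the rest and then passes about half of them to each child (as in a median split). Writing \(D(n)\) for the number of distance computations needed to build a tree over \(n\) points:

\[
D(n) = (n - 1) + D\left(\left\lfloor \tfrac{n-1}{2} \right\rfloor\right) + D\left(\left\lceil \tfrac{n-1}{2} \right\rceil\right), \qquad D(0) = D(1) = 0.
\]

Each level of the tree computes fewer than \(n\) distances in total, and there are about \(\log_2 n\) levels, so \(D(n) = O(n \log n)\). A KD-Tree does the same amount of splitting work but only compares single coordinates. Sampling-based VP selection adds to this: if, as an assumption, each of \(s\) candidates is scored against a sample of \(t\) points, every node pays roughly \(s \cdot t\) extra distance computations on top of \(D(n)\).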

Another option I added is how to make the split. Querying slows down when we are unable to prune out branches, meaning our splits aren't adequately separating points that should be separated. Instead of selecting the median, we could select the radius that minimizes the total variance of the two child nodes. This gives up the theoretical justification for expecting results in \(O(\log n)\) time, but hopefully makes the separations better. It shows up in the graph below as VP-Tree-MV.
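A minimal Java sketch of a minimum-variance split. The text only says the radius minimizes the total variance of the child nodes; here I assume that means the sum of squared deviations of the distances on each side, computed with prefix sums so every candidate split is scored in \(O(1)\). `MinVarianceSplit` and `bestSplit` are hypothetical names, and JSAT's VP-Tree-MV may define the criterion differently.

```java
import java.util.Arrays;

public class MinVarianceSplit {
    /**
     * Given distances to the VP sorted ascending, returns the index k (1 <= k < m)
     * minimizing SSE(d[0..k)) + SSE(d[k..m)). Assumed criterion: SSE is the sum of
     * squared deviations from each side's own mean. The radius is then d[k].
     */
    static int bestSplit(double[] sorted) {
        int m = sorted.length;
        if (m < 2) throw new IllegalArgumentException("need at least two distances");
        double[] sum = new double[m + 1], sumSq = new double[m + 1];
        for (int i = 0; i < m; i++) {
            sum[i + 1] = sum[i] + sorted[i];
            sumSq[i + 1] = sumSq[i] + sorted[i] * sorted[i];
        }
        int bestK = 1;
        double bestCost = Double.POSITIVE_INFINITY;
        for (int k = 1; k < m; k++) {
            double cost = sse(sum, sumSq, 0, k) + sse(sum, sumSq, k, m);
            if (cost < bestCost) { bestCost = cost; bestK = k; }
        }
        return bestK;
    }

    // SSE of elements [lo, hi) = sum of x^2 - (sum of x)^2 / count
    private static double sse(double[] sum, double[] sumSq, int lo, int hi) {
        double s = sum[hi] - sum[lo];
        return (sumSq[hi] - sumSq[lo]) - s * s / (hi - lo);
    }

    public static void main(String[] args) {
        double[] d = {1.0, 1.1, 1.2, 1.3, 5.0, 5.2};
        Arrays.sort(d);
        int k = bestSplit(d);
        System.out.println("split at index " + k + ", radius " + d[k]); // prints split at index 4, radius 5.0
    }
}
```

On these distances the median split (index 3) would put 1.3 in the outer shell along with 5.0 and 5.2, while the minimum-variance split puts the boundary in the gap. Nothing forces the two sides to be the same size, which is why the \(O(\log n)\) depth argument no longer holds.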

So, how do they compare on some problems of decently large dimension? I've computed the average build and query times for selecting the 1-NN on four different data sets from the UCI repository. Build times were calculated on the training set of 9-fold cross validation, and query times on the testing fold. This way no query point was also stored in the tree, so there were no distances of zero, and we still used all the data to get a good result.

The graphs for these times are below. The title indicates the data set, and below it is (dimension) x (data set size).

We can see the Naive method has no build time, which makes sense: there is no work to do, we just have to hold on to the data points. We can also see that VP-Trees are very slow to build with sampling. While sampling does improve query times, it's not by much. If we know we will be doing far more querying than building, it might be worth it; otherwise, the cost is a bit too much. We can also see that VP-Tree-MV often has a small advantage at relatively little cost, though its performance is likely to vary more across data sets.

So, are VP-Trees always better? Usually, for large dimensions. If you have a lower-dimensional problem, KD-Trees can perform a query using fewer resources. VP-Trees also have the advantage that they work with any distance metric, whereas KD-Trees are only valid for the p-norm metrics.

VP-Trees are similar to another algorithm called Ball-Trees. However, I've never found Ball-Trees to query faster than VP-Trees, and they are even more expensive to build. They're also more complicated to implement, so I've only implemented VP-Trees in JSAT. Ball-Trees are very popular, but I've never understood why.

JSAT also has an implementation of R-Trees; however, they need specialized batch insertion methods to be efficient in practice. Because those methods are still slower than VP-Trees and Ball-Trees, and fairly complicated, I've never bothered to implement them.