Sunday, March 29, 2015

JSAT vs Weka on MNIST

I'm a strong proponent of using the right tool for the job, and I need to preface this post with the fact that Java is not the best tool for Machine Learning in general. But you may know Java best, you may need to interface with a Java infrastructure, or you might just prefer Java over other languages. For any number of reasons you may need or want to do some ML work in Java, and Weka is the choice that most people hear of first.

My problem is that most of Weka is inexcusably slow - and doesn't provide enough algorithms for you to really choose the right tool for your job. The code base is massive for the few algorithms it does support, and it's bloated with code duplication and inefficient design. The biggest time suck in ML is usually model building / testing, and so I've put together a small benchmark on MNIST.

MNIST, as a data set, is often overused. However, it is just large enough that if your code is inefficient you will feel it. We also know very well what accuracies are obtainable with various models on MNIST, making it a good sanity check. So I downloaded an ARFF version of the data from here and made this benchmark (hosted here). The purpose of this benchmark is to show that JSAT has significantly faster implementations of many of the same algorithms, and has better algorithms to use as well.

For this benchmark I've tried to keep everything apples-to-apples, and did my best to make sure each algorithm was doing the same thing and had the same options set. In most cases I made JSAT use the parameters Weka uses by default when possible. For some algorithms (like SVMs) that need a good set of parameters, I used values from a grid search I did before with LIBSVM.

All the times below are presented in seconds, run on my iMac (2.66 GHz Core i5, 16GB of RAM) with nothing else running. I've omitted the train/test time for algorithms where training/testing is not a performance issue.

| Algorithm | Weka Time | Weka Error | JSAT Time | JSAT Error | JSAT Speedup |
|---|---|---|---|---|---|
| SVM w/ RBF Kernel (Full Cache) | Train: 7713.549, Test: 1339.633 | | Train: 3661.7, Test: 337.765 | | Train: 2.1x, Test: 4.0x |
| SVM w/ RBF Kernel (No Cache) | Train: 5657.924, Test: 1336.557 | | Train: 2558.846, Test: 317.663 | | Train: 2.2x, Test: 4.2x |
| RBF SVM stochastic w/ 5 iterations | | | Train: 518.654, Test: 10.443 | | Train: 10.9x, Test: 128.0x (over SVM) |
| RBF SVM RKS features w/ Linear Solver | | | Train: 68.398, Test: 0.571 | | Train: 82.7x, Test: 2340x (over SVM) |
| C4.5 Decision Tree | Train: 303.373 | 0.1134 | Train: 117.785 | 0.1146 | Train: 2.6x |
| Random Forest w/ 50 trees | Train: 143.127 | 0.0326 | Train: 100.673 | 0.0453 | Train: 1.4x |
| 1-NN (brute force) | Test: 2537.483 | 0.0309 | Test: 648.71 | 0.0309 | Test: 3.9x |
| 1-NN (Ball Tree) | Train: 52.263, Test: 3269.183 | | | | |
| 1-NN (Cover Tree) | Train: 538.132, Test: 2245.709 | | | | |
| 1-NN (VPmv) | | | Train: 1.909, Test: 493.69 | | |
| 1-NN (Random Ball Cover) | | | Train: 13.448, Test: 576.737 | | |
| Logistic Regression by LBFGS, λ = 1e-4 | Train: 3301.899 | 0.0821 | Train: 907.259 | 0.0776 | Train: 3.6x |
| Log Regression stochastic w/ 10 iterations | | | Train: 10.545 | 0.0840 | Train: 313x (over LBFGS) |
| Logistic Regression OneVsAll DCD | | | Train: 276.865 | 0.080 | Train: 12.3x (over LBFGS) |
| k-means (Lloyd's algorithm) | 1010.6009 | | 41.1913 | | 24.5x |
| k-means (Hamerly's algorithm) | | | 10.5358 | | 95.9x |
| k-means (Elkan's algorithm) | | | | | |

When running the same algorithms, most results end up with almost exactly the same accuracy. For Logistic Regression JSAT does a little better. For the Random Forest JSAT does a little worse for some reason, and I can't find a setting in Weka to change. However, in the case of Random Forest, you could trade some of your speed savings for a few more trees to make up the difference.

For SVMs, the speed advantage is about 2x for training and 4x for prediction. Speed advantages for the SVM are particularly important since the SVM is very sensitive to parameters, so a grid search is going to be needed, multiplying the runtime by a factor of 10 to 100 depending on how many parameter combinations you want to test. The alternative algorithms in JSAT (stochastic kernel training directly & approximate feature space) up the advantage even further.
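To make the grid search cost concrete, here is a back-of-the-envelope sketch. The training times are the "No Cache" SVM rows from the table; the 5 x 5 grid and the class and method names are assumptions made up for illustration, and cross-validation folds (which would multiply the cost again) are ignored.

```java
/** Rough estimate of grid search cost: one full training run per parameter combination. */
final class GridSearchCost {

    /** Total training time in hours for the given number of parameter combinations. */
    static double totalHours(double trainSeconds, int combinations) {
        return trainSeconds * combinations / 3600.0;
    }

    public static void main(String[] args) {
        int combinations = 5 * 5;          // hypothetical grid: 5 values of C x 5 kernel widths
        double wekaTrain = 5657.924;       // seconds, Weka SVM w/ RBF Kernel (No Cache)
        double jsatTrain = 2558.846;       // seconds, JSAT SVM w/ RBF Kernel (No Cache)

        System.out.printf("Weka: %.1f hours%n", totalHours(wekaTrain, combinations));
        System.out.printf("JSAT: %.1f hours%n", totalHours(jsatTrain, combinations));
    }
}
```

Even a modest grid turns a 2x difference per run into many hours of wall-clock time.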

For the tree-based algorithms, JSAT's advantage isn't as big. But if you are going to use trees in an ensemble (as is often the case), the 2.6x speedup is going to add up.

For Nearest Neighbor algorithms, both Weka and JSAT have data structures for accelerating nearest neighbor queries. While both of the ones from JSAT improved time and were fast to train, the Ball Tree algorithm in Weka was slower than the naive approach - and the Cover Tree was only a little faster, but took almost as much time to train as JSAT did to do the whole problem to begin with!

For Logistic Regression, JSAT is again faster by a single-digit factor, 3.6x. While LBFGS is a great and versatile tool to have, it's not always the best choice, especially for a bread-and-butter algorithm like Logistic Regression. JSAT's alternative exact solver (DCD) is the same algorithm used in LIBLINEAR, and is over 12x faster. However you don't always need an exact solution, and the SGD based solver in JSAT gets a solution almost as good as Weka's LBFGS and gets it 313 times faster.

Finally, for k-means, JSAT is already 24 times faster using the exact same algorithm. But JSAT also has two additional algorithms that obtain the exact same solution, but avoid redundant work - resulting in 95x to almost 200x faster training. With k-means being such a common tool both in use and as a building block for other algorithms, the difference is huge.
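The "redundant work" these algorithms avoid is mostly distance computations that the triangle inequality proves unnecessary. Below is a minimal sketch of that one idea for the assignment step. It is not JSAT's code and not the full Elkan or Hamerly algorithm (both also carry distance bounds from one iteration to the next), and the class and method names are made up for illustration.

```java
/** Sketch of triangle-inequality pruning for the k-means assignment step. */
final class TrianglePruning {

    /** Euclidean distance between two points of equal dimension. */
    static double dist(double[] a, double[] b) {
        double sum = 0;
        for (int i = 0; i < a.length; i++) {
            double diff = a[i] - b[i];
            sum += diff * diff;
        }
        return Math.sqrt(sum);
    }

    /** Pairwise distances between centers, computed once per iteration. */
    static double[][] centerDistances(double[][] centers) {
        int k = centers.length;
        double[][] cd = new double[k][k];
        for (int i = 0; i < k; i++)
            for (int j = i + 1; j < k; j++)
                cd[i][j] = cd[j][i] = dist(centers[i], centers[j]);
        return cd;
    }

    /**
     * Finds the center nearest to x.
     * Returns {index of nearest center, number of point-to-center distances computed}.
     */
    static int[] nearestCenter(double[] x, double[][] centers, double[][] cd) {
        int best = 0;
        double bestDist = dist(x, centers[0]);
        int computed = 1;
        for (int j = 1; j < centers.length; j++) {
            // Triangle inequality: d(x, c_j) >= d(c_best, c_j) - d(x, c_best).
            // So if d(c_best, c_j) >= 2 * d(x, c_best), c_j cannot be closer than c_best.
            if (cd[best][j] >= 2 * bestDist)
                continue;
            double d = dist(x, centers[j]);
            computed++;
            if (d < bestDist) {
                bestDist = d;
                best = j;
            }
        }
        return new int[]{best, computed};
    }

    public static void main(String[] args) {
        double[][] centers = {{0, 0}, {10, 0}, {0, 10}};
        double[][] cd = centerDistances(centers);
        int[] result = nearestCenter(new double[]{1, 1}, centers, cd);
        System.out.println("nearest center: " + result[0]
                + ", distances computed: " + result[1] + " of " + centers.length);
    }
}
```

The assignment is identical to what a brute force scan would return; only the number of distance computations changes, which is why the exact algorithms reach the same solution faster.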

So, overall - JSAT has faster implementations of the same algorithms, and more algorithms - giving you the flexibility to pick the right tool for the job.

Just to be clear, I'm not claiming that JSAT has the fastest implementations ever of these algorithms. Many great tools, such as scikit-learn, have faster implementations for some algorithms. But JSAT does have a greater variety and if you are going to be working in Java, you should definitely consider JSAT over Weka just on the issue of speed.

Monday, March 16, 2015

Improving dataset loading in JSAT/Java

I've had a loader for the LIBSVM format of storing data sets for a while now. I recently ran into a rather fun issue, where in loading a rather large data set I got an error. I didn't read the error, but the first thing I did was run it again with the debugger on - waiting for it to break on wherever the error was. And it ran through perfectly!

Turns out I was getting a "GC overhead limit exceeded" error. This happens when the JVM spends too much time doing GC work, and just kills your job. By running in debug mode, I slowed down how quickly new garbage was created, which gave the GC enough extra head room to make it through.

While amusing, I did need to fix this - and I encountered something surprising. There doesn't seem to be any good way to turn a subset of a character sequence into a floating point value. The only method I could find was to run Double.parseDouble() on a substring, but that meant creating a new garbage object! So I had to write that myself, and added code to read the file into a buffer and walk through the buffer for the values I needed, re-filling when necessary.
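For readers who want to see the shape of the idea, here is a minimal sketch of parsing a number straight out of a char buffer range without creating a substring. This is not the JSAT loader code; the class name, method signature, and the supported syntax (optional sign, digits, optional fraction, optional exponent) are assumptions for illustration, and buffer re-filling is left out.

```java
/** Sketch: parse a decimal number from buf[start, end) without allocating a String. */
final class BufferDoubleParser {

    /**
     * Parses text such as "-12.5", "3", ".25" or "1.5e-3" from buf[start, end).
     * The result is not guaranteed to match Double.parseDouble in the last bits.
     */
    static double parse(char[] buf, int start, int end) {
        int pos = start;
        boolean negative = false;
        if (pos < end && (buf[pos] == '-' || buf[pos] == '+')) {
            negative = buf[pos] == '-';
            pos++;
        }

        long mantissa = 0;         // significant digits read so far, as an integer
        int digits = 0;            // number of significant digits stored in mantissa
        int scale = 0;             // power of ten to apply to mantissa at the end
        boolean seenPoint = false;
        boolean seenDigit = false;
        for (; pos < end; pos++) {
            char c = buf[pos];
            if (c >= '0' && c <= '9') {
                seenDigit = true;
                if (digits < 18) {             // 18 decimal digits always fit in a long
                    mantissa = mantissa * 10 + (c - '0');
                    if (mantissa != 0) digits++;   // leading zeros are not significant
                    if (seenPoint) scale--;
                } else if (!seenPoint) {
                    scale++;                   // dropped integer digit still shifts magnitude
                }
            } else if (c == '.' && !seenPoint) {
                seenPoint = true;
            } else {
                break;
            }
        }

        if (seenDigit && pos < end && (buf[pos] == 'e' || buf[pos] == 'E')) {
            pos++;
            boolean expNegative = false;
            if (pos < end && (buf[pos] == '-' || buf[pos] == '+')) {
                expNegative = buf[pos] == '-';
                pos++;
            }
            int expStart = pos;
            int exp = 0;
            for (; pos < end && buf[pos] >= '0' && buf[pos] <= '9'; pos++) {
                if (exp < 10000) exp = exp * 10 + (buf[pos] - '0');  // cap to avoid int overflow
            }
            if (pos == expStart) seenDigit = false;   // "1e" with no exponent digits is invalid
            scale += expNegative ? -exp : exp;
        }

        if (!seenDigit || pos != end) {
            // A String is only allocated on the error path.
            throw new NumberFormatException(new String(buf, start, end - start));
        }

        double value = mantissa;
        if (scale > 0) value *= Math.pow(10, scale);
        else if (scale < 0) value /= Math.pow(10, -scale);
        return negative ? -value : value;
    }

    public static void main(String[] args) {
        char[] line = "1:0.5 7:-1.25e2".toCharArray();   // LIBSVM-style index:value pairs
        System.out.println(parse(line, 2, 5));    // the characters "0.5"
        System.out.println(parse(line, 8, 15));   // the characters "-1.25e2"
    }
}
```

The only allocation on the normal path is the buffer itself, which is reused for every value, so the parser produces no per-value garbage for the GC to chase.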

Here is a graph of the memory use of loading a file with the original LIBSVM code. The first graph just shows memory use. The second shows GC overhead and the number of collections.

Not horrible, but this was on a small data set. On larger ones (like the 100GB file I encountered the error with) this is a problem. Below are the same graphs for the new code. 

That looks much better! It also looks more like what some people would expect had the code been written in C/C++ instead of Java, which was one of the reasons I wanted to show it. The GC in Java is a great tool, but that doesn't mean we should always rely on it. When necessary, there isn't much stopping us from writing Java code that looks like C/C++ and behaves very similarly. 

Currently I'm not super happy with the new code, as I wrote it in a hurry at home so I could use the updated code at work. If I go back and re-write it again, I'll probably make a small finite state machine so that the logic is easier to follow. But it does perform much better! 

I also had to implement my own parsing of a character sequence to a double to get this performance, and this has the unfortunate side effect that you don't necessarily get the exact same double value you would have gotten if you used the standard library. This is something I need to fix, and an artifact of how surprisingly involved dealing with floating point always is. However the current code's relative error is always less than 10^-14, which means the ML code will all be fine. But it's still not a nice surprise. 
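One way to check a bound like this is to compare a hand-rolled conversion against Double.parseDouble over a range of inputs. The sketch below does that for a deliberately naive conversion (integer mantissa scaled by a power of ten); it is an illustration with made-up names, not the test used for JSAT.

```java
/** Sketch: measure how far a naive decimal-to-double conversion drifts from the standard library. */
final class ParseErrorCheck {

    /** Naive conversion: integer mantissa scaled by a power of ten (not correctly rounded). */
    static double naive(long mantissa, int exp10) {
        return exp10 >= 0
                ? mantissa * Math.pow(10, exp10)
                : mantissa / Math.pow(10, -exp10);
    }

    /** Relative error of approx with respect to a reference value. */
    static double relativeError(double approx, double reference) {
        if (reference == 0.0) return Math.abs(approx);
        return Math.abs(approx - reference) / Math.abs(reference);
    }

    /** Largest relative error over all decimal exponents in [minExp, maxExp]. */
    static double maxRelativeError(long mantissa, int minExp, int maxExp) {
        double worst = 0;
        for (int e = minExp; e <= maxExp; e++) {
            // Double.parseDouble is the correctly rounded reference value.
            double reference = Double.parseDouble(mantissa + "e" + e);
            worst = Math.max(worst, relativeError(naive(mantissa, e), reference));
        }
        return worst;
    }

    public static void main(String[] args) {
        double worst = maxRelativeError(12345678901234567L, -200, 200);
        System.out.println("max relative error: " + worst);
    }
}
```

A difference in the last bit or two of a double is a relative error on the order of 1e-16, so even a sloppy conversion stays comfortably under the 10^-14 bound, while still not being bit-for-bit identical to the standard library.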

I also want to take a moment and complain about ASCII data sets. For most machine learning data sets, a common binary representation would take up less space, load faster, and be easier to handle. I don't understand everyone's fascination with making things JSON or XML or whatever other text format they are using. It's almost always unneeded and added complexity for the sole benefit of being "human readable", even if a human isn't ever going to read it!