Sunday, March 29, 2015

JSAT vs Weka on MNIST

I'm a strong proponent of using the right tool for the job, and I need to preface this post with the fact that Java is not the best tool for Machine Learning in general. But you may know Java best, you may need to interface with a Java infrastructure, or you might simply prefer Java over other languages. For any number of reasons you may need or want to do some ML work in Java, and Weka is the choice that most people hear of first.

My problem is that most of Weka is inexcusably slow, and it doesn't provide enough algorithms for you to really choose the right tool for your job. The code base is massive for the few algorithms it does support, and it's bloated with code duplication and inefficient design. The biggest time suck in ML is usually model building and testing, so I've put together a small benchmark on MNIST.

MNIST, as a data set, is often overused. However, it is just large enough that if your code is inefficient you will feel it. We also know very well what accuracies are obtainable with various models on MNIST, making it a good sanity check. So I downloaded an ARFF version of the data from here and made this benchmark (hosted here). The purpose of this benchmark is to show that JSAT has significantly faster implementations of many of the same algorithms, and has better algorithms to use as well.

For this benchmark I've tried to keep everything apples-to-apples, and I did my best to make sure each algorithm was doing the same thing and had the same options set. In most cases I made JSAT use the parameters Weka uses by default, when possible. For some algorithms (like SVMs) that need a well-chosen set of parameters, I used values from a grid search I had previously done with LIBSVM.
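Since the SVM parameters came from a grid search, here is a minimal sketch of what such a search looks like, assuming a hypothetical `evaluate` callback that trains and validates a model for one (C, gamma) pair and returns its error rate. The log2-spaced ranges in the usage below follow the coarse grid recommended in the LIBSVM practical guide (C = 2^-5 to 2^15, gamma = 2^-15 to 2^3); the grid actually used for this benchmark isn't stated, and the class name is illustrative.

```java
import java.util.function.BiFunction;

/** Hypothetical helper: exhaustive log2-spaced grid search over an RBF SVM's C and gamma. */
class GridSearch {
    /** Holds the best (C, gamma) pair found and how many evaluations were made. */
    static final class Best {
        final double c, gamma, error;
        final int evaluations;

        Best(double c, double gamma, double error, int evaluations) {
            this.c = c;
            this.gamma = gamma;
            this.error = error;
            this.evaluations = evaluations;
        }
    }

    /**
     * Tries every C = 2^ce and gamma = 2^ge for ce in [cFrom, cTo] and ge in [gFrom, gTo]
     * (inclusive, stepping the exponent by `step`) and keeps the pair with the lowest error.
     */
    static Best search(int cFrom, int cTo, int gFrom, int gTo, int step,
                       BiFunction<Double, Double, Double> evaluate) {
        double bestC = Double.NaN, bestGamma = Double.NaN, bestErr = Double.POSITIVE_INFINITY;
        int count = 0;
        for (int ce = cFrom; ce <= cTo; ce += step) {
            for (int ge = gFrom; ge <= gTo; ge += step) {
                double c = Math.pow(2, ce), gamma = Math.pow(2, ge);
                double err = evaluate.apply(c, gamma); // one full train/validate run per pair
                count++;
                if (err < bestErr) {
                    bestErr = err;
                    bestC = c;
                    bestGamma = gamma;
                }
            }
        }
        return new Best(bestC, bestGamma, bestErr, count);
    }
}
```

With those ranges and a step of 2 in the exponent, the search makes 11 × 10 = 110 calls to `evaluate`, each one a full training run, which is why per-model speed matters so much for SVMs.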

All the times below are presented in seconds, run on my iMac (2.66 GHz Core i5, 16GB of RAM) with nothing else running. I've omitted the train/test time for algorithms where training/testing is not a performance issue.

| Algorithm | Weka Time (s) | Weka Error | JSAT Time (s) | JSAT Error | JSAT Speedup |
| --- | --- | --- | --- | --- | --- |
| SVM w/ RBF Kernel (Full Cache) | Train: 7713.549, Test: 1339.633 | | Train: 3661.7, Test: 337.765 | | Train: 2.1x, Test: 4.0x |
| SVM w/ RBF Kernel (No Cache) | Train: 5657.924, Test: 1336.557 | | Train: 2558.846, Test: 317.663 | | Train: 2.2x, Test: 4.2x |
| RBF SVM, stochastic w/ 5 iterations | | | Train: 518.654, Test: 10.443 | | Train: 10.9x, Test: 128.0x (over Weka's SVM) |
| RBF SVM, RKS features w/ Linear Solver | | | Train: 68.398, Test: 0.571 | | Train: 82.7x, Test: 2340x (over Weka's SVM) |
| C4.5 Decision Tree | Train: 303.373 | 0.1134 | Train: 117.785 | 0.1146 | Train: 2.6x |
| Random Forest w/ 50 trees | Train: 143.127 | 0.0326 | Train: 100.673 | 0.0453 | Train: 1.4x |
| 1-NN (brute force) | Test: 2537.483 | 0.0309 | Test: 648.71 | 0.0309 | Test: 3.9x |
| 1-NN (Ball Tree) | Train: 52.263, Test: 3269.183 | | | | |
| 1-NN (Cover Tree) | Train: 538.132, Test: 2245.709 | | | | |
| 1-NN (VPmv) | | | Train: 1.909, Test: 493.69 | | |
| 1-NN (Random Ball Cover) | | | Train: 13.448, Test: 576.737 | | |
| Logistic Regression by LBFGS, λ = 1e-4 | Train: 3301.899 | 0.0821 | Train: 907.259 | 0.0776 | Train: 3.6x |
| Logistic Regression, stochastic w/ 10 iterations | | | Train: 10.545 | 0.0840 | Train: 313x (over Weka's LBFGS) |
| Logistic Regression, One-vs-All DCD | | | Train: 276.865 | 0.080 | Train: 12.3x (over Weka's LBFGS) |
| k-means (Lloyd's algorithm) | 1010.6009 | | 41.1913 | | 24.5x |
| k-means (Hamerly's algorithm) | | | 10.5358 | | 95.9x |
| k-means (Elkan's algorithm) | | | | | |

When running the same algorithms, most results end up with nearly identical accuracy. For Logistic Regression, JSAT does a little better. For the Random Forest, JSAT does a little worse for some reason, and I couldn't find a setting in Weka to change that would close the gap. Still, in the case of Random Forest, you could trade some of your speed savings for a few more trees to make up the difference.

For SVMs, the speed advantage is about 2x for training and 4x for prediction. Speed advantages for the SVM are particularly important: the SVM is very sensitive to its parameters, so a grid search is usually needed, multiplying the runtime by a factor of 10x-100x depending on how many parameter combinations you want to test. The alternative algorithms in JSAT (training the kernel SVM directly with a stochastic solver, and approximating the kernel's feature space with RKS features plus a linear solver) widen the advantage even further.
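"RKS" refers to Random Kitchen Sinks, i.e. random Fourier features. A minimal sketch of the idea for the RBF kernel k(x, y) = exp(-gamma * ||x - y||^2): each input is mapped to D features z(x) = sqrt(2/D) * cos(w·x + b), with each w drawn from N(0, 2 * gamma * I) and b uniform on [0, 2π), so that z(x)·z(y) approximates k(x, y) and a linear solver trained on z(x) approximates the kernel model. The class and method names are illustrative, not JSAT's API.

```java
import java.util.Random;

/** Illustrative random Fourier feature map approximating exp(-gamma * ||x - y||^2). */
class RandomFourierFeatures {
    private final double[][] w; // numFeatures x inputDim, entries drawn from N(0, 2 * gamma)
    private final double[] b;   // phase offsets drawn uniformly from [0, 2*pi)
    private final double scale; // sqrt(2 / numFeatures)

    RandomFourierFeatures(int inputDim, int numFeatures, double gamma, long seed) {
        Random rng = new Random(seed);
        double std = Math.sqrt(2 * gamma);
        w = new double[numFeatures][inputDim];
        b = new double[numFeatures];
        for (int i = 0; i < numFeatures; i++) {
            for (int j = 0; j < inputDim; j++) w[i][j] = std * rng.nextGaussian();
            b[i] = 2 * Math.PI * rng.nextDouble();
        }
        scale = Math.sqrt(2.0 / numFeatures);
    }

    /** Maps x to z(x); dot products of mapped points approximate the RBF kernel. */
    double[] transform(double[] x) {
        double[] z = new double[w.length];
        for (int i = 0; i < w.length; i++) {
            double proj = b[i];
            for (int j = 0; j < x.length; j++) proj += w[i][j] * x[j];
            z[i] = scale * Math.cos(proj);
        }
        return z;
    }

    static double dot(double[] a, double[] c) {
        double s = 0;
        for (int i = 0; i < a.length; i++) s += a[i] * c[i];
        return s;
    }

    /** Exact RBF kernel, for comparison with the approximation. */
    static double rbf(double[] x, double[] y, double gamma) {
        double sq = 0;
        for (int i = 0; i < x.length; i++) {
            double d = x[i] - y[i];
            sq += d * d;
        }
        return Math.exp(-gamma * sq);
    }
}
```

Because the map is a fixed D-dimensional transform, predicting a point costs O(D·d) rather than a kernel evaluation against every support vector, which is consistent with the very large test-time speedup in the table.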

For the tree-based algorithms, JSAT's advantage isn't as big. But if you are going to use trees in an ensemble (as is often the case), the 2.6x speedup is going to add up.
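To see why a per-tree speedup compounds in an ensemble, here is a minimal bagging sketch. The base learner is a hypothetical `Function<int[], M>` that trains one model on a bootstrap sample of row indices; it is invoked once per ensemble member, so training cost scales linearly with the number of trees. A Random Forest additionally samples features at each split, which this sketch leaves to the base learner.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Random;
import java.util.function.Function;

/** Minimal bagging driver; the base learner is supplied by the caller. */
class Bagging {
    /** Draws n row indices uniformly with replacement (a bootstrap sample). */
    static int[] bootstrap(int n, Random rng) {
        int[] idx = new int[n];
        for (int i = 0; i < n; i++) idx[i] = rng.nextInt(n);
        return idx;
    }

    /** Trains numModels base models, each on its own bootstrap sample of the n rows. */
    static <M> List<M> train(int n, int numModels, long seed, Function<int[], M> baseLearner) {
        Random rng = new Random(seed);
        List<M> models = new ArrayList<>(numModels);
        for (int t = 0; t < numModels; t++) {
            models.add(baseLearner.apply(bootstrap(n, rng))); // one full tree fit per member
        }
        return models;
    }
}
```

Since the base learner dominates the cost, any speedup in a single tree fit is collected again for every member of the ensemble.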

For Nearest Neighbor algorithms, both Weka and JSAT have data structures for accelerating nearest neighbor queries. Both of the JSAT structures (VPmv and Random Ball Cover) improved query time and were fast to build. In Weka, however, the Ball Tree was slower than the brute-force approach, and the Cover Tree was only a little faster but took almost as long to build (538 seconds) as JSAT took to solve the whole problem by brute force (649 seconds)!
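For readers unfamiliar with these structures, here is a compact vantage-point tree for exact 1-NN search under Euclidean distance. It is a plain VP-tree with a randomly chosen vantage point and a median split; JSAT's VPmv variant chooses vantage points differently, and this is not its implementation. The triangle inequality lets the search skip an entire subtree whenever the query's distance to the vantage point proves nothing in it can beat the current best, which is where the savings over brute force come from.

```java
import java.util.Arrays;
import java.util.Random;

/** Illustrative vantage-point tree for exact 1-nearest-neighbor search (not thread-safe). */
class VPTree {
    private static final class Node {
        int point;           // index of the vantage point
        double radius;       // median distance from the vantage point to the rest of its subset
        Node inside, outside;
    }

    private final double[][] data;
    private final Node root;
    private int best;        // search state for the current query
    private double bestDist;

    VPTree(double[][] data, long seed) {
        this.data = data;
        int[] idx = new int[data.length];
        for (int i = 0; i < idx.length; i++) idx[i] = i;
        root = build(idx, 0, idx.length, new Random(seed));
    }

    static double dist(double[] a, double[] b) {
        double s = 0;
        for (int i = 0; i < a.length; i++) { double d = a[i] - b[i]; s += d * d; }
        return Math.sqrt(s);
    }

    private Node build(int[] idx, int lo, int hi, Random rng) {
        if (lo >= hi) return null;
        int r = lo + rng.nextInt(hi - lo); // move a random vantage point to position lo
        int tmp = idx[lo]; idx[lo] = idx[r]; idx[r] = tmp;
        Node node = new Node();
        node.point = idx[lo];
        if (hi - lo == 1) return node;
        // Sort the remaining points by distance to the vantage point and split at the median.
        final double[] vp = data[node.point];
        Integer[] rest = new Integer[hi - lo - 1];
        for (int i = 0; i < rest.length; i++) rest[i] = idx[lo + 1 + i];
        Arrays.sort(rest, (a, b) -> Double.compare(dist(data[a], vp), dist(data[b], vp)));
        for (int i = 0; i < rest.length; i++) idx[lo + 1 + i] = rest[i];
        int mid = lo + 1 + rest.length / 2;
        node.radius = dist(data[idx[mid]], vp);
        node.inside = build(idx, lo + 1, mid, rng); // distances <= radius
        node.outside = build(idx, mid, hi, rng);    // distances >= radius
        return node;
    }

    /** Returns the index of the stored point nearest to q. */
    int nearest(double[] q) {
        best = -1;
        bestDist = Double.POSITIVE_INFINITY;
        search(root, q);
        return best;
    }

    private void search(Node node, double[] q) {
        if (node == null) return;
        double d = dist(q, data[node.point]);
        if (d < bestDist) { bestDist = d; best = node.point; }
        if (d < node.radius) {
            search(node.inside, q);
            // Outside points are at least (radius - d) away from q; visit only if that can win.
            if (node.radius - d <= bestDist) search(node.outside, q);
        } else {
            search(node.outside, q);
            // Inside points are at least (d - radius) away from q.
            if (d - node.radius <= bestDist) search(node.inside, q);
        }
    }
}
```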

For Logistic Regression, JSAT is again faster by a single-digit factor (3.6x). While LBFGS is a great and versatile tool to have, it's not always the best choice, especially for a bread-and-butter algorithm like Logistic Regression. JSAT's alternative exact solver (DCD, dual coordinate descent) is the same algorithm used in LIBLINEAR, and is over 12x faster. However, you don't always need an exact solution, and the SGD-based solver in JSAT gets a solution almost as good as Weka's LBFGS (0.0840 vs. 0.0821 error) and gets it 313 times faster.
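The stochastic solver is fast because each update touches a single example instead of evaluating the gradient over the whole data set. Here is a minimal sketch of binary L2-regularized logistic regression trained by plain SGD with a fixed step size; JSAT's actual solver and learning-rate schedule may differ, and a multi-class problem like MNIST would need one-vs-all or a softmax on top. The class name is illustrative.

```java
import java.util.Random;

/** Minimal binary logistic regression (labels in {-1, +1}) trained by SGD with an L2 penalty. */
class SgdLogisticRegression {
    final double[] w;
    double bias;

    SgdLogisticRegression(int dim) { w = new double[dim]; }

    /** Raw score w·x + bias; its sign is the predicted label. */
    double margin(double[] x) {
        double s = bias;
        for (int j = 0; j < w.length; j++) s += w[j] * x[j];
        return s;
    }

    /** Probability that x belongs to the +1 class. */
    double probability(double[] x) { return 1.0 / (1.0 + Math.exp(-margin(x))); }

    /** Runs `epochs` passes of n updates, each on one example sampled uniformly at random. */
    void fit(double[][] xs, int[] ys, int epochs, double eta, double lambda, long seed) {
        Random rng = new Random(seed);
        int n = xs.length;
        for (int epoch = 0; epoch < epochs; epoch++) {
            for (int step = 0; step < n; step++) {
                int i = rng.nextInt(n);
                double y = ys[i];
                // Derivative of log(1 + exp(-y * m)) with respect to the margin m.
                double g = -y / (1.0 + Math.exp(y * margin(xs[i])));
                for (int j = 0; j < w.length; j++) w[j] -= eta * (g * xs[i][j] + lambda * w[j]);
                bias -= eta * g; // the bias is not regularized
            }
        }
    }
}
```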

Finally, for k-means, JSAT is already 24 times faster using the exact same algorithm (Lloyd's). But JSAT also has two additional algorithms (Hamerly's and Elkan's) that obtain the exact same solution while avoiding redundant work, resulting in 95x to almost 200x faster training. With k-means being such a common tool, both on its own and as a building block for other algorithms, the difference is huge.
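Hamerly's and Elkan's algorithms can return exactly the same clustering as Lloyd's because they only skip distance computations that the triangle inequality proves cannot change a point's assignment. Below is a compact sketch of Lloyd's algorithm and a simplified Hamerly's algorithm that start from the same centers; the class and method names are hypothetical, and the lower-bound update subtracts the largest center movement for every point, which is slightly looser (but still exact) than the published version. Elkan's algorithm keeps one lower bound per center instead of a single one.

```java
/** Illustrative k-means: Lloyd's algorithm and a simplified Hamerly's algorithm. */
class KMeansSketch {
    static long distanceCalls = 0; // counts point-to-center distance evaluations only

    static double dist(double[] a, double[] b) {
        double s = 0;
        for (int j = 0; j < a.length; j++) { double d = a[j] - b[j]; s += d * d; }
        return Math.sqrt(s);
    }

    static double pointDist(double[] x, double[] c) { distanceCalls++; return dist(x, c); }

    static double[][] copy(double[][] a) {
        double[][] c = new double[a.length][];
        for (int i = 0; i < a.length; i++) c[i] = a[i].clone();
        return c;
    }

    /** Moves each center to the mean of its points and returns how far each center moved. */
    static double[] updateCenters(double[][] xs, int[] assign, double[][] centers) {
        int k = centers.length, d = xs[0].length;
        double[][] sum = new double[k][d];
        int[] count = new int[k];
        for (int i = 0; i < xs.length; i++) {
            count[assign[i]]++;
            for (int j = 0; j < d; j++) sum[assign[i]][j] += xs[i][j];
        }
        double[] moved = new double[k];
        for (int c = 0; c < k; c++) {
            if (count[c] == 0) continue; // an empty cluster keeps its old center
            for (int j = 0; j < d; j++) sum[c][j] /= count[c];
            moved[c] = dist(centers[c], sum[c]);
            centers[c] = sum[c];
        }
        return moved;
    }

    /** Lloyd's: every pass computes all n * k point-to-center distances. */
    static int[] lloyd(double[][] xs, double[][] init, int maxIter) {
        double[][] centers = copy(init);
        int[] assign = new int[xs.length];
        for (int it = 0; it < maxIter; it++) {
            boolean changed = it == 0;
            for (int i = 0; i < xs.length; i++) {
                int best = 0;
                double bestD = pointDist(xs[i], centers[0]);
                for (int c = 1; c < centers.length; c++) {
                    double dc = pointDist(xs[i], centers[c]);
                    if (dc < bestD) { bestD = dc; best = c; }
                }
                if (best != assign[i]) changed = true;
                assign[i] = best;
            }
            if (!changed) break;
            updateCenters(xs, assign, centers);
        }
        return assign;
    }

    /** Full scan for point i: sets its assignment, upper bound (closest) and lower bound (second closest). */
    static void fullAssign(double[] x, double[][] centers, int i, int[] assign, double[] upper, double[] lower) {
        double d1 = Double.POSITIVE_INFINITY, d2 = Double.POSITIVE_INFINITY;
        for (int c = 0; c < centers.length; c++) {
            double dc = pointDist(x, centers[c]);
            if (dc < d1) { d2 = d1; d1 = dc; assign[i] = c; } else if (dc < d2) d2 = dc;
        }
        upper[i] = d1;
        lower[i] = d2;
    }

    /** Simplified Hamerly's: skips points whose bounds prove the assignment cannot change. */
    static int[] hamerly(double[][] xs, double[][] init, int maxIter) {
        double[][] centers = copy(init);
        int n = xs.length, k = centers.length;
        int[] assign = new int[n];
        double[] upper = new double[n], lower = new double[n], s = new double[k];
        for (int i = 0; i < n; i++) fullAssign(xs[i], centers, i, assign, upper, lower);
        for (int it = 1; it < maxIter; it++) {
            double[] moved = updateCenters(xs, assign, centers);
            double maxMoved = 0;
            for (double m : moved) maxMoved = Math.max(maxMoved, m);
            for (int c = 0; c < k; c++) { // s[c]: half the distance to the nearest other center
                s[c] = Double.POSITIVE_INFINITY;
                for (int o = 0; o < k; o++) if (o != c) s[c] = Math.min(s[c], 0.5 * dist(centers[c], centers[o]));
            }
            boolean changed = false;
            for (int i = 0; i < n; i++) {
                upper[i] += moved[assign[i]]; // own center may have moved away
                lower[i] -= maxMoved;         // any other center may have moved closer
                double bound = Math.max(s[assign[i]], lower[i]);
                if (upper[i] <= bound) continue;                 // provably still the closest center
                upper[i] = pointDist(xs[i], centers[assign[i]]); // tighten the upper bound and retest
                if (upper[i] <= bound) continue;
                int old = assign[i];
                fullAssign(xs[i], centers, i, assign, upper, lower);
                if (assign[i] != old) changed = true;
            }
            if (!changed) break;
        }
        return assign;
    }
}
```

Both methods walk through the same sequence of centers, so on data without exact distance ties they return identical assignments; Hamerly's simply performs fewer point-to-center distance computations once the centers stop moving much.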

So, overall, JSAT has faster implementations of the same algorithms, and more algorithms, giving you the flexibility to pick the right tool for the job.

Just to be clear, I'm not claiming that JSAT has the fastest implementations ever of these algorithms. Many great tools, such as scikit-learn, have faster implementations for some algorithms. But JSAT does have a greater variety and if you are going to be working in Java, you should definitely consider JSAT over Weka just on the issue of speed.
