Friday, November 20, 2015

A Binary Format for JSAT Datasets

I made a post a while ago about improving the LIBSVM file loader in JSAT so that it would use much less memory and run a good deal faster, and I complained about everyone using human-readable ASCII file formats. Spurred by a recent pull request, I've finally gone ahead and implemented a simple binary format for storing datasets in JSAT. I'm not using Java's serialization for this, so one could just as easily implement a reader/writer for the format in other languages.

The binary format supports both sparse and dense storage of numeric features, and stores the string names for categorical features. Since floating-point values take up most of the space, it also supports several ways of storing them. Currently it can save values as a 32- or 64-bit float, as a short, or as a signed/unsigned byte. The default method is to scan through the dataset and check which of the options would result in the smallest file without losing any information. Despite this overhead, it's faster than writing either an ARFF or a LIBSVM file! You can also explicitly pass the storage method you want, which skips the scan and does a lossy conversion if necessary.
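
A minimal sketch of that AUTO scan, written as plain Java rather than JSAT's actual writer code: the StoreType enum and the smallestLossless method are hypothetical names. It assumes a value can go in a byte or short only if it is an integer within that type's range, and in FP32 only if it survives a round trip through float unchanged. Once every option smaller than FP64 has been ruled out, it stops scanning.

```java
public class StorageTypeSketch {
    // Hypothetical names for the storage options; not JSAT's actual identifiers.
    enum StoreType { U_BYTE, BYTE, SHORT, FP32, FP64 }

    /** Returns the smallest storage type that holds every value without loss. */
    static StoreType smallestLossless(double[] values) {
        boolean uByte = true, sByte = true, sShort = true, fp32 = true;
        for (double v : values) {
            boolean isInt = v == Math.rint(v);
            if (!(isInt && v >= 0 && v <= 255)) uByte = false;
            if (!(isInt && v >= Byte.MIN_VALUE && v <= Byte.MAX_VALUE)) sByte = false;
            if (!(isInt && v >= Short.MIN_VALUE && v <= Short.MAX_VALUE)) sShort = false;
            // Double.compare treats NaN as equal to itself, so NaN still counts as FP32-safe.
            if (Double.compare((double) (float) v, v) != 0) fp32 = false;
            // Early exit: once only FP64 remains, the rest of the data can't change the answer.
            if (!uByte && !sByte && !sShort && !fp32) return StoreType.FP64;
        }
        if (uByte) return StoreType.U_BYTE;
        if (sByte) return StoreType.BYTE;
        if (sShort) return StoreType.SHORT;
        return fp32 ? StoreType.FP32 : StoreType.FP64;
    }

    public static void main(String[] args) {
        System.out.println(smallestLossless(new double[]{0, 17, 255}));  // U_BYTE
        System.out.println(smallestLossless(new double[]{-3, 100}));     // BYTE
        System.out.println(smallestLossless(new double[]{1000, -2}));    // SHORT
        System.out.println(smallestLossless(new double[]{0.5, 0.25}));   // FP32
        System.out.println(smallestLossless(new double[]{0.1, 42}));     // FP64
    }
}
```

Raw MNIST pixels (integers 0 to 255) come out as U_BYTE. After scaling to [0, 1], a value like 1/255 has no exact float representation, so the scan falls through to FP64 and exits early.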

I ran a simple performance test of reading/writing the MNIST training set. For JSATData I tested with sparse and dense numeric features (determined by how the data is stored in memory), using the AUTO, unsigned byte (U_BYTE), and 64-bit float (FP64) options.

For this first table, I left the data unnormalized, so it was stored as integers from 0 to 255, making it easy to save as bytes. The JSAT writers for ARFF and LIBSVM write out 0s as "0.0", so technically those files have some unnecessary padding. These numbers are from my MacBook Air, which has a nice SSD.

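Method                      Read Time (ms)  Write Time (ms)  File Size (MB)
ARFF                        7810            7091             203.7
LIBSVM                      3777            3322             87.5
JSATData FP64 (sparse)      989             790              108.9
JSATData U_BYTE (sparse)    758             586              45.6
JSATData AUTO (sparse)      758*            1652             45.6*
JSATData FP64 (dense)       2839            1894             377.1
JSATData U_BYTE (dense)     1735            895              47.4
JSATData AUTO (dense)       1735*           1580             47.4*
* AUTO chose U_BYTE here and wrote the same file, so the read time and file size are shared with U_BYTE.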

In this next table, I normalized the values to the range [0, 1]. This makes JSAT's AUTO option select FP64, and makes the ARFF and LIBSVM files use a lot more text. Since U_BYTE won't work any more, I also forced the values to be stored as FP32.
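Method                      Read Time (ms)  Write Time (ms)  File Size (MB)
ARFF                        15247           9445             318.3
LIBSVM                      6920            5643             202.1
JSATData FP64 (sparse)      1032            732              108.9
JSATData AUTO (sparse)      1032*           808              108.9*
JSATData FP32 (sparse)      941             833              72.7
JSATData FP64 (dense)       2873            1794             377.1
JSATData AUTO (dense)       2873*           1788             377.1*
JSATData FP32 (dense)       2765            1933             188.7
* AUTO chose FP64 here and wrote the same file, so the read time and file size are shared with FP64.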

In the second table, the write times for FP64 and AUTO barely differ. This is because the code that auto-detects the best format quits early once it has eliminated every option more compact than a 64-bit float. As you'd expect, the FP64 files don't change in size at all, which is much more consistent and desirable behavior. And even when JSATData does not produce smaller files, the format is simple, so reading and writing it is mostly IO bound. That makes it much faster than the CPU-bound ARFF and LIBSVM formats, which have to do a bunch of string processing and math to convert the strings to floats.

I've also added a small feature for strings stored in the format, since I save out the names of categorical features and their options. A simple marker indicates whether each string is ASCII or UTF-16, so that common ASCII strings don't waste as much space. The writer also auto-detects whether ASCII is safe or UTF-16 is needed.
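
Here is a rough sketch of how such a marker can work, using only java.io streams. The marker values (0 for ASCII, 1 for UTF-16), the 4-byte length prefix, and the method names are assumptions for illustration, not the actual JSATData layout.

```java
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;

public class StringEncodingSketch {
    // Hypothetical marker values; the real format's constants may differ.
    static final byte ASCII = 0, UTF_16 = 1;

    /** True if every char is 7-bit ASCII, so one byte per char loses nothing. */
    static boolean isAscii(String s) {
        for (int i = 0; i < s.length(); i++)
            if (s.charAt(i) > 127) return false;
        return true;
    }

    static void writeString(DataOutputStream out, String s) throws IOException {
        boolean ascii = isAscii(s);
        Charset cs = ascii ? StandardCharsets.US_ASCII : StandardCharsets.UTF_16BE;
        byte[] bytes = s.getBytes(cs);
        out.writeByte(ascii ? ASCII : UTF_16); // 1-byte marker
        out.writeInt(bytes.length);            // length prefix, in bytes
        out.write(bytes);
    }

    static String readString(DataInputStream in) throws IOException {
        byte marker = in.readByte();
        byte[] bytes = new byte[in.readInt()];
        in.readFully(bytes);
        return new String(bytes, marker == ASCII ? StandardCharsets.US_ASCII : StandardCharsets.UTF_16BE);
    }

    public static void main(String[] args) throws IOException {
        ByteArrayOutputStream buf = new ByteArrayOutputStream();
        DataOutputStream out = new DataOutputStream(buf);
        writeString(out, "color");       // 1 + 4 + 5 bytes
        writeString(out, "caf\u00e9");   // 1 + 4 + 8 bytes: the accent forces UTF-16
        out.flush();
        System.out.println(buf.size());  // 23

        DataInputStream in = new DataInputStream(new ByteArrayInputStream(buf.toByteArray()));
        System.out.println(readString(in).equals("color") && readString(in).equals("caf\u00e9")); // true
    }
}
```

An ASCII string costs one byte per character instead of two. Only strings that actually need it pay the UTF-16 price.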

I've written this format with JSAT's three main dataset types in mind, but hopefully it can be useful for others as well. If there is interest, I may write a reader/writer for Python and C/C++ and host them as small projects on GitHub.

Monday, October 26, 2015

Visualization Algorithms in JSAT

For the next update of JSAT (you can try them early in the 0.0.3-SNAPSHOT release), I've been implementing a few algorithms meant specifically for visualizing datasets. For now I've chosen three of the more common, well-known algorithms: MDS, Isomap, and t-SNE. Unlike PCA, these algorithms are mostly meant to transform a single dataset at once, the intention being to visualize that block of data. You can't necessarily apply them to new data points, nor would you necessarily want to. Their purpose is really to help you explore data and learn more about it.

A classic test problem is the "Swiss roll" dataset. In this dataset some of the points, such as the red and blue points below, are close in Euclidean distance, but our intuition is that they are really far apart, since getting from one to the other means traveling along the roll. So a good 2D visualization of this data would place those points far away from each other while keeping the red/orange points close together.
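
To put numbers on that intuition, here is a small sketch using one common Swiss roll parameterization: the spiral (t cos t, t sin t) swept along a height h. This parameterization is an assumption; JSAT's generator may differ. It compares the straight-line distance between two points on adjacent layers with the distance you travel along the roll. Isomap approximates the latter with shortest paths through a nearest-neighbor graph.

```java
public class SwissRollDistance {
    /** A point on the roll: the spiral (t cos t, t sin t) swept along height h. */
    static double[] point(double t, double h) {
        return new double[]{t * Math.cos(t), h, t * Math.sin(t)};
    }

    static double euclidean(double[] a, double[] b) {
        double s = 0;
        for (int i = 0; i < a.length; i++) s += (a[i] - b[i]) * (a[i] - b[i]);
        return Math.sqrt(s);
    }

    /** Antiderivative of sqrt(1 + t^2), the speed of the spiral (t cos t, t sin t). */
    static double arc(double t) {
        double r = Math.sqrt(1 + t * t);
        return 0.5 * (t * r + Math.log(t + r)); // log(t + r) = asinh(t)
    }

    /** Distance traveled along the roll between t1 and t2 at a fixed height. */
    static double alongRoll(double t1, double t2) {
        return Math.abs(arc(t2) - arc(t1));
    }

    public static void main(String[] args) {
        // Two points on adjacent layers of the roll, one full turn apart.
        double t1 = 1.5 * Math.PI, t2 = 3.5 * Math.PI;
        double[] a = point(t1, 5), b = point(t2, 5);
        System.out.printf("Euclidean: %.2f  along the roll: %.2f%n", euclidean(a, b), alongRoll(t1, t2));
    }
}
```

The straight-line distance is exactly one layer spacing, 2π. The distance along the roll comes out roughly eight times larger.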

t-SNE doesn't do a great job of visualizing the Swiss roll data, but for more general datasets t-SNE is a great algorithm. The implementation in JSAT uses the faster Barnes-Hut approximation, so it runs in O(n log n) time, making it applicable to larger datasets. Below is a visualization of t-SNE on the MNIST dataset, which makes it easy to see 10 clusters, one for each digit. The relationships between the clusters make sense, too. The 9 and 4 clusters are near each other, which is reasonable, since noisy 9s and 4s could easily be confused. We also see some noisy 9s in the cluster of 7s, another understandable case.

t-SNE was a bit of work to implement, and I had difficulty replicating the results using the exact details in the paper. A few changes worked well for me, and had the benefit of being a bit easier to code. Right now my t-SNE implementation is a bit slower than the standard implementation when run single-threaded, but mine does support multi-threaded execution.

Hopefully these new tools will be useful for people trying to get a better intuition of what's happening in their data. All three algorithms support multi-threaded execution and implement a new VisualizationTransform interface. The Isomap implementation also includes an extension called C-Isomap, and I'm hoping to add more useful visualization algorithms in the future.

On another note, I've moved the GUI components out of JSAT completely. Anything GUI related will now be in a new project, JSATFX. As you might guess from the name, I'm going to try writing more of the code in JavaFX so I can learn it. The visualizations above are from that library.

Saturday, July 11, 2015

Easier Hyperparameter Tuning

I've just released the newest version of JSAT to my Maven repo and GitHub, and the biggest change is some work I've done to make parameter tuning much easier for people who are new to Machine Learning or to JSAT specifically.

The crux of this new code, from the user's perspective, is a new method on the GridSearch object, autoAddParameters. This method takes a DataSet object and automatically adds hyperparameters, along with values to be tested. These values are adjusted to the given dataset, so that reasonable results can be obtained even if the dataset isn't scaled to a reasonable range.

The goal of this is to help new users and people new to ML. Tuning an algorithm has been the most common issue in the feedback I receive from users, and part of the problem is that new users simply don't know what values are reasonable to try for the many algorithms in JSAT. Now, with code like the example below, new users can get good results without having to know which parameter values are reasonable.

import java.io.File;
import java.io.IOException;
import java.util.List;
import jsat.classifiers.*;
import jsat.classifiers.svm.PlatSMO;
import jsat.classifiers.svm.SupportVectorLearner.CacheMode;
import jsat.distributions.kernels.RBFKernel;
import jsat.io.LIBSVMLoader;
import jsat.parameters.RandomSearch;

/**
 * @author Edward Raff
 */
public class EasyParameterSearch
{
    public static void main(String[] args) throws IOException
    {
        ClassificationDataSet dataset = LIBSVMLoader.loadC(new File("diabetes.libsvm"));

        ///////First, the code someone new would use////////
        PlatSMO model = new PlatSMO(new RBFKernel());
        model.setCacheMode(CacheMode.FULL);//Small dataset, so we can do this

        ClassificationModelEvaluation cme = new ClassificationModelEvaluation(model, dataset);
        cme.evaluateCrossValidation(10);//10-fold CV estimate for the untuned model
        System.out.println("Error rate: " + cme.getErrorRate());

        /*
         * Now some easy code to tune the model. Because the parameter values
         * can be impacted by the dataset, we should split the data into a train
         * and test set to avoid overfitting.
         */
        List<ClassificationDataSet> splits = dataset.randomSplit(0.75, 0.25);
        ClassificationDataSet train = splits.get(0);
        ClassificationDataSet test = splits.get(1);

        RandomSearch search = new RandomSearch((Classifier)model, 3);
        if(search.autoAddParameters(train) > 0)//this method adds parameters, and returns the number of parameters added
        {
            //that way we only do the search if there are any parameters to actually tune
            search.trainC(train);//run the search on the training split
            PlatSMO tunedModel = (PlatSMO) search.getTrainedClassifier();

            cme = new ClassificationModelEvaluation(tunedModel, train);
            cme.evaluateTestSet(test);//train on the training split, evaluate on the held-out test split
            System.out.println("Tuned Error rate: " + cme.getErrorRate());
        }
        else//otherwise we will just have to trust our original CV error rate
            System.out.println("This model doesn't seem to have any easy to tune parameters");
    }
}

The code above uses the new RandomSearch class rather than GridSearch, but the code would look the same either way. RandomSearch is a good alternative to GridSearch, especially when more than 2 parameters are going to be searched over. Using the diabetes dataset, we end up with output that looks like this:

Error rate: 0.3489583
Tuned Error rate: 0.265625

Yay, an improvement in the error rate! A potentially better error rate could be found by increasing the number of trials that RandomSearch performs. While the code that makes this possible is a little hairy, so far I'm happy with the way it works for the user. 
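
One way to see why random search helps once there are more than two parameters: for the same number of trials, a grid only ever tries a few distinct values of each parameter, while random sampling tries a fresh value of every parameter on every trial. This is a generic illustration, not JSAT's RandomSearch code. The [0.1, 10) range, the log-uniform sampling, and the budget of 27 trials are arbitrary choices.

```java
import java.util.Arrays;
import java.util.HashSet;
import java.util.Random;
import java.util.Set;

public class GridVsRandom {
    /** Number of distinct values each parameter (column) takes across all trials. */
    static int[] distinctPerParam(double[][] trials) {
        int[] counts = new int[trials[0].length];
        for (int j = 0; j < counts.length; j++) {
            Set<Double> seen = new HashSet<>();
            for (double[] trial : trials) seen.add(trial[j]);
            counts[j] = seen.size();
        }
        return counts;
    }

    public static void main(String[] args) {
        double[] axis = {0.1, 1.0, 10.0};    // 3 grid values per parameter
        int params = 3, budget = 27;         // a full 3 x 3 x 3 grid uses 27 trials

        double[][] grid = new double[budget][params];
        for (int i = 0; i < budget; i++) {
            int idx = i;                      // read i as a base-3 number, one digit per parameter
            for (int j = 0; j < params; j++) {
                grid[i][j] = axis[idx % axis.length];
                idx /= axis.length;
            }
        }

        Random rng = new Random(42);
        double[][] random = new double[budget][params];
        for (double[] trial : random)
            for (int j = 0; j < params; j++)
                trial[j] = Math.pow(10, -1 + 2 * rng.nextDouble()); // log-uniform over [0.1, 10)

        System.out.println("grid:   " + Arrays.toString(distinctPerParam(grid)));   // [3, 3, 3]
        System.out.println("random: " + Arrays.toString(distinctPerParam(random))); // [27, 27, 27]
    }
}
```

Suppose only one of the three parameters really matters. The grid has then spent 27 trials testing just 3 values of it, while random search tested 27.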

Monday, June 1, 2015

Exact Models and Cross Validation from Warm Starts

One of the most common tasks when building a model is Cross Validation. Computationally speaking, this can be very expensive. If you want to do 10-fold cross validation, you're basically going to be waiting ~10x as long for your model to build.

However, there is a potential speedup available for Cross Validation, but only when using models that converge to an exact solution. First you train one model on all the data, and then train the 10 CV models using the first model as a warm start. This way we are training 11 models, but the 10 CV models should train significantly faster, which should be enough to more than offset the time spent training the extra model.

The use of models that converge to a single exact solution is critical for this to be a sensible idea. The point of CV is to keep data separate, and evaluate on data you haven't seen/trained on. By warm starting from a model trained on everything, the CV models have, in a sense, seen all the data before. But because the model converges to an exact solution, it would reach the same solution regardless of the warm start. The model basically forgets about the earlier data.

If we had used a model that does not converge to an exact solution, or has multiple local minima, this approach would bias the CV error. In the former case, we may not move all the way from the first solution we found. Being closer to the initial solution means enough bias may be left over to "remember" some of what we saw. This is even worse in the latter case: the initial solution on all the data may put you in a local minimum that you wouldn't have been in without the warm start, which basically guarantees you won't be able to get too far away.
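
A small sketch of the idea, under assumptions of my own: ridge regression trained by plain gradient descent stands in for an "exact" model. Its objective is strictly convex, so it has a single minimum. This is not JSAT code; the fit method, the LAMBDA/STEP/TOL constants, and the synthetic data are all hypothetical. Each fold is trained twice, once from zero and once warm-started from the all-data model. The sketch reports the total number of steps and how far apart the two fold models end up.

```java
import java.util.Arrays;

public class WarmStartCV {
    static final double LAMBDA = 0.01, STEP = 0.1, TOL = 1e-6;

    /**
     * Gradient descent on the ridge objective (1/n) sum_i (w.x_i - y_i)^2 + LAMBDA ||w||^2,
     * using only the rows where use[i] is true. Updates w in place and returns the number
     * of steps taken before the gradient norm fell below TOL.
     */
    static int fit(double[][] x, double[] y, boolean[] use, double[] w) {
        int n = 0;
        for (boolean u : use) if (u) n++;
        for (int iters = 0; ; iters++) {
            double[] grad = new double[w.length];
            for (int i = 0; i < x.length; i++) {
                if (!use[i]) continue;
                double err = -y[i];
                for (int j = 0; j < w.length; j++) err += w[j] * x[i][j];
                for (int j = 0; j < w.length; j++) grad[j] += 2 * err * x[i][j] / n;
            }
            double norm2 = 0;
            for (int j = 0; j < w.length; j++) {
                grad[j] += 2 * LAMBDA * w[j];
                norm2 += grad[j] * grad[j];
            }
            if (Math.sqrt(norm2) < TOL) return iters;
            for (int j = 0; j < w.length; j++) w[j] -= STEP * grad[j];
        }
    }

    public static void main(String[] args) {
        int n = 500, folds = 10;
        double[][] x = new double[n][];
        double[] y = new double[n];
        for (int i = 0; i < n; i++) {
            x[i] = new double[]{Math.sin(i), Math.cos(2 * i)};
            y[i] = 3 * x[i][0] - 2 * x[i][1] + 0.1 * Math.sin(7 * i); // deterministic "noise"
        }
        boolean[] all = new boolean[n];
        Arrays.fill(all, true);
        double[] wAll = new double[2];
        int allSteps = fit(x, y, all, wAll);   // the one extra model, trained on everything

        int coldSteps = 0, warmSteps = 0;
        double maxGap = 0;
        for (int f = 0; f < folds; f++) {
            boolean[] train = new boolean[n];
            for (int i = 0; i < n; i++) train[i] = i % folds != f;
            double[] cold = new double[2];      // usual CV: start from zero
            double[] warm = wAll.clone();       // warm start from the all-data model
            coldSteps += fit(x, y, train, cold);
            warmSteps += fit(x, y, train, warm);
            maxGap = Math.max(maxGap, Math.hypot(cold[0] - warm[0], cold[1] - warm[1]));
        }
        System.out.println("all-data model: " + allSteps + " steps");
        System.out.println("CV from zero: " + coldSteps + " steps, CV warm-started: " + warmSteps + " steps");
        System.out.println("largest gap between cold and warm fold models: " + maxGap);
    }
}
```

Because the objective has one minimum, the warm-started fold models land on the same solution as the cold ones, up to the stopping tolerance. That is the "forgetting" argument above. Plain gradient descent converges linearly, so the saving is roughly proportional to the log of how much closer the warm start begins. Faster exact solvers may benefit differently.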

Hopefully I'll get some time to experiment with actually implementing this idea.

Sunday, March 29, 2015

JSAT vs Weka on MNIST

I'm a strong proponent of using the right tool for the job, and I need to preface this post with the fact that Java is not the best tool for Machine Learning in general. But you may know Java best, you may need to interface with Java infrastructure, or you might just prefer Java over other languages. For any number of reasons you may need or want to do some ML work in Java, and Weka is the choice most people hear of first.

My problem is that most of Weka is inexcusably slow, and it doesn't provide enough algorithms for you to really choose the right tool for your job. The code base is massive for the few algorithms it does support, and it's bloated with code duplication and inefficient design. The biggest time sink in ML is usually model building/testing, so I've put together a small benchmark on MNIST.

MNIST, as a dataset, is often overused. However, it is just large enough that if your code is inefficient you will feel it. We also know very well what accuracies are obtainable with various models on MNIST, making it a good sanity check. So I downloaded an ARFF version of the data from here and made this benchmark (hosted here). The purpose of this benchmark is to show that JSAT has significantly faster implementations of many of the same algorithms, and has better algorithms to use as well.

For this benchmark I've tried to keep everything apples-to-apples, and did my best to make sure each algorithm was doing the same thing and had the same options set. In most cases I made JSAT use the parameters Weka uses by default. For some algorithms (like SVMs) that need a good set of parameters, I used values from a grid search I had done earlier with LIBSVM.

All the times below are in seconds, measured on my iMac (2.66 GHz Core i5, 16GB of RAM) with nothing else running. I've omitted the train/test time for algorithms where training/testing is not a performance issue.

Algorithm                                   Phase  Weka Time  Weka Error  JSAT Time  JSAT Error  JSAT Speedup
SVM w/ RBF Kernel (Full Cache)              Train  7713.549               3661.7                 2.1x
                                            Test   1339.633               337.765                4.0x
SVM w/ RBF Kernel (No Cache)                Train  5657.924               2558.846               2.2x
                                            Test   1336.557               317.663                4.2x
RBF SVM stochastic w/ 5 iterations          Train                         518.654                10.9x (over Weka SVM, no cache)
                                            Test                          10.443                 128.0x (over Weka SVM, no cache)
RBF SVM RKS features w/ Linear Solver       Train                         68.398                 82.7x (over Weka SVM, no cache)
                                            Test                          0.571                  2340x (over Weka SVM, no cache)
C4.5 Decision Tree                          Train  303.373    0.1134      117.785    0.1146      2.6x
Random Forest w/ 50 trees                   Train  143.127    0.0326      100.673    0.0453      1.4x
1-NN (brute force)                          Test   2537.483   0.0309      648.71     0.0309      3.9x
1-NN (Ball Tree)                            Train  52.263
                                            Test   3269.183
1-NN (Cover Tree)                           Train  538.132
                                            Test   2245.709
1-NN (VPmv)                                 Train                         1.909
                                            Test                          493.69
1-NN (Random Ball Cover)                    Train                         13.448
                                            Test                          576.737
Logistic Regression by LBFGS, λ = 1e-4      Train  3301.899   0.0821      907.259    0.0776      3.6x
Log Regression stochastic w/ 10 iterations  Train                         10.545     0.0840      313x (over LBFGS)
Logistic Regression OneVsAll DCD            Train                         276.865    0.080       12.3x (over LBFGS)
k-means (Lloyd's algorithm)                 Train  1010.6009              41.1913                24.5x
k-means (Hamerly's algorithm)               Train                         10.5358                95.9x
k-means (Elkan's algorithm)                 Train

When running the same algorithms, most results end up with nearly the same accuracy. For the Random Forest, JSAT does a little worse for some reason, and I couldn't find a setting in Weka to change that would account for it. For Logistic Regression, JSAT does a little better. In the case of Random Forest, though, you could trade some of your speed savings for a few more trees to make up the difference.

For SVMs, the speed advantage is about 2x for training and 4x for prediction. Speed is particularly important for the SVM since it is very sensitive to its parameters, so a grid search is going to be needed, multiplying the runtime by a factor of 10x-100x depending on how many parameter combinations you want to test. The alternative algorithms in JSAT (training the kernel model directly with a stochastic method, and approximating the feature space) widen the advantage even further.

For the tree-based algorithms, JSAT's advantage isn't as big. But if you are going to use trees in an ensemble (as is often the case), the 2.6x speedup is going to add up.

For Nearest Neighbor search, both Weka and JSAT have data structures for accelerating queries. Both of the ones from JSAT improved query time and were fast to train. Weka's Ball Tree, however, was slower than the naive approach. Its Cover Tree was only a little faster, and it took almost as much time to train as JSAT took to do the whole problem to begin with!

For Logistic Regression, JSAT is again faster by a single-digit factor, 3.6x. While LBFGS is a great and versatile tool to have, it's not always the best choice, especially for a bread-and-butter algorithm like Logistic Regression. JSAT's alternative exact solver (DCD) is the same algorithm used in LIBLINEAR, and is over 12x faster. However, you don't always need an exact solution, and the SGD-based solver in JSAT gets a solution almost as good as Weka's LBFGS, and gets it 313 times faster.

Finally, for k-means, JSAT is already 24 times faster using the exact same algorithm. But JSAT also has two additional algorithms that obtain the exact same solution, but avoid redundant work - resulting in 95x to almost 200x faster training. With k-means being such a common tool both in use and as a building block for other algorithms, the difference is huge.
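
Hamerly's and Elkan's algorithms reach the same answer as Lloyd's algorithm, but they use the triangle inequality to skip distance computations that can't change a point's assignment. Below is a minimal sketch of the simplest such test, the center-to-center bound from Elkan's algorithm, applied to one assignment step. It is not JSAT's implementation, and the full algorithms also carry per-point upper and lower bounds across iterations. The dataset and centers are made up for illustration.

```java
import java.util.Random;

public class TriangleInequalityAssign {
    static double dist(double[] a, double[] b) {
        double s = 0;
        for (int i = 0; i < a.length; i++) s += (a[i] - b[i]) * (a[i] - b[i]);
        return Math.sqrt(s);
    }

    /** Brute force: compute the distance to every center. */
    static int nearest(double[] x, double[][] centers) {
        int best = 0;
        double bestD = dist(x, centers[0]);
        for (int c = 1; c < centers.length; c++) {
            double d = dist(x, centers[c]);
            if (d < bestD) { bestD = d; best = c; }
        }
        return best;
    }

    static long skipped = 0;

    /**
     * Same answer, but skips center c whenever d(best, c) >= 2 d(x, best), because then
     * d(x, c) >= d(best, c) - d(x, best) >= d(x, best). cc holds center-to-center distances.
     */
    static int nearestPruned(double[] x, double[][] centers, double[][] cc) {
        int best = 0;
        double bestD = dist(x, centers[0]);
        for (int c = 1; c < centers.length; c++) {
            if (cc[best][c] >= 2 * bestD) { skipped++; continue; }
            double d = dist(x, centers[c]);
            if (d < bestD) { bestD = d; best = c; }
        }
        return best;
    }

    public static void main(String[] args) {
        double[][] centers = {{0, 0}, {10, 0}, {0, 10}, {10, 10}};
        int k = centers.length;
        double[][] cc = new double[k][k];   // computed once per iteration, shared by all points
        for (int a = 0; a < k; a++)
            for (int b = 0; b < k; b++) cc[a][b] = dist(centers[a], centers[b]);

        Random rng = new Random(1);
        int points = 1000, mismatches = 0;
        for (int i = 0; i < points; i++) {
            double[] x = {12 * rng.nextDouble() - 1, 12 * rng.nextDouble() - 1};
            if (nearest(x, centers) != nearestPruned(x, centers, cc)) mismatches++;
        }
        System.out.println("mismatches: " + mismatches); // 0
        System.out.println("distance computations skipped: " + skipped + " of " + points * (k - 1));
    }
}
```

The k x k table of center distances costs little compared with the point-to-center distances it lets you skip. This is why the pruned algorithms pull further ahead as the number of points grows.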

So, overall - JSAT has faster implementations of the same algorithms, and more algorithms - giving you the flexibility to pick the right tool for the job.

Just to be clear, I'm not claiming that JSAT has the fastest implementations ever of these algorithms. Many great tools, such as scikit-learn, have faster implementations for some algorithms. But JSAT does have a greater variety and if you are going to be working in Java, you should definitely consider JSAT over Weka just on the issue of speed.

Monday, March 16, 2015

Improving dataset loading in JSAT/Java

I've had a loader for the LIBSVM dataset format for a while now. I recently ran into a rather fun issue: while loading a large dataset, I got an error. I didn't read the error; the first thing I did was run it again with the debugger on, waiting for it to break wherever the error was. And it ran through perfectly!

Turns out I was getting a "GC overhead limit exceeded" error. This happens when the JVM spends too much of its time doing GC work while recovering very little memory, and it just kills your job. By running in debug mode, I slowed down how quickly new garbage was created, which gave the GC enough extra headroom to make it through.

While amusing, I did need to fix this, and I encountered something surprising. There doesn't seem to be any good way to turn a subset of a character sequence into a floating-point value. The only method I could find was to run Double.parseDouble() on a substring, but that meant creating a new garbage object! So I had to write that myself, and added code to read the file into a buffer and walk through the buffer for the values I needed, refilling it when necessary.
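
A simplified sketch of parsing a double straight out of a char buffer range, with no substring and no garbage. This is not JSAT's actual parser. The parse method and the LIBSVM-style sample row are illustrative, and the sketch skips NaN/Infinity, hex floats, and overflow handling. Like the approach described here, it can disagree with Double.parseDouble in the last bits for long inputs.

```java
public class CharRangeParser {
    static boolean isDigit(char c) { return c >= '0' && c <= '9'; }

    /** Parses a decimal such as "-12.5e-3" from buf[start, end) without creating a String. */
    static double parse(char[] buf, int start, int end) {
        int i = start;
        boolean negative = false;
        if (i < end && (buf[i] == '-' || buf[i] == '+')) negative = buf[i++] == '-';
        double digits = 0;   // all mantissa digits, read as one integer
        int exp10 = 0;       // power of ten to apply at the end
        while (i < end && isDigit(buf[i])) digits = digits * 10 + (buf[i++] - '0');
        if (i < end && buf[i] == '.') {
            i++;
            while (i < end && isDigit(buf[i])) {
                digits = digits * 10 + (buf[i++] - '0');
                exp10--;     // each fractional digit moves the decimal point one place
            }
        }
        if (i < end && (buf[i] == 'e' || buf[i] == 'E')) {
            i++;
            boolean expNegative = false;
            if (i < end && (buf[i] == '-' || buf[i] == '+')) expNegative = buf[i++] == '-';
            int e = 0;
            while (i < end && isDigit(buf[i])) e = e * 10 + (buf[i++] - '0');
            exp10 += expNegative ? -e : e;
        }
        // 10^k is exact for k <= 22 while 10^-k is not, so divide rather than multiply.
        double value = exp10 >= 0 ? digits * Math.pow(10, exp10) : digits / Math.pow(10, -exp10);
        return negative ? -value : value;
    }

    public static void main(String[] args) {
        char[] line = "1 3:0.25 7:-12.5e-3".toCharArray();   // one LIBSVM-style row
        for (int pos = 0; pos < line.length; ) {
            int end = pos;
            while (end < line.length && line[end] != ' ') end++;  // token is line[pos, end)
            int colon = pos;
            while (colon < end && line[colon] != ':') colon++;
            if (colon < end)                                       // "index:value" token
                System.out.println(parse(line, colon + 1, end));   // 0.25, then -0.0125
            pos = end + 1;
        }
    }
}
```

In a real loader the char array would be the refillable read buffer. The only allocations are the buffer itself and the output arrays.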

Here is a graph of the memory use of loading a file with the original LIBSVM code. The first graph just shows memory use. The second shows GC overhead and the number of collections.

Not horrible, but this was on a small data set. On larger ones (like the 100GB file I encountered the error with) this is a problem. Below are the same graphs for the new code. 

That looks much better! It also looks more like what some people would expect had the code been written in C/C++ instead of Java, which was one of the reasons I wanted to show it. The GC in Java is a great tool, but that doesn't mean we should always rely on it. When necessary, there isn't much stopping us from writing Java code that looks like C/C++ and behaves very similarly.

Currently I'm not super happy with the new code, as I wrote it in a hurry at home so I could use the updated code at work. If I go back and rewrite it, I'll probably make a small finite state machine so that the logic is easier to follow. But it does perform much better!

I also had to implement my own parsing of a character sequence to a double to get this performance, and this has the unfortunate side effect that you don't necessarily get the exact same double value you would have gotten from the standard library. This is something I need to fix, and an artifact of how surprisingly involved dealing with floating point always is. However, the current code's relative error is always less than 10^-14, which means the ML code will all be fine. But it's still not a nice surprise.

I also want to take a moment and complain about ASCII datasets. For most machine learning datasets, a common binary representation would take up less space, load faster, and be easier to handle. I don't understand everyone's fascination with making things JSON or XML or whatever other text format they are using. It's almost always unneeded, added complexity for the sole benefit of being "human readable", even if a human is never going to read it!

Wednesday, February 25, 2015

ML Theory, some Common Questions

Prompted by a few recent exchanges, I thought I would take a stab at pre-answering a few common questions I see/get about some results in Machine Learning, and how they relate to applying ML in practice. You can find answers to all of these elsewhere as well.

Q: A Neural Network with a single hidden layer can approximate any function arbitrarily well, so why have deep learning / multiple layers?  Why don't Neural Networks solve everything? 

A: First, I've intentionally phrased this one the way I hear it, because it's not quite true. A single hidden layer can only approximate continuous functions arbitrarily well. To get discontinuous functions in the mix, we need a second hidden layer. But that's not terribly important.

What this result doesn't tell us is how many neurons to use, and one of the benefits of having multiple layers is that it may take fewer neurons overall to represent something with a hierarchy than to try to represent it with one colossal hidden layer. That is to say, the representational power of a network is not linear in the number of neurons and hidden layers. Increasing the number of hidden layers can significantly reduce the size / complexity of the network needed for a problem.

The other issue is that this result only covers predicting the response in the region the network was trained on. It tells us nothing about how the network might perform when given an outlier as input. This subtly leads us to the issue that the result doesn't tell us how the network will generalize to new data. If we have concept drift, or not enough data to fully cover the problem, this result does not help us at all, as we won't be able to generalize to new data.

Q: Algorithm X converges to the Bayes optimal error rate (times some \(\epsilon\) ), so I just need enough data and I will get the optimal classifier. 

A: This is similar to the first question, and is a result that is true for doing a simple Nearest Neighbor search, but instead of function approximation the result is explicitly for classification. As one may have guessed, this doesn't tell us anything about how much data is needed to get to the Bayes error rate. Even more sadly: this doesn't tell us how close we are to reaching the Bayes error rate / when we get there - just that eventually we will.

A more subtle issue comes from the use of a classification problem instead of a regression / function approximation problem. When we talk about function approximation, we generally assume that we know all the inputs to the function (though this isn't necessarily true). When we set up classification problems in practice, we generally don't know for sure what the needed features are. This is important because the Bayes optimal error rate depends on the features given, i.e., different sets of features have different Bayes optimal error rates. The result doesn't tell us that we will reach the best possible error rate for the problem, only the best possible error rate given those features.

Q: Logistic Regression and Support Vector Machines all use a loss function, but why don't we use the zero-one loss? Phrased another way: why don't we learn to minimize the classification errors, rather than some other loss function. 

A: The answer to this is simply that we don't know how! It's also an interesting question whether this would be the best thing to do (i.e., will it generalize to new data?). SVMs and Logistic Regression are large-margin classifiers (the former being maximum-margin), and it's often been found that large-margin models work better than those that don't have a large margin. However, a large-margin loss can incur a penalty on data that is classified correctly, where the zero-one loss wouldn't. So that result should give us pause to think about the real problem: what should we optimize to minimize our generalization error? Again, we don't really know!
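
To make that penalty concrete, let m = y·f(x) be the signed margin, where positive means the point is classified correctly. The loss formulas below are the standard hinge (SVM) and logistic losses; the code itself is just an illustration. Part of why the zero-one loss is hard to optimize is visible here too: it is flat everywhere except at m = 0, so it gives no gradient to follow.

```java
public class LossComparison {
    // m = y * f(x), the signed margin: m > 0 means the point is classified correctly.
    static double zeroOne(double m)  { return m > 0 ? 0 : 1; }
    static double hinge(double m)    { return Math.max(0, 1 - m); }        // SVM
    static double logistic(double m) { return Math.log1p(Math.exp(-m)); }  // logistic regression

    public static void main(String[] args) {
        for (double m : new double[]{-1.0, 0.5, 2.0})
            System.out.printf("margin %4.1f: zero-one %.0f, hinge %.2f, logistic %.3f%n",
                    m, zeroOne(m), hinge(m), logistic(m));
    }
}
```

At m = 0.5 the point is on the correct side, so the zero-one loss is 0. The hinge loss is still 0.50, though, and the logistic loss is positive. Only at a comfortable margin like m = 2 does the hinge loss drop to 0.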

Q: I don't understand why high-dimensional data is more difficult. What's the intuition to the curse-of-dimensionality being a problem, isn't more better?

A: For explaining the intuition, I like to think of it like this: for every feature/input, we have a weight to learn. Let's assume that each data point can contribute to learning only one of the weights correctly. If we have just as many data points as we have weights, and the data is perfectly clean, then we can solve the problem (if the problem is full rank). However, data is never perfect and functions can be noisy / overlap, so we want multiple data points for each feature. That way we can get a better idea of a weight value that works best for most of the data.

When we have high-dimensional problems, \(D \gg N\), we can't even get a full rank solution! That's the basis for an intuition of why high-dimensional problems are hard. The way I've explained it suggests that we need a linear relationship between the number of features \(D\) and the number of data points \(N\), where \(N > D\). The truth is that we may need exponentially more data points as we increase the dimension of the problem.
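
The full-rank point can be checked directly. With N data points and D features, least squares has to invert the D x D matrix X^T X. The rank of that matrix is at most min(N, D), so with fewer points than features it is singular and there is no unique solution. The small sketch below is not from JSAT; its rank helper uses Gaussian elimination with partial pivoting and an arbitrary 1e-10 tolerance.

```java
public class RankDeficiency {
    /** Matrix rank via Gaussian elimination with partial pivoting (tolerance-based). */
    static int rank(double[][] a) {
        double[][] m = new double[a.length][];
        for (int i = 0; i < a.length; i++) m[i] = a[i].clone();
        int rows = m.length, cols = m[0].length, r = 0;
        for (int c = 0; c < cols && r < rows; c++) {
            int p = r;
            for (int i = r + 1; i < rows; i++) if (Math.abs(m[i][c]) > Math.abs(m[p][c])) p = i;
            if (Math.abs(m[p][c]) < 1e-10) continue;   // no usable pivot in this column
            double[] t = m[p]; m[p] = m[r]; m[r] = t;
            for (int i = r + 1; i < rows; i++) {
                double f = m[i][c] / m[r][c];
                for (int j = c; j < cols; j++) m[i][j] -= f * m[r][j];
            }
            r++;
        }
        return r;
    }

    /** Gram matrix X^T X (D x D) for N data points (rows) with D features (columns). */
    static double[][] gram(double[][] x) {
        int d = x[0].length;
        double[][] g = new double[d][d];
        for (double[] row : x)
            for (int i = 0; i < d; i++)
                for (int j = 0; j < d; j++) g[i][j] += row[i] * row[j];
        return g;
    }

    public static void main(String[] args) {
        double[][] x = {{1, 2, 0}, {0, 1, 3}};  // N = 2 points, D = 3 features
        System.out.println("rank of X^T X = " + rank(gram(x)) + ", need " + x[0].length); // 2, need 3
    }
}
```

Adding a third point that isn't in the span of the first two would make the matrix full rank. Adding a fourth feature would leave it short again.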

A simple analogy: If I gave you a budget of $500 (your data points) and 10 products (your features) to evaluate, you could easily buy all the products, test them, and tell me which are more or less important (the importance being your weights). But if I told you to evaluate 10,000 products, and I upped your budget to $500,000 - you wouldn't really be able to just evaluate all the products any more. You could hire others to help you, but then you wouldn't have enough money to buy all the products. And as you hire more people someone has to manage them, which is even more of your budget - your overheads are increasing! Obviously don't take this too literally, but hopefully it connects how just increasing one thing (the features) can have a bigger impact on what you need (data points) to solve the problem.

Q: I don't understand why we have different algorithms for classification/regression, aren't they all doing the same thing? Why isn't there a 'best' algorithm, or which one is it?

A: There actually exists a proof (the No Free Lunch theorem) that there can be no one best algorithm. The fact of the matter is that all of these algorithms are models trying to approximate the world, and each model makes different assumptions about the world. So we try to choose the model that best matches our understanding of the world, or at least the one that performs best if we don't know how the world really works!

Often there exist models that make very few assumptions about the world, like nearest neighbor / decision trees. As a consequence, we often see them perform very well on a variety of problems. But even these have some implicit assumptions.

A standard decision tree implicitly assumes that separations occur along the axes. This means we would need a tree of infinite depth to "truly" model any boundary that isn't aligned with an axis. But we can often approximate such boundaries well enough, and so it works. If we know that there is a lot of curvature / bending in the data, though, it may be worthwhile to use a method that assumes that behavior is part of the world, as it will probably have an easier time learning. This other model would then be better for the problem.
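
A small worked example of that cost, under assumptions chosen for simplicity: take the unit square with the diagonal boundary y = x. Approximate the boundary with an axis-aligned staircase of k equal steps lying on one side of the diagonal. This is the kind of boundary a tree builds out of axis-aligned splits, though not necessarily the best one. The misclassified area is

```latex
\mathrm{err}(k) = \int_0^1 \left( \frac{\lceil kx \rceil}{k} - x \right) dx
               = k \int_0^{1/k} \left( \frac{1}{k} - u \right) du
               = k \cdot \frac{1}{2k^2}
               = \frac{1}{2k}
```

So the error shrinks as the tree adds more steps, but it only reaches zero as k goes to infinity. A model that can draw a diagonal line gets this boundary exactly with a single split.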

That doesn't mean decision trees would do badly, or that there aren't any number of models that would get equally good results. It just means we have to pick and choose what models/tools we use for the various problems we encounter.