Saturday, July 11, 2015

Easier Hyperparameter Tuning

I've just released the newest version of JSAT to my Maven repo and GitHub, and the biggest change is some work I've done to make hyperparameter tuning much easier for people who are new to Machine Learning or to JSAT specifically.

The crux of this new code, from the user's perspective, is a new method on the GridSearch object: autoAddParameters. This method takes a DataSet object and automatically adds hyperparameters, along with values to be tested. The values are adjusted to the given dataset, so that reasonable results can be obtained even if the dataset isn't scaled to a reasonable range.
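To make that concrete, here is a rough sketch of the core calls. The GridSearch constructor arguments are my best guess, mirroring the RandomSearch code in the full example further down, and the snippet assumes the same imports plus jsat.parameters.GridSearch and an already-loaded dataset:

GridSearch search = new GridSearch((Classifier) new PlatSMO(new RBFKernel()), 5);
//autoAddParameters picks the hyperparameters and candidate values for us,
//scaled to the dataset it is given, and returns how many were added
if(search.autoAddParameters(dataset) > 0)
{
    search.trainC(dataset);//runs the search over the added parameters
    Classifier tuned = search.getTrainedClassifier();//best model found
}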

The goal of this is to help people who are new to ML or new to JSAT. Tuning an algorithm has been the most common piece of feedback I receive from users, and part of the problem is that new users simply don't know what values are reasonable to try for the many algorithms in JSAT. Now, with the code below, new users can get good results with any algorithm and dataset.


import java.io.File;
import java.io.IOException;
import java.util.List;
import jsat.classifiers.*;
import jsat.classifiers.svm.PlatSMO;
import jsat.classifiers.svm.SupportVectorLearner.CacheMode;
import jsat.distributions.kernels.RBFKernel;
import jsat.io.LIBSVMLoader;
import jsat.parameters.RandomSearch;

/**
 *
 * @author Edward Raff
 */
public class EasyParameterSearch
{
    public static void main(String[] args) throws IOException
    {
        ClassificationDataSet dataset = LIBSVMLoader.loadC(new File("diabetes.libsvm"));
        
        ///////First, the code someone new would use////////
        PlatSMO model = new PlatSMO(new RBFKernel());
        model.setCacheMode(CacheMode.FULL);//Small dataset, so we can do this
        
        ClassificationModelEvaluation cme = new ClassificationModelEvaluation(model, dataset);
        cme.evaluateCrossValidation(10);
        
        System.out.println("Error rate: " + cme.getErrorRate());
        
        /*
         * Now some easy code to tune the model. Because the parameter values
         * can be impacted by the dataset, we should split the data into a train 
         * and test set to avoid overfitting. 
         */

        List<ClassificationDataSet> splits = dataset.randomSplit(0.75, 0.25);
        ClassificationDataSet train = splits.get(0);
        ClassificationDataSet test = splits.get(1);
        
        RandomSearch search = new RandomSearch((Classifier)model, 3);//3-fold cross validation is used to score each candidate during the search
        if(search.autoAddParameters(train) > 0)//this method adds parameters, and returns the number of parameters added
        {
            //that way we only do the search if there are any parameters to actually tune
            search.trainC(train);//search only on the training split, so the test set stays unseen
            PlatSMO tunedModel = (PlatSMO) search.getTrainedClassifier();

            cme = new ClassificationModelEvaluation(tunedModel, train);
            cme.evaluateTestSet(test);
            System.out.println("Tuned Error rate: " + cme.getErrorRate());
        }
        else//otherwise we will just have to trust our original CV error rate
            System.out.println("This model doesn't seem to have any easy to tune parameters");
    }
}

The code above uses the new RandomSearch class rather than GridSearch, but the code would look the same either way. RandomSearch is a good alternative to GridSearch, especially when more than two parameters are going to be searched over. Using the diabetes dataset, we end up with output that looks like

Error rate: 0.3489583
Tuned Error rate: 0.265625

Yay, an improvement in the error rate! A potentially better error rate could be found by increasing the number of trials that RandomSearch performs. The code that makes all of this possible is a little hairy, but so far I'm happy with the way it works for the user.
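If you do want more trials, it should be just one extra call before training. Note that setTrials is my assumption of the setter name here, so check the RandomSearch javadoc if it doesn't compile:

RandomSearch search = new RandomSearch((Classifier)model, 3);
search.setTrials(50);//assumed setter name: how many random candidate configurations to try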
