The crux of this new code, from the user's perspective, is a new method on the GridSearch object: autoAddParameters. This method takes a DataSet object and automatically adds hyperparameters, along with the values to be tested for each. These values are adjusted to the given dataset, so reasonable results can be obtained even if the dataset isn't scaled to a reasonable range.
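JSAT's internals aren't shown in this post, so here is only a minimal sketch of one way candidate values could be made to depend on the dataset. It assumes an RBF kernel whose width parameter is called sigma, and uses the well-known median heuristic: take the median pairwise distance between data points and try a few powers of two around it. The class and method names (DataScaledGrid, medianPairwiseDistance, candidateSigmas) are hypothetical, and this is not necessarily what autoAddParameters does internally.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

/**
 * Hypothetical sketch: derive candidate RBF widths from the data's own scale
 * (the "median heuristic"). Not JSAT's actual implementation.
 */
public class DataScaledGrid
{
    /** Median of all pairwise Euclidean distances between rows of x (assumes at least two rows). */
    static double medianPairwiseDistance(double[][] x)
    {
        List<Double> dists = new ArrayList<>();
        for (int i = 0; i < x.length; i++)
            for (int j = i + 1; j < x.length; j++)
            {
                double sum = 0;
                for (int k = 0; k < x[i].length; k++)
                {
                    double d = x[i][k] - x[j][k];
                    sum += d * d;
                }
                dists.add(Math.sqrt(sum));
            }
        double[] arr = dists.stream().mapToDouble(Double::doubleValue).toArray();
        Arrays.sort(arr);
        int n = arr.length;
        // Middle element for an odd count, mean of the two middle elements for an even count
        return n % 2 == 1 ? arr[n / 2] : (arr[n / 2 - 1] + arr[n / 2]) / 2.0;
    }

    /** Candidate sigmas: the median distance times 2^-2, 2^-1, 2^0, 2^1, 2^2. */
    static double[] candidateSigmas(double[][] x)
    {
        double m = medianPairwiseDistance(x);
        double[] out = new double[5];
        for (int p = -2; p <= 2; p++)
            out[p + 2] = m * Math.pow(2, p);
        return out;
    }

    public static void main(String[] args)
    {
        double[][] small = {{0, 0}, {3, 4}, {6, 8}};
        // The same points with every feature multiplied by 1000
        double[][] large = {{0, 0}, {3000, 4000}, {6000, 8000}};
        System.out.println(Arrays.toString(candidateSigmas(small))); // [1.25, 2.5, 5.0, 10.0, 20.0]
        System.out.println(Arrays.toString(candidateSigmas(large))); // [1250.0, 2500.0, 5000.0, 10000.0, 20000.0]
    }
}
```

Because the candidates scale with the data, multiplying every feature by 1000 multiplies every candidate sigma by 1000 as well, which is the property that lets a search behave sensibly on unscaled data. Computing all pairwise distances is quadratic in the number of points, so a practical version would probably work on a random subsample.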
The goal of this is to help new users and people new to ML. Tuning an algorithm has been the most common issue raised in the feedback I receive from users, and part of the problem is that new users simply don't know what values are reasonable to try for the many algorithms in JSAT. Now, with the code below, new users should be able to get good results with any algorithm and dataset.
```java
import java.io.File;
import java.io.IOException;
import java.util.List;
import jsat.classifiers.*;
import jsat.classifiers.svm.PlatSMO;
import jsat.classifiers.svm.SupportVectorLearner.CacheMode;
import jsat.distributions.kernels.RBFKernel;
import jsat.io.LIBSVMLoader;
import jsat.parameters.RandomSearch;

/**
 *
 * @author Edward Raff
 */
public class EasyParameterSearch
{
    public static void main(String[] args) throws IOException
    {
        ClassificationDataSet dataset = LIBSVMLoader.loadC(new File("diabetes.libsvm"));

        ///////First, the code someone new would use////////
        PlatSMO model = new PlatSMO(new RBFKernel());
        model.setCacheMode(CacheMode.FULL);//Small dataset, so we can do this

        ClassificationModelEvaluation cme = new ClassificationModelEvaluation(model, dataset);
        cme.evaluateCrossValidation(10);
        System.out.println("Error rate: " + cme.getErrorRate());

        /*
         * Now some easy code to tune the model. Because the parameter values
         * can be impacted by the dataset, we should split the data into a train
         * and test set to avoid overfitting.
         */
        List<ClassificationDataSet> splits = dataset.randomSplit(0.75, 0.25);
        ClassificationDataSet train = splits.get(0);
        ClassificationDataSet test = splits.get(1);

        RandomSearch search = new RandomSearch((Classifier)model, 3);
        if(search.autoAddParameters(train) > 0)//this method adds parameters, and returns the number of parameters added
        {
            //that way we only do the search if there are any parameters to actually tune
            search.trainC(train);//tune on the training split only, so the test split stays unseen
            PlatSMO tunedModel = (PlatSMO) search.getTrainedClassifier();

            cme = new ClassificationModelEvaluation(tunedModel, train);
            cme.evaluateTestSet(test);
            System.out.println("Tuned Error rate: " + cme.getErrorRate());
        }
        else//otherwise we will just have to trust our original CV error rate
            System.out.println("This model doesn't seem to have any easy to tune parameters");
    }
}
```
The code above uses the new RandomSearch class rather than GridSearch, but the code would look the same either way. RandomSearch is a good alternative to GridSearch, especially when more than two parameters are being searched over. Using the diabetes dataset, we get output that looks like this:
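To make the GridSearch versus RandomSearch trade-off concrete, here is a minimal, library-free sketch of random search over two hyperparameters (think of an SVM's C and an RBF width). It is a generic illustration, not JSAT's RandomSearch code: each trial draws both values log-uniformly from an assumed range and keeps the best score, and the objective is a made-up stand-in for cross-validated error. The names RandomSearchSketch, logUniform, search, and toyError are hypothetical.

```java
import java.util.Random;
import java.util.function.BiFunction;

/**
 * Hypothetical sketch of random search over two hyperparameters.
 * The objective passed in stands in for a cross-validated error rate.
 */
public class RandomSearchSketch
{
    /** Draw a value log-uniformly between lo and hi (both must be > 0). */
    static double logUniform(Random rng, double lo, double hi)
    {
        double a = Math.log(lo), b = Math.log(hi);
        return Math.exp(a + rng.nextDouble() * (b - a));
    }

    /** Returns {bestC, bestSigma, bestError} after the given number of trials. */
    static double[] search(BiFunction<Double, Double, Double> error, int trials, long seed)
    {
        Random rng = new Random(seed);
        double[] best = {Double.NaN, Double.NaN, Double.POSITIVE_INFINITY};
        for (int t = 0; t < trials; t++)
        {
            double c = logUniform(rng, 1e-2, 1e3);     // assumed range for C
            double sigma = logUniform(rng, 1e-2, 1e2); // assumed range for sigma
            double e = error.apply(c, sigma);
            if (e < best[2])
                best = new double[]{c, sigma, e};
        }
        return best;
    }

    public static void main(String[] args)
    {
        // Toy objective whose minimum is at C = 10, sigma = 1 (measured in log10 space)
        BiFunction<Double, Double, Double> toyError = (c, s) ->
        {
            double dc = Math.log10(c) - 1, ds = Math.log10(s);
            return dc * dc + ds * ds;
        };
        // A grid with k values per parameter costs k^d evaluations for d parameters;
        // random search costs exactly `trials`, however many parameters there are.
        int k = 10, d = 2;
        System.out.println("Grid evaluations: " + (int) Math.pow(k, d));
        for (int trials : new int[]{10, 100})
        {
            double[] r = search(toyError, trials, 42);
            System.out.printf("%d trials -> C=%.3g sigma=%.3g error=%.4f%n", trials, r[0], r[1], r[2]);
        }
    }
}
```

With a fixed seed, the 100-trial run replays the same first 10 draws as the 10-trial run, so its best error can only match or improve on it. That mirrors the idea that raising the trial count is a cheap way to look for a better setting, while a grid's cost grows exponentially with each added parameter.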
Error rate: 0.3489583
Tuned Error rate: 0.265625
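Yay, an improvement in the error rate! An even better error rate might be found by increasing the number of trials that RandomSearch performs. (Keep in mind that the tuned number comes from a single held-out test split, so it is a noisier estimate than the 10-fold cross-validation figure.) While the code that makes this possible is a little hairy, so far I'm happy with the way it works for the user.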