Monday, July 22, 2013

L1 Regularization with Superfluous Features

I've been working on adding some more linear algorithms again, with a focus on making sure there are some L1 regularized learners in JSAT. This was the one area where JSAT was really lacking, and there really are not a lot of great algorithms for it. A number of the published algorithms I attempted to implement didn't even work well outside of a very narrow range.

The two algorithms I did implement that worked well were Bayesian Binary Regression (BBR) and Sparse Truncated Gradient Descent (STGD). The former is a batch Logistic Regression algorithm and the latter is an online algorithm for the squared loss. Both support L1 regularization.

I just used BBR to create a fun graph. Consider a data set where all features are drawn independently from N(0, 1). Let only 5 of these features be relevant, and restrict yourself to only 50 data points. How do you find the relevant features as you keep adding more irrelevant dimensions? Unregularized learners will quickly degrade to random guessing or worse. L2 regularization is easy enough to do quickly, but still isn't strong enough to learn the right coefficients as the dimensionality increases.

This is where L1 regularization comes into play. While most papers mention the sparsity property of L1, the real power is that L1 gives us theoretical bounds on its performance in the face of irrelevant features. This is incredibly important as we collect more and more data with lots of features, where we don't really know which of the features are useful for our decision. 
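The sparsity property mentioned above can be seen in a small one-dimensional sketch. It compares the closed-form minimizers of \(\frac{1}{2}(z - w)^2 + \lambda |z|\) (L1, known as soft-thresholding) and \(\frac{1}{2}(z - w)^2 + \frac{\lambda}{2} z^2\) (L2). This is only an illustration of the two penalties, not how BBR is implemented internally, and the class and method names are made up for the example.

```java
public class SoftThresholdDemo
{
    /**
     * L1 case: moves w toward zero by lambda and sets it to exactly
     * zero when |w| <= lambda (soft-thresholding).
     */
    public static double softThreshold(double w, double lambda)
    {
        if (w > lambda)
            return w - lambda;
        if (w < -lambda)
            return w + lambda;
        return 0.0;
    }

    /**
     * L2 case: rescales w toward zero, but a non-zero w never
     * becomes exactly zero.
     */
    public static double l2Shrink(double w, double lambda)
    {
        return w / (1.0 + lambda);
    }

    public static void main(String[] args)
    {
        double[] weights = {3.0, -0.2, 0.5, -2.0};
        double lambda = 0.5;
        for (double w : weights)
            System.out.println(w + " -> L1: " + softThreshold(w, lambda)
                    + ", L2: " + l2Shrink(w, lambda));
    }
}
```

Small weights are zeroed out entirely under the L1 penalty, while the L2 penalty only scales them down. That is why an L1 learner can discard irrelevant features and an L2 learner keeps a small weight on every one of them.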



It's a fun plot to look at, and you can clearly see the L1 prior is helping maintain a reasonable level of performance. This graph is slightly unusual in how I constructed the problem. Most L1 vs L2 graphs like this show the L1 prior doing much better, staying near the original accuracy. Often they construct the irrelevant features as sparse, with a binary 0/1 value that has a small probability of being 1. The performance here is still good, but it's important to emphasize that the L1 prior isn't impervious to random features.

Below is the code to generate the same data, including some more that I didn't include. You can try it out to see how the regularization affects the problem, and even change the problem to the sparse 0/1 version and see how it holds up. Note that BBR with the L1 prior takes a while to converge for very small regularization values (\(\lambda \leq 10^{-4}\) for me).


import java.text.DecimalFormat;
import java.util.Random;
import jsat.classifiers.CategoricalData;
import jsat.classifiers.ClassificationDataSet;
import jsat.classifiers.ClassificationModelEvaluation;
import jsat.classifiers.linear.BBR;
import jsat.linear.DenseVector;
import jsat.linear.Vec;
import jsat.utils.random.XORWOW;
import jsat.regression.LogisticRegression;

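/**
 * Compares L1 and L2 regularized logistic regression (BBR) as more
 * irrelevant features are added to a small data set.
 *
 * @author Edward Raff
 */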
public class L1L2Comparison
{

    /**
     * @param args the command line arguments
     */
    public static void main(String[] args)
    {
        
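        int relevant = 5;   // number of features that determine the label
        int classSize = 50; // total number of data points per data set
        // number of irrelevant features added in each experiment
        int[] irrelevant = new int[]
        {
            0, 5, 15, 45, 95, 245, 495, 
        };
        
        // regularization strengths tried for both priors
        double[] regPenalties = new double[]
        {
            1e-4, 0.001, 0.01, 0.1, 0.5, 1.0, 5.0, 10.0
        };
        
        Random rand = new XORWOW();
        
        // random weights in [0, 10) for the relevant features
        double[] coef = new double[relevant];
        
        for(int i = 0; i < relevant; i++)
            coef[i] = rand.nextDouble()*10;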
        
        DecimalFormat df = new DecimalFormat("#.#########");
        
        for(int i = 0; i < irrelevant.length; i++)
        {
            int D = irrelevant[i]+relevant;
            ClassificationDataSet cds = new ClassificationDataSet(D, new CategoricalData[0], new CategoricalData(2));
            for(int k = 0; k < classSize; k++)
            {
                // every feature, relevant or not, is drawn from N(0, 1)
                Vec xP = new DenseVector(D);
                for(int j = 0; j < D; j++)
                    xP.set(j, rand.nextGaussian());
                // the label depends only on the first 'relevant' features
                double result = 0;
                for(int j = 0; j < relevant; j++)
                    result += coef[j]*xP.get(j);
                if(result > 0)
                    cds.addDataPoint(xP, new int[0], 1);
                else
                    cds.addDataPoint(xP, new int[0], 0);
            }
            
            System.out.println("\n\nD: " + D);
            LogisticRegression lr = new LogisticRegression();
            
            ClassificationModelEvaluation cmeLr = new ClassificationModelEvaluation(lr, cds);
            cmeLr.evaluateCrossValidation(10);
            
            System.out.println("UNIFORM: " + df.format(cmeLr.getErrorRate()));
            
            System.out.print("REG: ");
            for(double reg : regPenalties)
                System.out.print(df.format(reg) + ", ");
            System.out.print("\n L2: ");
            for(double reg : regPenalties)
            {
                BBR bbr = new BBR(reg, 1000, BBR.Prior.GAUSSIAN);
                ClassificationModelEvaluation cme = new ClassificationModelEvaluation(bbr, cds);
                cme.evaluateCrossValidation(10);
                System.out.print(df.format(cme.getErrorRate()) + ", ");
            }
            System.out.print("\n L1: ");
            for(double reg : regPenalties)
            {
                BBR bbr = new BBR(reg, 1000, BBR.Prior.LAPLACE);
                ClassificationModelEvaluation cme = new ClassificationModelEvaluation(bbr, cds);
                cme.evaluateCrossValidation(10);
                System.out.print(df.format(cme.getErrorRate()) + ", ");
            }
            
        }
    }
}
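
For the sparse 0/1 version mentioned above, only the feature generation needs to change. The following is a minimal sketch using plain arrays so that it runs without JSAT. The class name, the method name, and the probability of 0.05 are my own assumptions for illustration, not values used for the plot.

```java
import java.util.Random;

public class SparseIrrelevantFeatures
{
    /**
     * Builds one data point: the first 'relevant' features are N(0, 1),
     * the remaining 'irrelevant' features are 1 with probability p and
     * 0 otherwise.
     */
    public static double[] makePoint(int relevant, int irrelevant, double p, Random rand)
    {
        double[] x = new double[relevant + irrelevant];
        for (int j = 0; j < relevant; j++)
            x[j] = rand.nextGaussian();
        for (int j = relevant; j < x.length; j++)
            x[j] = rand.nextDouble() < p ? 1.0 : 0.0;
        return x;
    }

    public static void main(String[] args)
    {
        Random rand = new Random(42);
        double[] x = makePoint(5, 495, 0.05, rand); // p = 0.05 is an assumed value
        int nonZero = 0;
        for (int j = 5; j < x.length; j++)
            if (x[j] != 0)
                nonZero++;
        System.out.println("non-zero irrelevant features: " + nonZero + " of 495");
    }
}
```

To use the same idea in the listing above, keep `rand.nextGaussian()` for indices below `relevant` in the loop that fills `xP`, and set the 0/1 value for the remaining indices.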