Sunday, June 23, 2013

Cosine Similarity Locality Sensitive Hashing

I have been meaning to try implementing and learning more about Locality Sensitive Hashing (LSH) for a while now. The general goal of LSH is a hash function where you actually want collisions: similar items should collide into the same bucket. That way, you can find similar points by just hashing into a bucket instead of searching the whole data set.

From my prior perusing, I found there is a lot of inconsistency between papers in how LSH is defined and how the objective is stated. So I started out with two of the earlier algorithms - one called E2LSH and one based on Random Projections (RP) [in revisions r723 and r725].

The Random Projections one is particularly interesting because, on its own, it isn't really hashing similar items into the same bucket (at least, not often). Instead it creates a bit signature, and you can then do LSH again on the signature using the Hamming distance (which is equivalent to the L1 distance, since the values are only 0 and 1). But if you don't want to do this tiered hashing, RP LSH can still be useful for doing brute force nearest neighbor (NN) search in a fast, approximate way. This is because the expected fraction of differing bits between two signatures is θ/π, where θ is the angle between the original vectors - so cos(π · differingBits / signatureLength) gives an estimate of their cosine similarity. So I decided to try it out on the 20 News Groups data set.
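
To make that relationship concrete, here is a small self-contained sketch of the random-projection signature idea. This is illustrative code only - the class and helper names are mine, not JSAT's RandomProjectionLSH - but it shows how signatures are built from random hyperplanes and how the cosine similarity can be estimated from the number of differing bits.

import java.util.BitSet;
import java.util.Random;

public class RPSignatureDemo
{
    public static void main(String[] args)
    {
        int dim = 50, bits = 1024;
        Random rand = new Random(42);

        //random hyperplanes: one Gaussian vector per signature bit
        double[][] planes = new double[bits][dim];
        for (double[] plane : planes)
            for (int j = 0; j < dim; j++)
                plane[j] = rand.nextGaussian();

        //two similar-ish vectors to compare
        double[] a = new double[dim], b = new double[dim];
        for (int j = 0; j < dim; j++)
        {
            a[j] = rand.nextGaussian();
            b[j] = a[j] + 0.5 * rand.nextGaussian();
        }

        BitSet sigA = signature(a, planes), sigB = signature(b, planes);
        sigA.xor(sigB); //after the xor, the set bits are the positions where the signatures differ
        double estimate = Math.cos(Math.PI * sigA.cardinality() / bits);

        System.out.println("true cosine similarity:    " + cosine(a, b));
        System.out.println("estimate from signatures:  " + estimate);
    }

    //bit i is set if the vector falls on the positive side of hyperplane i
    static BitSet signature(double[] x, double[][] planes)
    {
        BitSet sig = new BitSet(planes.length);
        for (int i = 0; i < planes.length; i++)
        {
            double dot = 0;
            for (int j = 0; j < x.length; j++)
                dot += planes[i][j] * x[j];
            if (dot >= 0)
                sig.set(i);
        }
        return sig;
    }

    static double cosine(double[] a, double[] b)
    {
        double dot = 0, na = 0, nb = 0;
        for (int j = 0; j < a.length; j++)
        {
            dot += a[j] * b[j];
            na += a[j] * a[j];
            nb += b[j] * b[j];
        }
        return dot / Math.sqrt(na * nb);
    }
}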

I wrote the code for a simple loader that you can see below. It's very basic (and the file path should really be a parameter), but I was being lazy. It's common practice to skip the header information for the news groups data set (so that your algorithm does not associate the people with the news group, but the content). To do this I used a quick hack of just skipping everything up to the "Lines:" line of the header. I ripped this code out of a bigger file with a lot of other testing / experiment code, so I didn't bother with the imports - but your IDE should be able to resolve them for you.

/*
 * A simple example on how to load a text data set. The 20 News Groups data set is
 * designed such that each folder is a different class, and each file in the folder 
 * contains a document. 
 */

public class NewsGroupsLoader extends ClassificationTextDataLoader
{
    
    private File mainDirectory = new File("<your path to the data set>/20_newsgroup/");
    private List<File> dirs;

    public NewsGroupsLoader(Tokenizer tokenizer, WordWeighting weighting)
    {
        super(tokenizer, weighting);
    }

    @Override
    protected void setLabelInfo()
    {
        File[] list = mainDirectory.listFiles();
        dirs = new ArrayList<File>();
        for(File file : list)
            if(file.isDirectory())
                dirs.add(file);
        
        labelInfo = new CategoricalData(dirs.size());
        labelInfo.setCategoryName("News Group");
        for(int i = 0; i < dirs.size(); i++)
            labelInfo.setOptionName(dirs.get(i).getName(), i);
    }

    @Override
    public void initialLoad()
    {
        for(int i = 0; i < dirs.size(); i++)
        {
            for(File doc : dirs.get(i).listFiles())
            {
                try
                {
                    String line;
                    StringBuilder sb = new StringBuilder();
                    BufferedReader br = new BufferedReader(new FileReader(doc));
                    
                    boolean startConsidering = false;
                    while((line = br.readLine()) != null)
                    {
                        if(!startConsidering)
                        {
                            //the format is such that most of the header occurs
                            //before the "Lines:" line, so this way we can skip
                            //most of the message header info without a lot of work.
                            if(line.trim().startsWith("Lines:"))
                                startConsidering = true;
                            continue;
                        }
                        sb.append(line).append(" ");
                    }
                    
                    br.close();
                    
                    addOriginalDocument(sb.toString(), i);
                    
                }
                catch (FileNotFoundException ex)
                {
                    ex.printStackTrace();
                }
                catch(IOException ex)
                {
                    ex.printStackTrace();
                }
            }
        }
    }
    
}



I then just did 10-fold cross validation with 7 nearest neighbors (chosen arbitrarily). I've graphed the results below for the Error rate, Training time, and Query time summed over the 10 folds. There are 19997 data points and 14572 features after I removed all words that occurred fewer than 10 times.

For the RP LSH, the naive setup requires building a large matrix of random values in memory, which takes up a lot of space! Using doubles, a 512 bit signature means a matrix of 512 rows by 14572 columns at 8 bytes per double, for a total of just under 60 megabytes. For a 4096 bit signature, that comes to roughly 460 megabytes. On my MacBook Air, that was the biggest I could do before paging started to occur with everything else I was running.

However, you don't have to explicitly keep this matrix in memory. One way is to generate every random value as it is needed. In JSAT I have RandomMatrix and RandomVector classes explicitly for making this kind of code easier. Unfortunately, generating random Gaussian values is expensive, making this approach very slow at runtime unless your data is incredibly sparse - or you only need the values infrequently. The code itself is also a little on the slow side, which is something I'm still investigating (if you know of a good and fast hash function for an input of 3 integers, let me know!).
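
As a concrete illustration of the "generate on demand" idea, here is a sketch in the same spirit. It is not JSAT's actual RandomMatrix class, and the index mixing is deliberately simplistic - entry (i, j) is recomputed from a hash of (seed, i, j) every time it is requested, so nothing is ever stored.

import java.util.Random;

/*
 * Sketch of an implicit random matrix: no backing storage, every entry is
 * regenerated deterministically on demand from a hash of (seed, row, col).
 */
public class ImplicitGaussianMatrix
{
    private final long seed;
    private final int rows, cols;

    public ImplicitGaussianMatrix(long seed, int rows, int cols)
    {
        this.seed = seed;
        this.rows = rows;
        this.cols = cols;
    }

    public double get(int i, int j)
    {
        //mix (seed, i, j) into one long - this stands in for the
        //"hash of 3 integers" mentioned above
        long h = seed;
        h = h * 31 + i;
        h = h * 31 + j;
        //creating a Random and drawing a Gaussian per entry is exactly
        //the expensive part that makes this approach slow in practice
        return new Random(h).nextGaussian();
    }

    public int rows() { return rows; }
    public int cols() { return cols; }
}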

The other alternative is to use a technique called pooling. Essentially you create a pre-defined pool of unit normal values, and then index into the pool whenever the matrix needs to return a value. This way we can create a small (less than one megabyte) cache of values that is reused for all indices. This works surprisingly well for sparse data sets, and is fairly robust to poor random number generation, so I was able to exploit that to make the code a lot faster.
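
And a matching sketch of the pooled variant - again illustrative only, with a made-up index-mixing function rather than whatever JSAT actually uses. The point is that the pool is tiny (10^5 doubles is under a megabyte) and each entry lookup is just a cheap hash and an array access, with no Gaussian generation at query time.

import java.util.Random;

/*
 * Sketch of a pooled random matrix: a small fixed pool of pre-drawn Gaussians
 * is shared by every (row, col) position, and a cheap hash of the indices
 * picks which pool entry to hand back.
 */
public class PooledGaussianMatrix
{
    private final double[] pool;

    public PooledGaussianMatrix(int poolSize, long seed)
    {
        pool = new double[poolSize];
        Random rand = new Random(seed);
        for (int k = 0; k < poolSize; k++)
            pool[k] = rand.nextGaussian();
    }

    public double get(int i, int j)
    {
        //cheap mixing of the two indices into a pool position
        long h = 31L * i + j;
        h ^= (h >>> 16);
        h *= 0x9E3779B97F4A7C15L; //a common multiplicative mixing constant
        return pool[(int) ((h >>> 1) % pool.length)];
    }
}

Because many (row, column) positions share pool entries, the projections are no longer truly independent, but as noted above this turns out to matter little on sparse data sets.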

So below are the Error rate, Training time, and Query time for the naive true NN, an in-memory LSH, and two pools of 10^4 and 10^5 elements (80 KB and 780 KB respectively).




As you can see, NN isn't the most accurate method for this data set (other algorithms only get down to around 0.2 in terms of error). However, as we increase the signature size it gets more and more accurate - and by 8192 bits (which is a considerable signature), the accuracies are comparable. The pooling method isn't much worse than the in-memory version of RP LSH, with the 10^5 pool almost identical in accuracy while using 500 times less memory. This also allows us to do even larger signatures than the in-memory method.

Looking at query time, we can see that even with these large signatures, the LSH query time is less than half that of the normal brute force method, which is pretty good. While JSAT isn't designed for it, this would be particularly important if the original data set didn't fit into memory, since the signatures probably would. We can also see that the pooled versions are slightly faster in terms of query time, mostly due to better locality (i.e., we don't have over 400 megabytes of matrix eating up the cache).

In training time, we see that the naive NN doesn't really have any, but LSH needs to construct the random matrix and then compute all the signatures. We also see the pooling method is a good deal faster than the in-memory method, and the difference grows with the signature length. This is because the pooled version creates a fixed number of random Gaussians independent of the signature size, whereas the in-memory one has to generate enough to fill the whole matrix.

Overall, it's pretty easy to see how the cosine LSH could be useful. I'm not sure if such large signature lengths are generally needed, or are an artifact of this data set in particular - that's something to look at in the future. But even with these long signatures, it is faster than normal brute force in terms of query performance, and the trade-off between accuracy and runtime controlled by the signature length could be exploited in a lot of ways.

Below is the code for trying the different LSH schemes; just uncomment the one you want (and get your IDE to fill in the imports). You'll need the latest subversion revision (r725) for it to work.
public class NewsGroupsMain
{
    public static void main(String[] args)
    {
        //create the tokenizer and stemmer
        Tokenizer token = new NaiveTokenizer();
        token = new StemmingTokenizer(new PorterStemmer(), token);
        //create the data set loader
        NewsGroupsLoader ngl = new NewsGroupsLoader(token, new TfIdf());
        
        ClassificationDataSet cds = ngl.getDataSet();

        //remove all features that occur less than 10 times in the whole data set
        cds.applyTransform(ngl.getMinimumOccurrenceDTF(10).getTransform(cds));

        CategoricalData pred = cds.getPredicting();
        System.out.println("Loaded");

        System.out.println("Data Set Size: " + cds.getSampleSize());
        System.out.println("Features: " + cds.getNumNumericalVars());

        Classifier classifier;
        
        //NN the naive way (default collection used for high dimensional data is brute force)
        //classifier = new NearestNeighbour(7, false, new CosineDistance());
        //NN for the LSH with an in memory matrix
        //classifier = new NearestNeighbour(7, false, new CosineDistance(), new RandomProjectionLSH.RandomProjectionLSHFactory<VecPaired<Vec, Double>>(16, true));
        //NN for the LSH with a pool of 10^4
        classifier = new NearestNeighbour(7, false, new CosineDistance(), new RandomProjectionLSH.RandomProjectionLSHFactory<VecPaired<Vec, Double>>(16, 10000));
        
        ClassificationModelEvaluation cme = new ClassificationModelEvaluation(classifier, cds);
        cme.evaluateCrossValidation(10);
        System.out.println("Error Rate: " + cme.getErrorRate());
        
        System.out.println("Training Time: " + cme.getTotalTrainingTime());
        System.out.println("Testing Time: " + cme.getTotalClassificationTime());
        
        cme.prettyPrintConfusionMatrix();
    }
}
