Sunday, June 29, 2014

Stochastic LDA

Not too long ago I finally implemented an algorithm for Latent Dirichlet allocation (LDA), which has become one of the most popular topic modeling techniques. While most implementations you will find online use some form of Gibbs sampling, I was interested in a stochastic version, specifically the algorithm described in Hoffman, M., Blei, D., & Bach, F. (2010). Online Learning for Latent Dirichlet Allocation. In Advances in Neural Information Processing Systems (pp. 856–864). I had tried implementing it a few times before with little success, until I found some extra notes on performing the update in a sparse manner here.
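In that paper, each mini-batch update blends the current topic-word parameters λ with a batch estimate: λ ← (1 − ρ_t)λ + ρ_t λ̂. For any word that does not occur in the batch, λ̂ is just the prior η. One common way to exploit this is to store λ_kw = η + c·m_kw with a single shared scale c. This is not necessarily the approach in the notes mentioned above. With this layout, the (1 − ρ_t) shrink becomes one multiplication, and only the words that occur in the batch are touched. Below is a minimal Java sketch of the idea. The class LazyLambda, its methods, and the 1e-30 refold threshold are hypothetical choices, not JSAT's API. The batch statistics are assumed to be already rescaled by D/S (corpus size over batch size).

```java
/**
 * Hypothetical sketch: lambda[k][w] is stored as eta + scale * m[k][w], so the global
 * (1 - rho) shrink of the online update is O(1) and only words seen in a batch are touched.
 */
class LazyLambda {
    private final double eta;     // Dirichlet prior on the topic-word distributions
    private final double[][] m;   // unscaled topic-word statistics, K x V
    private final double[] mSum;  // row sums of m, so sum_w lambda[k][w] is O(1)
    private double scale = 1.0;   // shared multiplier applied lazily to every entry of m

    LazyLambda(double[][] initial, double eta) {
        this.eta = eta;
        this.m = new double[initial.length][];
        this.mSum = new double[initial.length];
        for (int k = 0; k < initial.length; k++) {
            m[k] = initial[k].clone();
            for (double v : m[k]) mSum[k] += v;
        }
    }

    double get(int k, int w) { return eta + scale * m[k][w]; }

    double rowSum(int k) { return m[k].length * eta + scale * mSum[k]; }

    /**
     * Applies lambda <- (1 - rho) * lambda + rho * (eta + stats). stats[j][k] is the
     * (already D/S-rescaled) statistic for word words[j] and topic k; every other word
     * has an implicit statistic of zero. Assumes 0 < rho < 1.
     */
    void update(double rho, int[] words, double[][] stats) {
        if (!(rho > 0 && rho < 1)) throw new IllegalArgumentException("need 0 < rho < 1");
        scale *= (1 - rho);              // shrink every entry at once
        if (scale < 1e-30) fold();       // keep m from growing without bound
        for (int j = 0; j < words.length; j++) {
            int w = words[j];
            for (int k = 0; k < m.length; k++) {
                double delta = rho * stats[j][k] / scale;
                m[k][w] += delta;
                mSum[k] += delta;
            }
        }
    }

    /** Dense O(K*V) pass that pushes the scale into m; it runs only rarely. */
    private void fold() {
        for (int k = 0; k < m.length; k++) {
            for (int w = 0; w < m[k].length; w++) m[k][w] *= scale;
            mSum[k] *= scale;
        }
        scale = 1.0;
    }

    public static void main(String[] args) {
        LazyLambda lam = new LazyLambda(new double[][]{{1, 2, 3}, {4, 5, 6}}, 0.1);
        lam.update(0.5, new int[]{1}, new double[][]{{10, 20}});
        System.out.println(lam.get(0, 1) + " " + lam.rowSum(0));
    }
}
```

The occasional fold() is a dense pass. It only runs once the shared scale has shrunk toward zero, so almost every update costs time proportional to the batch's distinct words times the number of topics. The row sums that E[log β] needs stay available in O(1).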

Since LDA is often used on unlabeled data, which is very abundant, I wanted an implementation that would work for extremely large corpora. With a stochastic algorithm, new data points can be streamed in from disk as needed, whereas a normal Gibbs sampler requires all of the data to be in memory for every update.
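The streaming side can be sketched in a few lines. Documents are read one batch at a time, and each batch is handed to the update together with the step size ρ_t = (τ0 + t)^(−κ) from Hoffman et al. The following is a minimal Java sketch, assuming one document per line. MiniBatchStream and the callback are hypothetical names, and the callback stands in for the actual LDA update. The τ0 and κ values in main are illustrative only.

```java
import java.io.BufferedReader;
import java.io.IOException;
import java.io.Reader;
import java.io.StringReader;
import java.io.UncheckedIOException;
import java.util.ArrayList;
import java.util.List;
import java.util.function.BiConsumer;

class MiniBatchStream {
    /** Step size rho_t = (tau0 + t)^(-kappa); it shrinks as more batches are seen. */
    static double rho(int t, double tau0, double kappa) {
        return Math.pow(tau0 + t, -kappa);
    }

    /**
     * Reads one document per line and passes each mini-batch, with its rho_t, to the
     * update callback. Only one batch is held in memory at a time.
     * Returns the number of updates performed.
     */
    static int stream(Reader source, int batchSize, double tau0, double kappa,
                      BiConsumer<List<String>, Double> update) {
        try (BufferedReader in = new BufferedReader(source)) {
            List<String> batch = new ArrayList<>(batchSize);
            int t = 0;
            String line;
            while ((line = in.readLine()) != null) {
                batch.add(line);
                if (batch.size() == batchSize) {
                    update.accept(batch, rho(t++, tau0, kappa));
                    batch = new ArrayList<>(batchSize);  // previous batch can now be collected
                }
            }
            if (!batch.isEmpty()) update.accept(batch, rho(t++, tau0, kappa));  // last partial batch
            return t;
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
    }

    public static void main(String[] args) {
        String corpus = "doc one\ndoc two\ndoc three\ndoc four\ndoc five\n";
        // tau0 = 64 and kappa = 0.7 are illustrative values only.
        int updates = stream(new StringReader(corpus), 2, 64, 0.7,
                (docs, r) -> System.out.printf("%d docs, rho = %.4f%n", docs.size(), r));
        System.out.println(updates + " updates");
    }
}
```

For a real corpus on disk, the Reader could come from java.nio.file.Files.newBufferedReader(path). Memory use then depends on the batch size rather than the corpus size.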

Another benefit is that, since the LDA implementation works on relatively large mini-batches of data, and the documents within a batch can be processed independently, parallelism was fairly easy to get. Gibbs samplers, by contrast, require considerably more work to scale across cores, let alone machines.
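Within one mini-batch the topic-word parameters are fixed, so each document's work depends only on that document. The batch can therefore be processed as a parallel map followed by a sum. Below is a minimal Java sketch of that pattern using a parallel IntStream, assuming documents are stored as parallel arrays of word ids and counts. addDocStats is a hypothetical stand-in that spreads each count evenly over the topics; it is not a real E-step. The class name ParallelBatch is illustrative, not JSAT's code.

```java
import java.util.Arrays;
import java.util.stream.IntStream;

class ParallelBatch {
    /**
     * Hypothetical stand-in for the per-document E-step: spreads each word's count evenly
     * over the topics. A real E-step would compute phi from gamma and lambda, but it has
     * the same shape: it reads shared, fixed parameters and writes only to its accumulator.
     */
    static void addDocStats(int[] words, int[] counts, double[][] acc) {
        int topics = acc.length;
        for (int j = 0; j < words.length; j++)
            for (int k = 0; k < topics; k++)
                acc[k][words[j]] += counts[j] / (double) topics;
    }

    /** Processes documents in parallel and sums their K x V statistics. */
    static double[][] batchStats(int[][] words, int[][] counts, int topics, int vocab) {
        return IntStream.range(0, words.length).parallel().collect(
                () -> new double[topics][vocab],                     // private accumulator per task
                (acc, d) -> addDocStats(words[d], counts[d], acc),   // independent per-document work
                (a, b) -> {                                          // merge two partial sums
                    for (int k = 0; k < topics; k++)
                        for (int w = 0; w < vocab; w++) a[k][w] += b[k][w];
                });
    }

    public static void main(String[] args) {
        int[][] words = {{0, 2}, {2}, {1, 2}};
        int[][] counts = {{2, 4}, {6}, {3, 1}};
        System.out.println(Arrays.deepToString(batchStats(words, counts, 2, 3)));
    }
}
```

The dense K x V accumulator per task keeps the sketch short. An implementation meant for large vocabularies would more likely accumulate only the words each document actually contains.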

I recently did a quick run of the LDA implementation in JSAT for a talk at work, and I figured I would share some of the results. I used the data from the recent KDD Cup competition at Kaggle (249,555 documents), but instead of doing prediction I simply did topic modeling on the data. I arbitrarily chose to search for 100 topics and compared my implementation to the one in MALLET. MALLET is written in Java, uses doubles instead of floats, supports a parallel Gibbs sampler, and is written by a knowledgeable researcher in the field. JSAT is also written in Java and uses doubles instead of floats, so I thought the comparison would be on level ground.

First, one thing I had learned before was that using a large stop-list is critical to getting good results with LDA. Below are some of the topics I got when running JSAT's implementation with no stop-list.


of, the, their, will, in, students, these, a, have, skills
i, a, of, my, in, that, have, is, are, with
stage, production, theatre, watercolor, watercolors, design, kindergrtners, pan, repeating, omnikin
instruments, band, sticks, percussion, bass, oil, playing, reeds, wood, scrapbook
the, of, in, for, have, i, this, students, our, on
energy, earth, shakespeare, solar, cells, scientists, planet, sun, globe, ocean
to, students, in, and, a, my, technology, for, with, are
calculators, calculator, fish, calculus, graphs, dress, ii, high, clothes, classes

Some of the topics kinda make sense, but I sincerely doubt the relevance of fish to calculators and graphs. Many of the useless topics contain only common words like "i, a, of" and so on. To make my comparison as fair as possible, I opted to simply use the same stop-list that MALLET used. With it, I got very similar results from both, and I matched up some of the "same" topics from MALLET and JSAT to show they are both learning the same thing.
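The stop-list is applied before documents are turned into word counts, so stop words never reach the model at all. Below is a minimal Java sketch of that preprocessing step. The tiny stop set and the ASCII-only regex tokenizer are illustrative assumptions, not MALLET's stop-list or JSAT's tokenizer.

```java
import java.util.Arrays;
import java.util.HashSet;
import java.util.Locale;
import java.util.Map;
import java.util.Set;
import java.util.TreeMap;

class StopWordFilter {
    /**
     * Lower-cases the text, splits on anything that is not an ASCII letter, drops
     * stop words, and counts the remaining tokens (sorted for readable output).
     */
    static Map<String, Integer> bagOfWords(String text, Set<String> stopWords) {
        Map<String, Integer> counts = new TreeMap<>();
        for (String token : text.toLowerCase(Locale.ROOT).split("[^a-z]+")) {
            if (token.isEmpty() || stopWords.contains(token)) continue;  // a leading separator yields ""
            counts.merge(token, 1, Integer::sum);
        }
        return counts;
    }

    public static void main(String[] args) {
        // A tiny illustrative stop set; a real stop-list has hundreds of entries.
        Set<String> stop = new HashSet<>(Arrays.asList("i", "a", "of", "the", "in", "my", "to", "and"));
        System.out.println(bagOfWords("The students in my class love the calculators.", stop));
        // prints {calculators=1, class=1, love=1, students=1}
    }
}
```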


MALLET | JSAT
reading, readers, read, comprehension, fluency, reader, struggling, level, improve, independent | students, reading, read, listening, readers, center, stories, love, fluency, comprehension
language, english, learners, spanish, speak, vocabulary, speaking, arts, esl, bilingual | language, english, spanish, learners, speak, learning, learn, words, speaking, class
problem, students, calculators, problems, math, solving, solve, algebra, graphing, mathematics | math, concepts, algebra, mathematics, understanding, real, graphing, mathematical, graph, abstract
supplies, paper, basic, pencils, school, markers, year, notebooks, pencil, colored | supplies, paper, school, pencils, art, markers, pencil, basic, create, projects
disabilities, autism, sensory, special, skills, social, motor, fine, severe, emotional | learning, classroom, learn, skills, technology, education, learners, interactive, lessons, disabilities
fulfillment, including, cost, donorschoose, org, www, http, html, htm, shipping | fulfillment, htm, cost, including, donorschoose, shipping, org, www, http, html
equipment, physical, education, activity, play, fitness, active, activities, exercise, balls | physical, equipment, balls, sports, jump, ball, recess, gym, playground, fit

While not identical, it's clear they are the same topics. Since LDA is not a convex problem, and both Gibbs sampling and SGD can get stuck in local optima, we wouldn't expect perfectly identical topics either.

Now that we know both produce comparable results, why choose the stochastic implementation in JSAT over MALLET? Besides the general benefits of stochastic implementations mentioned above, below are the runtimes (sans IO) for both code bases, run on a machine with 8 cores.
  • MALLET, single threaded: 37 minutes 21 seconds
  • MALLET, multi threaded: 8 minutes 25 seconds
  • JSAT, single threaded: 28 minutes 7 seconds
  • JSAT, multi threaded: 4 minutes 30 seconds
From these numbers, we can see that stochastic LDA in JSAT is faster than Gibbs sampling in MALLET to begin with. In addition, MALLET's multithreaded speedup was only 4.4x, compared to 6.2x for JSAT, so we can expect JSAT to scale better as we throw more cores at the problem.
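The speedup figures follow directly from the timings: speedup S = T1 / T8 and parallel efficiency E = S / 8. To put a rough number on "scales better", the sketch also computes the Karp-Flatt serial fraction e = (1/S − 1/p) / (1 − 1/p). This is the serial fraction that Amdahl's law would need in order to explain the measured speedup. Using Amdahl's law here is my assumption; real scaling also depends on memory bandwidth and synchronization.

```java
class Speedup {
    static int seconds(int minutes, int secs) { return 60 * minutes + secs; }

    /** Speedup S = T1 / Tp. */
    static double speedup(double t1, double tp) { return t1 / tp; }

    /** Parallel efficiency E = S / p. */
    static double efficiency(double s, int p) { return s / p; }

    /** Karp-Flatt serial fraction e = (1/S - 1/p) / (1 - 1/p). */
    static double serialFraction(double s, int p) {
        return (1.0 / s - 1.0 / p) / (1.0 - 1.0 / p);
    }

    /** Amdahl's law: predicted speedup on p cores for serial fraction e. */
    static double amdahl(double e, int p) { return 1.0 / (e + (1.0 - e) / p); }

    public static void main(String[] args) {
        String[] names = {"MALLET", "JSAT"};
        int[][] times = {{seconds(37, 21), seconds(8, 25)}, {seconds(28, 7), seconds(4, 30)}};
        for (int i = 0; i < names.length; i++) {
            double s = speedup(times[i][0], times[i][1]);
            double e = serialFraction(s, 8);
            System.out.printf("%s: speedup %.2fx, efficiency %.0f%%, serial fraction %.3f, Amdahl 16-core estimate %.1fx%n",
                    names[i], s, 100 * efficiency(s, 8), e, amdahl(e, 16));
        }
    }
}
```

With the timings above, MALLET comes out at roughly 55% efficiency with a serial fraction of about 0.11, and JSAT at 78% with about 0.04. Plugging those fractions into Amdahl's law for 16 cores gives about 5.9x versus 10.0x. That is a model-based extrapolation, not a measurement.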

It's possible to increase the scalability of JSAT further, but for now I plan on leaving the code in its current, simpler form. In general the algorithm in JSAT is very good for large corpora, but the Gibbs samplers are probably better for smaller datasets. The computational complexity will scale linearly with the number of topics.

The stop-list issue is somewhat interesting. At least for the stochastic implementation, a non-trivial amount of work is done for every distinct word in every document of each batch. Stop words are the most common words, so they add work to nearly every document. Without the stop words removed, JSAT took about 3 times longer to run.
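To see why the extra words cost so much, here is a minimal Java sketch of the per-document variational E-step from Hoffman et al. (2010). Its inner loops visit every (distinct word, topic) pair in the document on each iteration. The cost therefore grows linearly with both the number of topics and the number of distinct words, and stop words show up in almost every document. The names DocEStep and eStep and the tolerance parameter are assumptions for illustration. Java's standard library has no digamma function, so the sketch uses a standard recurrence plus asymptotic series. This is not JSAT's implementation.

```java
import java.util.Arrays;

class DocEStep {
    /** Digamma via psi(x) = psi(x + 1) - 1/x and the asymptotic series for large x. */
    static double digamma(double x) {
        double result = 0;
        while (x < 6) {          // shift x upward until the series is accurate
            result -= 1 / x;
            x += 1;
        }
        double f = 1 / (x * x);
        return result + Math.log(x) - 0.5 / x
                - f * (1.0 / 12 - f * (1.0 / 120 - f * (1.0 / 252 - f * (1.0 / 240 - f / 132))));
    }

    /** exp(E[log beta_kw]) = exp(psi(lambda_kw) - psi(sum_w lambda_kw)), computed once per batch. */
    static double[][] expElogBeta(double[][] lambda) {
        double[][] out = new double[lambda.length][];
        for (int k = 0; k < lambda.length; k++) {
            double psiSum = digamma(Arrays.stream(lambda[k]).sum());
            out[k] = new double[lambda[k].length];
            for (int w = 0; w < lambda[k].length; w++)
                out[k][w] = Math.exp(digamma(lambda[k][w]) - psiSum);
        }
        return out;
    }

    /** exp(E[log theta_k]) = exp(psi(gamma_k) - psi(sum_k gamma_k)). */
    static double[] expElogTheta(double[] gamma) {
        double psiSum = digamma(Arrays.stream(gamma).sum());
        double[] out = new double[gamma.length];
        for (int k = 0; k < gamma.length; k++)
            out[k] = Math.exp(digamma(gamma[k]) - psiSum);
        return out;
    }

    /**
     * Variational E-step for one document; words/counts hold its non-zero entries.
     * Every iteration visits each (distinct word, topic) pair once: O(nnz * K) work.
     * Returns gamma and adds counts[j] * phi[j][k] into stats[j][k].
     */
    static double[] eStep(int[] words, int[] counts, double[][] eBeta, double alpha,
                          int maxIter, double tol, double[][] stats) {
        int topics = eBeta.length;
        double[] gamma = new double[topics];
        Arrays.fill(gamma, 1.0);                      // arbitrary constant starting point
        for (int iter = 0; iter < maxIter; iter++) {
            double[] eTheta = expElogTheta(gamma);
            double[] next = new double[topics];
            Arrays.fill(next, alpha);
            for (int j = 0; j < words.length; j++) {
                double norm = 0;                      // normalizer for phi[j][.]
                for (int k = 0; k < topics; k++) norm += eTheta[k] * eBeta[k][words[j]];
                for (int k = 0; k < topics; k++)
                    next[k] += counts[j] * eTheta[k] * eBeta[k][words[j]] / norm;
            }
            double change = 0;
            for (int k = 0; k < topics; k++) change += Math.abs(next[k] - gamma[k]);
            gamma = next;
            if (change / topics < tol) break;         // mean absolute change in gamma
        }
        double[] eTheta = expElogTheta(gamma);        // phi from the final gamma
        for (int j = 0; j < words.length; j++) {
            double norm = 0;
            for (int k = 0; k < topics; k++) norm += eTheta[k] * eBeta[k][words[j]];
            for (int k = 0; k < topics; k++)
                stats[j][k] += counts[j] * eTheta[k] * eBeta[k][words[j]] / norm;
        }
        return gamma;
    }

    public static void main(String[] args) {
        double[][] eBeta = expElogBeta(new double[][]{{10, 0.1, 1}, {0.1, 10, 1}});
        double[][] stats = new double[2][2];
        double[] gamma = eStep(new int[]{0, 2}, new int[]{5, 1}, eBeta, 0.1, 100, 1e-6, stats);
        System.out.println(Arrays.toString(gamma) + " " + Arrays.deepToString(stats));
    }
}
```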

It's also possible to improve the results by using TF-IDF weighting instead of a stop-list. However, that still has the slowdown, since the common words are still being considered. It also has the issue that TF-IDF weights do not map well onto the word-count model that LDA uses.
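Both problems are visible in the weights themselves. This minimal Java sketch assumes the textbook tf·log(N/df) variant, although libraries offer several TF-IDF variants. A word that appears in every document gets weight 0. A word that appears in most, but not all, documents still gets a small non-zero weight, so it stays in the sparse vector and still costs work. The weights are also no longer integer counts, which is what LDA's multinomial word model assumes. The class name TfIdf and the document frequencies in main are made up for illustration.

```java
class TfIdf {
    /** Textbook tf-idf weight: raw count times log(N / df). */
    static double weight(int count, int numDocs, int docFreq) {
        return count * Math.log((double) numDocs / docFreq);
    }

    public static void main(String[] args) {
        int n = 249_555;  // corpus size from the post; the document frequencies below are made up
        System.out.println(weight(3, n, n));           // word in every document: 0.0
        System.out.println(weight(3, n, n * 9 / 10));  // word in 90% of documents: small but non-zero
        System.out.println(weight(3, n, 500));         // rarer word: larger, still not an integer
    }
}
```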

Currently I'm not sure what I want the API of LDA and future topic models to look like, so it doesn't implement any interface at the moment. Hopefully, as I get more opportunities to apply LDA and implement other algorithms, I'll figure out what is best.