Thursday, October 16, 2014

Stacking

I've recently added Stacking as an Ensemble to JSAT. Unlike Boosting or Random Forests, Stacking is very general - it can be applied to any models (the weak learner in boosting must support weighted data, and RF are just special trees). It's been on my TODO list for a long while now, and so far I've gotten pretty good results with it.

The phenomenon of simply combining the results of multiple models to improve performance is well known, and many competition submissions are just the averaged (or majority vote) predictions from several models. The difference with stacking is that it learns a weighted vote based on the performance of each classifier, which should perform better than any individual model in the ensemble (and better than naive averaging).

However, there is no reason why the combining phase has to be a linear model, so in my implementation the aggregating model can be specified. So far, though, I have only used linear models in my toying around.

The other thing I've not seen mentioned is that Stacking can be done in an online fashion. By updating the aggregating model before the models being ensembled, you can get unbiased updates to the aggregator. I've been combining this with Passive Aggressive models as the aggregators and have been getting some great results. I use the PA models because they adapt and learn very quickly, and don't need any parameter tuning to work well.
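To make the update order concrete, here is a minimal Python sketch of online stacking. It is not JSAT's code or API: the class names, the toy data stream, the use of plain linear logistic models (with different learning rates) in place of kernel machines, and the mapping of probabilities to [-1, 1] are all assumptions made for illustration. The one thing it is meant to show is that the aggregator is updated on predictions the base models made before they saw the label.

```python
import math
import random


class OnlineLogistic:
    """Tiny online logistic regression trained by SGD (illustrative base model)."""

    def __init__(self, dim, lr):
        self.w = [0.0] * (dim + 1)  # the last weight is the bias
        self.lr = lr

    def prob(self, x):
        z = self.w[-1] + sum(wi * xi for wi, xi in zip(self.w, x))
        z = max(-30.0, min(30.0, z))  # clamp so exp() stays well behaved
        return 1.0 / (1.0 + math.exp(-z))

    def update(self, x, y):  # y in {0, 1}
        g = self.prob(x) - y
        for i, xi in enumerate(x):
            self.w[i] -= self.lr * g * xi
        self.w[-1] -= self.lr * g


class PassiveAggressive:
    """PA-I binary classifier over the base models' outputs (the aggregator)."""

    def __init__(self, dim, C=1.0):
        self.w = [0.0] * dim
        self.C = C

    def score(self, z):
        return sum(wi * zi for wi, zi in zip(self.w, z))

    def update(self, z, y):  # y in {-1, +1}
        loss = max(0.0, 1.0 - y * self.score(z))
        norm_sq = sum(zi * zi for zi in z)
        if loss > 0.0 and norm_sq > 0.0:
            tau = min(self.C, loss / norm_sq)
            self.w = [wi + tau * y * zi for wi, zi in zip(self.w, z)]


def meta_features(models, x):
    # Map each probability from [0, 1] to [-1, 1] so "unsure" becomes 0,
    # then append a constant bias feature. This encoding is an assumption.
    return [2.0 * m.prob(x) - 1.0 for m in models] + [1.0]


def online_stack(stream, models, aggregator):
    """One pass over the stream; returns how many mistakes the stack made."""
    mistakes = 0
    for x, y in stream:  # y in {0, 1}
        z = meta_features(models, x)  # predictions made BEFORE seeing y
        y_pm = 1 if y == 1 else -1
        if y_pm * aggregator.score(z) <= 0:
            mistakes += 1
        aggregator.update(z, y_pm)  # aggregator first, so z is not tainted by y
        for m in models:  # only now do the base models learn from (x, y)
            m.update(x, y)
    return mistakes


def make_stream(n, seed=0):
    """Toy 2D stream: label is 1 when x0 + x1 > 0 (purely illustrative data)."""
    rng = random.Random(seed)
    points = [(rng.uniform(-1, 1), rng.uniform(-1, 1)) for _ in range(n)]
    return [((a, b), 1 if a + b > 0 else 0) for a, b in points]
```

A run would look like `online_stack(make_stream(2000), [OnlineLogistic(2, lr) for lr in (0.01, 0.1, 1.0)], PassiveAggressive(dim=4))`. If the two update calls were swapped, the aggregator would be trained on predictions from models that had already seen the label, which is the bias the ordering avoids.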

To visually see that it was working, I created a simple ensemble of 3 models on a 2D dataset. My base models are online logistic kernel machines, with 3 different but poorly chosen RBF kernels. One is far too large, one is just a bit too large, and the final one is so small that it overfits.



Then I applied my online Stacking implementation. As you can see, it clearly learns a model that is better than any individual model from the ensemble. This can be very useful in scenarios that don't require real time predictions.

Online Stacking of 3 Logistic Kernel Machines

In doing this I explicitly used a Logistic model because it produces output probabilities. This makes a huge difference in the performance of Stacking. Using the hinge loss just doesn't work as well, as shown below.

Online Stacking of 3 Online SVMs

This is because with the hinge, you only get a "yes/no" prediction and we need to learn the weight for each model based on that alone. If a model knew that it doesn't know what to predict and returned probabilities indicating such, we could learn a better weight that takes into account the confidence in predictions. In the case of the model 'not knowing', its vote will get washed out when making a prediction by contributing equally to all classes - and we can learn to exploit this. Here we show the probabilities of the 3 models, and it becomes much easier to see how they could be combined in a way that improves the overall accuracy, giving the above result.


The wide model gives us the votes for the top and left/right sides of the space, the middle model gives us most of the edges around the data, and the left model gives us the votes near the borders of the classes (lightened by the smoother border of the middle model).
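The "washing out" effect can be seen with a few made-up numbers. The sketch below is only an illustration: the function names and the three probability vectors are hypothetical, and equal weights are used instead of learned ones. Two unsure models and one confident model are combined once with their probabilities and once with hard yes/no votes.

```python
def weighted_vote(prob_vectors, weights):
    """Sum each model's per-class outputs, scaled by that model's weight."""
    n_classes = len(prob_vectors[0])
    totals = [0.0] * n_classes
    for probs, weight in zip(prob_vectors, weights):
        for k in range(n_classes):
            totals[k] += weight * probs[k]
    return totals


def harden(probs):
    """Turn a probability vector into a hard vote, as a hinge-loss model gives."""
    best = max(range(len(probs)), key=lambda k: probs[k])
    return [1.0 if k == best else 0.0 for k in range(len(probs))]


def argmax(values):
    return max(range(len(values)), key=lambda k: values[k])


# Hypothetical outputs: two models that barely lean to class 0,
# and one model that is confident in class 1.
probs = [[0.51, 0.49], [0.52, 0.48], [0.10, 0.90]]
weights = [1.0, 1.0, 1.0]

soft = weighted_vote(probs, weights)                       # about [1.13, 1.87]
hard = weighted_vote([harden(p) for p in probs], weights)  # [2.0, 1.0]
print(argmax(soft), argmax(hard))  # prints "1 0"
```

With probabilities, the two unsure models contribute almost equally to both classes and the confident model decides. With hard votes that information is gone and the two unsure models win the vote.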

One could also imagine learning rules to exploit various correlations between the models' predictions. This is why I've made the aggregating model configurable rather than forcing everyone to use linear aggregators.

Sunday, October 5, 2014

Beginner Advice on Learning to Implement ML Algorithms

Jason Brownlee contacted me recently to ask if I could give my advice/opinion on a few questions (prompted by a post of mine on reddit). I'll be answering them here. He also asked about some optimization tricks which I'll answer in a later post.

How can you implement algorithms from scratch to learn ML?


The short, if unsatisfying, answer is practice. When you read a new algorithm or paper, there will be a lot of assumed knowledge. The authors have only a handful of pages to convey a new concept to the reader, so it's up to you to read all their references and learn the field well enough that you can read between the lines, learn the tricks, and translate higher level papers to lower level code. Some of this will be a bit circular in logic, but the truth is practice is an iterative process. You have to go through everything in cycles and ‘bootstrap’ yourself up from confusion to understanding. The beginning may not be the hardest part, depending on what kind of methods you are interested in, but it will be the most discouraging period. That said, here are the “steps” I generally follow when implementing (or trying) a new algorithm.

1. Read the whole paper, no skimming. Then wait. There is little reason to believe you will grasp everything at first, and re-reading over and over immediately is just going to fatigue you and desensitize you to the words. You need time to let the paper sink in, think about how it relates to other algorithms and ideas, and try to build a mental model of everything that is happening. Then you go back and re-read the paper, enhance your understanding, and repeat until you reach diminishing returns.

Obviously different algorithms require different amounts of mental effort – a few of my ‘TODO’ papers have been on my list for a few years. But no matter how I feel, or how much I want to just ‘dive in’ and start writing some code – in the end I've never regretted doing this process. Over time, it also builds up a much deeper understanding of how different algorithms are related and that’s when you can start coming up with your own ideas.

2. Come up with (or find) the simplest problem that is the most complicated that the algorithm can solve perfectly. This may be a bit confusing to read, but essentially you want a toy problem that will force the algorithm to exhibit desirable behavior but will also allow you to get a consistently perfect result.

Doing this often requires some level of understanding of the algorithm in the first place, so in some cases can be a bit tricky. Besides helping to make you think about the algorithm and what it is actually solving, it is also a huge boon when developing and testing your implementation. Especially if you can visualize the problem and solution. For this reason I often create 2D problems for initial development.

First, this provides a useful unit test of macro functionality. If you ever go back to improve / refactor / modify your code, it may catch you breaking something by accident.

Second, it can help you catch “near working” implementations. I used to always start with replicating results on common benchmark datasets, such as MNIST. But it is possible to implement an algorithm but fail to account for certain corner cases or code that doesn't quite solve the right problem, but happens to return good accuracies. This is a particularly notorious problem in Neural Networks, and even happens at the algorithm level in top academic journals.

3. Speed comes last. There are lots of performance tricks to get code and ML type algorithms moving faster, but they often clutter the code. When initially developing, it’s more important to get the logic right. This is part of the observation that debugging code is much harder than writing it. If you attempt to be fast before you have confirmed the logic is working, you will have to first determine if it is a problem with the algorithm itself, or the optimization you've done. It’s an extra step that simply isn't needed. It's also much easier to validate optimizations as both correct and an improvement when you can compare them to a base implementation.

4. Know what you are built upon. Given a Linear Algebra library (BLAS/LAPACK), a number of algorithms become very easy to implement efficiently. You should almost always use these instead of implementing them from scratch like I do. My purpose for going from scratch is self understanding and education. In reality this provides almost no practical benefits for my library or code. However, if you do implement an algorithm on top of these tools, keep in mind that you are using a tool that likely has over 3 decades of built up knowledge, algorithm development, testing, and performance chasing. Try implementing the methods you call from scratch when you have a chance and see how it impacts performance, just so you can get an appreciation of how much work is done for you.
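As an illustration of step 2, here is a small Python sketch of a toy problem used as a macro-level test. The perceptron is just a stand-in for "the algorithm under development", and the data generator, margin, and function names are assumptions made for the example. The data is 2D and linearly separable with a gap around the boundary, so a correct perceptron must reach 100% training accuracy, and anything less points at a bug.

```python
import random


def make_toy_data(n=200, margin=0.2, seed=0):
    """2D points labelled by the sign of x0 - x1, with a gap around the boundary."""
    rng = random.Random(seed)
    data = []
    while len(data) < n:
        x0, x1 = rng.uniform(-1, 1), rng.uniform(-1, 1)
        if abs(x0 - x1) >= margin:  # reject points too close to the boundary
            data.append(((x0, x1), 1 if x0 > x1 else -1))
    return data


def train_perceptron(data, max_epochs=1000):
    """Classic perceptron with a bias term; stops once an epoch has no mistakes."""
    w = [0.0, 0.0, 0.0]  # weights for x0, x1, and the bias
    for _ in range(max_epochs):
        mistakes = 0
        for (x0, x1), y in data:
            if y * (w[0] * x0 + w[1] * x1 + w[2]) <= 0:
                w[0] += y * x0
                w[1] += y * x1
                w[2] += y
                mistakes += 1
        if mistakes == 0:
            break
    return w


def accuracy(w, data):
    correct = 0
    for (x0, x1), y in data:
        predicted = 1 if w[0] * x0 + w[1] * x1 + w[2] > 0 else -1
        correct += predicted == y
    return correct / len(data)


def test_perceptron_solves_toy_problem():
    data = make_toy_data()
    assert accuracy(train_perceptron(data), data) == 1.0
```

Because the expected result is exactly perfect, the test is unambiguous: there is no "good enough" accuracy for a near-working implementation to hide behind. Being 2D, the same data can also be plotted along with the learned boundary.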

What are some traps that you see?

Most all coding advice applies here (especially any warning about floating point), but I will mention a few ML centric mistakes I often see (in general and for beginners). 

1. Don't assume the paper is perfect or even correct at all. I've implemented a number of papers that had significant errors in the paper that needed to be fixed before the algorithm worked. Peer review is far from perfect, and serious errors will get by. 

I've attempted a few papers I was never able to reproduce the results of. Some were simply missing too many details, but seemed like they could work. Sometimes there are minor mistakes in the paper, or an equation that never got updated/corrected in the final version. Rarely I've suspected papers of having serious issues in the algorithm or evaluation. 

2. Don't try and get a “math free” understanding. This is a particular problem for beginners – especially those weak in math. You can’t understand what you are implementing without the math. You can often implement it without a full understanding, and that’s fine – you can't master all of everything. But don't try to get by with none of it. 

If you really want to implement ML algorithms but feel too lacking in the math department, just keep iterating. Try your best, learn, and be frustrated. Eventually it will get a little easier. I personally feel very weak in my math skills, but when I read back the same papers I read 3-4 years ago, I'm amazed at how much more I understand and can comprehend. 

3. Don't start with other people’s source code! I'll let you read my reddit comment for the whole spiel. But the short of it is that most people don't actually learn anything by reading the source code if they don't understand what the algorithm is in the first place. 

4. A common question I see is some variant of “how accurately could ML predict X?” or “how would I apply ML to doing X?” While it may seem reasonable to someone new to the field, these are very poor questions to ask. 

In some instances, if the problem is very similar to one that has been solved before, you could point to those previous results as a guess. With more experience, you might be able to hazard a good guess if you know some details about the problem itself. But there are a huge number of variables about the problem that someone isn't going to know about your data, or domain expertise on the problem they simply don't have. 

Even if they did have this information, there is application context that is missing. Does the solution need to run in real time? Are the resources for training the model limited? Does the model have to run on an embedded device? Is ‘accuracy’ the real goal, or is it something else? If accuracy is the goal, what level of accuracy is needed? Are reasonable probabilities needed? Will the model be updated over time? Creating and applying a Machine Learning solution has to take in a lot of factors. 

Instead, as you are learning, go look for already existing ML solutions and read about how they were done or how they work. Once you've learned and seen a number of solutions and the work involved, you can hopefully apply some of those thoughts and ideas to your own problems. 

5. Using the default random number generator. This may seem like an odd one, but it’s a bit nuanced. A number of machine learning algorithms rely on randomness. So you use the PRNG built into your language of choice and everything is great. Except sometimes not, but that’s how randomness works, right? But a number of languages and libraries (especially Java) have very poor quality PRNGs as the default. While there is no need for cryptographically secure PRNGs, switching to a decent quality generator can save a lot of headache debugging behavior that is very difficult to diagnose as coming from the PRNG. I've only had this really impact me a few times, but switching away from the default has been a big help. This is particularly important if you plan on using an algorithm that revolves around randomness. 

6. Using Python or Ruby or some other interpreted language. This is some of my own personal bias here, but really implementing an algorithm efficiently isn't a fun exercise in these languages. You need a faster JITed or fully AOT compiled language so you can figure out where your code is really slow. There are ways around this in Python & friends, but at that point you are learning more about implementing in Python, rather than the algorithms themselves.
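One way to act on trap 5 is to never reach for the global generator inside an algorithm, and instead take the generator as an argument so it can be swapped or seeded from outside. The Python sketch below shows only that structure. The function name is hypothetical, and it is not a statement about the quality of any particular language's default generator.

```python
import random


def shuffled_minibatches(n_items, batch_size, rng):
    """Split the indices 0..n_items-1 into shuffled batches using the given generator.

    `rng` only needs a shuffle() method, so any random.Random subclass
    (or another generator wrapped to look like one) can be passed in.
    """
    indices = list(range(n_items))
    rng.shuffle(indices)
    return [indices[i:i + batch_size] for i in range(0, n_items, batch_size)]


# Seeded generator: the same batches on every run, which makes odd
# behavior reproducible while debugging.
batches = shuffled_minibatches(10, 3, random.Random(42))

# Swapping the source of randomness is a one-argument change.
other = shuffled_minibatches(10, 3, random.SystemRandom())
```

With this shape, checking whether strange behavior comes from the generator is a matter of re-running with a different one, rather than hunting through the code for hidden calls to a global default.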

What resources would you suggest?

This is a hard question to answer, because it depends a lot on what your interests are and what your background is. A lot of people get into ML from other fields because it's useful and has a lot of interesting parts. Linguists are involved a lot in the Natural Language Processing side, and I know of a number of physicists and math guys who like some of the more applied and theoretical parts. Topic modeling is evidently growing as a tool for History and English departments to analyze large swaths of text that would be too time consuming to do as an individual. So I'm going to instead list some resources and people I like. 

In my experience, replicating results is often the best way to get good at implementing algorithms. So what I'm considering a 'good' resource for getting better at implementing ML algorithms is a combination of 3 items:

  • Pseudocode description of algorithm. Some papers scatter the details of the algorithm out, making them a bit hard to follow 
  • Explanation of math details for the algorithm. Pseudocode alone isn't always enough, and can be very vague. While you should hopefully get to a point where a derivative or the form of a function could be done yourself, to start it's a detail you probably want made explicit. 
  • Reproducible results. Many papers gloss over important details, like parameters tested and how parameter selection is done. But without that information you can't re-create the results to confirm your code works! Even better is when code is provided so you can compare results (not just copy their code). 


Online Course:

You've probably found it or heard of it by now. But Andrew Ng's Machine Learning course on Coursera is quite good for raw beginners. He also takes you through an example of decomposing a problem (OCR) as a series of steps that can be solved with ML algorithms. This is most likely the place to start if you are a complete beginner. 

Books:


This is a very common book for people to learn and get, especially since it is free. In my opinion, the book isn't particularly great in any area – but it's not particularly bad at anything either. Though if you are more of a visual learner this book has lots of graphs and diagrams to try and give intuitions of how things work / what's going on. If your goal is to implement for learning, recreating their visual results would be an excellent exercise. 


This is currently my favorite ML book. Murphy does an excellent job explaining the algorithms, and relating them to each other to help foster a deeper understanding. Though the learning curve is a bit wonky at times, and the later chapters don’t have quite as much lower level detail – it's overall excellent. For implementing in particular, much more refined (and explained) pseudo code is present for many of the algorithms, and many of the chapters go through the math needed to develop them. Occasionally some implementation considerations are discussed. While not often, it is more than most books on Machine Learning. 

People in General: 


Chih-Jen's group consistently puts out very high quality work, including the widely used LIBSVM and LIBLINEAR packages. Their papers are especially good if you are looking for algorithms to implement. He always describes the algorithm in pseudo code with a good amount of detail, including implementation details and choices made with the implementation in mind. They are also meticulous at making their work reproducible, and providing very detailed experimental procedure so you can replicate it with their code or your own. 


Almost every paper I've read that has had Nathan as an author has also provided good pseudo code to follow and implement from. 


Not a huge number of papers, but I've generally found his very concisely described if you are just interested in implementing it. 


Perhaps too ambitious for beginners, Pedro has a lot of very interesting papers that are well described. Not always, but he also provides some code implementations that you can compare against. 

Some papers in particular: 

Bob Carpenter's Lazy Logistic Regression is a good paper to learn from, and goes through the details and derivation of a logistic regression algorithm. He also has a good blog. 

Greg Hamerly's Making k-means even faster is a great algorithm in general (every library's k-means implementation should switch from Lloyd's to this as the default). But the paper itself is also unique in covering a lot of the implementation tricks used to make it fast (an unfortunate rarity). 

Lawrence Cayton's Accelerating Nearest Neighbor Search on Manycore Systems is an algorithm for accelerating Nearest Neighbor queries. This one doesn't have any pseudo code, but the paper is detailed and also talks about implementation considerations. It's also a good exercise when you get more comfortable to extend the algorithm from just nearest neighbor queries to range queries. 

Hal Daume's Notes on CG and LM-BFGS Optimization of Logistic Regression is probably too much for a first attempt. But it's a good starting place for your first really complicated algorithm, as he has assembled all the parts together for you. You'll still be missing a lot of the theory and math, as the paper alone isn't sufficient background for what's being implemented. But it's a very practical goal oriented discussion and detailing. Implementing it will give you a good appreciation for the amount of work in writing and debugging more significant code. 

Thursday, September 11, 2014

Kernel Methods Aren’t Dead Yet

The other night I gave a talk at the Data Science MD meetup group about using Kernel Methods on larger datasets. I have the slides up here. Public speaking is really outside my comfort zone, but it's something I want to do / get better at. Overall it was a great experience, and my first experience giving a talk about Machine Learning in front of a large group. I'm quite sure I left a good deal of room for improvement.

The majority of my talk was about methods that could be used to approximate the solution of a support vector machine, specifically forming approximate features and using a linear solver or performing bounded kernel learning using projection and support vector merging. The motivation being that these approximate methods are good enough to find the parameters we want to use, and then a slower more exact solver can be used on the final parameters chosen.

Part of what motivated this idea is that many people feel that they have to use the exact same algorithm for all steps. There is simply no reason for that, and the grid search is a perfect example of this. If a pair of parameters C and \(\sigma \) are going to perform badly using the exact solver, they aren't going to get any better with an approximate solution. So the approximate solution is more than good enough to filter these out.

One could envision a tiered approach. In the first stages very fast but low quality approximations are used, whittling the set of all parameters to try down to a smaller set. Then a more accurate approximation cuts out more. This continues until we feel that a 'best' pair can be selected, and the most accurate (or exact) solver gets us the final model.
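A rough Python sketch of that tiered filtering is below. Everything in it is an assumption made for illustration: the function names, the fraction of candidates kept at each tier, and especially the two scoring functions, which are synthetic stand-ins for "cross-validated accuracy from an approximate solver" and "from the exact solver" rather than real SVM training.

```python
import math


def tiered_search(candidates, tiers, keep_fraction=0.25):
    """Filter parameter sets through scorers ordered from cheapest to most exact.

    Each scorer maps a candidate to a score where higher is better. Every tier
    except the last only keeps the top `keep_fraction` of what it was given.
    """
    survivors = list(candidates)
    for score in tiers[:-1]:
        ranked = sorted(survivors, key=score, reverse=True)
        survivors = ranked[:max(1, int(len(ranked) * keep_fraction))]
    return max(survivors, key=tiers[-1])


def make_counting_scorer(fn):
    """Wrap a scorer so we can see how often the expensive tier is invoked."""
    calls = {"n": 0}

    def scorer(params):
        calls["n"] += 1
        return fn(params)

    return scorer, calls


# Synthetic stand-ins: both peak at C = 10, sigma = 0.1, but the cheap one
# measures distance from the peak more crudely than the "exact" one.
def exact_score(params):
    C, sigma = params
    return -((math.log10(C) - 1) ** 2 + (math.log10(sigma) + 1) ** 2)


def cheap_score(params):
    C, sigma = params
    return -(abs(math.log10(C) - 1) + abs(math.log10(sigma) + 1))


grid = [(10.0 ** i, 10.0 ** j) for i in range(-2, 4) for j in range(-3, 3)]
```

Calling `tiered_search(grid, [cheap_score, exact])`, where `exact` comes from `make_counting_scorer(exact_score)`, evaluates the expensive scorer on only a quarter of the 36 grid points. The sketch leans on the same assumption as the text: parameters that look bad under the approximation are also bad under the exact solver.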

My other motivation for the talk is (my perceived) dying out of using SVMs. A lot of people now seem to be trying to use either only linear models or Neural Networks. Both certainly have their place, but I feel kernel methods are more easily applied than Neural Networks while having significantly wider applicability than simple linear models. They are one of many tools in the tool box that's just collecting dust when it could be solving problems.

For me personally, it was good experience to talk to a wider audience / range of skillsets. The number of blank stares I got from the audience makes me think I may have lost a few people. So in future versions I think I'm going to try and add more slides that give the intuition of what's going on. I'm not sure how to best give the intuition for Random Kitchen Sinks if the mathy version doesn't work, so that will take some thought.

I was also quite nervous, and went through my slides a bit too fast. I try hard to not be that guy reading their slides word for word, and in doing so forgot to talk about some things that weren't in my slides. Hopefully more practice will decrease the nervousness in the future.

Sunday, June 29, 2014

Stochastic LDA

Not too long ago I finally implemented an algorithm for Latent Dirichlet allocation, which has become one of the most popular topic modeling techniques. While most implementations you will find online are some form of Gibbs sampling, a stochastic version was what interested me. Specifically the algorithm described in Hoffman, M., Blei, D., & Bach, F. (2010). Online Learning for Latent Dirichlet Allocation. In Advances in Neural Information Processing Systems (pp. 856–864). I had tried implementing this a few times before with little success, until I found some extra notes on performing the update in a sparse manner here

Since LDA is often going to be used on unlabeled data, which is very abundant - I wanted an implementation that would work for extremely large corpora. Using a stochastic algorithm, the new data points can be streamed in from disk as needed - whereas a normal Gibbs sampler requires all of the data to be in memory for every update. 
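The streaming part is simple to picture in code. The Python sketch below is a generic mini-batch reader and not JSAT's implementation: the function name, the one-document-per-line file layout, and the `update_model` call in the usage note are all assumptions. It pulls documents lazily from any iterable, such as an open file, so only one batch is in memory at a time.

```python
import itertools


def minibatches(documents, batch_size):
    """Yield lists of up to batch_size documents from any (possibly lazy) iterable."""
    iterator = iter(documents)
    while True:
        batch = list(itertools.islice(iterator, batch_size))
        if not batch:
            return
        yield batch
```

With a hypothetical corpus file holding one document per line, the training loop would be `for batch in minibatches(open(path), 256): update_model(batch)`, where `update_model` stands for whatever stochastic update the model uses.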

Another benefit is that, since the LDA implementation uses relatively large mini-batches of data, parallelism was fairly easy to get - where Gibbs samplers require considerably more work to scale across cores - let alone machines. 

I recently performed a quick run of the LDA implementation in JSAT for a talk at work, and I figured I would share some of the results. I used the data from the recent KDD cup competition at Kaggle (249,555 documents), but instead of doing prediction I simply did topic modeling of the data. I arbitrarily chose to search for 100 topics and compared my implementation to the one in MALLET. MALLET is written in Java, uses doubles instead of floats, supports a parallel Gibbs sampler, and is written by a knowledgeable researcher in the field. JSAT is also Java and uses doubles instead of floats, so I thought the comparison would be on level ground. 

First, one of the things I learned before was that using a large stop-list is critical to getting good results with LDA. Below are some of the topics I got when running JSAT's implementation with no stop-list. 


of, the, their, will, in, students, these, a, have, skills,
i, a, of, my, in, that, have, is, are, with
stage, production, theatre, watercolor, watercolors, design, kindergrtners, pan, repeating, omnikin,
instruments, band, sticks, percussion, bass, oil, playing, reeds, wood, scrapbook
the, of, in, for, have, i, this, students, our, on,
energy, earth, shakespeare, solar, cells, scientists, planet, sun, globe, ocean,
to, students, in, and, a, my, technology, for, with, are,
calculators, calculator, fish, calculus, graphs, dress, ii, high, clothes, classes,

Some of the topics kinda make sense, but I sincerely doubt the relevance of fish to calculators and graphs. Many of the useless topics contain only common words like "i, a, of" and so on. To make my comparison as fair as possible, I opted to simply use the same stop-list that MALLET used. Running the code, I then got very similar results - and matched up some of the "same" topics from MALLET and JSAT to show they are both learning the same thing. 


MALLET | JSAT
reading readers read comprehension fluency reader struggling level improve independent | students, reading, read, listening, readers, center, stories, love, fluency, comprehension
language english learners spanish speak vocabulary speaking arts esl bilingual | language, english, spanish, learners, speak, learning, learn, words, speaking, class,
problem students calculators problems math solving solve algebra graphing mathematics | math, concepts, algebra, mathematics, understanding, real, graphing, mathematical, graph, abstract,
supplies paper basic pencils school markers year notebooks pencil colored | supplies, paper, school, pencils, art, markers, pencil, basic, create, projects,
disabilities autism sensory special skills social motor fine severe emotional | learning, classroom, learn, skills, technology, education, learners, interactive, lessons, disabilities,
fulfillment including cost donorschoose org www http html htm shipping | fulfillment, htm, cost, including, donorschoose, shipping, org, www, http, html
equipment physical education activity play fitness active activities exercise balls | physical, equipment, balls, sports, jump, ball, recess, gym, playground, fit,

While not identical, it's clear they are the same topics. Since LDA is not a convex problem, and both Gibbs sampling and SGD can hit local optima - we wouldn't expect perfectly identical topics either. 

So now we know that both produce comparable results, so why choose the stochastic implementation in JSAT over MALLET? Besides the aforementioned benefits to stochastic implementations in general, below are the runtime results (sans IO) for both code bases. These were run on a machine with 8 cores. 
  • MALLET, single threaded: 37 minutes 21 seconds
  • MALLET, multi threaded: 8 minutes 25 seconds
  • JSAT, single threaded: 28 minutes 7 seconds
  • JSAT, multi threaded: 4 minutes 30 seconds
From these numbers, we can see that stochastic LDA in JSAT is faster to begin with than Gibbs sampling in MALLET. In addition, MALLET's multithreaded speedup was only 4.4x compared to 6.2x in JSAT. So we can expect better performance scalability as we throw more cores at the problem. 
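For anyone who wants to check the arithmetic, the speedups follow directly from the four timings listed above. The short Python snippet below just redoes that calculation. The parallel efficiency figure (speedup divided by the 8 cores) is an extra derived number, not something reported in the original runs.

```python
def to_seconds(minutes, seconds):
    return 60 * minutes + seconds


# (single threaded, multi threaded) wall-clock times from the list above
timings = {
    "MALLET": (to_seconds(37, 21), to_seconds(8, 25)),
    "JSAT": (to_seconds(28, 7), to_seconds(4, 30)),
}
CORES = 8


def speedup(single, multi):
    return single / multi


for name, (single, multi) in timings.items():
    s = speedup(single, multi)
    print(f"{name}: {s:.1f}x speedup, {100 * s / CORES:.0f}% parallel efficiency")
# MALLET: 4.4x speedup, 55% parallel efficiency
# JSAT: 6.2x speedup, 78% parallel efficiency
```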

It's possible to increase the scalability of JSAT further, but for now I plan on leaving the code in its current and simpler form. In general the algorithm in JSAT is very good for large corpora, but the Gibbs samplers are probably better for smaller datasets. The computational complexity will scale linearly with the number of topics. 

The stop-list issue is somewhat interesting. At least for the stochastic implementation, a non-trivial amount of work occurs in each batch update - and stop words tend to be the most common words. JSAT without the stop words removed took about 3 times longer to run. 
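A quick way to see why stop words cost so much is to count how many tokens a stop-list removes. The Python sketch below uses a made-up sentence and a tiny illustrative stop-list (not MALLET's), with hypothetical function names. Since the work in each batch update grows with the number of word occurrences that have to be processed, dropping a large share of tokens gives a matching drop in work.

```python
def remove_stop_words(tokens, stop_words):
    """Keep only the tokens that are not in the stop-list (case-insensitive)."""
    return [t for t in tokens if t.lower() not in stop_words]


def fraction_removed(tokens, stop_words):
    kept = remove_stop_words(tokens, stop_words)
    return 1.0 - len(kept) / len(tokens)


# Tiny illustrative stop-list and document; a real stop-list is far larger.
stop_words = {"the", "of", "in", "a", "i", "my", "to", "and", "will", "for"}
tokens = "my students will use the calculators in a unit of algebra and graphing".split()

kept = remove_stop_words(tokens, stop_words)
# kept == ['students', 'use', 'calculators', 'unit', 'algebra', 'graphing']
```

Here 7 of the 13 tokens are stop words, so over half of the per-document work would go to words that carry no topical information.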

It's also possible to improve the results by using TF-IDF weighting instead of stop words - but that still has the slowdown since we are considering the words, and it has the issue of not mapping well to the model LDA uses. 

Currently I'm not sure what I want the API of LDA and future topic models to look like, so it doesn't implement any interface at the moment. Hopefully as I get more opportunities to apply LDA and implement other algorithms I'll figure out what is best. 


Tuesday, February 18, 2014

Fail Fast for Model Building?

Fail Fast is a concept that has been around for a while, and is something I'm very fond of for software development. A simple way to state the goal for a programmer is as follows: Attempt to throw an error sooner rather than later.

This doesn't mean to simply throw errors because you can. What we want is to identify situations where we know that the procedure can't possibly succeed. Instead of waiting for the process to fail on its own, we will fail the process (or throw an error or alert someone) the instant we know that failure is the only option.
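In ordinary code this usually means checking the things that guarantee failure before any expensive work starts. The Python sketch below is a generic example with a hypothetical function name and hypothetical checks. It rejects training inputs that cannot possibly work, instead of letting the training loop run into them minutes later.

```python
def check_training_inputs(features, labels, regularization):
    """Raise immediately if training with these inputs cannot possibly succeed."""
    if len(features) == 0:
        raise ValueError("no training data was provided")
    if len(features) != len(labels):
        raise ValueError(
            f"{len(features)} feature rows but {len(labels)} labels")
    if not regularization > 0:  # also rejects NaN
        raise ValueError(f"regularization must be positive, got {regularization}")
    width = len(features[0])
    for i, row in enumerate(features):
        if len(row) != width:
            raise ValueError(f"row {i} has {len(row)} values, expected {width}")
```

None of these checks throws an error just because it can. Each one describes a situation where failure is already the only possible outcome, so reporting it at once saves the wait.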

But what about model training? Can we develop models where we know that the model is in a bad spot, and terminate early? The thought came to me as I was implementing the AMM algorithm, as the published results seemed incredibly good. Not quite too good to be true, as I hadn't seen much other discussion of the paper. But something was missing. You can see their reported times here, and the AMM algorithm is almost as accurate as the RBF SVM while being only a bit slower than the linear SVM. But those times are just for the final model, so how did they select the parameters?

The answer, as it often is, is grid search. When I ran the same algorithm through Grid Search I quickly discovered the flaw: AMM takes an order of magnitude or 2 longer to train when given bad parameters for the specific data. So can we really say that AMM is so fast? On new data we will have to run grid search - and it is going to take us at least as long as the slowest model to get all of our results.

This isn't a problem unique to the paper I've singled out. An SVM's runtime is very heavily influenced by its regularization parameter C. Budgeted kernel methods exhibit the same behavior, so it's not an issue of the number of support vectors. Even L1 regularized linear models in particular are enormously slow when using very little regularization. To me there seem to be 2 possible solutions.

  1. Develop flexible non-linear models that have more strictly bounded computational complexity, independent of the model's regularization. 
  2. Develop methods of detecting early when a model will fail to reach a good solution.
The first is, I think, the preferred option. Though I'm not entirely sure what it would look like. In some sense, budgeted kernel methods like the Forgetron already fit this description. But they don't quite fit the bill. We still need to determine the budget size, and even when the budget is small enough to be fast, computational cost can be much higher when the model makes errors due to bad parameters (for example, consider a projection step that only gets done on errors). 

The Computer Scientist option is the second one, which is what I'm talking about. I want to run grid search and have some of the models fail fast, where they simply stop training and report to the grid search that the parameter set was bad. 

I think some preliminary easy gains could be achieved with some work. Running a faster linear model at the first step gives us a good upper bound on our desired error rate. If the model thinks it won't be able to beat the error, why bother continuing? We have a linear model that is simpler to deal with that can reach the same goal. So go ahead and fail, even though we would have eventually reached some solution. 

The immediate option is to estimate our error periodically on a held out validation set, and give up if we don't see significant gains in performance over time. But this won't always work. This paper by Cotter et al shows that it would fail for some models, as SMO based SVM algorithms don't start reaching a good solution until very close to the final solution. So such a strategy would be limited to models that have useful intermediate solutions. 
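Put together, the linear-baseline bound and the periodic validation check could look something like the Python sketch below. This is only one way the idea might be wired up, not an existing API: the `FailFast` exception, the `step` and `validation_error` callbacks, and the `check_every`, `patience`, and `min_gain` knobs are all hypothetical. As noted above, it only makes sense for models whose intermediate solutions are meaningful.

```python
class FailFast(Exception):
    """Signals to the grid search that this parameter set is not worth finishing."""


def train_with_fail_fast(step, validation_error, baseline_error, max_iters,
                         check_every=10, patience=3, min_gain=1e-3):
    """Run training steps, giving up early when progress stalls above the baseline.

    step():             performs one training iteration (hypothetical callback)
    validation_error(): current error on a held out set (hypothetical callback)
    baseline_error:     error of a fast linear model on the same held out set
    """
    best = float("inf")
    stale_checks = 0
    for iteration in range(1, max_iters + 1):
        step()
        if iteration % check_every != 0:
            continue
        error = validation_error()
        if error < best - min_gain:
            best, stale_checks = error, 0
        else:
            stale_checks += 1
        # Stalled AND still no better than the cheap linear model: stop now.
        if stale_checks >= patience and best >= baseline_error:
            raise FailFast(
                f"stalled at error {best:.3f} after {iteration} iterations "
                f"(linear baseline: {baseline_error:.3f})")
    return best
```

A grid search would wrap each call in `try`/`except FailFast` and simply record the parameter set as bad. A run that has stalled but already beats the baseline is allowed to finish, since it may still be the best candidate.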

Perhaps someone knows of some similar work that is currently out there, but I think it's a thought process worth investigating for the future. 

Thursday, January 23, 2014

An Overview: Implementing the Digamma Function

I've been busy getting used to my full time job, which has graciously given me permission to continue working on JSAT - so I haven't done much updating recently. I've been particularly neglectful of this blog. So to compensate I've decided to do a bit of an overview on implementing a special function. In particular, how I came to the digamma function that I added to JSAT in anticipation of some other items on my TODO list.

I've always enjoyed numerical analysis and implementing functions, though implementing them is never something that was really "taught" in my numerical analysis courses. We learn to analyze them and determine the potential for errors and all that, but how do we actually take a function and implement it? So for this I'll walk through what I did to create mine. By no means am I an expert at this - I usually prefer to follow papers of known good implementations, but it is fun to do and sometimes necessary.

Besides the Wikipedia page, there are 2 other very useful websites when doing this kind of work. First is obviously WolframAlpha. The other is also from Wolfram, but is the lesser known functions.wolfram.com site. If you are approximating a well known function (such as the digamma function) that is listed there, they have a ton of properties and series already listed.

First things first, we plot the function we are interested in to get an idea of what it's all about.



So, first thing we notice is that everything < 0 is clearly not a fun place to be. We would prefer to try and approximate areas that don't change much and don't change quickly. And if they do change, it would be nice if they changed in a consistent way. Occasionally we can get around such situations by looking for symmetry, a point where we can wrap a negative input value back into the positive space where things will be easier. This is where functions.wolfram comes into play.

If we pay attention, we realize that the left kind of looks like a \( \tan(x) \) put through a transformation. So if we were to approximate it directly, I'd want to try and find a set of extra terms that would make \( \tan(x) \) behave more like what we see - and then try to optimize those coefficients. But in this case we can use symmetry since (from wolfram)

$$ \psi(1-x) = \psi(x) + \frac{\pi}{\tan(\pi x)} $$

Using that we can shift any negative value of \( x \) into the positive range.

While we are shifting values, we can also see that small positive values of \( x \) are rapidly approaching \( -\infty \). We again use another relation

$$ \psi(x) = \psi(x+1) - \frac{1}{x} $$

To shift away from the very small positive values. I decided to shift all values so that I only ever evaluate \( x \geq 2 \). This way I can focus on areas that appear to be much smoother and easier to deal with.
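The two relations above can be sketched in code as follows. This is only an illustration of the range reduction, not JSAT's implementation: the class name `DigammaSketch` and the threshold `SHIFT_TO = 10` are assumptions made for this example (the post shifts to \( x \geq 2 \) and then uses fitted polynomials), and the large-\( x \) evaluation here swaps in a few standard terms of the asymptotic series instead.

```java
/** Illustrative sketch only: reflection + recurrence range reduction for digamma. */
final class DigammaSketch
{
    /** Hypothetical threshold chosen so the short series below is accurate enough. */
    private static final double SHIFT_TO = 10.0;

    static double digamma(double x)
    {
        if (x <= 0)
        {
            if (x == Math.floor(x))
                return Double.NaN; // poles at 0, -1, -2, ...
            // Reflection: psi(x) = psi(1 - x) - pi / tan(pi x)
            return digamma(1 - x) - Math.PI / Math.tan(Math.PI * x);
        }
        // Recurrence: psi(x) = psi(x + 1) - 1/x, applied until x is large enough
        double shift = 0;
        while (x < SHIFT_TO)
        {
            shift -= 1 / x;
            x += 1;
        }
        return shift + asymptotic(x);
    }

    /** log(x) - 1/(2x) - 1/(12x^2) + 1/(120x^4) - 1/(252x^6) + 1/(240x^8) - 1/(132x^10) */
    static double asymptotic(double x)
    {
        double inv2 = 1 / (x * x);
        double tail = inv2 * (-1.0 / 12 + inv2 * (1.0 / 120 + inv2 * (-1.0 / 252
                + inv2 * (1.0 / 240 + inv2 * (-1.0 / 132)))));
        return Math.log(x) - 0.5 / x + tail;
    }

    public static void main(String[] args)
    {
        // psi(1) should be close to the negative of the Euler-Mascheroni constant
        System.out.println(digamma(1.0));
        System.out.println(digamma(-0.25));
    }
}
```

Shifting further to the right costs more divisions in the recurrence loop but lets a shorter series do the work, which is the same trade-off behind picking a cut-off like \( x \geq 2 \).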

Now begins the approximation part. My preferred method is to use a continued fraction, as they often have an excellent radius of convergence. If the terms are fairly simple to compute (or even better, just coefficients) the continued fraction can also be very quick to compute. Unfortunately, WolframAlpha tells us the continued fraction for the digamma involves evaluating the Riemann zeta function. So that would be too computationally expensive.

So we look for a series, and we note that the increase of \( \psi(x) \) is similar to that of \( \log(x) \). And indeed, we find the series expansion

$$ \psi(x) = \log(x) - \frac{1}{2 x} + \sum_{n = 1}^{\infty} \frac{\zeta(1-2 n)}{x^{2 n}} $$

Obviously we can't just keep summing until it's accurate enough, because that would require us to evaluate the Riemann zeta function. We could either try and find the number of terms needed to be accurate for all \( x > 2\), or use a truncation of the series and try to correct for the error.

I chose the latter option, and truncated at

$$ \psi(x) \approx \log(x) - \frac{1}{2 x} - \frac{1}{12 x^2} $$

since this seemed to get me the most out of the diminishing returns in accuracy. In addition, I liked that it is actually a lower bound on \( \psi(x) \) for \( x > 0\). At this point, the first thing I did was try and find the point at which the difference was small enough that I could just accept the approximation as close enough. By \(x = 500 \) the error was on the order of \( 10^{-15} \) or less, so that was my cut off.
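As a worked check on that truncation, the first few zeta values appearing in the series are standard constants (they follow from \( \zeta(1-2n) = -\frac{B_{2n}}{2n} \), where \( B_{2n} \) are the Bernoulli numbers):

$$ \zeta(-1) = -\frac{1}{12}, \quad \zeta(-3) = \frac{1}{120}, \quad \zeta(-5) = -\frac{1}{252} $$

Substituting them gives the next terms that the truncation leaves out:

$$ \psi(x) \approx \log(x) - \frac{1}{2 x} - \frac{1}{12 x^2} + \frac{1}{120 x^4} - \frac{1}{252 x^6} $$

The first dropped term, \( \frac{1}{120 x^4} \), is positive, which is consistent with the three-term truncation sitting just below \( \psi(x) \).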

That left me with the range \( x \in [2, 500] \) that I needed to approximate. Using some software to do MiniMax Approximations I tried to accurately approximate as much of the range as I could with a relatively small polynomial correction term. I started with the end of the range being 500 and decreased \( x \) as much as I could. This is because \( \psi(x) \) is behaving better and better as we increase \( x \), so I want to get as much of the easy range as possible before having to do something more convoluted or clever. I was able to get \(x \in [70, 500] \) with one polynomial and \(x \in (7, 70) \) with a second.

This left me with only \(x \in [2, 7] \) to get. I was aiming for high accuracy throughout all value ranges, and the code I was using for the MiniMax approximation wasn't working well in this range. I tried using some simpler tricks like a Padé approximation centered at 4.5, but the radius of convergence wasn't quite large enough (as is often the case).

At this point I started looking around for a bit of help, and found this paper. For small positive values (in a different range) they used

$$ \frac{\psi(x)}{x-x_0} $$

where \(x_0 \) is the only positive value such that \( \psi(x_0) = 0 \). This turns out to be easier to approximate. While the paper notes having to do extra work due to numerical stability issues, because I used the trick in a larger range I did not need to apply any extra work.

All of that work combined got me my digamma function. While I've not done any performance tests, accuracy wise it seems to be very good (usually less than \(10^{-14}\)) unless you evaluate near one of the asymptotes.

You may feel that parts of this were a bit hand wavy. However, I feel most of the really difficult parts in implementing such functions are more about having the intuition behind the behavior: being able to spot a pattern, look for (or derive) a relation that will help you, or transform the function into a better behaved form.

Monday, July 22, 2013

L1 Regularization with Superfluous Features

I've been working on adding some more Linear algorithms again, with a focus on making sure there were some L1 regularized learners in JSAT. This was the one area JSAT was really lacking in, and there really are not a lot of great algorithms for it. A number of the published papers I attempted implementing didn't even work well outside of a very narrow range.

The two algorithms I did implement that worked well were Bayesian Binary Regression (BBR) and Sparse Truncated Gradient Descent (STGD). The former is a batch Logistic Regression algorithm and the latter an online algorithm for the squared loss. Both support L1 regularization and worked well for me once implemented.

I just used BBR to create a fun graph. Consider a data set where all features are drawn from independent N(0, 1). Let only 5 of these features be relevant, and restrict yourself to only 50 data points. How do you find the relevant features as you keep adding more irrelevant dimensions? Unregularized learners will quickly degrade to random guessing or worse. L2 regularization is easy enough to do quickly, but still isn't strong enough to learn the right coefficients as the dimensionality increases.

This is where L1 regularization comes into play. While most papers mention the sparsity property of L1, the real power is that L1 gives us theoretical bounds on its performance in the face of irrelevant features. This is incredibly important as we collect more and more data with lots of features, where we don't really know which of the features are useful for our decision. 



It's a fun plot to look at, and you can clearly see the L1 prior is helping maintain a reasonable level of performance. This graph is slightly unusual in how I made the problem. Most L1 vs L2 graphs like this show the L1 prior doing much better, staying near the original accuracy. Often they construct the irrelevant features as sparse, with a binary 0/1 value with a small probability of being 1. The performance here is still good, but it's important to emphasize that the L1 prior isn't impervious to random features.

Below is the code to generate the same data, including some more that I didn't include. You can try it out to see how the regularization affects the problem, and even change the problem to the sparse 0/1 version and see how it holds up. Note that BBR takes a while to converge for very small regularization values (\(\lambda \leq 10^{-4}\) for me) for the L1 prior.


import java.text.DecimalFormat;
import java.util.Random;
import jsat.classifiers.CategoricalData;
import jsat.classifiers.ClassificationDataSet;
import jsat.classifiers.ClassificationModelEvaluation;
import jsat.classifiers.linear.BBR;
import jsat.linear.DenseVector;
import jsat.linear.Vec;
import jsat.utils.random.XORWOW;
import jsat.regression.LogisticRegression;

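/**
 * Compares unregularized, L2 (Gaussian prior), and L1 (Laplace prior) logistic
 * regression as irrelevant Gaussian features are added to a small data set.
 *
 * @author Edward Raff
 */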
public class L1L2Comparison
{

    /**
     * @param args the command line arguments
     */
    public static void main(String[] args)
    {
        
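        int relevant = 5;   // number of features that actually determine the label
        int classSize = 50; // total number of data points generated per data set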
        int[] irrelevant = new int[]
        {
            0, 5, 15, 45, 95, 245, 495, 
        };
        
        double[] regPenalties = new double[]
        {
            1e-4, 0.001, 0.01, 0.1, 0.5, 1.0, 5.0, 10.0
        };
        
        Random rand = new XORWOW();
        
        double[] coef = new double[relevant];
        
        for(int i = 0; i < relevant; i++)
            coef[i] = rand.nextDouble()*10;
        
        DecimalFormat df = new DecimalFormat("#.#########");
        
        for(int i = 0; i < irrelevant.length; i++)
        {
            int D = irrelevant[i]+relevant;
            ClassificationDataSet cds = new ClassificationDataSet(D, new CategoricalData[0], new CategoricalData(2));
            for(int k = 0; k < classSize; k++)
            {
                Vec xP = new DenseVector(D);
                for(int j = 0; j < D; j++)
                    xP.set(j, rand.nextGaussian());
                double result = 0;
                for(int j = 0; j < relevant; j++)
                    result += coef[j]*xP.get(j);
                if(result > 0)
                    cds.addDataPoint(xP, new int[0], 1);
                else
                    cds.addDataPoint(xP, new int[0], 0);
            }
            
            System.out.println("\n\nD: " + D);
            LogisticRegression lr  = new LogisticRegression();
            
            ClassificationModelEvaluation cmeLr = new ClassificationModelEvaluation(lr, cds);
            cmeLr.evaluateCrossValidation(10);
            
            System.out.println("UNIFORM: " + df.format(cmeLr.getErrorRate()));
            
            System.out.print("REG: ");
            for(double reg : regPenalties)
                System.out.print(df.format(reg) + ", ");
            System.out.print("\n L2: ");
            for(double reg : regPenalties)
            {
                BBR bbr = new BBR(reg, 1000, BBR.Prior.GAUSSIAN);
                ClassificationModelEvaluation cme = new ClassificationModelEvaluation(bbr, cds);
                cme.evaluateCrossValidation(10);
                System.out.print(df.format(cme.getErrorRate()) + ", ");
            }
            System.out.print("\n L1: ");
            for(double reg : regPenalties)
            {
                BBR bbr = new BBR(reg, 1000, BBR.Prior.LAPLACE);
                ClassificationModelEvaluation cme = new ClassificationModelEvaluation(bbr, cds);
                cme.evaluateCrossValidation(10);
                System.out.print(df.format(cme.getErrorRate()) + ", ");
            }
            
        }
    }
}
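To try the sparse 0/1 version mentioned earlier, the feature generation could be swapped for something like the sketch below. It is written against plain arrays so it stands alone; the class name `SparseFeatureSketch`, the method `makePoint`, and the probability parameter `pOne` are hypothetical names for this example, and the 0.05 used in `main` is just an arbitrary small probability.

```java
import java.util.Random;

/** Illustrative sketch only: Gaussian relevant features plus sparse binary noise features. */
final class SparseFeatureSketch
{
    /**
     * Builds one data point: the first `relevant` values are drawn from N(0, 1),
     * and the remaining `irrelevant` values are 1 with probability pOne, else 0.
     */
    static double[] makePoint(int relevant, int irrelevant, double pOne, Random rand)
    {
        double[] x = new double[relevant + irrelevant];
        for (int j = 0; j < relevant; j++)
            x[j] = rand.nextGaussian();
        for (int j = relevant; j < x.length; j++)
            x[j] = rand.nextDouble() < pOne ? 1.0 : 0.0;
        return x;
    }

    public static void main(String[] args)
    {
        Random rand = new Random();
        double[] x = makePoint(5, 45, 0.05, rand);
        int ones = 0;
        for (int j = 5; j < x.length; j++)
            ones += (int) x[j];
        System.out.println("Irrelevant features set to 1: " + ones);
    }
}
```

The resulting values could then be copied into the `DenseVector` with `xP.set(j, x[j])` in place of the `rand.nextGaussian()` loop in the listing above, leaving the label computation on the first `relevant` features unchanged.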