Saturday, July 7, 2012

Probability Estimation: Nearest Neighbor

In my last post, I talked about why estimating the probability of some event \(x\) is useful, but I neglected to mention how to do it. This is the task many machine learning (ML) algorithms are trying to accomplish. The Nearest Neighbor algorithm is a very simple method that does this, and it is unreasonably accurate for many applications.


In fact, the algorithm is easier to understand than the theory behind it, so we will start there.


In words: suppose you have a training data set \(D\) composed of \(n\) data points \(\vec{x}_i\), each paired with a target value \(y_i\). Then we can classify a new data point \(\vec{q}\) by finding the data point in \(D\) that is closest to \(\vec{q}\) and giving \(\vec{q}\) the same class.


To be more precise, we need a notion of distance. We denote the distance between two points \(a\) and \(b\) by \(d(a, b)\); a common choice is the Euclidean distance. With that, we can state the algorithm explicitly:

$$
\begin{aligned}
&\text{Find } i^* = \arg\min_{i \in \{1, \dots, n\}} d(\vec{x}_i, \vec{q}) \\
&\vec{q} \text{'s label is then } y_{i^*}
\end{aligned}
$$
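Here is a minimal sketch of this rule in Python. It assumes Euclidean distance for \(d\) and plain lists for the data. The function names are hypothetical, and this is not how JSAT implements it.

```python
import math
from typing import Sequence


def euclidean(a: Sequence[float], b: Sequence[float]) -> float:
    """d(a, b): ordinary straight-line (Euclidean) distance."""
    return math.sqrt(sum((ai - bi) ** 2 for ai, bi in zip(a, b)))


def nearest_neighbor_classify(xs, ys, q, dist=euclidean):
    """Give q the label y_i of the training point x_i closest to it.

    xs is a list of training points and ys the matching list of labels.
    Every training point is examined, so one query costs O(n) distances.
    """
    best_i = min(range(len(xs)), key=lambda i: dist(xs[i], q))
    return ys[best_i]


# Toy data: two points labeled "a" near the origin, one "b" further out.
train_x = [(0.0, 0.0), (1.0, 1.0), (5.0, 5.0)]
train_y = ["a", "a", "b"]
print(nearest_neighbor_classify(train_x, train_y, (4.0, 4.5)))  # b
```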

The intuition is very simple: we expect similar data points to be located near each other, and therefore to be likely to have the same class or value. This method works for both classification and regression problems. We can also improve performance by considering the \(k\) nearest neighbors (k-NN) instead of just one. We then take a majority vote of their labels for classification, or average their target values for regression. The intuition is that with multiple neighbors, we won't get thrown off by one odd or mislabeled data point. But how is this related to estimating the probability of something?
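The \(k\)-neighbor version can be sketched the same way. This sketch assumes Euclidean distance (via `math.dist`, Python 3.8+) and uses a simple majority vote or mean. The helper names are hypothetical. The toy data plants one mislabeled point so you can see a larger \(k\) outvote it.

```python
import math
from collections import Counter


def k_nearest(xs, q, k):
    """Indices of the k training points closest to q (Euclidean distance)."""
    order = sorted(range(len(xs)), key=lambda i: math.dist(xs[i], q))
    return order[:k]


def knn_classify(xs, ys, q, k):
    """Classification: majority vote over the labels of the k nearest points."""
    votes = Counter(ys[i] for i in k_nearest(xs, q, k))
    return votes.most_common(1)[0][0]


def knn_regress(xs, ys, q, k):
    """Regression: mean of the target values of the k nearest points."""
    idx = k_nearest(xs, q, k)
    return sum(ys[i] for i in idx) / len(idx)


# The point at 0.15 is deliberately "mislabeled" as b.
pts = [(0.0,), (0.1,), (0.2,), (0.15,), (1.0,)]
labels = ["a", "a", "a", "b", "b"]
print(knn_classify(pts, labels, (0.14,), k=1))  # b: fooled by the odd point
print(knn_classify(pts, labels, (0.14,), k=3))  # a: outvoted 2 to 1

# Regression: average the k nearest target values.
xs_r = [(0.0,), (1.0,), (2.0,), (10.0,)]
ys_r = [1.0, 2.0, 3.0, 50.0]
print(knn_regress(xs_r, ys_r, (1.1,), k=3))  # 2.0, the mean of 2, 3 and 1
```

For two-class problems, picking an odd \(k\) avoids tied votes.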

We can estimate the probability density at a point \(\vec{x}\) from a data set \(D\) as
$$ P(\vec{x}) \approx \frac{k}{n \cdot V} $$
where
\(k\) is the number of data points from \(D\) that fall inside a small region around \(\vec{x}\),
\(n\) is the total number of data points in \(D\), and
\(V\) is the volume of that region.
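Here is a quick worked example with made-up numbers. Suppose \(D\) contains \(n = 100\) one-dimensional points, and \(k = 5\) of them fall within a distance of \(0.25\) of \(\vec{x}\). In one dimension, the "volume" of that region is just the length of the interval, \(V = 2 \times 0.25 = 0.5\), so
$$ P(\vec{x}) \approx \frac{5}{100 \cdot 0.5} = 0.1 $$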

The k-NN algorithm uses this estimate implicitly. We fix the number of neighbors \(k\) ahead of time, and the size \(n\) of the data set does not change, so only the volume \(V\) varies. It is the volume of the smallest region around the query that contains its \(k\) nearest neighbors. If all the neighbors are very close, \(V\) will be small, making the estimated density larger. If they are far away, \(V\) will be large, making the estimate smaller.
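To make that concrete, here is a small sketch of the k-NN density estimate. It assumes the region is a Euclidean ball centered at the query, whose radius is the distance to the \(k\)-th nearest neighbor. The function names are hypothetical, and this is not JSAT code.

```python
import math


def ball_volume(radius, dim):
    """Volume of a dim-dimensional Euclidean ball (2r in 1-D, pi r^2 in 2-D)."""
    return math.pi ** (dim / 2) * radius ** dim / math.gamma(dim / 2 + 1)


def knn_density(data, x, k):
    """k-NN density estimate P(x) ~= k / (n * V).

    k and n = len(data) are fixed; V is the volume of the smallest ball
    centered at x that contains x's k nearest neighbors in data. Raises
    ZeroDivisionError if that ball has radius 0 (k copies of x in data).
    """
    n = len(data)
    dists = sorted(math.dist(p, x) for p in data)
    r_k = dists[k - 1]  # distance to the k-th nearest neighbor
    volume = ball_volume(r_k, len(x))
    return k / (n * volume)


data = [(0.0,), (0.1,), (0.2,), (0.3,), (5.0,)]
print(knn_density(data, (0.15,), k=2))  # about 4.0: neighbors close, V small
print(knn_density(data, (5.0,), k=2))   # much smaller: neighbors far, V large
```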

Why don't we use this density estimate directly? As it turns out, it is not a very accurate estimate when used explicitly. In another post, we will look at fixing the volume instead and letting \(k\) change as needed. Still, k-NN works well in practice and has some strong theoretical results behind it. You can use the k-NN algorithm in JSAT with the NearestNeighbour class.

So are there any downsides to k-NN? Yes. Unlike most algorithms, k-NN does essentially no work at training time and shifts all of it to classification time. The naive implementation must examine every data point to make a decision, meaning it takes \(O(n)\) distance computations to produce a single result! To tackle this issue, JSAT separates the task of finding a query's nearest neighbors from the k-NN algorithm itself and moves it into the VectorCollection interface. A VectorCollection knows how to search a data set to find either the \(k\) nearest neighbors or all neighbors within a certain radius.
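To illustrate that separation, here is a hypothetical Python sketch. The class name `BruteForceCollection` and the methods `search_k` and `search_radius` are made up for this example; they are not JSAT's Java `VectorCollection` API. This version just does the naive linear scan. The point is that a classifier only needs these two kinds of query, so a faster search structure could be swapped in without touching the k-NN code.

```python
import heapq
import math


class BruteForceCollection:
    """Hypothetical nearest-neighbor search structure (not JSAT's API).

    Answers both query types by scanning every stored vector, so each
    query costs O(n) distance computations.
    """

    def __init__(self, vectors):
        self.vectors = list(vectors)

    def search_k(self, query, k):
        """(distance, vector) pairs for the k vectors nearest to query."""
        pairs = ((math.dist(v, query), v) for v in self.vectors)
        return heapq.nsmallest(k, pairs, key=lambda pair: pair[0])

    def search_radius(self, query, radius):
        """(distance, vector) pairs for every vector within radius of query."""
        pairs = ((math.dist(v, query), v) for v in self.vectors)
        return sorted((p for p in pairs if p[0] <= radius), key=lambda p: p[0])


collection = BruteForceCollection([(0.0, 0.0), (1.0, 0.0), (3.0, 4.0)])
print(collection.search_k((0.0, 0.0), 2))
print(collection.search_radius((0.0, 0.0), 2.0))
```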

Even with faster methods for finding the nearest neighbors, k-NN can still be very slow and cannot always be used on large data sets.

