Monday, July 14, 2008

Monadic Sampling in Scala

(If you don't know what Monads are, see here. Short answer: they're a design pattern coupled with syntactic sugar for implementing "wrappers" around an arbitrary type T.)

I was initially quite resistant to monads, probably because I didn't really grok the syntax that Haskell was pushing. Seeing Scala's for-comprehensions, which are just like Haskell's do-notation, changed my mind, and in an effort to teach myself monads, I've decided to write a simple library for doing sampling for Bayesian inference. What follows is inspired in part by this paper and this blog series.

For motivation, let's look at the specification of a generative model in standard notation:
θ ~ Dir(α)
z ~ Mult(θ)
w ~ Mult(β_z)
Our goal is to make a syntax that looks a lot like that for generating new data. We'll end up with something that looks like:

for( theta <- Dir(alpha);
     z <- Mult(theta);
     w <- Mult(beta_z))
  yield w;
which isn't so bad.

Let's start with a simple monad in the Scala tradition:
trait Rand[T] {
  def get() : T

  def draw() = get()

  def sample(n : Int) = List.tabulate(n, x => draw());

  def flatMap[E](f : T => Rand[E]) : Rand[E]

  def map[E](f : T => E) : Rand[E]

  def filter(p : T => Boolean) = condition(p);

  def condition(p : T => Boolean) : Rand[T];
}
Thus far, nothing terribly special. Aside from normal monadic stuff and aliases, we have draw and get, which intuitively should give us a sample from the random distribution. sample is just an easy way to get many samples at once.
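To see why the trait wants exactly these methods, recall what the compiler does with a for-comprehension: generators become nested calls to flatMap and map, while guards and patterns become calls to filter (in the Scala of this era). The target syntax from the top of the post, for instance, is rewritten into roughly the following sketch, where Dir and Mult are still just the hypothetical constructors from that example:

// Roughly what the compiler makes of the for-comprehension above.
Dir(alpha).flatMap(theta =>
  Mult(theta).flatMap(z =>
    Mult(beta_z).map(w => w)))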

A default implementation for most of these methods is fairly straightforward:
trait Rand[T] {
  def get() : T // still undefined

  def draw() = get();

  def sample(n : Int) = List.tabulate(n, x => get);

  // Note the def rather than val: f(get()) is re-evaluated every time the
  // resulting Rand's get is called, so each draw is fresh.
  def flatMap[E](f : T => Rand[E]) = {
    def x = f(get());
    new Rand[E] {
      def get = x.get;
    }
  }

  def map[E](f : T => E) = {
    def x = f(get());
    new Rand[E] {
      def get = x;
    }
  }

  def filter(p : T => Boolean) = condition(p);

  // Rejection sampling: keep drawing until the predicate holds.
  // Not the most efficient implementation ever, but meh.
  def condition(p : T => Boolean) = new Rand[T] {
    def get() = {
      // Rand.this.get is the outer distribution's get; a bare get here
      // would recurse on this anonymous class's own get.
      var x = Rand.this.get;
      while(!p(x)) {
        x = Rand.this.get;
      }
      x
    }
  }
}

The only difference from a normal "container" monad like Option is that the Rand returned by map and flatMap re-evaluates get() every time it is sampled, rather than caching a single value, to ensure that we're being random.
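To make that concrete, here's a small illustration of my own (the die generator is a throwaway example, not part of the library): mapping over a Rand produces a new Rand whose draws are fresh on every call, not a single frozen value.

// A throwaway six-sided die, defined directly against the trait.
val die = new Rand[Int] {
  private val r = new java.util.Random;
  def get = 1 + r.nextInt(6);
}

val doubledRoll = die.map(_ * 2);  // same as: for(roll <- die) yield roll * 2

doubledRoll.get   // one draw, e.g. 8
doubledRoll.get   // a fresh, quite possibly different, draw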

Now for some basic random number generators. These live in object Rand, but that's omitted here for something akin to clarity. First, we have the analogue of Haskell's return, for completeness.

def always[T](t : T) = new Rand[T] {
  def get = t;
}
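As a quick sanity check of my own (not from the post): since always ignores randomness entirely, repeated draws simply repeat the value.

// Every draw from always(42) is, unsurprisingly, 42.
val constant = Rand.always(42);
constant.sample(3)   // List(42, 42, 42)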


Straightforward enough. Now for our building block:

val uniform = new Rand[Double] {
  private val r = new java.util.Random;
  def get = r.nextDouble;
}


Which lets us do things like:

scala> Rand.uniform.sample(10)
res6: List[Double] = List(0.8940037286915604, 0.34021110772450114,
0.2045633072974703, 0.44871569906073616, 0.47697121133477594,
0.8410830818576492, 0.6738322287017577, 0.16060602963773707,
0.602623326916021, 0.34327615862458416)


and:


scala> val twice = for( x <- Rand.uniform) yield x * 2;
twice: java.lang.Object with scalanlp.stats.Rand[Double] =
scalanlp.stats.Rand$$anon$3@5c772046

scala> twice.sample(10)
res7: List[Double] = List(1.8602320334301579, 0.0872446976570771,
0.032309170483379335, 1.9753336995209254, 1.220452839716684,
1.0214181828533413, 1.41457180527561, 1.6988361279393165,
1.460110077486223, 0.6762038442765996)
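The filter and condition methods haven't made an appearance yet, so here's a small example of my own: restricting the uniform distribution to the upper half of the unit interval by rejection sampling. (An if guard in a for-comprehension desugars to the same filter call.)

// Rejection-sample the uniform distribution, keeping only draws above 0.5.
val upperHalf = Rand.uniform.filter(_ > 0.5);

// The same thing, written with a guard in a for-comprehension:
val upperHalf2 = for(x <- Rand.uniform if x > 0.5) yield x;

upperHalf.sample(3)   // three draws, each greater than 0.5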


So we already have some use of do-notation. And if that's not convincing for you, consider sampling two correlated Gaussian variables. First, we need a univariate Gaussian:

private val rng = new java.util.Random;

// A standard normal: mean 0, standard deviation 1. (A def, not a val, so it
// can be overloaded with the parameterized version below.)
def gaussian : Rand[Double] = new Rand[Double] {
  def get = rng.nextGaussian;
}

// m is the mean and s is the standard deviation (so s * s is the variance)
def gaussian(m : Double, s : Double) : Rand[Double] = new Rand[Double] {
  def get = m + s * gaussian.get
}
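As an aside (my own sketch, not from the original post), gaussian(m, s) is really just a map over the standard normal, so it could equally well be written with a for-comprehension:

// The same distribution as gaussian(m, s), written monadically.
def scaledGaussian(m : Double, s : Double) : Rand[Double] =
  for(z <- gaussian) yield m + s * z;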
Then sampling two independent Gaussians is straightforward:

scala> val biGauss = for(x <- Rand.gaussian; y <- Rand.gaussian) yield (x,y)
biGauss: java.lang.Object with scalanlp.stats.Rand[(Double, Double)] =
scalanlp.stats.Rand$$anon$2@caa6635

scala> biGauss.sample(3)
res9: List[(Double, Double)] = List((-0.9601505823179303,-0.1480670696609196),
(0.02594332256575975,0.02401831998712138),
(1.4885591927916324,1.1998923591137476))

Now suppose we want to draw two correlated Gaussians. That is, Gaussians where knowing one tells you something about the other. Drawing on this article, suppose that z1 and z2 are drawn i.i.d. from the standard normal distribution, and that we're interested in sampling normals (x,y) with means mu1 and mu2, standard deviations s1 and s2, and correlation r. Then we can draw x and y by computing:

x = mu1 + s1 * z1
y = mu2 + s2 * (r * z1 + sqrt(1 - r^2) * z2)
With that, we can easily write a generator in Scala:

scala> def corrGauss(mu1 : Double, mu2 : Double, s1 : Double, s2 : Double, r : Double) =
     |   for( (z1,z2) <- biGauss;
     |        x = mu1 + s1 * z1;
     |        y = mu2 + s2 * (r * z1 + Math.sqrt(1 - r * r) * z2))
     |   yield (x,y);
corrGauss: (Double,Double,Double,Double,Double)java.lang.Object with scalanlp.stats.Rand[(Double, Double)]
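As a quick sanity check of my own (not in the original post), we can draw a few thousand pairs and confirm that the empirical correlation comes out near r:

// Draw 10,000 pairs and estimate their correlation; it should be roughly 0.8.
val pairs = corrGauss(0.0, 0.0, 1.0, 2.0, 0.8).sample(10000);
val xs = pairs.map(_._1);
val ys = pairs.map(_._2);

def mean(l : List[Double]) = l.reduceLeft(_ + _) / l.length;
val (mx, my) = (mean(xs), mean(ys));
val cov = mean((xs zip ys).map { case (x,y) => (x - mx) * (y - my) });
val corr = cov / (Math.sqrt(mean(xs.map(x => (x - mx) * (x - mx)))) *
                  Math.sqrt(mean(ys.map(y => (y - my) * (y - my)))));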


And that's it! Next time, I'll write about Multinomials, Dirichlets, and generating more data.