Sunday, February 26, 2023

Weighted Random

Some time ago, I implemented a way to generate weighted random values from a discrete distribution in Factor. It ended up being a satisfyingly simple word that builds a cumulative weight table, generates a random number between 0 and the total weight, then searches the table to find which value to return:

: weighted-random ( histogram -- obj )
    unzip cum-sum [ last >float random ] keep bisect-left swap nth ;
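
For readers less familiar with Factor, here is a minimal Python sketch of the same three steps (cumulative table, random draw, binary search). It is an illustration rather than a translation of the Factor word: the name weighted_random and the rand parameter are assumptions made for this example.

```python
import bisect
import itertools
import random


def weighted_random(histogram, rand=random.random):
    """Pick one key with probability proportional to its weight.

    `rand` is a hypothetical hook returning a float in [0, 1);
    it exists only so the draw can be pinned down in tests.
    """
    items = list(histogram.keys())
    # Running totals of the weights, e.g. [1, 2, 3, 4] -> [1, 3, 6, 10]
    cumulative = list(itertools.accumulate(histogram.values()))
    # Uniform draw scaled to the total weight
    r = rand() * cumulative[-1]
    # Leftmost slot whose running total is >= r
    return items[bisect.bisect_left(cumulative, r)]


dist = {"A": 1, "B": 2, "C": 3, "D": 4}
print(weighted_random(dist))
```

Each call rebuilds the cumulative table, which is the cost the rest of this post works on removing.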

Is It Fast?

We can define a simple discrete distribution of values:

CONSTANT: dist H{ { "A" 1 } { "B" 2 } { "C" 3 } { "D" 4 } }

And it seems to work — we can make a few random values from it:

IN: scratchpad dist weighted-random .
"C"
IN: scratchpad dist weighted-random .
"C"
IN: scratchpad dist weighted-random .
"D"
IN: scratchpad dist weighted-random .
"B"

After generating a lot of random values, we can see the histogram matches our distribution:

IN: scratchpad 10,000,000 [ dist weighted-random ] replicate histogram .
H{
    { "A" 998403 }
    { "B" 2000400 }
    { "C" 3001528 }
    { "D" 3999669 }
}

But, how fast is it?

IN: scratchpad [ 10,000,000 [ dist weighted-random ] replicate ] time 
Running time: 3.02998325 seconds

Okay, so it's not that fast... generating around 3.3 million values per second on one of my computers.

Improvements

We can make two quick improvements to this:

  1. First, we can hoist the table-building step out of the random number generation, so it runs once instead of once per generated value.
  2. Second, we can take advantage of a recent improvement to the random vocabulary: the random word, which was previously implemented separately for different types, now gets the current random-generator and passes it to a random* implementation. This allows a few speedups where we can look up the dynamic variable once and then use it many times.
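
The shape of those two changes can be sketched in Python. This is only a loose analogy: Python has no random-generator dynamic variable, so binding the generator function and bisect_left to local names stands in for "look it up once". The name weighted_randoms mirrors the Factor word, and the rand parameter is an assumption made for this example.

```python
import bisect
import itertools
import random


def weighted_randoms(length, histogram, rand=random.random):
    """Draw `length` keys, building the cumulative table only once."""
    items = list(histogram.keys())
    cumulative = list(itertools.accumulate(histogram.values()))
    total = cumulative[-1]
    # Resolve the lookup once, outside the loop
    bisect_left = bisect.bisect_left
    return [items[bisect_left(cumulative, rand() * total)]
            for _ in range(length)]


dist = {"A": 1, "B": 2, "C": 3, "D": 4}
print(weighted_randoms(5, dist))
```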

Those two changes result in this definition:

: weighted-randoms ( length histogram -- seq )
    unzip cum-sum swap
    [ [ last >float random-generator get ] keep ] dip
    '[ _ _ random* _ bisect-left _ nth ] replicate ;

That gives us a nice speedup, just over 10 million per second:

IN: scratchpad [ 10,000,000 dist weighted-randoms ] time histogram .
Running time: 0.989039625 seconds

H{
    { "A" 1000088 }
    { "B" 1999445 }
    { "C" 3000688 }
    { "D" 3999779 }
}

That's pretty nice, but it turns out that we can do better.

Vose Alias Method

Keith Schwarz wrote a fascinating blog post about some better algorithms for sampling from a discrete distribution. One of those algorithms is the Vose Alias Method, which creates a data structure of items, probabilities, and an alias table that is used to return an alternate choice:

TUPLE: vose
    { n integer }
    { items array }
    { probs array }
    { alias array } ;

We construct a vose tuple by splitting the distribution into items and their probabilities, scaling the probabilities so that they average 1, and then sorting their indices into lists of small (less than 1) or large (greater than or equal to 1). We then iteratively alias each small index to a large one, deducting from the large entry the probability it gives up to fill the small one.

:: <vose> ( dist -- vose )
    V{ } clone :> small
    V{ } clone :> large
    dist assoc-size :> n
    n f <array> :> alias

    dist unzip dup [ length ] [ sum ] bi / v*n :> ( items probs )
    probs [ swap 1 < small large ? push ] each-index

    [ small empty? not large empty? not and ] [
        small pop :> s
        large pop :> l
        l s alias set-nth
        l dup probs [ s probs nth + 1 - dup ] change-nth
        1 < small large ? push
    ] while

    1 large [ probs set-nth ] with each
    1 small [ probs set-nth ] with each

    n items probs alias vose boa ;
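
To make the construction concrete, here is a Python sketch that follows the same steps and can be traced by hand. It is an illustration, not the math.extras code: the name build_vose is hypothetical, and Fraction is used so the scaled probabilities stay exact, much as Factor keeps ratios exact for integer weights.

```python
from fractions import Fraction


def build_vose(dist):
    """Return (items, probs, alias) tables for the alias method."""
    items = list(dist.keys())
    n = len(items)
    total = sum(dist.values())
    # Scale each weight by n / total so the probabilities average 1
    probs = [Fraction(w * n, total) for w in dist.values()]
    alias = [None] * n
    small = [i for i, p in enumerate(probs) if p < 1]
    large = [i for i, p in enumerate(probs) if p >= 1]

    while small and large:
        s = small.pop()
        l = large.pop()
        alias[s] = l
        # The large entry gives up (1 - probs[s]) to top up the small one
        probs[l] = probs[l] + probs[s] - 1
        (small if probs[l] < 1 else large).append(l)

    # Whatever is left over fills its own slot completely
    for i in large + small:
        probs[i] = Fraction(1)
    return items, probs, alias


items, probs, alias = build_vose({"A": 1, "B": 2, "C": 3, "D": 4})
print([str(p) for p in probs])  # ['2/5', '4/5', '1', '4/5']
print(alias)                    # [3, 3, None, 2]
```

Reading the result: slot "A" keeps "A" with probability 2/5 and otherwise yields "D"; slot "C" always yields "C". Summing each item's share across the four equally likely slots gives back 1/10, 2/10, 3/10 and 4/10.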

We can implement the random* generic to select a random item from the vose tuple: choose a random item index, check its probability against a random number between 0.0 and 1.0, and if the random number exceeds that probability, return the aliased item instead:

M:: vose random* ( obj rnd -- elt )
    obj n>> rnd random*
    dup obj probs>> nth rnd (random-unit) >=
    [ obj alias>> nth ] unless
    obj items>> nth ;
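
The sampling step can be sketched in Python as well. The tables below are the ones the construction described above yields for the weights 1, 2, 3, 4, worked out by hand and written as floats; the names ITEMS, PROBS, ALIAS and vose_sample are assumptions made for this example, not part of any library.

```python
import random

# Hand-computed alias tables for weights A=1, B=2, C=3, D=4
ITEMS = ["A", "B", "C", "D"]
PROBS = [0.4, 0.8, 1.0, 0.8]
ALIAS = [3, 3, None, 2]  # slot 2 never needs an alias


def vose_sample(items, probs, alias, rng=random):
    """Draw one item: pick a slot uniformly, then keep it or take its alias."""
    i = rng.randrange(len(items))
    # Keep slot i when probs[i] >= u, otherwise switch to the aliased slot
    if rng.random() > probs[i]:
        i = alias[i]
    return items[i]


print(vose_sample(ITEMS, PROBS, ALIAS))
```

Each draw costs one random index, one random float, and at most two array lookups, with no search over the table. That is where the speedup over the binary-search version comes from.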

The vose version is much faster, at over 14.4 million per second:

IN: scratchpad [ 10,000,000 dist <vose> randoms ] time 
Running time: 0.693588458 seconds

This is available now in the math.extras vocabulary in the current development version, along with a few tweaks that bring the performance to over 21.7 million per second...
