How LLMs got me into sampling theory
by Manuel de Prada Corral
6 min read
This WIP post collects some of my notes on sampling theory. My end goal is to have a global understanding of importance sampling and Horvitz-Thompson estimators.
Sampling from a probabilistic model can serve many purposes. The obvious one is to generate samples, such as images, text, or audio. However, we can also use sampling to compute expectations, such as the expected value of a function of the samples.
These notes were made with sampling from a Language Model in mind, but they are applicable to many autoregressive models. The key idea is that we cannot sample from the model directly; instead, we have to recursively sample the next word from the conditional distributions¹.
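In symbols (with $x_t$ denoting the $t$-th word and $x_{<t}$ the prefix before it), the joint distribution factorizes by the chain rule and we sample ancestrally, one token at a time:

$$p(x_1, \dots, x_T) = \prod_{t=1}^{T} p(x_t \mid x_{<t}), \qquad x_t \sim p(\,\cdot \mid x_{<t}).$$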
Hurray! We got a sample from the model! However, as often happens with LLMs, the samples are not good enough, in the sense of what humans judge as "good text".
In practice, when sampling from LLMs, often one of the steps yields an unlikely word, ruining the whole sample. Having a good calibration for unlikely words is very difficult, and there are ad-hoc interventions such as sampling adaptors. Another approach is to avoid sampling altogether, and instead just deterministically search for the most likely sequence of words, which is called beam search² (see the sketch below).
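As an illustration, here is a minimal beam search sketch over a toy model. The names are mine, not from any particular library: `next_token_logprobs` is a hypothetical function returning the log-probabilities of the next token given a prefix, and `eos` marks the end of a sequence.

```python
import numpy as np

def beam_search(next_token_logprobs, vocab_size, eos, beam_width=3, max_len=10):
    """Deterministically search for a high-probability sequence.

    next_token_logprobs(prefix) -> array of shape (vocab_size,)
    with log p(x_t | prefix). Returns the highest-scoring sequence found.
    """
    beams = [((), 0.0)]  # (prefix, cumulative log-probability)
    finished = []
    for _ in range(max_len):
        candidates = []
        for prefix, score in beams:
            logp = next_token_logprobs(prefix)
            for tok in range(vocab_size):
                cand = (prefix + (tok,), score + logp[tok])
                if tok == eos:
                    finished.append(cand)   # sequence ended, set it aside
                else:
                    candidates.append(cand)
        if not candidates:
            break
        # Keep only the beam_width highest-scoring unfinished prefixes.
        candidates.sort(key=lambda c: c[1], reverse=True)
        beams = candidates[:beam_width]
    return max(finished + beams, key=lambda c: c[1])
```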
But, what if there was a principled way to get better samples?
TODO: Introduce utility function, minimum bayes risk, and importance sampling...
Sampling with replacement from the categorical distribution
In machine learning, we often fit a model to produce a vector of unnormalized log-probabilities $\ell = (\ell_1, \dots, \ell_n)$, and we need to sample from the corresponding categorical distribution with probabilities $p_i = e^{\ell_i} / \sum_{j=1}^{n} e^{\ell_j}$.
The naive approach
To sample from the categorical distribution, we can use the inverse transform sampling method:

1. Normalize the log-probabilities to get $p_i = e^{\ell_i} / \sum_{j=1}^{n} e^{\ell_j}$ (complexity $O(n)$).
2. Compute the cumulative distribution function (CDF) as $F_i = \sum_{j \le i} p_j$ (complexity $O(n)$).
3. Sample from the uniform distribution $u \sim \mathrm{Uniform}(0, 1)$.
4. Finally, we pick the biggest index $i$ such that $F_{i-1} < u$, i.e., the index with $F_{i-1} < u \le F_i$ (complexity $O(\log n)$ per sample by binary search, since the CDF is sorted).

In total, this naive approach has complexity $O(n + k \log n)$, where $k$ is the number of samples.
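A minimal NumPy sketch of the whole procedure (function and variable names are mine; the max-subtraction in the first line is explained under numerical stability below):

```python
import numpy as np

def sample_categorical(logits, k, seed=0):
    """Draw k samples (with replacement) via inverse transform sampling."""
    rng = np.random.default_rng(seed)
    logits = logits - logits.max()             # for numerical stability (see below)
    p = np.exp(logits) / np.exp(logits).sum()  # step 1: normalize, O(n)
    cdf = np.cumsum(p)                         # step 2: CDF, O(n)
    u = rng.uniform(size=k)                    # step 3: k uniform draws
    return np.searchsorted(cdf, u)             # step 4: binary search, O(k log n)
```

For example, `sample_categorical(np.log(np.array([0.5, 0.3, 0.2])), k=5)` returns five indices in $\{0, 1, 2\}$ distributed according to those probabilities.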
Proof of inverse transform sampling
We are interested in sampling from a random variable $X$ with CDF $F_X$ using a uniform random variable $U \sim \mathrm{Uniform}(0,1)$, so we need to find a function $g$ such that $g(U) \overset{d}{=} X$. Now, assuming $g$ is increasing,

$$P(g(U) \le x) = P(U \le g^{-1}(x)) = g^{-1}(x).$$

Hence, $g^{-1} = F_X$, and we can sample from $X$ by sampling $u$ from $\mathrm{Uniform}(0,1)$ and applying $F_X^{-1}(u)$.
Numerical stability considerations
When computing the normalizer in step 1, the sum $\sum_i e^{\ell_i}$ might overflow. We can avoid this by subtracting the maximum value from all the log-probabilities, so that the maximum value is zero, i.e., $p_i = e^{\ell_i - m} / \sum_j e^{\ell_j - m}$ with $m = \max_j \ell_j$.
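A quick sketch of why this matters (in NumPy):

```python
import numpy as np

logits = np.array([1000.0, 999.0, 998.0])
naive = np.exp(logits)      # overflows to [inf, inf, inf] with a RuntimeWarning
m = logits.max()
stable = np.exp(logits - m) / np.exp(logits - m).sum()
print(stable)               # [0.665, 0.245, 0.090], no overflow
```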
The Gumbel trick
The Gumbel trick allows us to sample from the categorical distribution without computing the CDF or the normalizing constant.
1. Sample $n$ independent Gumbel variables $g_1, \dots, g_n \sim \mathrm{Gumbel}(0, 1)$. This can be easily done using inverse transform sampling: $g_i = -\log(-\log u_i)$ with $u_i \sim \mathrm{Uniform}(0, 1)$.
2. Compute the perturbed log-probabilities $\ell_i + g_i$. Since the Gumbel distributions are a location-scale family, $\ell_i + g_i \sim \mathrm{Gumbel}(\ell_i, 1)$.
3. Finally, we pick the biggest index $i$ such that $\ell_i + g_i \ge \ell_j + g_j$ for all $j$ (complexity $O(n)$, since the perturbed values are not sorted). In other words, we take the index $\arg\max_i (\ell_i + g_i)$.
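A minimal NumPy sketch of the trick (names are mine); note that it operates directly on the unnormalized log-probabilities:

```python
import numpy as np

def gumbel_sample(logits, k, seed=0):
    """Draw k samples (with replacement) using the Gumbel trick.

    Works directly on unnormalized log-probabilities: no CDF,
    no normalizing constant.
    """
    rng = np.random.default_rng(seed)
    u = rng.uniform(size=(k, len(logits)))   # one row of uniforms per sample
    g = -np.log(-np.log(u))                  # Gumbel(0, 1) via inverse transform
    return np.argmax(logits + g, axis=1)     # argmax of perturbed logits, O(n)
```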
Proof
Lemma 1: The inverse CDF of the exponential distribution $\mathrm{Exp}(\lambda)$ is $F^{-1}(u) = -\frac{\log(1-u)}{\lambda}$.

Lemma 2: If $E_1 \sim \mathrm{Exp}(\lambda_1), \dots, E_n \sim \mathrm{Exp}(\lambda_n)$ are independent, then $\min_i E_i \sim \mathrm{Exp}\big(\sum_i \lambda_i\big)$ and $P\big(E_j = \min_i E_i\big) = \frac{\lambda_j}{\sum_i \lambda_i}$.

Observe that the probability of a tie is zero.

Proof. We want to prove that the probability of picking the index $i$ is $p_i = \frac{e^{\ell_i}}{\sum_j e^{\ell_j}}$, i.e., the probability of $\ell_i + g_i$ being the biggest perturbed log-probability is $p_i$.

First part: show that $e^{-(\ell_i + g_i)} \sim \mathrm{Exp}(e^{\ell_i})$:

Recall that $\ell_i$ are the unnormalized log-probabilities and $g_i = -\log(-\log u_i)$, so we can write

$$e^{-(\ell_i + g_i)} = e^{-\ell_i} \, e^{\log(-\log u_i)} = -\frac{\log u_i}{e^{\ell_i}}.$$

Hence, $e^{-(\ell_i + g_i)} = -\frac{\log u_i}{e^{\ell_i}}$, which is the inverse CDF of the exponential distribution with parameter $e^{\ell_i}$ (Lemma 1, noting that $u_i$ and $1 - u_i$ have the same distribution).

Using inverse transform sampling, since $u_i \sim \mathrm{Uniform}(0, 1)$, we have that $e^{-(\ell_i + g_i)} \sim \mathrm{Exp}(e^{\ell_i})$.

Second part: Note that, since $x \mapsto e^{-x}$ is decreasing,

$$\arg\max_i \, (\ell_i + g_i) = \arg\min_i \, e^{-(\ell_i + g_i)}.$$

Third part: Finally,

$$P\Big(\arg\max_i \, (\ell_i + g_i) = j\Big) = P\Big(e^{-(\ell_j + g_j)} = \min_i \, e^{-(\ell_i + g_i)}\Big) = \frac{e^{\ell_j}}{\sum_i e^{\ell_i}} = p_j,$$

where in the last step we have used Lemma 2.
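A quick empirical sanity check of the result (a sketch, not part of the proof): the argmax frequencies should match the softmax probabilities.

```python
import numpy as np

rng = np.random.default_rng(0)
logits = np.array([2.0, 1.0, 0.0, -1.0])
p = np.exp(logits) / np.exp(logits).sum()    # target categorical distribution

k = 100_000
g = -np.log(-np.log(rng.uniform(size=(k, len(logits)))))
counts = np.bincount(np.argmax(logits + g, axis=1), minlength=len(logits))

print(p)           # [0.644, 0.237, 0.087, 0.032]
print(counts / k)  # empirical frequencies, close to p
```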
The top-k Gumbel trick
We just saw how the Gumbel trick allows us to sample from the categorical distribution by computing $\arg\max_i \, (\ell_i + g_i)$.
TODO: how Maddison et al. proved that the max and the argmax are independent, and that taking the top-k is equivalent to sampling without replacement.
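Pending that proof, here is a sketch of what the top-k variant would look like (assuming the equivalence stated above, the k distinct indices it returns are a sample without replacement from the categorical distribution):

```python
import numpy as np

def gumbel_top_k(logits, k, seed=0):
    """Sketch: k distinct indices via the top-k of Gumbel-perturbed logits."""
    rng = np.random.default_rng(seed)
    g = -np.log(-np.log(rng.uniform(size=len(logits))))
    return np.argsort(logits + g)[-k:][::-1]  # indices of the k largest perturbations
```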
Computing expectations
In order to compute expectations, if we don't have any domain-specific closed-form expression, we typically resort to Monte Carlo (MC) estimation. This involves sampling $k$ times from the model, and computing the average of the function of interest $f$:

$$\mathbb{E}_{x \sim p}[f(x)] \approx \frac{1}{k} \sum_{i=1}^{k} f(x^{(i)}), \qquad x^{(i)} \sim p.$$
The intuition is simple: in a discrete world, the most probable samples will be sampled more often, so the average will be close to the expectation.
The Monte Carlo estimator is unbiased, but it has high variance. To compensate, we would need more samples, which is often infeasible. Also, if the distribution has low entropy, we will be inefficiently sampling the same values over and over again.
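A small sketch of the estimator, reusing the Gumbel sampler from above (the toy $f$ is only for illustration; in this discrete case the exact expectation is available for comparison):

```python
import numpy as np

rng = np.random.default_rng(0)
logits = np.array([2.0, 1.0, 0.0, -1.0])
p = np.exp(logits) / np.exp(logits).sum()

f = np.array([1.0, 10.0, 100.0, 1000.0])   # toy function of the index
exact = (p * f).sum()                       # closed form, available in this toy case

k = 10_000
g = -np.log(-np.log(rng.uniform(size=(k, len(logits)))))
samples = np.argmax(logits + g, axis=1)     # k draws from p
mc = f[samples].mean()                      # Monte Carlo estimate

print(exact, mc)                            # the two should be close
```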
Importance sampling
Importance sampling is a technique to reduce the variance of the Monte Carlo estimator. The idea is to sample from a different (proposal) distribution $q$, and then reweight the samples by the ratio of the probabilities:

$$\mathbb{E}_{x \sim p}[f(x)] = \mathbb{E}_{x \sim q}\!\left[\frac{p(x)}{q(x)} f(x)\right] \approx \frac{1}{k} \sum_{i=1}^{k} \frac{p(x^{(i)})}{q(x^{(i)})} f(x^{(i)}), \qquad x^{(i)} \sim q.$$
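Continuing the toy example, a sketch of the reweighted estimator (the proposal $q$ here is an arbitrary illustrative choice that puts more mass where $|f|$ is large):

```python
import numpy as np

rng = np.random.default_rng(0)
p = np.array([0.644, 0.237, 0.087, 0.032])  # target distribution
f = np.array([1.0, 10.0, 100.0, 1000.0])

q = np.array([0.1, 0.2, 0.3, 0.4])          # proposal distribution

k = 10_000
samples = rng.choice(len(p), size=k, p=q)   # sample from q, not p
weights = p[samples] / q[samples]           # importance weights p(x)/q(x)
is_estimate = (weights * f[samples]).mean()

print((p * f).sum(), is_estimate)           # exact vs. importance-sampled
```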
How sampling theory can help
TODO: stratified sampling, Horvitz-Thompson estimator, weighted reservoir sampling, priority sampling, sparse vector representations.
Bayes Risk and MBR
Sampling without replacement from a LLM
Appendix: a rigorous definition of the notation for discrete distributions
A categorical (generalized Bernoulli) distribution is characterized as a vector $p = (p_1, \dots, p_n)$, where $p_i$ represents the probability mass of the $i$-th event from a discrete outcome space $\Omega = \{\omega_1, \dots, \omega_n\}$. Commonly, we take the index random variable $X: \Omega \to \mathbb{N}$ to map the outcome space to the natural numbers, giving $P(X = i) = p_i$.
As a valid distribution, it satisfies $p_i \ge 0$ and $\sum_{i=1}^{n} p_i = 1$. We can express the expectation of the random variable as $\mathbb{E}[X] = \sum_{i=1}^{n} i \, p_i$.
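For instance, with $p = (0.5, 0.3, 0.2)$:

$$\mathbb{E}[X] = 1 \cdot 0.5 + 2 \cdot 0.3 + 3 \cdot 0.2 = 1.7.$$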
If the distribution represents, for example, the possible next words in a language model, computing the expected value may not be particularly meaningful, as it would reflect the average index of the next word.
Sometimes, we are interested in the expected value of a function of the outcomes $f: \Omega \to \mathbb{R}$, that is, $\mathbb{E}[f(X)]$. However, for brevity, we often write $\mathbb{E}[f(X)] = \sum_{i=1}^{n} f(\omega_i) \, p_i$ (using the law of the unconscious statistician).
Acknowledgements
Thanks @Clara Meister for guiding me through the literature, and thanks @Tim Vieira for pointing out my initial runtime complexity mistakes!