How LLMs got me into sampling theory
by Manuel de Prada Corral
This WIP post collects some of my notes on sampling theory. My end goal is to gain a global understanding of importance sampling and Horvitz-Thompson estimators.
Sampling from a probabilistic model can serve many purposes. The obvious one is to generate samples, such as images, text, or audio. However, we can also use sampling to estimate expectations, such as the expected value of some function of the samples.
These notes were made with sampling from a language model in mind, but they apply to many autoregressive models. The key idea is that we cannot sample a whole sequence from the model directly; instead, we have to recursively sample each next word from the conditional distributions^{1}.
$\begin{aligned} y_1 &\sim p(y_1) \\ y_2 &\sim p(y_2 \mid y_1) \\ &\vdots \\ y_T &\sim p(y_T \mid y_1,\dots,y_{T-1}) \end{aligned}$
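To make the recursion concrete, here is a minimal Python sketch of this step-by-step (ancestral) sampling. The bigram table `TOY_TABLE` and the helpers `next_token_probs`, `sample_sequence`, and `sequence_prob` are hypothetical stand-ins for a real model, assumed only for illustration.

```python
import random

# Hypothetical toy "language model": a hard-coded bigram table keyed on the
# last token. A real LLM would compute next-token probabilities with a
# forward pass over the whole prefix.
TOY_TABLE = {
    "<s>": {"the": 0.6, "a": 0.4},
    "the": {"cat": 0.5, "dog": 0.5},
    "a": {"cat": 0.3, "dog": 0.7},
    "cat": {"</s>": 1.0},
    "dog": {"</s>": 1.0},
}


def next_token_probs(prefix):
    """Return p(y_t | y_1, ..., y_{t-1}); this toy model only looks at the last token."""
    last = prefix[-1] if prefix else "<s>"
    return TOY_TABLE[last]


def sample_sequence(rng, max_len=10):
    """Ancestral sampling: draw y_t ~ p(y_t | y_1, ..., y_{t-1}) one step at a time."""
    prefix = []
    for _ in range(max_len):
        probs = next_token_probs(prefix)
        tokens = list(probs)
        y_t = rng.choices(tokens, weights=[probs[t] for t in tokens], k=1)[0]
        if y_t == "</s>":  # end-of-sequence token stops generation
            break
        prefix.append(y_t)
    return prefix


def sequence_prob(seq):
    """p(y) as the product of the conditionals, including the end-of-sequence step."""
    prob = 1.0
    for t, y_t in enumerate(list(seq) + ["</s>"]):
        prob *= next_token_probs(seq[:t]).get(y_t, 0.0)
    return prob


rng = random.Random(0)
sample = sample_sequence(rng)
print(sample, sequence_prob(sample))
```

Scoring a sequence uses the same factorization: `sequence_prob` multiplies the conditionals. With a real LLM only `next_token_probs` would change.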
Hurray! We got a sample $\mathbf{y}=(y_1,\dots,y_T)$ from the model! However, as often happens with LLMs, the samples are not always good enough, in the sense of what humans judge as "good text".
In practice, when sampling from LLMs, one of the $T$ steps often yields an unlikely word that ruins the whole sample. Calibrating the probabilities of unlikely words is very difficult, and there are ad hoc interventions such as sampling adaptors. Another approach is to avoid sampling altogether and instead deterministically search for the most likely sequence of words, which in practice is approximated with beam search^{2}.
But, what if there was a principled way to get better samples?
TODO: Introduce utility function, minimum Bayes risk, and importance sampling...
Sampling with replacement from the categorical distribution
In machine learning, we often fit a model to produce a vector of unnormalized log-probabilities $\mathbf{\phi}=(\phi_1,\dots,\phi_n)$, and we need to sample from the corresponding categorical distribution.
The naive approach
To sample from the categorical distribution, we can use the inverse transform sampling method:
1. Normalize the log-probabilities (complexity $O(n)$):
$p_i = \frac{\exp(\phi_i)}{\sum_j \exp(\phi_j)}$

2. Compute the cumulative distribution function (CDF) as $F(i)=\sum_{j=1}^i p_j$ (complexity $O(n)$).

3. Sample from the uniform distribution $u\sim\mathcal{U}(0,1)$.

4. Finally, pick the smallest index $i$ such that $F(i)\geq u$ (complexity $O(\log n)$ by binary search, since the CDF is sorted).
In total, this naive approach has complexity $O(n + k\log n)$, where $k$ is the number of samples: the normalization and the CDF are computed once, and each sample needs one binary search.
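The four steps can be sketched with the Python standard library. This is a minimal sketch, not a reference implementation: the names `categorical_cdf` and `sample_inverse_cdf` are hypothetical, indices are 0-based, and the max-subtraction from the numerical stability note is already folded into the normalization.

```python
import bisect
import math
import random
from itertools import accumulate


def categorical_cdf(phi):
    """Steps 1-2 (O(n)): softmax of unnormalized log-probabilities, then the CDF."""
    m = max(phi)  # shift by the max so math.exp cannot overflow
    weights = [math.exp(x - m) for x in phi]
    total = sum(weights)
    probs = [w / total for w in weights]
    cdf = list(accumulate(probs))
    cdf[-1] = 1.0  # guard against round-off leaving the last entry slightly below 1
    return probs, cdf


def sample_inverse_cdf(cdf, rng):
    """Steps 3-4 (O(log n)): draw u ~ U[0, 1), return the smallest i with F(i) >= u.

    Indices are 0-based here, while the text counts from 1.
    """
    u = rng.random()
    return bisect.bisect_left(cdf, u)


rng = random.Random(0)
phi = [1000.0, 1001.0, 1002.0]  # math.exp(1000.0) alone would raise OverflowError
probs, cdf = categorical_cdf(phi)                                # paid once: O(n)
samples = [sample_inverse_cdf(cdf, rng) for _ in range(10_000)]  # k draws: O(k log n)
print(probs, [samples.count(i) / len(samples) for i in range(len(phi))])
```

The split into two functions mirrors the complexity bound: `bisect.bisect_left` returns the first position whose CDF value is at least $u$, which is exactly the "smallest index" rule.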
Proof of inverse transform sampling
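We are interested in sampling from a random variable $X$ using a random variable $U\sim\mathcal{U}(0,1)$, so we need to find a function $T$ such that $X=T(U)$. Assuming $T$ is strictly increasing (so that it is invertible and preserves inequalities), and using $F_U(u)=u$ for $u\in[0,1]$,
$F_X(x) = \mathbb{P}(X\leq x) = \mathbb{P}(T(U)\leq x) = \mathbb{P}(U\leq T^{-1}(x)) = F_U(T^{-1}(x)) = T^{-1}(x)$
Hence, $T=F_X^{-1}$, and we can sample from $X$ by sampling from $U$ and applying $T$. For a discrete $X$, $F_X$ is a step function, and $F_X^{-1}$ is understood as the generalized inverse $F_X^{-1}(u)=\min\{x : F_X(x)\geq u\}$.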
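As a continuous sanity check of $T=F_X^{-1}$: the exponential distribution $\text{Exp}(\lambda)$ has CDF $F(x)=1-e^{-\lambda x}$, which inverts in closed form to $F^{-1}(u)=-\frac{1}{\lambda}\log(1-u)$. The sketch below (the helper `sample_exponential` is a hypothetical name) draws $U$ and applies $F^{-1}$; the empirical mean should approach $1/\lambda$.

```python
import math
import random


def sample_exponential(lam, rng):
    """Inverse transform sampling for Exp(lam): X = F^{-1}(U) = -log(1 - U) / lam."""
    u = rng.random()  # U ~ U[0, 1), so 1 - u lies in (0, 1] and the log is finite
    return -math.log(1.0 - u) / lam


rng = random.Random(0)
lam = 2.0
xs = [sample_exponential(lam, rng) for _ in range(100_000)]
print(sum(xs) / len(xs))  # should be close to the true mean 1 / lam = 0.5
```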
Numerical stability considerations
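When computing the normalizer in step 1, the exponentials $\exp(\phi_i)$ might overflow. We can avoid this by subtracting the maximum value from all the log-probabilities, so that the maximum value is zero, i.e., $\phi_i' = \phi_i - \max_j \phi_j$. This leaves the probabilities $p_i$ unchanged, since the factor $\exp(-\max_j \phi_j)$ cancels between numerator and denominator.
The Gumbel trick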
The Gumbel trick lets us sample from the categorical distribution without normalizing the log-probabilities or computing the CDF.

1. Sample $n$ independent Gumbel variables $g_i\sim\mathcal{G}(0,1)$. This is easily done by inverse transform sampling: $g_i = -\log(-\log(u_i)),\ \ u_i\sim\mathcal{U}(0,1)$.

2. Compute the perturbed log-probabilities $\phi_i' = \phi_i + g_i$. Since the Gumbel distributions form a location-scale family, $\phi_i'\sim\mathcal{G}(\phi_i,1)$.

3. Finally, pick the index $i$ with the largest perturbed log-probability, i.e., $\phi_i'\geq \phi_j'$ for all $j\neq i$ (complexity $O(n)$, since the values are not sorted). In other words, we take the index
$\arg\max_i \phi_i' = \arg\max_i (\phi_i + g_i)$
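The three steps translate almost line by line into code. A minimal standard-library sketch; `sample_gumbel` and `gumbel_max_sample` are hypothetical helper names. It redraws $u$ in the measure-zero case $u=0$, where $\log(u)$ is undefined.

```python
import math
import random


def sample_gumbel(rng):
    """Standard Gumbel G(0,1) via inverse transform: g = -log(-log(u))."""
    u = rng.random()
    while u == 0.0:  # log(0) is undefined; redraw in this measure-zero case
        u = rng.random()
    return -math.log(-math.log(u))


def gumbel_max_sample(phi, rng):
    """Return argmax_i (phi_i + g_i); no normalization or CDF needed, O(n)."""
    perturbed = [x + sample_gumbel(rng) for x in phi]
    return max(range(len(phi)), key=perturbed.__getitem__)


rng = random.Random(0)
phi = [0.0, 1.0, 2.0]  # unnormalized log-probabilities
counts = [0] * len(phi)
for _ in range(20_000):
    counts[gumbel_max_sample(phi, rng)] += 1

z = sum(math.exp(x) for x in phi)
print([c / 20_000 for c in counts], [math.exp(x) / z for x in phi])
```

Comparing the empirical frequencies with the softmax probabilities is a quick empirical check of the claim proved next, not a substitute for the proof.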
Proof
Lemma 1: The inverse CDF of the exponential distribution $\text{Exp}(\lambda)$ is
$F^{-1}(u) = -\frac{1}{\lambda}\log(1-u).$
Lemma 2: If $X_1 \sim \text{Exp}(\lambda_1), \dots, X_n \sim \text{Exp}(\lambda_n)$ are independent, then $\min_i X_i \sim \text{Exp}(\sum_i \lambda_i)$ and $\mathbb{P}(X_i = \min_j X_j) = \frac{\lambda_i}{\sum_j \lambda_j}.$
Observe that the probability of a tie is zero.
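Lemma 2 (sometimes called the "exponential race") can be checked numerically. The simulation below is only an illustration with arbitrary rates `lams` of our choosing; it uses `random.expovariate`, whose argument is the rate $\lambda$ (mean $1/\lambda$).

```python
import random

rng = random.Random(0)
lams = [0.5, 1.5, 3.0]  # arbitrary rates, chosen for illustration
n_trials = 50_000
wins = [0] * len(lams)
mins = []
for _ in range(n_trials):
    xs = [rng.expovariate(lam) for lam in lams]  # independent X_i ~ Exp(lam_i)
    m = min(xs)
    mins.append(m)
    wins[xs.index(m)] += 1  # which X_i was the minimum

total = sum(lams)
print([w / n_trials for w in wins], [lam / total for lam in lams])
print(sum(mins) / n_trials, 1 / total)  # mean of the min vs. mean of Exp(sum(lams))
```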
Proof. We want to prove that the probability of picking the index $i$ is $p_i$, i.e., the probability of $\phi'_i$ being the largest perturbed log-probability is $p_i$.
First part: show that $\exp(-\phi'_i)\sim \text{Exp}(p_i\alpha)$:
Recall that $\phi_i=\log p_i +\log\alpha$ are the unnormalized log-probabilities, where $\alpha=\sum_j \exp(\phi_j)$ is the normalizer. We can write $\begin{aligned} \phi_i' &= \phi_i + g_i = \log p_i +\log \alpha - \log(-\log(u_i)) \\ &= -\log\left(\frac{1}{p_i\alpha} \cdot \log\frac{1}{u_i}\right). \end{aligned}$
Hence, $\exp(-\phi'_i) = \frac{1}{p_i\alpha} \cdot \log\frac{1}{u_i} = -\frac{1}{p_i\alpha}\log\big(1-(1-u_i)\big)$, which is the inverse CDF of the exponential distribution with parameter $p_i\alpha$ (Lemma 1) evaluated at $1-u_i$.
Using inverse transform sampling, since $1-u_i\sim\mathcal{U}(0,1)$ whenever $u_i\sim\mathcal{U}(0,1)$, we have that $\exp(-\phi'_i)\sim \text{Exp}(p_i\alpha)$, independently across $i$.
Second part: Since $x\mapsto\exp(-x)$ is strictly decreasing, $\arg\max_i \phi_i' = \arg\min_i \exp(-\phi'_i) \sim \arg\min_i \text{Exp}(p_i\alpha).$
Third part: Finally,
$\begin{aligned} \mathbb{P}(\arg\max_i \phi_i' = i) &= \mathbb{P}(\arg\min_i \text{Exp}(p_i\alpha) = i)\\ &= \mathbb{P}\left(\min_j \text{Exp}(p_j\alpha) = \text{Exp}(p_i\alpha)\right)\\ &= \frac{p_i\alpha}{\sum_j p_j\alpha} = p_i, \end{aligned}$
where in the last step we have used Lemma 2.
The top-k Gumbel trick
We just saw how the Gumbel trick lets us sample from the categorical distribution by computing
$\arg\max_i \left(\phi_i - \log(-\log(u_i))\right),\ \ u_i\sim\mathcal{U}(0,1).$
TODO: how Maddison et al. proved that the max and argmax are independent, and that taking the top-k is equivalent to sampling without replacement.
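Even before the proofs, the top-k version of the trick is easy to sketch: perturb once, then keep the $k$ largest perturbed values instead of only the argmax. The helper `gumbel_top_k` is hypothetical, and the comparison below is an empirical check of the without-replacement claim for a couple of ordered pairs, not a proof.

```python
import heapq
import math
import random


def gumbel_top_k(phi, k, rng):
    """Perturb each unnormalized log-probability with independent G(0,1) noise
    and return the indices of the k largest perturbed values, largest first."""
    perturbed = []
    for x in phi:
        u = rng.random()
        while u == 0.0:  # avoid log(0) in the measure-zero case u = 0
            u = rng.random()
        perturbed.append(x - math.log(-math.log(u)))
    return heapq.nlargest(k, range(len(phi)), key=perturbed.__getitem__)


rng = random.Random(0)
phi = [0.0, 1.0, 2.0]
n_trials = 20_000
pairs = [tuple(gumbel_top_k(phi, 2, rng)) for _ in range(n_trials)]

# Sequential sampling without replacement draws the ordered pair (2, 1) with
# probability p_2 * p_1 / (1 - p_2), where p = softmax(phi).
z = sum(math.exp(x) for x in phi)
p = [math.exp(x) / z for x in phi]
print(pairs.count((2, 1)) / n_trials, p[2] * p[1] / (1 - p[2]))
```

`heapq.nlargest` keeps the selection at $O(n\log k)$ rather than sorting all $n$ perturbed values.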
Computing expectations
To compute expectations when no domain-specific closed-form expression is available, we typically resort to Monte Carlo (MC) estimation. This involves sampling $m$ times from the model and averaging the function of interest $f$:
$\mathbb{E}[f(X)] \approx \frac{1}{m}\sum_{i=1}^m f(x_i),\ \ x_i\sim p(x).$
The intuition is simple: in a discrete world, each value $x$ is drawn with frequency close to $p(x)$ (by the law of large numbers), so more probable values are sampled more often and the average ends up close to the expectation.
The Monte Carlo estimator is unbiased, but its variance, $\mathrm{Var}[f(X)]/m$, can be high. To compensate, we would need more samples, which is often infeasible. Also, if the distribution has low entropy, we will inefficiently sample the same values over and over again.
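A minimal sketch of the plain Monte Carlo estimator on a small categorical distribution, where the exact expectation is also available for comparison. The values, probabilities, the choice $f(x)=x^2$, and the helper name `mc_estimate` are arbitrary illustrative assumptions.

```python
import random


def mc_estimate(f, values, probs, m, rng):
    """Plain Monte Carlo: average f over m i.i.d. draws x_i ~ p."""
    draws = rng.choices(values, weights=probs, k=m)
    return sum(f(x) for x in draws) / m


def f(x):
    return x ** 2


values = [0, 1, 2, 3]
probs = [0.1, 0.2, 0.3, 0.4]

# Exact expectation: 0*0.1 + 1*0.2 + 4*0.3 + 9*0.4 = 5 (up to float round-off)
exact = sum(f(x) * p for x, p in zip(values, probs))
rng = random.Random(0)
print(exact, mc_estimate(f, values, probs, m=10_000, rng=rng))
```

For this toy example $\mathrm{Var}[f(X)] = 37.4 - 5^2 = 12.4$, so $m=10{,}000$ draws give a standard error of about $\sqrt{12.4/10^4}\approx 0.035$.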
Importance sampling
Importance sampling is a technique that can reduce the variance of the Monte Carlo estimator when the proposal is chosen well. The idea is to sample from a different distribution $q(x)$ and reweight the samples by the ratio of the probabilities, which requires $q(x)>0$ wherever $f(x)p(x)\neq 0$:
$\mathbb{E}[f(X)] = \sum_x f(x) p(x) = \sum_x f(x) \frac{p(x)}{q(x)} q(x) = \mathbb{E}_q\left[f(X)\frac{p(X)}{q(X)}\right].$
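On the same toy problem as plain Monte Carlo ($p=(0.1,0.2,0.3,0.4)$ over values $0,\dots,3$, $f(x)=x^2$), here is a minimal sketch of the importance sampling estimator. The helper `importance_estimate` and both proposals are illustrative assumptions: a uniform $q$, and the idealized $q(x)\propto f(x)p(x)$.

```python
import random


def importance_estimate(f, values, p, q, m, rng):
    """Estimate E_p[f(X)] from m draws x ~ q, each reweighted by p(x) / q(x).

    Assumes q[i] > 0 wherever f(values[i]) * p[i] != 0.
    """
    idx = rng.choices(range(len(values)), weights=q, k=m)
    return sum(f(values[i]) * p[i] / q[i] for i in idx) / m


def f(x):
    return x ** 2


values = [0, 1, 2, 3]
p = [0.1, 0.2, 0.3, 0.4]
exact = sum(f(x) * pi for x, pi in zip(values, p))

# A generic proposal: uniform over the four values.
q_uniform = [0.25] * 4

# An idealized proposal q(x) proportional to f(x) p(x): for nonnegative f every
# reweighted term equals E_p[f(X)], so the estimator has zero variance. Building
# it requires the normalizer sum_x f(x) p(x), i.e. the very expectation we want.
unnorm = [f(x) * pi for x, pi in zip(values, p)]
q_opt = [w / sum(unnorm) for w in unnorm]

rng = random.Random(0)
est_uniform = importance_estimate(f, values, p, q_uniform, 10_000, rng)
est_opt = importance_estimate(f, values, p, q_opt, 10_000, rng)
print(exact, est_uniform, est_opt)
```

The two proposals show that any variance reduction depends entirely on $q$: here the uniform proposal has per-sample variance $32.76$, larger than the $12.4$ of plain Monte Carlo with the same $f$ and $p$, while the idealized proposal has zero variance but cannot be built without already knowing the answer.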
How sampling theory can help
TODO: stratified sampling, Horvitz-Thompson estimator, weighted reservoir sampling, priority sampling, sparse vector representations.
Bayes Risk and MBR
Sampling without replacement from an LLM
Appendix: a rigorous definition for the $\mathbf {(p_1,\dots,p_n)}$ notation for discrete distributions
A categorical (generalized Bernoulli) distribution is characterized by a vector $\mathbf{p}=(p_1,\dots,p_n)$, where $p_i$ represents the probability mass of the $i$-th event from a discrete outcome space $\Omega$. Commonly, we take the index random variable $I:\Omega\to\mathbb{N}$ to map the outcome space to the natural numbers, giving
$\mathbb{P}(I = i) := p(i) = \begin{cases} p_i & \text{if } i \in \{1,\dots,n\} \\ 0 & \text{otherwise.} \end{cases}$
As a valid distribution, it satisfies $\sum_i p_i = 1$. We can express the expectation of the random variable $I$ as $\mathbb{E}[I] = \sum_i i p_i$.
If the distribution represents, for example, the possible next words in a language model, computing the expected value may not be particularly meaningful, as it would reflect the average index of the next word.
Sometimes, we are interested in the expected value of a function of the outcomes $f:\Omega\to\mathbb{R}$, that is, $\mathbb{E}_I[f(I^{-1}(I))] = \sum_i f(I^{-1}(i))\, p_i$. However, for brevity, we often write $\mathbb{E}[f(I)] = \sum_i f(i) p_i$ (using the law of the unconscious statistician).
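A tiny worked example of the notation, assuming a hypothetical three-word vocabulary as $\Omega$ and word length as $f$: the list position plays the role of $I$, and `vocab[i - 1]` the role of $I^{-1}(i)$.

```python
# Hypothetical toy outcome space Omega: a three-word vocabulary. The index
# variable I maps each word to {1, 2, 3}; vocab[i - 1] plays the role of I^{-1}(i).
vocab = ["cat", "dog", "bird"]
p = [0.5, 0.3, 0.2]  # p_i = P(I = i)


def f(word):
    """A function on outcomes: here, the word length."""
    return len(word)


# E[I] = sum_i i * p_i: the "average index", rarely meaningful for words.
expected_index = sum(i * p_i for i, p_i in enumerate(p, start=1))

# E[f(I^{-1}(I))] = sum_i f(I^{-1}(i)) * p_i, abbreviated E[f(I)] in the text.
expected_f = sum(f(vocab[i - 1]) * p_i for i, p_i in enumerate(p, start=1))

print(expected_index, expected_f)
```

Here $\mathbb{E}[I] = 1.7$ is just an average position in the list, while $\mathbb{E}[f(I)] = 3\cdot 0.5 + 3\cdot 0.3 + 4\cdot 0.2 = 3.2$ is the expected word length.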
Acknowledgements
Thanks @Clara Meister for guiding me through the literature and thanks @Tim Vieira for pointing out my initial runtime complexity mistakes!