How LLMs got me into sampling theory

by Manuel de Prada Corral


This WIP post collects some of my notes on sampling theory. My end goal is to have a global understanding of importance sampling and Horvitz-Thompson estimators.

Sampling from a probabilistic model can serve many purposes. The obvious one is to generate samples, such as images, text, or audio. However, we can also use sampling to compute expectations, such as the expected value of a function of the samples.

These notes were made with sampling from a language model in mind, but they apply to many autoregressive models. The key idea is that we cannot sample a whole sequence from the model directly; instead, we have to recursively sample the next word from the conditional distributions[1]:

$$
\begin{aligned}
y_1 &\sim p(y_1) \\
y_2 &\sim p(y_2 \mid y_1) \\
&\vdots \\
y_T &\sim p(y_T \mid y_1,\dots,y_{T-1})
\end{aligned}
$$
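The recursion above can be sketched in a few lines of Python. This is only an illustration: `sample_sequence`, `toy_logits`, and the three-token vocabulary are hypothetical stand-ins for a real language model, which would return next-token logits given the prefix.

```python
import numpy as np

def sample_sequence(next_token_logits, max_len, eos_id, rng):
    """Draw y_t ~ p(y_t | y_1, ..., y_{t-1}) one step at a time."""
    prefix = []
    for _ in range(max_len):
        logits = np.asarray(next_token_logits(prefix), dtype=float)
        probs = np.exp(logits - logits.max())  # softmax, shifted for stability
        probs /= probs.sum()
        token = int(rng.choice(len(probs), p=probs))
        prefix.append(token)
        if token == eos_id:  # stop once the end-of-sequence token is drawn
            break
    return prefix

# Hypothetical toy "model" over a 3-token vocabulary (token 2 ends the sequence).
def toy_logits(prefix):
    # The end-of-sequence logit grows with the length of the prefix.
    return np.array([1.0, 0.5, 0.2 * len(prefix)])

rng = np.random.default_rng(0)
sample = sample_sequence(toy_logits, max_len=20, eos_id=2, rng=rng)
```

Each call to the model conditions on everything sampled so far, which is why a single unlucky draw affects all later steps.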

Hurray! We got a sample $\mathbf{y}=(y_1,\dots,y_T)$ from the model! However, as often happens with LLMs, the samples are not good enough, in the sense of what humans judge as "good text".

In practice, when sampling from LLMs, one of the $T$ steps often yields an unlikely word, ruining the whole sample. Calibrating the probabilities of unlikely words well is very difficult, and there are ad-hoc interventions such as sampling adaptors. Another approach is to avoid sampling altogether and instead deterministically search (approximately) for the most likely sequence of words, as beam search does[2].

But what if there were a principled way to get better samples?

TODO: Introduce utility function, minimum bayes risk, and importance sampling...

Sampling with replacement from the categorical distribution

In machine learning, we often fit a model to produce a vector of unnormalized log-probabilities $\boldsymbol{\phi}=(\phi_1,\dots,\phi_n)$, and we need to sample from the corresponding categorical distribution.

The naive approach

To sample from the categorical distribution, we can use the inverse transform sampling method:

  1. Normalize the log-probabilities (complexity $O(n)$):

$$
p_i = \frac{\exp(\phi_i)}{\sum_j \exp(\phi_j)}
$$

  2. Compute the cumulative distribution function (CDF) as $F(i)=\sum_{j=1}^i p_j$ (complexity $O(n)$).

  3. Sample from the uniform distribution $u\sim\mathcal{U}(0,1)$.

  4. Finally, pick the smallest index $i$ such that $F(i)\geq u$ (complexity $O(\log n)$ by binary search, since the CDF is nondecreasing).

In total, this naive approach has complexity $O(n + k\cdot \log n)$, where $k$ is the number of samples.
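A minimal NumPy sketch of these steps follows. The function name `sample_categorical_inverse_cdf` is an illustrative choice, and the sketch uses the convention of picking the smallest index whose CDF value is at least `u`, which is what `np.searchsorted` with `side="left"` returns.

```python
import numpy as np

def sample_categorical_inverse_cdf(phi, k, rng):
    """Draw k samples (with replacement) from softmax(phi) by inverting the CDF."""
    phi = np.asarray(phi, dtype=float)
    p = np.exp(phi - phi.max())   # normalize: O(n)
    p /= p.sum()
    cdf = np.cumsum(p)            # CDF: O(n)
    cdf[-1] = 1.0                 # guard against rounding in the last entry
    u = rng.random(k)             # k uniforms in [0, 1)
    # Binary search: smallest i with cdf[i] >= u, O(log n) per sample.
    return np.searchsorted(cdf, u, side="left")

rng = np.random.default_rng(0)
phi = np.log([0.1, 0.2, 0.7]) + 3.0   # unnormalized log-probabilities
samples = sample_categorical_inverse_cdf(phi, 200_000, rng)
freq = np.bincount(samples, minlength=3) / samples.size
```

With enough samples, `freq` should be close to the normalized probabilities, here `(0.1, 0.2, 0.7)`.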

Proof of inverse transform sampling

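We are interested in sampling from a random variable $X$ using a random variable $U\sim\mathcal{U}(0,1)$, so we need to find a function $T$ such that $X=T(U)$. Assuming $T$ is strictly increasing (and hence invertible),

$$
F_X(x) = \mathbb{P}(X\leq x) = \mathbb{P}(T(U)\leq x) = \mathbb{P}(U\leq T^{-1}(x)) = F_U(T^{-1}(x)) = T^{-1}(x),
$$

where the last step uses $F_U(u)=u$ for $u\in[0,1]$.

Hence, $T=F_X^{-1}$, and we can sample from $X$ by sampling from $U$ and applying $T$. For discrete distributions such as the categorical, $F_X$ is a step function, so $F_X^{-1}$ is understood as the generalized inverse $F_X^{-1}(u)=\min\{x : F_X(x)\geq u\}$.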

Numerical stability considerations

When computing the normalizer in step 1, the exponentials might overflow. We can avoid this by subtracting the maximum value from all the log-probabilities, so that the largest one is zero, i.e., $\phi_i' = \phi_i - \max_j \phi_j$. This shift leaves the normalized probabilities $p_i$ unchanged, because the common factor cancels between numerator and denominator.
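To see the overflow concretely, here is a small sketch comparing the naive normalization with the shifted one. The helper name `log_softmax` and the example values are illustrative.

```python
import numpy as np

def log_softmax(phi):
    """Normalized log-probabilities, computed after subtracting max(phi)."""
    phi = np.asarray(phi, dtype=float)
    shifted = phi - phi.max()            # largest entry becomes 0
    return shifted - np.log(np.exp(shifted).sum())

phi = np.array([1000.0, 1001.0, 1002.0])

# Naive version: exp(1000) overflows to inf in float64, so inf / inf gives nan.
with np.errstate(over="ignore", invalid="ignore"):
    naive = np.exp(phi) / np.exp(phi).sum()

stable = np.exp(log_softmax(phi))        # finite and sums to 1
```

The stable result equals the softmax of `(0, 1, 2)`, since shifting all entries by a constant does not change the normalized probabilities.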

The Gumbel trick

The Gumbel trick allows us to sample from the categorical distribution without computing the CDF.

  1. Sample $n$ independent Gumbel variables $g_i\sim\mathcal{G}(0,1)$. This can be easily done using inverse transform sampling: $g_i = -\log(-\log(u_i)),\ \ u_i\sim\mathcal{U}(0,1)$.

  2. Compute the perturbed log-probabilities $\phi_i' = \phi_i + g_i$. Since the Gumbel distributions form a location-scale family, $\phi_i'\sim\mathcal{G}(\phi_i,1)$.

  3. Finally, pick the index $i$ whose perturbed log-probability is the largest, i.e., such that $\phi_i'\geq \phi_j'$ for all $j\neq i$ (complexity $O(n)$, since they are not sorted). In other words, we take the index

$$
\arg\max_i \phi_i' = \arg\max_i \left(\phi_i + g_i\right)
$$
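The three steps can be sketched with NumPy as follows. The function name `sample_categorical_gumbel` is illustrative, and the lower bound on the uniform draws is an added assumption to avoid taking the logarithm of zero.

```python
import numpy as np

def sample_categorical_gumbel(phi, k, rng):
    """Draw k samples from softmax(phi) using the Gumbel-max trick."""
    phi = np.asarray(phi, dtype=float)
    # Step 1: Gumbel(0, 1) noise via inverse transform sampling.
    # The tiny lower bound keeps u strictly positive so log(u) is finite.
    u = rng.uniform(low=np.finfo(float).tiny, high=1.0, size=(k, phi.size))
    g = -np.log(-np.log(u))
    # Steps 2 and 3: perturb the log-probabilities and take the argmax.
    return np.argmax(phi + g, axis=1)

rng = np.random.default_rng(0)
phi = np.log([0.1, 0.2, 0.7]) + 3.0   # unnormalized: no softmax is needed
samples = sample_categorical_gumbel(phi, 100_000, rng)
freq = np.bincount(samples, minlength=3) / samples.size
```

Note that `phi` is never normalized here: adding a constant to every entry does not change the argmax. NumPy also offers `rng.gumbel`, which could replace the explicit inverse transform.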

Proof

Lemma 1: The inverse CDF of the exponential distribution $\text{Exp}(\lambda)$ is

$$
F^{-1}(u) = -\frac{1}{\lambda}\log(1-u).
$$

Lemma 2: If $X_1 \sim \text{Exp}(\lambda_1), \dots, X_n \sim \text{Exp}(\lambda_n)$ are independent, then $\min_i X_i \sim \text{Exp}(\sum_i \lambda_i)$ and $\mathbb{P}(X_i = \min_j X_j) = \frac{\lambda_i}{\sum_j \lambda_j}$.

Observe that the probability of a tie is zero.
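Lemma 2 is easy to check empirically. The following sketch simulates independent exponentials with illustrative rates and compares the winner frequencies and the mean of the minimum against the lemma. Note that NumPy parametrizes the exponential by its scale, which is the inverse of the rate.

```python
import numpy as np

rng = np.random.default_rng(0)
rates = np.array([0.5, 1.5, 3.0])     # lambda_1, ..., lambda_n (illustrative)

# Each row holds one independent draw of (X_1, ..., X_n); scale = 1 / rate.
x = rng.exponential(scale=1.0 / rates, size=(200_000, rates.size))

# Frequency with which each X_i is the minimum; Lemma 2 predicts rates / rates.sum().
winner_freq = np.bincount(x.argmin(axis=1), minlength=rates.size) / x.shape[0]

# Mean of the minimum; Lemma 2 predicts 1 / rates.sum().
min_mean = x.min(axis=1).mean()
```

For these rates the lemma predicts winner probabilities `(0.1, 0.3, 0.6)` and a mean of `0.2` for the minimum.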

Proof. We want to prove that the probability of picking the index $i$ is $p_i$, i.e., the probability of $\phi'_i$ being the biggest perturbed log-probability is $p_i$.

First part: show that $\exp(-\phi'_i)\sim \text{Exp}(p_i\alpha)$:

Recall that $\phi_i=\log p_i +\log\alpha$ are the unnormalized log-probabilities, where $\alpha=\sum_j \exp(\phi_j)$ is the normalizer. We can write

$$
\begin{aligned}
\phi_i' &= \phi_i + g_i = \log p_i +\log \alpha - \log(-\log(u_i)) \\
&= -\log\left(\frac{1}{p_i\alpha} \cdot \log\left(\frac{1}{u_i}\right)\right).
\end{aligned}
$$

Hence, $\exp(-\phi'_i) = \frac{1}{p_i\alpha} \cdot \log\left(\frac{1}{u_i}\right) = -\frac{1}{p_i\alpha}\log(u_i)$, which is the inverse CDF of the exponential distribution with parameter $p_i\alpha$ (Lemma 1) evaluated at $1-u_i$.

Using inverse transform sampling, since $u_i\sim\mathcal{U}(0,1)$ implies $1-u_i\sim\mathcal{U}(0,1)$, we have that $\exp(-\phi'_i)\sim \text{Exp}(p_i\alpha)$.

Second part: Since $x\mapsto\exp(-x)$ is strictly decreasing, note that $\arg\max_i \phi_i' = \arg\min_i \exp(-\phi'_i) \sim \arg\min_i \text{Exp}(p_i\alpha)$.

Third part: Finally,

$$
\begin{aligned}
\mathbb{P}(\arg\max_j \phi_j' = i) &= \mathbb{P}(\arg\min_j \text{Exp}(p_j\alpha) = i)\\
&= \mathbb{P}\left(\min_j \text{Exp}(p_j\alpha) = \text{Exp}(p_i\alpha)\right)\\
&= \frac{p_i\alpha}{\sum_j p_j\alpha} = p_i,
\end{aligned}
$$

where in the last step we have used Lemma 2.

The top-k Gumbel trick

We just saw how the Gumbel trick allows us to sample from the categorical distribution by computing

$$
\arg\max_i \left(\phi_i - \log(-\log(u_i))\right),\ \ u_i\sim\mathcal{U}(0,1).
$$

TODO: how Maddison et al. proved that max and argmax are independent, and that taking the top-k is equivalent to sampling without replacement.
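Until that derivation is written up, here is a rough sketch of what taking the top-k looks like in code: perturb every log-probability once and keep the `k` largest entries. The function name `gumbel_top_k` is illustrative, and the sketch assumes `k` is at most the number of categories.

```python
import numpy as np

def gumbel_top_k(phi, k, rng):
    """Return k distinct indices: the k largest Gumbel-perturbed log-probabilities."""
    phi = np.asarray(phi, dtype=float)
    perturbed = phi + rng.gumbel(size=phi.size)    # one Gumbel(0, 1) draw per category
    top = np.argpartition(-perturbed, k - 1)[:k]   # k largest, in arbitrary order
    return top[np.argsort(-perturbed[top])]        # order them from largest to smallest

rng = np.random.default_rng(0)
phi = np.log([0.1, 0.2, 0.7])
draws = np.array([gumbel_top_k(phi, 2, rng) for _ in range(20_000)])
first_freq = np.bincount(draws[:, 0], minlength=3) / draws.shape[0]
```

The first returned index is the ordinary Gumbel-max sample, so `first_freq` should be close to `(0.1, 0.2, 0.7)`, and the indices within each draw never repeat.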

Computing expectations

In order to compute expectations, if we don't have any domain-specific closed-form expression, we typically resort to Monte Carlo (MC) estimation. This involves sampling $m$ times from the model, and computing the average of the function of interest $f$:

$$
\mathbb{E}[f(X)] \approx \frac{1}{m}\sum_{i=1}^m f(x_i),\ \ x_i\sim p(x).
$$

The intuition is simple: in a discrete world, the most probable samples will be sampled more often, so the average will be close to the expectation.

The Monte Carlo estimator is unbiased, but it can have high variance. To compensate, we would need more samples, which is often infeasible. Also, if the distribution has low entropy, we will be inefficiently sampling the same values over and over again.
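As a small illustration, the sketch below estimates an expectation under a three-outcome categorical distribution and compares it with the exact value. The distribution `p`, the function `f`, and the name `mc_estimate` are all illustrative choices.

```python
import numpy as np

def mc_estimate(f, p, m, rng):
    """Average f over m samples drawn with replacement from the categorical p."""
    x = rng.choice(len(p), size=m, p=p)
    return f(x).mean()

def f(x):
    return np.asarray(x, dtype=float) ** 2

p = np.array([0.1, 0.2, 0.7])
exact = float((f(np.arange(len(p))) * p).sum())   # 0*0.1 + 1*0.2 + 4*0.7 = 3.0

rng = np.random.default_rng(0)
for m in (10, 1_000, 100_000):
    print(m, mc_estimate(f, p, m, rng))           # estimates tighten around 3.0 as m grows
```

The standard error of the estimate shrinks like one over the square root of `m`, so each extra digit of accuracy costs roughly a hundred times more samples.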

Importance sampling

Importance sampling is a technique that can reduce the variance of the Monte Carlo estimator when the proposal is chosen well. The idea is to sample from a different distribution $q(x)$, with $q(x)>0$ wherever $f(x)p(x)\neq 0$, and then reweight the samples by the ratio of the probabilities:

$$
\mathbb{E}[f(X)] = \sum_x f(x) p(x) = \sum_x f(x) \frac{p(x)}{q(x)} q(x) = \mathbb{E}_q\left[f(X)\frac{p(X)}{q(X)}\right].
$$
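A minimal sketch of this identity, with illustrative choices of `p`, `f`, and two proposals. The second proposal is proportional to `f(x) p(x)`, a choice that is only available in a toy setting because it requires knowing the answer, but it shows how much the proposal matters: every weighted sample then equals the expectation.

```python
import numpy as np

def importance_estimate(f, p, q, m, rng):
    """Estimate E_p[f(X)] from m samples drawn from the proposal q."""
    x = rng.choice(len(q), size=m, p=q)
    weights = p[x] / q[x]            # importance weights p(x) / q(x)
    return (f(x) * weights).mean()

def f(x):
    return np.asarray(x, dtype=float) ** 2

p = np.array([0.1, 0.2, 0.7])        # target distribution; E_p[f] = 3.0
rng = np.random.default_rng(0)

# Proposal 1: uniform. Valid, but not matched to f * p.
q_uniform = np.full(3, 1.0 / 3.0)
est_uniform = importance_estimate(f, p, q_uniform, 200_000, rng)

# Proposal 2: proportional to f * p. Outcome 0 gets q = 0, which is allowed
# because f(0) * p(0) = 0 there, and it is never sampled.
fp = f(np.arange(3)) * p
q_matched = fp / fp.sum()
est_matched = importance_estimate(f, p, q_matched, 1_000, rng)
```

Both estimates target the same value, but the matched proposal gets there with far fewer samples, while a poorly chosen proposal can make the variance worse than plain Monte Carlo.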

How sampling theory can help

TODO: stratified sampling, Horvitz-Thompson estimator, weighted reservoir sampling, priority sampling, sparse vector representations.

Bayes Risk and MBR

Sampling without replacement from a LLM

Appendix: a rigorous definition for the $(p_1,\dots,p_n)$ notation for discrete distributions

A categorical (generalized Bernoulli) distribution is characterized as a vector $\mathbf{p}=(p_1,\dots,p_n)$, where $p_i$ represents the probability mass of the $i$-th event from a discrete outcome space $\Omega$. Commonly, we take the index random variable $I:\Omega\to\mathbb{N}$ to map the outcome space to the natural numbers, giving

$$
\mathbb{P}(I = i) := p(i) = \begin{cases} p_i & \text{if } i \in \{1,\dots,n\} \\ 0 & \text{otherwise.} \end{cases}
$$

As a valid distribution, it satisfies $\sum_i p_i = 1$. We can express the expectation of the random variable $I$ as $\mathbb{E}[I] = \sum_i i\, p_i$.

If the distribution represents, for example, the possible next words in a language model, computing the expected value may not be particularly meaningful, as it would reflect the average index of the next word.

Sometimes, we are interested in the expected value of a function of the outcomes $f:\Omega\to\mathbb{R}$, that is, $\mathbb{E}_I[f(I^{-1}(i))]$. However, for brevity, we often write $\mathbb{E}[f(I)] = \sum_i f(i)\, p_i$ (using the law of the unconscious statistician).

Acknowledgements

Thanks @Clara Meister for guiding me through the literature and thanks @Tim Vieira for pointing out my initial runtime complexity mistakes!