How LLMs got me into sampling theory

by Manuel de Prada Corral


This WIP post collects some of my notes on sampling theory. My end goal is to have a global understanding of importance sampling and Horvitz-Thompson estimators.

Sampling from a probabilistic model can serve many purposes. The obvious one is to generate samples, such as images, text, or audio. However, we can also use sampling to compute expectations, such as the expected value of a function of the samples.

These notes were made with sampling from a Language Model in mind, but they apply to many autoregressive models. The key idea is that we cannot sample from the model directly; we have to recursively sample the next word from the conditional distributions:

$$
\begin{aligned}
y_1 &\sim p(y_1) \\
y_2 &\sim p(y_2|y_1) \\
&\vdots \\
y_T &\sim p(y_T|y_1,\dots,y_{T-1})
\end{aligned}
$$
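To make this concrete, here is a minimal, hypothetical sketch of this ancestral sampling loop. The function `next_token_probs` is just a stand-in for the model's forward pass and `EOS` is an assumed end-of-sequence token id; neither name comes from any particular library.

```python
import numpy as np

rng = np.random.default_rng(0)
EOS = 0  # hypothetical end-of-sequence token id

def next_token_probs(prefix):
    """Stand-in for the model: returns p(y_t | y_1, ..., y_{t-1}).

    A real LLM would run a forward pass on `prefix`; here we just return
    a random categorical distribution over a toy vocabulary of 5 tokens.
    """
    logits = rng.normal(size=5)
    return np.exp(logits) / np.exp(logits).sum()

def ancestral_sample(max_len=20):
    y = []
    for _ in range(max_len):
        p = next_token_probs(y)               # conditional distribution p(y_t | y_<t)
        token = int(rng.choice(len(p), p=p))  # sample the next token
        y.append(token)
        if token == EOS:                      # stop once the end-of-sequence token appears
            break
    return y

print(ancestral_sample())
```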

Hurray! We got a sample $\mathbf{y}=(y_1,\dots,y_T)$ from the model! However, as often happens with LLMs, the samples may not be good enough, in the sense of what humans judge as "good text".

In practice, when sampling from LLMs, one of the $T$ steps often yields an unlikely word, ruining the whole sample. Calibrating the probabilities of unlikely words well is very difficult, and there are ad-hoc interventions such as sampling adaptors. Another approach is to avoid sampling altogether and instead deterministically search for the most likely sequence of words, which is called beam search.

But what if there were a principled way to get better samples?

TODO: Introduce utility function, minimum bayes risk, and importance sampling...

Sampling with replacement from the categorical distribution

In machine learning, we often fit a model to produce a vector of unnormalized log-probabilities $\mathbf{\phi}=(\phi_1,\dots,\phi_n)$, and we need to sample from the corresponding categorical distribution.

The naive approach

To sample from the categorical distribution, we can use the inverse transform sampling method:

  1. Normalize the log-probabilities (complexity $O(n)$):

$$ p_i = \frac{\exp(\phi_i)}{\sum_j \exp(\phi_j)} $$

  2. Compute the cumulative distribution function (CDF) as $F(i)=\sum_{j=1}^i p_j$.

  3. Sample from the uniform distribution $u\sim\mathcal{U}(0,1)$.

  4. Finally, pick the smallest index $i$ such that $F(i)\geq u$ (complexity $O(\log(n))$ per sample by binary search, since the CDF is sorted).

In total, this naive approach has complexity $O(n + k\cdot \log(n))$, where $k$ is the number of samples.
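A rough numpy sketch of these steps (the function name and toy logits are mine, not from any library):

```python
import numpy as np

def sample_categorical_naive(phi, k, rng=np.random.default_rng()):
    """Draw k samples (with replacement) from Categorical(softmax(phi))."""
    phi = phi - phi.max()                # subtract max for numerical stability (see the note below)
    p = np.exp(phi) / np.exp(phi).sum()  # step 1: normalize, O(n)
    cdf = np.cumsum(p)                   # step 2: CDF, O(n)
    u = rng.uniform(size=k)              # step 3: k uniform draws
    # step 4: smallest i with F(i) >= u, via binary search, O(k log n)
    return np.searchsorted(cdf, u)

phi = np.log(np.array([0.1, 0.2, 0.3, 0.4])) + 7.0  # toy unnormalized log-probs
print(sample_categorical_naive(phi, k=5))
```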

Proof of inverse transform sampling

We are interested in sampling from a random variable $X$ using a random variable $U\sim\mathcal{U}(0,1)$, so we need to find a function $T$ such that $X=T(U)$. Assuming $T$ is strictly increasing (and hence invertible),

$$ F_X(x) = \mathbb{P}(X\leq x) = \mathbb{P}(T(U)\leq x) = \mathbb{P}(U\leq T^{-1}(x)) = F_U(T^{-1}(x)) = T^{-1}(x) $$

Hence, $T=F_X^{-1}$, and we can sample from $X$ by sampling from $U$ and applying $T$.

Numerical stability considerations: when computing the normalizer in step 1, the sum might overflow. We can avoid this by subtracting the maximum value from all the log-probabilities, so that the maximum value is zero, i.e., $\phi_i' = \phi_i - \max_j \phi_j$.
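A tiny demonstration of why the max-subtraction matters (the numbers are chosen just for illustration):

```python
import numpy as np

phi = np.array([1000.0, 1001.0, 1002.0])
print(np.exp(phi).sum())           # inf (overflow warning): naive normalization breaks

phi_stable = phi - phi.max()       # shift so the maximum log-probability is zero
p = np.exp(phi_stable) / np.exp(phi_stable).sum()
print(p)                           # well-defined probabilities, ~[0.09, 0.24, 0.67]
```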

The Gumbel trick

The Gumbel trick allows us to sample from the categorical distribution without computing the CDF.

  1. Sample $n$ independent Gumbel variables $g_i\sim\mathcal{G}(0,1)$. This can be easily done using inverse transform sampling: $g_i = -\log(-\log(u_i)),\ \ u_i\sim\mathcal{U}(0,1)$.

  2. Compute the perturbed log-probabilities $\phi_i' = \phi_i + g_i$. Since the Gumbel distributions are a location-scale family, $\phi_i'\sim\mathcal{G}(\phi_i,1)$.

  3. Finally, we pick the index $i$ with the largest perturbed log-probability, i.e., such that $\phi_i'\geq \phi_j'$ for all $j\neq i$ (complexity $O(n)$, since the perturbed values are not sorted). In other words, we take the index

argโกmaxโกiฯ•iโ€ฒ=argโกmaxโกiฯ•i+gi \arg\max_i \phi_i' = \arg\max_i \phi_i + g_i

Proof

Lemma 1: The inverse CDF of the exponential distribution $\text{Exp}(\lambda)$ is

$$ F^{-1}(u) = -\frac{1}{\lambda}\log(1-u). $$

Lemma 2: If $X_1 \sim \text{Exp}(\lambda_1), \dots, X_n \sim \text{Exp}(\lambda_n)$ are independent, then $\min_i X_i \sim \text{Exp}(\sum_i \lambda_i)$ and
$$ \mathbb{P}\left(X_i = \min_j X_j\right) = \frac{\lambda_i}{\sum_j \lambda_j}. $$

Observe that the probability of a tie is zero.
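As a quick, informal sanity check of Lemma 2 (a simulation of my own, not part of the proof):

```python
import numpy as np

rng = np.random.default_rng(0)
lam = np.array([1.0, 2.0, 3.0])
# x[j, i] ~ Exp(lam[i]); numpy parameterizes the exponential by scale = 1 / rate
x = rng.exponential(scale=1.0 / lam, size=(100_000, 3))
freq = np.bincount(x.argmin(axis=1), minlength=3) / len(x)
print(freq)                        # roughly lam / lam.sum() = [1/6, 2/6, 3/6]
```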

Proof. We want to prove that the probability of picking index $i$ is $p_i$, i.e., the probability of $\phi'_i$ being the biggest perturbed log-probability is $p_i$.

First part: show that $\exp(-\phi'_i)\sim \text{Exp}(p_i\alpha)$:

Recall that $\phi_i=\log p_i +\log\alpha$ are the unnormalized log-probabilities (so $\alpha=\sum_j \exp(\phi_j)$ is the normalizer). We can write
$$
\begin{aligned}
\phi_i' &= \phi_i + g_i = \log p_i +\log \alpha - \log(-\log(u_i)) \\
&= -\log\left(\frac{1}{p_i\alpha} \cdot \log\left(\frac{1}{u_i}\right)\right).
\end{aligned}
$$

Hence, $\exp(-\phi'_i) = \frac{1}{p_i\alpha} \cdot \log\left(\frac{1}{u_i}\right) = -\frac{1}{p_i\alpha}\log(u_i)$, which is the inverse CDF of the exponential distribution with parameter $p_i\alpha$ (Lemma 1), evaluated at $1-u_i$.

Using inverse transform sampling, since $u_i\sim\mathcal{U}(0,1)$ and therefore also $1-u_i\sim\mathcal{U}(0,1)$, we have that $\exp(-\phi'_i)\sim \text{Exp}(p_i\alpha)$.

Second part: since $x\mapsto e^{-x}$ is decreasing, $\arg\max_i \phi_i' = \arg\min_i \exp(-\phi'_i)$, i.e., the largest perturbed log-probability corresponds to the smallest of the independent $\text{Exp}(p_i\alpha)$ variables.

Third part: Finally,

$$
\begin{aligned}
\mathbb{P}\left(\arg\max_j \phi_j' = i\right) &= \mathbb{P}\left(\arg\min_j \exp(-\phi_j') = i\right)\\
&= \mathbb{P}\left(\exp(-\phi_i') = \min_j \exp(-\phi_j')\right)\\
&= \frac{p_i\alpha}{\sum_j p_j\alpha} = p_i,
\end{aligned}
$$

where in the last step we have used Lemma 2.

The top-k Gumbel trick

We just saw how the Gumbel trick allows us to sample from the categorical distribution by computing

argโกmaxโกi(ฯ•iโˆ’logโก(โˆ’logโก(ui)),  uiโˆผU(0,1)). \arg\max_i \left(\phi_i - \log(-\log(u_i)),\ \ u_i\sim\mathcal{U}(0,1)\right).

TODO: how Maddison et al. proved that max and argmax are independent, and that taking the top-k is equivalent to sampling without replacement.
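As a placeholder until that part is written, here is a hedged numpy sketch of the Gumbel top-k idea: keeping the $k$ largest perturbed log-probabilities yields $k$ distinct indices, which corresponds to sampling without replacement (the function name is mine):

```python
import numpy as np

def gumbel_top_k(phi, k, rng):
    """Return k distinct indices, distributed as k draws without replacement
    from Categorical(softmax(phi))."""
    g = -np.log(-np.log(rng.uniform(size=phi.shape)))
    return np.argsort(phi + g)[::-1][:k]  # indices of the k largest perturbed values

rng = np.random.default_rng(0)
phi = np.log(np.array([0.1, 0.2, 0.3, 0.4]))
print(gumbel_top_k(phi, k=2, rng=rng))
```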

Computing expectations

In order to compute expectations, if we don't have any domain-specific closed-form expression, we typically resort to Monte Carlo (MC) estimation. This involves sampling $m$ times from the model, and computing the average of the function of interest $f$:

$$ \mathbb{E}[f(X)] \approx \frac{1}{m}\sum_{i=1}^m f(x_i),\ \ x_i\sim p(x). $$

The intuition is simple: in a discrete world, the most probable outcomes will be sampled more often, so the average will be close to the expectation.

The Monte Carlo estimator is unbiased, but it can have high variance. To compensate, we would need more samples, which is often infeasible. Also, if the distribution has low entropy, we will inefficiently sample the same values over and over again.
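Before improving on it, here is a toy numpy illustration of the plain MC estimator (the distribution and $f$ below are arbitrary choices of mine):

```python
import numpy as np

rng = np.random.default_rng(0)
p = np.array([0.1, 0.2, 0.3, 0.4])         # toy categorical p(x) over {0, 1, 2, 3}
f = lambda x: x ** 2                       # any function of the outcome

exact = np.sum(f(np.arange(4)) * p)        # E[f(X)] = sum_x f(x) p(x) = 5.0
samples = rng.choice(4, size=10_000, p=p)  # x_i ~ p(x)
mc_estimate = f(samples).mean()            # Monte Carlo estimate
print(exact, mc_estimate)
```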

Importance sampling

Importance sampling is a technique to reduce the variance of the Monte Carlo estimator. The idea is to sample from a different distribution $q(x)$, and then reweight the samples by the ratio of the probabilities:

$$ \mathbb{E}[f(X)] = \sum_x f(x) p(x) = \sum_x f(x) \frac{p(x)}{q(x)} q(x) = \mathbb{E}_q\left[f(X)\frac{p(X)}{q(X)}\right]. $$
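A minimal numpy sketch of the reweighting, reusing the toy $p$ and $f$ from the MC example and an arbitrary uniform proposal $q$ (which must put mass wherever $p$ does):

```python
import numpy as np

rng = np.random.default_rng(0)
p = np.array([0.1, 0.2, 0.3, 0.4])        # target distribution
q = np.array([0.25, 0.25, 0.25, 0.25])    # proposal: must cover p's support
f = lambda x: x ** 2

x = rng.choice(4, size=10_000, p=q)       # sample from q, not from p
w = p[x] / q[x]                           # importance weights p(x) / q(x)
print(np.mean(w * f(x)))                  # unbiased estimate of E_p[f(X)] (~5.0 here)
```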

How sampling theory can help

TODO: stratified sampling, Horvitz-Thompson estimator, weighted reservoir sampling, priority sampling, sparse vector representations.

Bayes Risk and MBR

Sampling without replacement from a LLM

Appendix: a rigorous definition for the $\mathbf{(p_1,\dots,p_n)}$ notation for discrete distributions

A categorical (generalized Bernoulli) distribution is characterized as a vector $\mathbf{p}=(p_1,\dots,p_n)$, where $p_i$ represents the probability mass of the $i$-th event from a discrete outcome space $\Omega$. Commonly, we take the index random variable $I:\Omega\to\mathbb{N}$ to map the outcome space to the natural numbers, giving

$$ \mathbb{P}(I = i) := p(i) = \begin{cases} p_i & \text{if } i \in \{1,\dots,n\} \\ 0 & \text{otherwise.} \end{cases} $$

As a valid distribution, it satisfies $\sum_i p_i = 1$. We can express the expectation of the random variable $I$ as $\mathbb{E}[I] = \sum_i i\, p_i$.

If the distribution represents, for example, the possible next words in a language model, computing the expected value may not be particularly meaningful, as it would reflect the average index of the next word.

Sometimes, we are interested in the expected value of a function of the outcomes $f:\Omega\to\mathbb{R}$, that is, $\mathbb{E}_I[f(I^{-1}(I))]$. However, for brevity, we often write $\mathbb{E}[f(I)] = \sum_i f(i)\, p_i$ (using the law of the unconscious statistician).
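For instance (a toy example of my own), with $\mathbf{p} = (0.5, 0.3, 0.2)$ and $f(i)=i^2$,

$$ \mathbb{E}[f(I)] = 1^2\cdot 0.5 + 2^2\cdot 0.3 + 3^2\cdot 0.2 = 3.5. $$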

Acknowledgements

Thanks @Clara Meister for guiding me through the literature, and thanks @Tim Vieira for pointing out my initial runtime complexity mistakes!