Unofficial documentation for the HuggingFace🤗 generation pipeline
by Manuel de Prada Corral
6 min read
While implementing a new generation strategy for Transformer models, I found myself delving deep into the HuggingFace library. The documentation is clear with respect to the usage, but not so much with respect to the implementation details.
Here is a collection of notes I've compiled from my dive into the codebase. This may prove beneficial for anyone looking to understand or extend HuggingFace's generation pipeline.
The overall class structure
All HuggingFace Transformer models share one common ancestor: `PreTrainedModel`. This class is defined in `transformers/modeling_utils.py` and is a subclass of `torch.nn.Module`, `ModuleUtilsMixin`, `GenerationMixin` and `PushToHubMixin`.
graph TD;
ModuleUtilsMixin-->PreTrainedModel;
GenerationMixin-->PreTrainedModel;
PushToHubMixin-->PreTrainedModel;
torch.Module-->PreTrainedModel;
PretrainedMyModel-->MyModelForConditionalGeneration;
PreTrainedModel-->PretrainedMyModel;
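A quick way to check this hierarchy from Python (BART is an arbitrary example model; no weights need to be downloaded to inspect the classes):

```python
import torch
from transformers import BartForConditionalGeneration, PreTrainedModel
from transformers.generation import GenerationMixin

# The concrete model class inherits generation support through PreTrainedModel.
print(issubclass(BartForConditionalGeneration, PreTrainedModel))  # True
print(issubclass(BartForConditionalGeneration, GenerationMixin))  # True
print(issubclass(BartForConditionalGeneration, torch.nn.Module))  # True
```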
The generation pipeline for all Transformer models is centralized in `GenerationMixin`. This class is defined in `transformers/generation/utils.py`, and all models must implement `prepare_inputs_for_generation`. Additionally, models can implement `adjust_logits_during_generation` and `_reorder_cache`.
The main method in `GenerationMixin` is `generate`, which orchestrates the generation process and then calls the different specialized methods such as `contrastive_search`, `greedy_search`, `sample`, `beam_search`, `beam_sample`, `group_beam_search`, `constrained_beam_search` and `assisted_decoding`.
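The mode that `generate` dispatches to is determined by the arguments (or by the generation config). A minimal sketch, using `t5-small` as an arbitrary example model:

```python
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

tok = AutoTokenizer.from_pretrained("t5-small")
model = AutoModelForSeq2SeqLM.from_pretrained("t5-small")
inputs = tok("translate English to German: The house is wonderful.", return_tensors="pt")

greedy = model.generate(**inputs, max_new_tokens=40)                   # -> greedy_search
sampled = model.generate(**inputs, do_sample=True, top_k=50)           # -> sample
beams = model.generate(**inputs, num_beams=4, num_return_sequences=2)  # -> beam_search

print(tok.batch_decode(beams, skip_special_tokens=True))
```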
The generation pipeline
Let's break down the generation pipeline into its different steps. Note that these steps are marked with the same numbers in the code comments. This is a permalink to the generate method being analyzed in this post (note that HF is a fast-moving target, so some details may be outdated soon).
Another vital point to note is that generation happens in batches, meaning that `input_ids` has a shape of `(batch_size, seq_len)`. This allows, for example, translating multiple sentences at once.
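For instance, tokenizing two source sentences together (reusing the hypothetical `tok` and `model` from the sketch above) yields a single batched tensor:

```python
# Two sentences padded into one (batch_size, seq_len) tensor of token ids.
batch = tok(
    ["translate English to German: Hello.", "translate English to German: How are you?"],
    return_tensors="pt",
    padding=True,
)
print(batch["input_ids"].shape)            # torch.Size([2, seq_len])
out = model.generate(**batch, num_beams=4)
print(tok.batch_decode(out, skip_special_tokens=True))
```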
%%{init: { 'themeVariables': {'fontSize': '24px'} } }%%
timeline
1. Prepare generation_config : Merge model and user's gen config
2. Set generation parameters : Prepare logits processors and stopping criteria
3. Define model_inputs : Get encoder inputs if needed
4. Define other model kwargs
5. Prepare input_ids for the decoder : Initialize with <bos> if needed
%%{init: { 'themeVariables': {'fontSize': '23px'} } }%%
timeline
6. Prepare `max_length` depending on stopping criteria
7. Determine generation mode : Set is_greedy, is_sample, is_beam, ... : check if arguments are consistent
8. Prepare distribution pre_processing samplers : Prepare logits_processor
9. Prepare stopping criteria
10. Go into different generation modes
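Step 1 (merging the model's generation config with the user's arguments) can also be done explicitly by passing a `GenerationConfig`. A small sketch, reusing the hypothetical `model` and `inputs` from above:

```python
from transformers import GenerationConfig

gen_config = GenerationConfig.from_pretrained("t5-small")  # model defaults
gen_config.update(num_beams=4, max_new_tokens=32)          # user overrides
out = model.generate(**inputs, generation_config=gen_config)
```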
The `logits_processor` is a list of functions that are applied to the logits before selecting or sampling the next token. There is also a `logits_warper` that is applied to the logits after the `logits_processor`, but only in stochastic generation modes (`sample`, `beam_sample`, `assisted_decoding`, `constrained_beam_search` and `contrastive_search`). Also, in `beam_sample` mode, the `logits_processor` is applied to the logits, but then the logits are integrated into the beam search scores, and the `logits_warper` is applied to the beam search scores.
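To make the distinction concrete, here is a small self-contained sketch (the processor and warper choices are arbitrary examples) of how such lists transform a batch of logits:

```python
import torch
from transformers import (
    LogitsProcessorList,
    MinLengthLogitsProcessor,
    TemperatureLogitsWarper,
    TopKLogitsWarper,
)

batch_size, vocab_size, eos_token_id = 2, 100, 1
input_ids = torch.randint(0, vocab_size, (batch_size, 5))
logits = torch.randn(batch_size, vocab_size)

# Processors enforce constraints (here: forbid EOS before 10 tokens have been generated).
processors = LogitsProcessorList([MinLengthLogitsProcessor(10, eos_token_id=eos_token_id)])
# Warpers reshape the distribution before sampling (temperature, top-k, top-p, ...).
warpers = LogitsProcessorList([TemperatureLogitsWarper(0.7), TopKLogitsWarper(50)])

scores = processors(input_ids, logits)  # applied in every generation mode
scores = warpers(input_ids, scores)     # applied only in stochastic modes
next_token = torch.multinomial(scores.softmax(dim=-1), num_samples=1)
```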
The beam_search generation mode
timeline
11. Prepare beam search scorer : initialize beam hypotheses
12. Interleave input_ids with n_beams additional sequences : tensor of shape [batch_size, seq_len] -> [batch_size*n_beams, seq_len]
13. Run beam search : call beam_search method
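Step 12 amounts to repeating each sequence of the batch `n_beams` times; a minimal sketch of the tensor manipulation (token ids are illustrative):

```python
import torch

input_ids = torch.tensor([[101, 7, 9], [101, 4, 2]])    # (batch_size=2, seq_len=3)
n_beams = 3
expanded = input_ids.repeat_interleave(n_beams, dim=0)   # (batch_size*n_beams, seq_len)
print(expanded.shape)                                    # torch.Size([6, 3])
```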
The beam search generation mode has two main components:
- The `beam_search` method, found in `GenerationMixin`, handles the primary decoding loop, maintains the beam scores and calls the model (referenced in step 13 of `generate`).
- In `transformers/generation/beam_search.py`, `BeamSearchScorer` has one `BeamHypotheses` object for each sequence in the batch. It is a general construction that makes sense for generalizing beam search to diverse_beam_search (keeping different groups of beams to ensure diversity).
  - The `BeamHypotheses` object keeps the list of the `n_beams` best hypotheses for each sequence in the batch, with their beam scores and beam indices.
The beam_search method
- Initialize the beam_scores to 0 as a tensor of dimension `(batch_size, n_beams)`.
- Set beam_scores to -1e9 (effectively -inf) for all beams except the first one (`beam_scores[:, 1:] = -1e9`), so that the first topk call does not simply select `n_beams` copies of the same token.
- View beam_scores as a 1D tensor of dimension `(batch_size*n_beams)`.
- Generation loop (the tensor arithmetic is sketched after this list):
  - Run the model and get the outputs for the next token over all beams of all sequences in the batch.
  - Locally normalize the output (apply log_softmax).
  - Apply the `logits_processor` to the logits.
  - Add the new logits to the running beam scores. Note that now we have a tensor of dimension `(batch_size*n_beams, vocab_size)`.
  - To form the next_token_scores, view it as a tensor of dimension `(batch_size, n_beams*vocab_size)`.
  - Get the `2*n_beams` best scores from next_token_scores by applying `torch.topk`. Derive the beam indices and token indices.
  - Call `beam_scorer.process` to update the beam hypotheses. Get the new beam scores, indices and next_tokens for each beam. Update `input_ids` with the new tokens.
  - If all beams are finished or the stopping criteria are met, break the loop.
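A rough sketch of that per-step arithmetic (tensor names are illustrative, not the exact variables used in HF's code):

```python
import torch

batch_size, n_beams, vocab_size = 2, 3, 100
beam_scores = torch.zeros(batch_size, n_beams)
beam_scores[:, 1:] = -1e9               # only the first beam is "alive" at step 0
beam_scores = beam_scores.view(-1)      # (batch_size * n_beams,)

logits = torch.randn(batch_size * n_beams, vocab_size)        # model output for one step
next_token_scores = torch.log_softmax(logits, dim=-1)         # local normalization
next_token_scores = next_token_scores + beam_scores[:, None]  # running log-probabilities
next_token_scores = next_token_scores.view(batch_size, n_beams * vocab_size)

# Keep 2*n_beams candidates per sequence, then recover beam and token indices.
topk_scores, topk_indices = torch.topk(next_token_scores, 2 * n_beams, dim=1)
beam_indices = topk_indices // vocab_size
token_indices = topk_indices % vocab_size
```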
The BeamScorer process method
This method is defined in `transformers/generation/beam_search.py` and takes as input the `2*n_beams` top-k elements and indices calculated above. The beam search scorer is initialized with a `BeamHypotheses` object for each sequence in the batch.
- Create new tensors for the next scores, tokens and indices of dimension `(batch_size, group_size)` (this is because of diverse beam search; `group_size` equals `n_beams` for normal beam search, so in that case the tensors have dimension `(batch_size, n_beams)`).
- For each beam hypotheses object in the scorer (i.e. for each sentence in the batch):
  - If the sentence is finished, do nothing and continue to the next sentence.
  - For each (token, score, index) in the top `2*n_beams` next scores among the `n_beams*vocab_size` scores:
    - If the token is the EOS token, check if the beam is still among the `n_beams` best beams. If so, add the beam to the list of hypotheses of the sentence. The beam_score for this beam is then 0, since it moves from the running beams to the finished beams.
    - If the token is not the EOS token, add the token, score and beam_index to the next scores, tokens and indices tensors. If we already have all the running beams, break the loop (remember that we started from the top scores, so we only want to keep the `n_beams` best finished beams and the `n_beams` best running beams).
We can see how the beam_hypotheses keep the `n_beams` best finished beams, while the `n_beams` best running beams are kept in the `next_scores`, `next_tokens` and `next_indices` tensors, which are sent back and forth between the `beam_search` method and the `process` method as the main loop of `beam_search` progresses through the running beams.
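A hedged sketch of that back-and-forth, feeding placeholder tensors (shaped like the `torch.topk` outputs of the earlier sketch) into `BeamSearchScorer.process`:

```python
import torch
from transformers import BeamSearchScorer

batch_size, n_beams, vocab_size = 2, 3, 100
scorer = BeamSearchScorer(batch_size=batch_size, num_beams=n_beams, device="cpu")

input_ids = torch.zeros(batch_size * n_beams, 1, dtype=torch.long)  # running beams so far
# Placeholder candidates, standing in for the torch.topk results of the main loop:
next_scores = torch.randn(batch_size, 2 * n_beams).sort(dim=1, descending=True).values
next_tokens = torch.randint(2, vocab_size, (batch_size, 2 * n_beams))
next_indices = torch.randint(0, n_beams, (batch_size, 2 * n_beams))

out = scorer.process(input_ids, next_scores, next_tokens, next_indices,
                     pad_token_id=0, eos_token_id=1)
beam_scores = out["next_beam_scores"]   # (batch_size * n_beams,) running scores
beam_tokens = out["next_beam_tokens"]   # token to append to each kept running beam
beam_idx = out["next_beam_indices"]     # which previous beam each kept beam continues

input_ids = torch.cat([input_ids[beam_idx, :], beam_tokens.unsqueeze(-1)], dim=-1)
```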
The interesting (and obscure) bits
Why do we need to select the `2*n_beams` best beams? It looks strange at first sight. From a theoretical point of view, each new generation step always makes the sequence probabilities smaller, so the first `n_beams` sequences that reach `<EOS>` will always have higher probability than any possible continuation. However, there are two empirical reasons to keep more beams alive.
First, in closed-vocabulary models, we might encounter that `<UNK>` is the best token at some point. Most beam search implementations fall back to the next best token in this case, hence needing `n_beams+1` tokens. Second, beam search is commonly used with length normalization, which lets sequences gain score as they grow longer. This means that we need to store the best finished beams and the best running beams separately, and only compare them once they are finished (thanks Clara for helping me figure this out!).
This is why HF's `beam_search` saves `2*n_beams` beams. We might encounter situations where all of the alive `n_beams` sequences reach `<EOS>`, leaving no live sequences to continue. With `2*n_beams`, we are guaranteed to have at least one non-`EOS` token for each beam hypothesis.
On top of this, without length normalization, we can stop generation when `n_beams` sequences reach `<EOS>`. This is achieved in HF by setting `early_stopping=True`. When `early_stopping` is set to `False` or `"never"`, HF uses two different, not entirely satisfactory heuristics to stop generation whenever the best running beam is estimated to be worse than the worst finished beam. Surprisingly, no setting of `early_stopping` fully disables early stopping and lets generation continue until all beams are finished or the maximum length is reached. To be fair, this would probably cause OOM problems.
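For reference, the three accepted values are passed directly to `generate` (reusing the hypothetical `model` and `inputs` from the first sketch):

```python
# True: stop as soon as n_beams finished hypotheses exist.
out_true = model.generate(**inputs, num_beams=4, early_stopping=True)
# False / "never": keep going, using (different) heuristics that compare the
# best running beam against the worst finished beam.
out_false = model.generate(**inputs, num_beams=4, early_stopping=False)
out_never = model.generate(**inputs, num_beams=4, early_stopping="never")
```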
Interestingly, the beam search in HuggingFace was adapted from facebookresearch/XLM. You can check out the original 2019 commit here. Early days when Thomas Wolf was coding and HuggingFace was still a chatbot for teenagers!
The different scores and how to interpret them
During beam search, we keep track of the following scores:
- `beam_scores`: The running scores of the beams, i.e. the sum of the log probabilities of the tokens generated so far for each beam. It is a tensor of dimension `(batch_size * n_beams)`. The model logits may have been modified by the logits processors or by the length penalty. Optionally, also:
- `scores`: The word-per-word scores of the beams, that is, the log probabilities for every token in the vocabulary at each generation step. It is a tuple of size `seq_len` of tensors of dimension `(batch_size * n_beams, vocab_size)`. Beam indices are needed to recover the scores for each selected token.
- `beam_indices`: The indices of the beams that generated the scores at each time step. I believe these beam_indices refer to the indices of the `n_beams * vocab_size` scores from the previous timestep's `torch.topk` call. However, I am not sure, and the indices may maintain coherence across timesteps. TODO: investigate this.
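As a closing sketch (again reusing the hypothetical `model`, `tok` and `inputs` from the first example), these tensors can be requested from `generate` and combined with `compute_transition_scores`:

```python
out = model.generate(
    **inputs,
    num_beams=4,
    max_new_tokens=20,
    return_dict_in_generate=True,
    output_scores=True,          # populates out.scores (and out.beam_indices for beam search)
)
print(out.sequences_scores)      # final beam scores of the returned hypotheses
print(len(out.scores))           # one (batch_size*n_beams, vocab_size) tensor per generated step

# Per-token scores of the selected tokens can be recovered via the beam indices:
transition_scores = model.compute_transition_scores(
    out.sequences, out.scores, beam_indices=out.beam_indices, normalize_logits=False
)
print(transition_scores.shape)   # (batch_size*num_return_sequences, generated_len)
```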