In this week’s reading, we look at speculative decoding. It is a landmark paper in transformer optimization, and one that has piqued my interest for quite a while.
Speculative decoding is mainly concerned with autoregressive models that suffer from a sequential bottleneck: each new token distribution must be computed by running the entire large model, one step after another. This is both compute-intensive and slow. The authors propose a method that reduces decoding latency by leveraging a much cheaper approximation model.
In essence, we let a small and fast model speculate a batch of future tokens sequentially, then run the large model in parallel on all of those speculative prefixes. We accept the speculated tokens that are consistent with the large model’s own distribution, and fall back on the large model only at the first token whose guess deviates from that distribution by a significant margin.
The authors claim that this approach recovers exactly the same distribution as standard autoregressive decoding while yielding 2 to 3x wall-time speedups in practice, without finetuning or making intrusive changes to the large model.
In my (ongoing) effort to be more organized, I will break down this blog into the following parts:
- Explain in more detail what the sequential bottleneck is
- Discuss rejection sampling, the heart of the algorithm
- Show how Algorithm 1 of the paper is derived
- Provide choices for the ‘cheap’ approximation model
- Summarize key results in the paper, mainly the speedups
- Write my concluding thoughts about the paper
What is the sequential bottleneck?
For an autoregressive model $M_p$ (the ‘large’ model), we produce an output sequence by iteratively sampling $x_t \sim p(x_t \mid x_1, \dots, x_{t-1})$.
So at every time step $t$, we must run a full forward pass through $M_p$ to compute the distribution over the vocabulary. The numbers add up fast: if the output is 200 tokens, that’s 200 sequential calls into a multi-billion-parameter network, which is taxing in compute. There is also a latency problem: since each token is generated sequentially, we cannot compute future tokens until the current token is finalized.
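To make the cost concrete, here is a minimal sketch of naive autoregressive decoding. The callable `large_model_distribution` is a hypothetical stand-in for one full forward pass of $M_p$ that returns the next-token distribution for a given prefix.

```python
import numpy as np

# Naive autoregressive decoding: one full (expensive) forward pass per token.
# `large_model_distribution` is a hypothetical callable standing in for M_p.
def generate(prefix, n_tokens, large_model_distribution, seed=0):
    rng = np.random.default_rng(seed)
    tokens = list(prefix)
    for _ in range(n_tokens):                    # 200 tokens => 200 sequential calls
        p = large_model_distribution(tokens)     # distribution over the vocabulary
        tokens.append(int(rng.choice(len(p), p=p)))
    return tokens
```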
So maybe, if we could predict the next several tokens in advance (i.e. guess the next $\gamma$ tokens in one shot), we could run $M_p$ on all of those prefixes in parallel, paying the cost of one batched forward pass instead of $\gamma$ sequential passes. This works as long as the predicted tokens are drawn from (a close approximation of) the original distribution.
This is the key problem: how do we ensure that the predicted tokens from the cheap model do not drift away from the true distribution $p$? Speculative decoding solves this through rejection sampling, in which ‘good’ guesses are accepted and ‘bad’ guesses are corrected by resampling. This guarantees that the output is still drawn exactly from the true $p$.
The intuition behind rejection sampling
Suppose we have:
- A target distribution $p(x)$ of the large model $M_p$ on a discrete vocabulary
- A predicted distribution $q(x)$ of the cheap, fast model $M_q$ on the same vocabulary
To sample from $p$ using $q$, first draw $x \sim q$. Then we accept $x$ with probability $\frac{p(x)}{M\,q(x)}$, where $M$ is a constant chosen so that $p(x) \le M\,q(x)$ for every $x$; otherwise we reject and try again. Speculative decoding uses a modified version of this scheme:
- We set $M = 1$ and only reject when $q(x) > p(x)$ (i.e. when the small model overestimates a token relative to the large model). Formalizing this, for each candidate $x \sim q$, we accept with probability $\min\!\left(1, \frac{p(x)}{q(x)}\right)$
- If we accept, we output $x$. If we reject, we must pick a new token from the leftover distribution $p'(x) = \mathrm{norm}\!\big(\max(0,\, p(x) - q(x))\big)$
It is important to note that this two-step procedure produces samples distributed exactly according to $p$. A detailed proof is available in Appendix A.1 of the paper.
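To convince myself of this, here is a quick numerical check on a toy three-token vocabulary (the distributions below are made up for illustration): sample proposals from $q$, apply the accept/correct rule, and verify that the empirical output frequencies match $p$.

```python
import numpy as np

rng = np.random.default_rng(0)
p = np.array([0.5, 0.3, 0.2])   # target distribution (the large model)
q = np.array([0.2, 0.5, 0.3])   # predicted distribution (the small model)

def speculative_sample(p, q):
    x = rng.choice(len(q), p=q)                    # draw a candidate from q
    if rng.uniform() < min(1.0, p[x] / q[x]):
        return x                                   # accept the guess
    leftover = np.maximum(p - q, 0.0)
    leftover /= leftover.sum()                     # p'(x) = norm(max(0, p - q))
    return rng.choice(len(leftover), p=leftover)   # correct by resampling

samples = [speculative_sample(p, q) for _ in range(100_000)]
print(np.bincount(samples) / len(samples))         # ~[0.5, 0.3, 0.2], i.e. exactly p
```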
Algorithm 1 in detail
With these two concepts in mind, let’s go through Algorithm 1 of the paper step by step.
Let’s say:
- $M_p$ is the large and expensive target model, whose next-token distribution we denote $p(x)$
- $M_q$ is the small and cheap approximation model, with distribution $q(x)$
- The current prefix (i.e. the tokens generated so far) is $x_{<t}$
- The speculation length is $\gamma$
The algorithm returns between 1 and $\gamma + 1$ new tokens, depending on how many guesses are rejected. All of the evaluations of $M_p$ are done in one parallel batch.
Step 1: Speculative sampling via $M_q$
We first sample a sequence of $\gamma$ tokens from $M_q$ autoregressively.
- Set the working context to the current prefix, $c \leftarrow x_{<t}$
- For $i = 1$ to $\gamma$:
  - Compute the small model’s distribution $q_i(x) = M_q(c)$
  - Sample a guess $\tilde{x}_i \sim q_i(x)$ and append it to the context, $c \leftarrow c + [\tilde{x}_i]$
- At the end, we should have the $\gamma$ guesses $\tilde{x}_1, \dots, \tilde{x}_\gamma$ along with their distributions $q_1, \dots, q_\gamma$
This costs $\gamma$ sequential forward passes through $M_q$, but since $M_q$ is much cheaper than $M_p$, we assume that this cost is “negligible” compared to one call of $M_p$.
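Here is a minimal sketch of this drafting step, assuming a hypothetical `q_distribution(tokens)` callable that returns the small model’s next-token distribution as a numpy array.

```python
import numpy as np

# Step 1: draft gamma guesses from the small model M_q, autoregressively.
def draft_tokens(context, gamma, q_distribution, seed=0):
    rng = np.random.default_rng(seed)
    guesses, q_dists = [], []
    for _ in range(gamma):
        q = q_distribution(context + guesses)   # cheap forward pass of M_q
        x = int(rng.choice(len(q), p=q))        # sample the next guess
        guesses.append(x)
        q_dists.append(q)
    return guesses, q_dists                     # guesses x~_1..x~_gamma and q_1..q_gamma
```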
Step 2: Parallel evaluation by $M_p$
Next, we need the large model’s distribution at each speculative position to decide whether we accept that token or not.
- In parallel, run $M_p$ on the $\gamma + 1$ contexts $x_{<t}$; $x_{<t}, \tilde{x}_1$; ...; $x_{<t}, \tilde{x}_1, \dots, \tilde{x}_\gamma$
- Each of these contexts yields one distribution $p_i(x)$ over the vocabulary, for $i = 1, \dots, \gamma + 1$
- Note that we assume running $M_p$ on these contexts in parallel does not incur any additional cost compared to running it once.
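For a decoder-only transformer, this does not even require $\gamma + 1$ separate calls: one forward pass over the prefix plus the drafted tokens already yields next-token logits at every position. A minimal sketch, assuming a hypothetical `p_logits(tokens)` that returns per-position next-token logits in a single pass:

```python
import numpy as np

# Step 2: get the large model's distributions p_1..p_{gamma+1} in one pass.
def target_distributions(context, guesses, p_logits):
    tokens = context + guesses
    logits = p_logits(tokens)               # shape: (len(tokens), vocab_size)
    # Rows len(context)-1 .. end predict the gamma speculated positions
    # plus one extra ("bonus") position after the last guess.
    rows = logits[len(context) - 1:]
    rows = rows - rows.max(axis=-1, keepdims=True)   # numerically stable softmax
    probs = np.exp(rows)
    return probs / probs.sum(axis=-1, keepdims=True)
```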
Step 3: Determine acceptance with rejection sampling
We now compare each proposal probability $q_i(\tilde{x}_i)$ with the true probability $p_i(\tilde{x}_i)$ under the large model. For $i = 1, \dots, \gamma$:
- Draw a uniform random number $r_i \sim U[0, 1]$
- Compute the acceptance criterion $\min\!\left(1, \frac{p_i(\tilde{x}_i)}{q_i(\tilde{x}_i)}\right)$
- Accept speculated token $\tilde{x}_i$ iff $r_i < \min\!\left(1, \frac{p_i(\tilde{x}_i)}{q_i(\tilde{x}_i)}\right)$; this ratio effectively serves as the acceptance threshold for $r_i$
Let $n = \min\!\left(\left\{\, i - 1 \;\middle|\; 1 \le i \le \gamma,\ r_i > \tfrac{p_i(\tilde{x}_i)}{q_i(\tilde{x}_i)} \,\right\} \cup \{\gamma\}\right)$
Basically, $n$ is exactly the number of accepted tokens before the first failure. For example:
- If the first speculated token is rejected, then $n = 0$
- If all $\gamma$ guesses are accepted, then $n = \gamma$
At this point, we know that $\tilde{x}_1, \dots, \tilde{x}_n$ are correct samples taken from the real distribution $p$. Meanwhile, $\tilde{x}_{n+1}$ (if $n < \gamma$) is rejected. We still need to determine how to produce a “true sample” for that position.
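A minimal sketch of this acceptance loop, reusing the `guesses`, `q_dists`, and `p_dists` values from the sketches above:

```python
import numpy as np

# Step 3: modified rejection sampling over the gamma guesses.
def count_accepted(guesses, q_dists, p_dists, seed=0):
    rng = np.random.default_rng(seed)
    n = 0
    for i, x in enumerate(guesses):
        r = rng.uniform()
        if r < min(1.0, p_dists[i][x] / q_dists[i][x]):
            n += 1              # guess i is accepted
        else:
            break               # first rejection: stop accepting here
    return n                    # number of accepted tokens before the first failure
```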
Step 4: Fixing the distribution
In the case where some token is rejected, we must re-draw that token from a rejection-corrected distribution $p'$. Let’s say the failure happens at position $n + 1$, as above.
- The true distribution at that position is $p_{n+1}(x)$
- We already know that $q_{n+1}(\tilde{x}_{n+1}) > p_{n+1}(\tilde{x}_{n+1})$, since that’s the precondition for a rejection (see Step 3).
- Define the “leftover” vector $\max\!\big(0,\, p_{n+1}(x) - q_{n+1}(x)\big)$
- Normalize this leftover to get $p'(x) = \frac{\max\left(0,\, p_{n+1}(x) - q_{n+1}(x)\right)}{\sum_{x'} \max\left(0,\, p_{n+1}(x') - q_{n+1}(x')\right)}$
Note that the denominator simplifies to $1 - \sum_x \min\!\big(p_{n+1}(x), q_{n+1}(x)\big)$, since $p(x) - \min(p(x), q(x)) = \max(0,\, p(x) - q(x))$.
- Sample the corrected token $x' \sim p'(x)$
- Return the extended block $\tilde{x}_1, \dots, \tilde{x}_n, x'$, i.e. the $n$ accepted guesses followed by the corrected token
By construction, we have recovered the same marginal for the $(n+1)$-th token as if we had sampled it directly from $p$. Therefore, speculative decoding reproduces the true next-token distribution exactly while using cheap predictions from $M_q$.
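A sketch of this final step. One detail worth noting from the paper’s Algorithm 1: if all $\gamma$ guesses are accepted, the extra distribution $p_{\gamma+1}$ computed in Step 2 lets us sample one additional “bonus” token from the large model for free.

```python
import numpy as np

# Step 4: resample the rejected position from p'(x) = norm(max(0, p - q)),
# or sample a bonus token from p_{gamma+1} when every guess was accepted.
def finalize(guesses, q_dists, p_dists, n, seed=0):
    rng = np.random.default_rng(seed)
    if n == len(guesses):
        extra_dist = p_dists[-1]                       # bonus token: all guesses accepted
    else:
        leftover = np.maximum(p_dists[n] - q_dists[n], 0.0)
        extra_dist = leftover / leftover.sum()         # rejection-corrected p'(x)
    extra = int(rng.choice(len(extra_dist), p=extra_dist))
    return guesses[:n] + [extra]                       # n accepted guesses + 1 more token
```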
Designing $M_q$
A good candidate $M_q$ must (1) be much faster per token than $M_p$ to maximize the speedup, and (2) have sufficient overlap with $M_p$ so that the expected per-step token acceptance rate, $\alpha$, is reasonably high.
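As a quick aside, the paper quantifies this trade-off: with acceptance rate $\alpha$, speculation length $\gamma$, and cost coefficient $c$ (the cost of one $M_q$ call relative to one $M_p$ call), the expected wall-time improvement factor is $\frac{1 - \alpha^{\gamma+1}}{(1 - \alpha)(\gamma c + 1)}$. A tiny calculator for playing with these numbers (the inputs below are made-up illustrative values, not figures from the paper):

```python
# Expected wall-time improvement from the paper's analysis:
# (expected tokens per iteration) / (relative cost of one iteration).
def expected_speedup(alpha: float, c: float, gamma: int) -> float:
    expected_tokens = (1 - alpha ** (gamma + 1)) / (1 - alpha)  # accepted + corrected/bonus token
    cost_per_iteration = gamma * c + 1                          # gamma M_q calls + one batched M_p call
    return expected_tokens / cost_per_iteration

print(expected_speedup(alpha=0.8, c=0.05, gamma=5))  # ~2.95 under these assumed values
```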
Options in the paper include:
- n-gram language models
  - With a bigram or trigram model, inference is effectively just a table lookup, which is practically zero-cost
- Tiny transformer models
  - Take the same transformer architecture, but shrink it to fewer layers and a narrower hidden size
- Heuristic copy models
  - If the prefix already appears earlier in the context, copy the token that followed it there as the guess
  - This is also a zero-cost approximation model
- Non-autoregressive models
  - Instead of running the approximation model sequentially, we do only one forward pass.
Summary of key results
There are two main experiments conducted in the paper.
Empirical walltime improvement
- Tested on T5 models, on machine translation and summarization tasks
- Base model is T5-XXL with 11B parameters
- As the size of the approximation model increases, $\alpha$ increases
- T5-small is shown to be the best approximation model, yielding 2 to 3x speedups regardless of the decoding strategy (greedy or sampling)
- Other approximation models (T5-base and T5-large) also result in roughly 1x to 2x speedups
Empirical $\alpha$ values
- Tested on two extra models: a GPT-like model with 97M parameters, and LaMDA with 137B parameters
- The extra models are tested on dialog and text-generation tasks, while the T5 models are tested on the same tasks as above
- Tiny transformer models used as the approximation model tend to perform best, with $\alpha$ values between 0.5 and 0.9
- Unigram and bigram models still produce speedups; for example, in English-to-German translation, a bigram model yields a 1.25x speedup over the original T5-XXL model
- However, the speedups from n-gram models are still lower than those from tiny transformer approximation models such as T5-small
My thoughts
I can clearly see why this is a landmark paper. It is very well written, with a strong theoretical foundation combined with clear empirical significance.
I can’t help but compare this method with knowledge distillation, as both are optimization methods that leverage a smaller model. It is quite interesting that knowledge distillation seems to be more widely used (?), at least in my own experience, with DeepSeek-R1 being distilled into LLaMA and Qwen models. Maybe there’s a way to combine both: what if we use a distilled model as the approximation model? Would it improve $\alpha$ by much?
Another method I am curious to see implemented in the speculative decoding context is the escalating framework from this paper. What if we start with n-gram models, and then slowly escalate to larger models when the smaller ones are unable to produce a good $\alpha$? Would this method result in worse performance than just using one approximation model? What would be the theoretical expected speedup of this escalation method? These are interesting avenues.
Other than that, great paper. Would love to implement the original vanilla speculative decoding for my work with Whisper soon.
