
Fast Inference from Transformers via Speculative Decoding

Leviathan, Kalman, Matias

In this week’s paper, we read about speculative decoding, a landmark technique in transformer inference optimization that has piqued my interest for quite a while.

Speculative decoding is mainly concerned with autoregressive models that suffer from a sequential bottleneck: each new token’s distribution must be computed by running the entire large model, one token after another. This is both compute-intensive and latency-heavy. The authors propose a method to reduce decoding latency by leveraging a much cheaper approximation model.

In essence, we let a small, fast model speculate a batch of future tokens sequentially, then run the large model in parallel on all of those speculative prefixes. We accept the speculated tokens that are consistent with the large model’s distribution, and fall back on the large model only for the tokens that deviate from it.

The authors claim that this approach recovers exactly the same distribution as standard autoregressive decoding while yielding 2-3x wall-time speedups in practice, without the need to finetune or make intrusive changes to the large model.

In my (ongoing) effort to be more organized, I will break down this post into the following parts:

  • What is the sequential bottleneck?
  • The intuition behind rejection sampling
  • Algorithm 1 in detail
  • Designing $M_q$
  • Summary of key results
  • My thoughts


What is the sequential bottleneck?

For an autoregressive model $M_p$ (the ‘large’ model), we produce an output sequence $y = (y_1, y_2, \dots, y_T)$ by iteratively sampling

$$y_t \sim p(y_t \mid y_{<t}) \quad \text{where} \quad p(\cdot \mid y_{<t}) = M_p(y_{<t})$$

So at every time step $t$, we must run a full forward pass through $M_p$ to compute the distribution over the vocabulary. The numbers add up fast: if $T$ is 200 tokens, that is 200 sequential calls into a multi-billion-parameter network, which is taxing in compute. There is also a latency problem: since each token is generated sequentially, we cannot compute future tokens until the current token is finalized.
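
To make the cost concrete, here is a minimal sketch of this naive loop in Python. The `large_model` stand-in, vocabulary size, and prompt are all made up for illustration; the only point is that the 200 calls are strictly one after another.

```python
import numpy as np

rng = np.random.default_rng(0)
VOCAB = 50  # toy vocabulary size, purely for illustration

def large_model(prefix):
    """Toy stand-in for M_p: maps a prefix to a next-token distribution.
    In reality this is one full forward pass of the big transformer."""
    local = np.random.default_rng(abs(hash(tuple(prefix))) % (2**32))
    logits = local.normal(size=VOCAB)
    probs = np.exp(logits - logits.max())
    return probs / probs.sum()

prefix = [1, 2, 3]          # toy prompt
for _ in range(200):        # T = 200 strictly sequential calls to M_p
    p = large_model(prefix)
    prefix.append(int(rng.choice(VOCAB, p=p)))
```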

So maybe, if we could predict several next tokens in advance (i.e. guess $(y_t, y_{t+1}, \dots, y_{t+\gamma})$ in one shot), we could run $M_p$ on all of those prefixes in parallel, paying the cost of one batched forward pass instead of $\gamma$ sequential passes. This works as long as the predicted tokens are drawn from (a close approximation of) the original $M_p$ distribution.

This is the key problem: how do we ensure that the predicted tokens from the cheap model $M_q$ do not drift away from the true distribution of $M_p$? Speculative decoding solves this through a form of rejection sampling, in which ‘good’ guesses are accepted and ‘bad’ guesses are corrected by resampling. This guarantees that the output is still drawn exactly from the true distribution of $M_p$.


The intuition behind rejection sampling

Suppose we have:

  • A target distribution $p(\cdot)$ of the large model $M_p$ over a discrete vocabulary $\mathcal{V}$
  • A predicted distribution $q(\cdot)$ of the cheap, fast model $M_q$ over the same vocabulary $\mathcal{V}$

To sample from $p$ via rejection sampling, first draw $x \sim q$, then accept $x$ with probability $\frac{p(x)}{M \times q(x)}$, where $M$ is a constant chosen so that $M q(x) \geq p(x)$ for all $x$; otherwise reject, correct, and resample. In the speculative decoding algorithm:

  1. We set $M = 1$ and reject only when $q(x) > p(x)$ (i.e. when the small model overestimates the probability assigned by the original model). Formalizing this, for each candidate $x \sim q$, we accept with probability
    $$\rho(x) = \min\left(1, \frac{p(x)}{q(x)}\right)$$
  2. If we accept, we output $x$. If we reject, we pick a new $x$ from the leftover distribution
    $$r(x) = \frac{\max(0,\, p(x) - q(x))}{\sum_{y \in \mathcal{V}} \max(0,\, p(y) - q(y))}$$

    It is important to note that this two-step procedure produces samples exactly from $p$. A detailed proof is available in Appendix A.1 of the paper.
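
To convince myself of this, here is a small numerical check of the accept/correct scheme. It is a toy sketch: the 5-token distributions `p` and `q` are made up, and the only claim is that the two-step procedure above yields samples from $p$.

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy target p and draft q over a 5-token vocabulary (made up for illustration)
p = np.array([0.40, 0.30, 0.15, 0.10, 0.05])
q = np.array([0.25, 0.25, 0.25, 0.15, 0.10])

def speculative_sample(p, q, rng):
    """Accept x ~ q with probability min(1, p(x)/q(x)); otherwise resample
    from the normalized leftover max(0, p - q)."""
    x = rng.choice(len(q), p=q)
    if rng.random() <= min(1.0, p[x] / q[x]):
        return x
    residual = np.maximum(0.0, p - q)
    return rng.choice(len(p), p=residual / residual.sum())

# Even though every candidate comes from q, the samples follow p.
samples = [speculative_sample(p, q, rng) for _ in range(100_000)]
print(np.round(np.bincount(samples, minlength=len(p)) / len(samples), 3))
# -> approximately [0.4, 0.3, 0.15, 0.1, 0.05]
```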


Algorithm 1 in detail

With these two concepts in mind, let’s go through Algorithm 1 step by step; the full pseudocode is given as Algorithm 1 in the paper.

Let’s say:

  • $M_p$ is a large and expensive autoregressive model
  • $M_q$ is a small and cheap approximation model
  • The current prefix (i.e. the tokens generated so far) is $y_{<t}$
  • The speculation length is $\gamma \geq 1$

The algorithm returns between 1 and $\gamma + 1$ new tokens (depending on how many guesses are rejected), and all of these evaluations of $M_p$ are done in one parallel batch.

Step 1: Speculative sampling via $M_q$

We first sample a sequence of $\gamma$ tokens from $M_q$ autoregressively.

  1. Set $\tilde{\mathrm{x}} = [\;]$
  2. For $i = 1$ to $\gamma$:
    1. Compute $q_i(\cdot) = M_q(y_{<t}, \tilde{x}_1, \dots, \tilde{x}_{i-1})$
    2. Sample a guess $\tilde{x}_i \sim q_i(\cdot)$
  3. At the end, we should have $\tilde{\mathrm{x}} = (\tilde{x}_1, \dots, \tilde{x}_\gamma)$

This costs $\gamma$ sequential forward passes of $M_q$, but since $M_q$ is much cheaper than $M_p$, we assume that this cost is “negligible” compared to one call of $M_p$.
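
As a sketch (continuing the toy setup from the earlier snippet, i.e. reusing `rng`, `VOCAB`, and `large_model`), Step 1 could look like the following; `small_model` and `GAMMA` are again stand-ins I made up, not code from the paper.

```python
GAMMA = 5  # speculation length, chosen arbitrarily for this sketch

def small_model(prefix):
    """Toy stand-in for M_q: same interface as large_model above, but
    imagined to be orders of magnitude cheaper per call."""
    local = np.random.default_rng(abs(hash(("q",) + tuple(prefix))) % (2**32))
    logits = local.normal(size=VOCAB)
    probs = np.exp(logits - logits.max())
    return probs / probs.sum()

def draft_tokens(prefix, gamma=GAMMA):
    """Step 1: run M_q autoregressively to propose gamma guesses, keeping
    each q_i so we can test acceptance against p_i later."""
    guesses, q_dists = [], []
    for _ in range(gamma):
        q_i = small_model(prefix + guesses)
        guesses.append(int(rng.choice(VOCAB, p=q_i)))
        q_dists.append(q_i)
    return guesses, q_dists
```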

Step 2: Parallel evaluation by $M_p$

Next, we need to know $p_i(\tilde{x}_i)$ for each speculative token $\tilde{x}_i$ to decide whether we accept it or not.

  • In parallel, run $M_p$ on the following $(\gamma + 1)$ contexts:
    $$\begin{aligned}
    &\text{Context } 1: \quad \mathbf{y}_{<t} \;\Longrightarrow\; p_1(\cdot),\\
    &\text{Context } 2: \quad \mathbf{y}_{<t},\, \tilde{x}_1 \;\Longrightarrow\; p_2(\cdot),\\
    &\quad\vdots\\
    &\text{Context } i: \quad \mathbf{y}_{<t},\, \tilde{x}_1, \dots, \tilde{x}_{i-1} \;\Longrightarrow\; p_i(\cdot),\\
    &\quad\vdots\\
    &\text{Context } (\gamma + 1): \quad \mathbf{y}_{<t},\, \tilde{x}_1, \dots, \tilde{x}_\gamma \;\Longrightarrow\; p_{\gamma+1}(\cdot).
    \end{aligned}$$
  • Each of these $\gamma + 1$ contexts yields one distribution $p_i(\cdot)$ over $\mathcal{V}$
  • Note that we assume that running $M_p$ on these contexts in parallel does not incur additional cost compared to running $M_p$ once.
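
In implementation terms (this is my reading, not code from the paper), a decoder-only transformer with causal masking already produces the next-token distribution at every position in a single forward pass, which is why all $\gamma + 1$ distributions come from what is effectively one call on the longest context. With the toy models above, the bookkeeping looks like this:

```python
def target_distributions(prefix, guesses):
    """Step 2: obtain p_1 ... p_{gamma+1}, one distribution per context.
    This toy version loops over the contexts for clarity; a real causal
    transformer returns all of them from one pass over the longest context."""
    contexts = [prefix + guesses[:i] for i in range(len(guesses) + 1)]
    return [large_model(ctx) for ctx in contexts]
```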

Step 3: Determine acceptance with rejection sampling

We now compare each $\tilde{x}_i$ with the true distribution $p_i$. For $i = 1, \dots, \gamma$:

  1. Draw a uniform random number $r_i \sim \text{Uniform}(0,1)$
  2. Compute the acceptance ratio:
    $$\rho_i = \frac{p_i(\tilde{x}_i)}{q_i(\tilde{x}_i)}$$
  3. Accept the speculated token $\tilde{x}_i$ iff $r_i \leq \min(1, \rho_i)$; comparing against the uniform draw $r_i$ is what implements “accept with probability $\min(1, \rho_i)$”


Let

$$n = \max\{\, k \in \{0, 1, \dots, \gamma\} : \tilde{x}_i \text{ was accepted for all } i \leq k \,\}$$

Basically, $n$ is exactly the number of tokens accepted before the first rejection. For example:

  • If the first speculated token $\tilde{x}_1$ is rejected, then $n = 0$
  • If all $\gamma$ guesses are accepted, then $n = \gamma$

At this point, we know that $\tilde{x}_1, \dots, \tilde{x}_n$ are valid samples from the true distributions $p_1, \dots, p_n$. Meanwhile, if $n < \gamma$, the token $\tilde{x}_{n+1}$ has been rejected, and we still need to determine how to produce a “true sample” for that position.
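
Continuing the toy sketch, the left-to-right scan of Step 3 that yields $n$ might look like:

```python
def accept_guesses(guesses, q_dists, p_dists):
    """Step 3: scan the guesses left to right and stop at the first
    rejection. Returns n, the number of accepted tokens (0 <= n <= gamma)."""
    n = 0
    for i, x_i in enumerate(guesses):
        if rng.random() <= min(1.0, p_dists[i][x_i] / q_dists[i][x_i]):
            n += 1
        else:
            break
    return n
```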

Step 4: Fixing the distribution

In the case that a token is rejected, we must re-draw it from a rejection-corrected distribution $p'(x)$. Let’s say the failure happens at token $n+1$, as above.

  1. The true distribution at that position is $p_{n+1}(\cdot)$
  2. We already know that $q_{n+1}(\tilde{x}_{n+1}) > p_{n+1}(\tilde{x}_{n+1})$, since that is the precondition for rejection (see Step 3).
  3. Define the “leftover” vector over $\mathcal{V}$:
    $$L(z) = \max(0,\, p_{n+1}(z) - q_{n+1}(z)), \quad \forall z \in \mathcal{V}$$
  4. Normalize this leftover to get
    $$\begin{aligned} p'(z) &= \frac{L(z)}{\sum_{y \in \mathcal{V}} L(y)} \\ &= \frac{\max\bigl(0,\, p_{n+1}(z) - q_{n+1}(z)\bigr)}{1 - \sum_{y \in \mathcal{V}} \min\bigl(p_{n+1}(y),\, q_{n+1}(y)\bigr)} \end{aligned}$$

    Note that the denominator follows because both $p_{n+1}$ and $q_{n+1}$ sum to 1: $\sum_{y} \max(0, p_{n+1}(y) - q_{n+1}(y)) = \sum_{y} \bigl(p_{n+1}(y) - \min(p_{n+1}(y), q_{n+1}(y))\bigr) = 1 - \sum_{y} \min(p_{n+1}(y), q_{n+1}(y))$.

  5. Sample $y_{t+n} \sim p'(\cdot)$
  6. Return the extended block $(\tilde{x}_1, \dots, \tilde{x}_n, y_{t+n})$
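
Here is a sketch of Step 4 on the same toy setup, plus the glue that chains Steps 1-4 into one iteration. One detail worth noting: if all $\gamma$ guesses are accepted, the paper samples a bonus token from $p_{\gamma+1}$ (which we already computed in Step 2), so every iteration emits at least one token.

```python
def correct_and_extend(prefix, guesses, q_dists, p_dists, n):
    """Step 4: resample the rejected position from the normalized leftover
    max(0, p - q); if every guess was accepted, sample the bonus token
    from p_{gamma+1} instead, so each iteration emits at least one token."""
    if n == len(guesses):
        p_next = p_dists[-1]                  # all accepted: use p_{gamma+1}
    else:
        residual = np.maximum(0.0, p_dists[n] - q_dists[n])
        p_next = residual / residual.sum()
    extra = int(rng.choice(VOCAB, p=p_next))
    return prefix + guesses[:n] + [extra]

# One full speculative iteration, chaining Steps 1-4:
prefix = [1, 2, 3]
guesses, q_dists = draft_tokens(prefix)
p_dists = target_distributions(prefix, guesses)
n = accept_guesses(guesses, q_dists, p_dists)
prefix = correct_and_extend(prefix, guesses, q_dists, p_dists, n)
```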

By construction, we recover the same marginal for the $(n+1)$-th token as if we had sampled it directly from $p_{n+1}$. Therefore, speculative decoding reproduces the true next-token distribution exactly while using cheap predictions from $M_q$.


Designing $M_q$

A good $M_q$ candidate must (1) be much faster per token than $M_p$ to maximize speedups, and (2) have sufficient overlap with $M_p$ so that the expected per-token acceptance rate, $\alpha$, is reasonably high.

Options in the paper include:

  1. n-gram language models
    • With a bigram or trigram model, inference is effectively just a table lookup, with practically zero cost
  2. Tiny transformer models
    • Take the same transformer architecture, but shrink it to fewer layers and a narrower hidden size
  3. Heuristic copy models
    • If the prefix appears earlier in the context, copy the token that followed it as the guess
    • This is also a zero-cost approximation model
  4. Non-autoregressive models
    • Instead of running the approximation model sequentially, we do only one forward pass.
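
To get a feel for how much $\alpha$ matters when choosing among these options: if I am reading the paper’s analysis correctly, the expected number of tokens emitted per iteration (i.e. per parallel call to $M_p$) is $\frac{1 - \alpha^{\gamma+1}}{1 - \alpha}$. A quick sketch:

```python
def expected_tokens_per_iteration(alpha: float, gamma: int) -> float:
    """Expected number of tokens emitted per speculative iteration,
    (1 - alpha**(gamma + 1)) / (1 - alpha), per the paper's analysis."""
    if alpha >= 1.0:
        return float(gamma + 1)
    return (1 - alpha ** (gamma + 1)) / (1 - alpha)

for alpha in (0.5, 0.7, 0.9):
    print(alpha, round(expected_tokens_per_iteration(alpha, gamma=5), 2))
# 0.5 -> ~1.97, 0.7 -> ~2.94, 0.9 -> ~4.69 tokens per call to M_p
```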

Summary of key results

There are two main experiments conducted in the paper.

Empirical wall-time improvement

  • Tested on T5 models, on machine translation and summarization tasks
  • Base model is T5-XXL with 11B parameters
  • As the size of the approximation model increases, $\alpha$ increases
  • Shown that T5-small is the best approximation model, yielding 2 to 3x speedups regardless of the decoding strategy (greedy or sampling)
  • Other approximation models (T5-base and T5-large) also result in ~1x to 2x speedups

Empirical $\alpha$ values

  • Tested on two additional models: a GPT-like model with 97M parameters, and LaMDA with 137B parameters
  • The additional models are tested on dialog and text-generation tasks, while the T5 models are tested on the same tasks as above
  • Shown that tiny transformer models used as the approximation model tend to perform best, with $\alpha$ values between 0.5 and 0.9
  • Unigram and bigram models still stand to produce speedups. For example, on English-to-German translation, a bigram model yields a 1.25x speedup over the original T5-XXL model
    • However, the speedups from n-gram models are still lower than those from tiny transformer approximation models such as T5-small

My thoughts

I can clearly see why this is a landmark paper. It is very well written, with a strong theoretical foundation combined with clear empirical significance.

I can’t help but compare this method with knowledge distillation, as both are optimization methods leveraging a smaller model. It is quite interesting to note that knowledge distillation is more widely used (?), at least in my own experience, with DeepSeek-R1 being distilled into LLaMA and Qwen models. Maybe there’s a way to combine both? What if we use a distilled model as the approximation model? Would it improve $\alpha$ by much?

Another method I am curious to see implemented in the speculative decoding context is the escalating framework from this paper. What if we start with n-gram models, and then slowly escalate to larger approximation models when the smaller models are unable to produce a good $\alpha$? Will this result in worse performance than just using one approximation model? What would be the theoretical expected speedup using this escalation method? These are interesting avenues.

Other than that, great paper. Would love to implement the original vanilla speculative decoding for my work with Whisper soon.