Speculative Sampling Explained

Speculative Sampling

The idea of speculative sampling is to draw tokens from a cheap draft distribution while guaranteeing that the final output is distributed exactly as if it had been sampled from the target distribution.

We have a target sampling distribution $p(x)$ and a draft sampling distribution $q(x)$. For each token $x_i$, we have two probabilities, $p(x_i)$ and $q(x_i)$, and exactly one of the following holds:

  • either $p(x_i) > q(x_i)$,
  • or $p(x_i) \leq q(x_i)$.

If we sample directly from $q(x)$, the sample $x$ does not follow the target distribution $p(x)$. A token $x_i$ is

  • over-sampled if $q(x_i) > p(x_i)$,
  • under-sampled if $q(x_i) < p(x_i)$.
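To make the two cases concrete, here is a minimal Python sketch. The toy distributions over a four-token vocabulary and the `classify` helper are hypothetical illustrations, not taken from the paper; the helper simply compares $q(x_i)$ with $p(x_i)$ for every token.

```python
# Hypothetical toy distributions over a 4-token vocabulary (illustrative numbers only).
p = [0.5, 0.2, 0.2, 0.1]  # target distribution p(x)
q = [0.3, 0.4, 0.1, 0.2]  # draft distribution q(x)


def classify(p, q):
    """Label each token index as over-sampled, under-sampled, or exact under q."""
    labels = []
    for p_i, q_i in zip(p, q):
        if q_i > p_i:
            labels.append("over")   # the draft proposes this token too often
        elif q_i < p_i:
            labels.append("under")  # the draft proposes this token too rarely
        else:
            labels.append("exact")  # q matches p on this token
    return labels


print(classify(p, q))  # ['under', 'over', 'under', 'over']
```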

The core trick of speculative sampling is a rejection scheme that down-samples the over-sampled tokens and up-samples the under-sampled tokens, so that the tokens it finally outputs follow exactly the target distribution $p(x)$.

How to down-sample the over-sampled tokens?

The idea is to use rejection to remove the excess probability of over-sampled tokens: a draft token $x_i$ is accepted only with probability $p(x_i) / q(x_i)$. Since an over-sampled token has $p(x_i) < q(x_i)$, this acceptance probability is less than 1, and the probability of drafting $x_i$ and then accepting it is $q(x_i) \cdot p(x_i) / q(x_i) = p(x_i)$, exactly its target probability. (This is also why the paper writes the acceptance probability as $\min(1, p(x_i) / q(x_i))$: the 1 handles the case $p(x_i) \geq q(x_i)$, where the token is not over-sampled and is always accepted.)
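The acceptance rule can be sketched in a few lines of Python. This is a minimal illustration, assuming the probabilities $p(x_i)$ and $q(x_i)$ of the drafted token are already known as floats; the function names `acceptance_prob` and `accept` are hypothetical.

```python
import random


def acceptance_prob(p_x, q_x):
    """min(1, p(x)/q(x)); q_x > 0 because x was drawn from q."""
    return min(1.0, p_x / q_x)


def accept(p_x, q_x, rng=random):
    # random() is uniform on [0, 1), so this returns True with probability
    # min(1, p/q): always for tokens that are not over-sampled, and with
    # probability p/q < 1 for over-sampled tokens.
    return rng.random() < acceptance_prob(p_x, q_x)


print(acceptance_prob(0.2, 0.4))  # 0.5 (over-sampled token)
print(acceptance_prob(0.5, 0.3))  # 1.0 (under-sampled token, always accepted)
```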

How to up-sample the under-sampled tokens?

The idea is to give under-sampled tokens an additional chance to be sampled. First, an under-sampled draft token is never rejected. Second, whenever a rejection is triggered by an over-sampled token, we re-sample from a distribution that covers only the under-sampled tokens.

To do this, we first define a new distribution that captures the missing (under-sampled) probability mass. We call it the residual distribution.

The residual distribution is defined as follows: $r(x_i) = \frac{\max(0, p(x_i) - q(x_i))}{\sum_{x_j} \max(0, p(x_j) - q(x_j))}$
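A minimal Python sketch of this definition, assuming $p$ and $q$ are given as plain lists of floats over the same vocabulary (the `residual` helper and the toy numbers are hypothetical):

```python
def residual(p, q):
    """r(x) = max(0, p(x) - q(x)) / sum_x max(0, p(x) - q(x))."""
    diffs = [max(0.0, p_i - q_i) for p_i, q_i in zip(p, q)]
    z = sum(diffs)  # normalization constant
    if z == 0.0:
        # p == q: no token is under-sampled, and no rejection can ever happen
        raise ValueError("residual distribution is undefined when p == q")
    return [d / z for d in diffs]


p = [0.5, 0.2, 0.2, 0.1]  # hypothetical target distribution
q = [0.3, 0.4, 0.1, 0.2]  # hypothetical draft distribution
print(residual(p, q))  # roughly [2/3, 0, 1/3, 0]
```

Only the under-sampled tokens (indices 0 and 2 here) receive non-zero mass.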

The residual distribution is a valid probability distribution (its values are non-negative and sum to 1, provided $p \neq q$), and it is supported only on the under-sampled tokens: every over-sampled token has zero probability under it.

Whenever a draft token is rejected, we re-sample the token from the residual distribution.
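Putting the pieces together, one sampling step can be sketched as below. This is a single-token illustration under stated assumptions: $p$ and $q$ are given directly as lists of floats (in practice they would come from the target and draft models), and `speculative_sample` is a hypothetical name. It uses the standard library's `random.choices`, which accepts unnormalized weights, so the residual does not need to be normalized explicitly.

```python
import random


def speculative_sample(p, q, rng=random):
    """Draw one token index whose distribution is exactly p, using a proposal from q."""
    vocab = range(len(p))
    # 1. The draft proposes x ~ q.
    x = rng.choices(vocab, weights=q)[0]
    # 2. Accept x with probability min(1, p(x)/q(x)).
    if rng.random() < min(1.0, p[x] / q[x]):
        return x
    # 3. On rejection, re-sample from the residual r(x) ∝ max(0, p(x) - q(x)).
    diffs = [max(0.0, p_i - q_i) for p_i, q_i in zip(p, q)]
    return rng.choices(vocab, weights=diffs)[0]


p = [0.5, 0.2, 0.2, 0.1]  # hypothetical target distribution
q = [0.3, 0.4, 0.1, 0.2]  # hypothetical draft distribution
rng = random.Random(0)
samples = [speculative_sample(p, q, rng) for _ in range(100_000)]
# Empirical frequencies; these should be close to p, not to q.
print([samples.count(i) / len(samples) for i in range(len(p))])
```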

Does this recover the target distribution?

Intuitively, the answer is yes. For a rigorous argument, see the proof of Theorem 1 in the paper.

The core step is to compute the total rejection probability, i.e., the probability that a re-sampling from the residual distribution is triggered.

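A token $x_i$ is drafted with probability $q(x_i)$ and then accepted with probability $\min(1, p(x_i) / q(x_i))$, so it is drafted and accepted with probability $\min(p(x_i), q(x_i))$. The total rejection probability is therefore $1 - \sum_{x_i} \min(p(x_i), q(x_i)) = \sum_{x_i} \max(0, p(x_i) - q(x_i))$, where the equality uses $\sum_{x_i} p(x_i) = 1$ and $p(x_i) - \min(p(x_i), q(x_i)) = \max(0, p(x_i) - q(x_i))$. Magic! This is exactly the normalization constant of the residual distribution.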
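As a quick numerical sanity check of this identity, the sketch below uses hypothetical toy distributions (illustrative numbers only) and compares the total rejection probability with the residual's normalization constant:

```python
# Hypothetical toy distributions (illustrative numbers only).
p = [0.5, 0.2, 0.2, 0.1]  # target p(x)
q = [0.3, 0.4, 0.1, 0.2]  # draft q(x)

# P(draft x and accept it) = q(x) * min(1, p(x)/q(x)) = min(p(x), q(x))
accept_total = sum(q_i * min(1.0, p_i / q_i) for p_i, q_i in zip(p, q))
reject_total = 1.0 - accept_total

# Normalization constant of the residual distribution
z = sum(max(0.0, p_i - q_i) for p_i, q_i in zip(p, q))

print(round(reject_total, 10), round(z, 10))  # 0.3 0.3
```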

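So the probability that the final output is $x_i$ is the acceptance term plus the total rejection probability times $r(x_i)$. Since the total rejection probability cancels the normalization constant of $r$, this gives $\min(p(x_i), q(x_i)) + \max(0, p(x_i) - q(x_i)) = p(x_i)$.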
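The same argument can be checked exactly, without sampling, by computing the output distribution of one step term by term. A minimal sketch, assuming toy distributions given as lists of floats; `output_distribution` is a hypothetical helper:

```python
def output_distribution(p, q):
    """Exact probability of emitting each token in one speculative sampling step."""
    # P(draft x and accept it) = min(p(x), q(x))
    accepted = [min(p_i, q_i) for p_i, q_i in zip(p, q)]
    # Total rejection probability
    reject_total = 1.0 - sum(accepted)
    # Residual distribution r(x); all zeros if p == q (no rejection possible)
    diffs = [max(0.0, p_i - q_i) for p_i, q_i in zip(p, q)]
    z = sum(diffs)
    r = [d / z for d in diffs] if z > 0 else [0.0] * len(p)
    # P(output = x) = min(p(x), q(x)) + P(reject) * r(x)
    return [a + reject_total * r_x for a, r_x in zip(accepted, r)]


p = [0.5, 0.2, 0.2, 0.1]  # hypothetical target distribution
q = [0.3, 0.4, 0.1, 0.2]  # hypothetical draft distribution
print([round(v, 10) for v in output_distribution(p, q)])  # [0.5, 0.2, 0.2, 0.1]
```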

Summary

The key concepts of speculative sampling are:

  • over-sampled and under-sampled tokens
  • down-sample the over-sampled tokens
  • up-sample the under-sampled tokens
  • residual distribution
  • total rejection probability to trigger a re-sampling from the residual distribution
  • recovery of the target distribution
Saibo Geng
PhD student at EPFL

My research interest is improving LLMs’ performance through decoding methods.