Speculative Sampling Explained
Speculative Sampling
The idea of speculative sampling is to draw tokens from a cheap draft distribution and still end up with samples that follow the target distribution exactly.
We have a target distribution $p(x)$ and a draft distribution $q(x)$. For each token $x_i$ there are two probabilities, $p(x_i)$ and $q(x_i)$, and exactly one of the following holds:
- either $p(x_i) > q(x_i)$,
- or $p(x_i) \leq q(x_i)$.
If we sample directly from $q(x)$, the samples do not follow the target distribution $p(x)$. A token $x_i$ is
- over-sampled if $q(x_i) > p(x_i)$,
- under-sampled if $q(x_i) < p(x_i)$.
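A minimal Python sketch of the two cases on a toy four-token vocabulary. The function name `classify_tokens` and the probabilities are hypothetical, chosen only for illustration. Tokens with $p(x_i) = q(x_i)$ are labeled `exact`, since they are neither over- nor under-sampled, and `Fraction` keeps the arithmetic exact.

```python
from fractions import Fraction

def classify_tokens(p, q):
    """Label each token by how the draft q samples it relative to the target p.

    p and q are lists of probabilities over the same vocabulary (index = token id).
    """
    labels = []
    for p_i, q_i in zip(p, q):
        if q_i > p_i:
            labels.append("over-sampled")
        elif q_i < p_i:
            labels.append("under-sampled")
        else:
            labels.append("exact")
    return labels

# Hypothetical toy distributions over four tokens.
p = [Fraction(n, 10) for n in (4, 3, 2, 1)]  # target
q = [Fraction(n, 10) for n in (1, 2, 2, 5)]  # draft

print(classify_tokens(p, q))
# ['under-sampled', 'under-sampled', 'exact', 'over-sampled']
```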
The core trick of speculative sampling is a rejection scheme that down-samples the over-sampled tokens and up-samples the under-sampled tokens, so that the final samples are distributed exactly according to $p(x)$.
How to down-sample the over-sampled tokens?
The idea is to reject the excess probability mass of an over-sampled token. A drafted token $x_i$ is accepted only with probability $p(x_i) / q(x_i)$. Because the token is over-sampled, $p(x_i) < q(x_i)$, so this acceptance probability is less than 1. (This is why the paper writes the acceptance probability as $\min(1, p(x_i) / q(x_i))$: the 1 covers the case $p(x_i) \geq q(x_i)$, where the token is not over-sampled and is always accepted.)
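The acceptance rule can be sketched in Python as follows; `acceptance_prob` is a hypothetical helper name and the distributions are illustrative toy numbers. The loop also checks that the probability of the draft proposing $x_i$ and keeping it is $q(x_i)\min(1, p(x_i)/q(x_i)) = \min(p(x_i), q(x_i))$, which is the quantity summed when computing the total rejection probability.

```python
from fractions import Fraction

def acceptance_prob(p_i, q_i):
    """Probability of keeping a drafted token: min(1, p(x_i) / q(x_i)).

    Assumes q_i > 0, which holds for any token the draft actually proposed.
    """
    return min(1, p_i / q_i)

p = [Fraction(n, 10) for n in (4, 3, 2, 1)]  # hypothetical target
q = [Fraction(n, 10) for n in (1, 2, 2, 5)]  # hypothetical draft

for p_i, q_i in zip(p, q):
    a = acceptance_prob(p_i, q_i)
    kept = q_i * a  # P(draft proposes x_i and it is accepted)
    assert kept == min(p_i, q_i)
    print(a, kept)
```

Only the over-sampled last token gets an acceptance probability below 1 (here 1/5); every other token is always kept.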
How to up-sample the under-sampled tokens?
The idea is to give the under-sampled tokens an extra chance to be sampled. First, under-sampled tokens are never rejected. Second, whenever a rejection is triggered on an over-sampled token, we re-sample from a distribution that puts mass only on the under-sampled tokens.
To do this, we first define a new distribution that represents the under-sampled probability mass, called the residual distribution.
The residual distribution is defined as $r(x_i) = \frac{\max(0, p(x_i) - q(x_i))}{\sum_{x_j} \max(0, p(x_j) - q(x_j))}$.
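The residual distribution is a valid probability distribution (non-negative and summing to 1) supported only on the under-sampled tokens: every over-sampled token, and every token with $p(x_i) = q(x_i)$, has zero probability under it. It is undefined only when $p = q$ everywhere, in which case no rejection ever happens.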
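A direct transcription of the residual distribution into Python, using the same hypothetical toy numbers; `residual_distribution` is an illustrative name, not a library function. The guard for a zero normalizer reflects the case $p = q$, where no rejection can occur and $r$ is never needed.

```python
from fractions import Fraction

def residual_distribution(p, q):
    """r(x_i) = max(0, p(x_i) - q(x_i)) / sum_j max(0, p(x_j) - q(x_j))."""
    excess = [max(0, p_i - q_i) for p_i, q_i in zip(p, q)]  # under-sampled mass
    norm = sum(excess)
    if norm == 0:
        # p == q everywhere: nothing is ever rejected, so r is undefined and unused.
        raise ValueError("residual distribution is undefined when p == q")
    return [e / norm for e in excess]

p = [Fraction(n, 10) for n in (4, 3, 2, 1)]  # hypothetical target
q = [Fraction(n, 10) for n in (1, 2, 2, 5)]  # hypothetical draft

r = residual_distribution(p, q)
print([str(r_i) for r_i in r])  # ['3/4', '1/4', '0', '0']
```

The exact and over-sampled tokens (the last two) get zero residual mass, as expected.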
When we reject a sample from the over-sampled probability, we re-sample from the residual distribution.
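Putting the pieces together, one speculative sampling step might look like the sketch below. It uses only Python's standard `random` module; `speculative_sample`, the toy distributions, and the sample size are assumptions for illustration, not the paper's code. The sketch handles a single token position with fixed $p$ and $q$.

```python
import random
from collections import Counter

def residual_distribution(p, q):
    """r(x_i) proportional to max(0, p(x_i) - q(x_i))."""
    excess = [max(0.0, p_i - q_i) for p_i, q_i in zip(p, q)]
    norm = sum(excess)
    return [e / norm for e in excess]

def speculative_sample(p, q, rng=random):
    """Hypothetical single-token step: returns (token, was_accepted)."""
    tokens = range(len(p))
    x = rng.choices(tokens, weights=q)[0]     # 1. draft a token from q
    if rng.random() < min(1.0, p[x] / q[x]):  # 2. accept with prob min(1, p/q)
        return x, True
    # 3. Rejected: re-sample from the residual. A rejection only happens when
    #    p[x] < q[x], so some other token has p > q and the normalizer is positive.
    r = residual_distribution(p, q)
    return rng.choices(tokens, weights=r)[0], False

rng = random.Random(0)
p = [0.4, 0.3, 0.2, 0.1]  # hypothetical target
q = [0.1, 0.2, 0.2, 0.5]  # hypothetical draft
n = 50_000
results = [speculative_sample(p, q, rng) for _ in range(n)]
freq = Counter(tok for tok, _ in results)
print([round(freq[i] / n, 3) for i in range(len(p))])  # close to p
print(sum(acc for _, acc in results) / n)              # close to 0.6
```

The empirical frequencies should approach $p$, and the acceptance rate should approach $\sum_{x_i} \min(p(x_i), q(x_i)) = 0.6$ for these numbers.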
Does this recover the target distribution?
Yes. For a rigorous argument, see the proof of Theorem 1 in the paper; the key steps are as follows.
The core step is to compute the total rejection probability, i.e. the probability that a re-sampling from the residual distribution is triggered.
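A token $x_i$ is drafted with probability $q(x_i)$ and kept with probability $\min(1, p(x_i)/q(x_i))$, so it is emitted without rejection with probability $q(x_i)\min(1, p(x_i)/q(x_i)) = \min(p(x_i), q(x_i))$. Using $\sum_{x_i} p(x_i) = 1$, the total rejection probability is therefore $1 - \sum_{x_i} \min(p(x_i), q(x_i)) = \sum_{x_i} \left(p(x_i) - \min(p(x_i), q(x_i))\right) = \sum_{x_i} \max(0, p(x_i) - q(x_i))$. Magic! This is exactly the normalization constant of the residual distribution.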
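The identity can be checked numerically with exact fractions on the same hypothetical toy distributions. This is only a sanity check on one example, not a proof.

```python
from fractions import Fraction

p = [Fraction(n, 10) for n in (4, 3, 2, 1)]  # hypothetical target
q = [Fraction(n, 10) for n in (1, 2, 2, 5)]  # hypothetical draft

# P(drafted token is accepted) = sum_i q_i * min(1, p_i / q_i) = sum_i min(p_i, q_i)
accept_total = sum(min(p_i, q_i) for p_i, q_i in zip(p, q))
reject_total = 1 - accept_total
# Normalization constant of the residual distribution
residual_norm = sum(max(0, p_i - q_i) for p_i, q_i in zip(p, q))
print(accept_total, reject_total, residual_norm)  # 3/5 2/5 2/5
```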
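So the final probability of emitting $x_i$ is the probability of drafting and accepting it, plus the total rejection probability times $r(x_i)$: $\min(p(x_i), q(x_i)) + \left(\sum_{x_j} \max(0, p(x_j) - q(x_j))\right) r(x_i) = \min(p(x_i), q(x_i)) + \max(0, p(x_i) - q(x_i)) = p(x_i)$.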
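To see the final identity end to end, the sketch below computes the exact emission probability of every token from the two branches (accept the draft, or reject and draw from $r$). `output_distribution` is a hypothetical helper, and the second pair of distributions is an arbitrary extra test case.

```python
from fractions import Fraction

def output_distribution(p, q):
    """Exact P(speculative sampling emits x_i) for each token i."""
    accept = [min(p_i, q_i) for p_i, q_i in zip(p, q)]     # draft x_i and keep it
    reject_total = 1 - sum(accept)                          # probability of any rejection
    excess = [max(0, p_i - q_i) for p_i, q_i in zip(p, q)]
    norm = sum(excess)                                      # residual normalizer
    # r(x_i) = excess_i / norm; if norm == 0 then p == q and rejection never happens.
    r = [e / norm for e in excess] if norm else [0] * len(p)
    return [a + reject_total * r_i for a, r_i in zip(accept, r)]

p = [Fraction(n, 10) for n in (4, 3, 2, 1)]  # hypothetical target
q = [Fraction(n, 10) for n in (1, 2, 2, 5)]  # hypothetical draft
print(output_distribution(p, q) == p)        # True

p2 = [Fraction(1, 2), Fraction(1, 3), Fraction(1, 6)]
q2 = [Fraction(1, 6), Fraction(1, 6), Fraction(2, 3)]
print(output_distribution(p2, q2) == p2)     # True
```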
Summary
The key concepts of speculative sampling are:
- over-sampled and under-sampled tokens
- down-sample the over-sampled tokens
- up-sample the under-sampled tokens
- residual distribution
- total rejection probability to trigger a re-sampling from the residual distribution
- recovery of the target distribution