Constrained Decoding is Posterior Inference

If you are familiar with constrained decoding, you may occasionally have heard people mention the term “posterior inference” in that context.

In this post, I will explain why constrained decoding can (and arguably should) be seen as a form of posterior inference.

What is Posterior Inference?

Posterior inference simply means inferring the posterior distribution of a model given some observations. It is a fundamental problem in Bayesian statistics and machine learning.

Given two random variables $X$ and $Y$, the posterior distribution of $X$ given $Y$ is defined as: $$ p(X|Y) = \frac{p(Y|X)p(X)}{p(Y)} $$ where $p(X)$ is the prior distribution of $X$, $p(Y|X)$ is the likelihood of $Y$ given $X$, and $p(Y)$ is the marginal likelihood of $Y$.

If we expand $p(Y)$, we get: $$ p(Y) = \sum_{X} p(Y|X)p(X) $$ Thus $$ p(X|Y) = \frac{p(Y|X)p(X)}{\sum_{X} p(Y|X)p(X)} $$
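As a concrete illustration, here is a minimal numeric sketch of this formula in Python; all numbers and the rain/sun scenario are made up purely for illustration:

```python
# Bayes' rule on a toy discrete example (all numbers are made up):
# X in {rain, sun}, Y = "the grass is wet".
prior = {"rain": 0.3, "sun": 0.7}        # p(X)
likelihood = {"rain": 0.9, "sun": 0.2}   # p(Y | X)

# p(Y) = sum_X p(Y|X) p(X)
evidence = sum(likelihood[x] * prior[x] for x in prior)
# p(X|Y) = p(Y|X) p(X) / p(Y)
posterior = {x: likelihood[x] * prior[x] / evidence for x in prior}
print(posterior)  # {'rain': ~0.66, 'sun': ~0.34}
```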

For more details on posterior inference, see the Bayesian Inference page.

Constrained Decoding as Posterior Inference

In the context of language models, we can view decoding under constraints as a form of posterior inference.

The variable $X$ is the sequence of tokens and $Y$ represents the constraints.

$p(X)$ is the prior distribution of the sequence of tokens, which is given by the LLM in a factorized form: $$ p(X) = \prod_{i=1}^T p(x_i|x_{<i}) $$
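As a sketch of what this prior looks like in code, here is one way to score a sequence under a Hugging Face causal LM. The model name `gpt2` and the example text are illustrative choices, not something the post prescribes; any causal LM exposes the same interface:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "gpt2"  # illustrative choice
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)
model.eval()

text = "Constrained decoding is posterior inference"
input_ids = tokenizer(text, return_tensors="pt").input_ids  # (1, T)

with torch.no_grad():
    logits = model(input_ids).logits  # (1, T, vocab_size)

# Position i-1 predicts token i, so shift logits and targets by one.
log_probs = torch.log_softmax(logits[:, :-1, :], dim=-1)
token_log_probs = log_probs.gather(-1, input_ids[:, 1:].unsqueeze(-1)).squeeze(-1)

# log p(X) = sum_i log p(x_i | x_<i)  (here: conditioned on the first token)
print(token_log_probs.sum().item())
```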

$p(Y|X)$ is the likelihood of the constraints given the sequence of tokens. We can simply define it as: $$ p(Y|X) = \begin{cases} 1 & \text{if the constraints are satisfied} \\ 0 & \text{otherwise} \end{cases} $$

Denoting the constraint variable $Y$ by $C$, we get

$$ p(X|C) = \frac{p(C|X)p(X)}{\sum_{X} p(C|X)p(X)} $$

where $p(C|X)$ is 1 if the constraints are satisfied and 0 otherwise and $p(X)$ is the prior distribution of the sequence of tokens given by the LLM.

For $X$ that satisfies the constraints, we have

$$ p(X|C) = \frac{p(X)}{\sum_{X} p(C|X)p(X)} $$

The denominator is problematic because it requires us to sum over all possible sequences of tokens that satisfy the constraints. This is infeasible in practice because the number of possible sequences is exponential in the length of the sequence.

However, we may not always need to compute the denominator, because we are often only interested in the most likely sequence of tokens that satisfies the constraints. Given two sequences of tokens $X_1$ and $X_2$ that both satisfy the constraints, we have $$ p(X_1|C) > p(X_2|C) \Leftrightarrow p(X_1) > p(X_2) $$
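To see why, note that for two constraint-satisfying sequences the normalizing constant and the indicator likelihood both cancel:

$$ \frac{p(X_1|C)}{p(X_2|C)} = \frac{p(C|X_1)\,p(X_1)}{p(C|X_2)\,p(X_2)} = \frac{p(X_1)}{p(X_2)}, $$

since $p(C|X_1) = p(C|X_2) = 1$.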

But in practice, the generation process of an LLM is factorized into conditional probabilities of each token given the previous tokens.

Thus we want to have a factorized form of the posterior distribution $p(X|C)$ $$ p(X|C) = \prod_{i=1}^T p(x_i|x_{<i}, C) $$

where $p(x_i|x_{<i}, C)$ is the conditional probability of the token $x_i$ given the previous tokens and the constraints.

The only general way to compute $p(x_i|x_{<i}, C)$ is to marginalize over all possible continuations of the sequence:

$$ p(x_i|x_{<i}, C) = \frac{p(x_i,x_{<i}, C)}{\sum_{x_i} p(x_i,x_{<i}, C)} $$ where

$$ p(x_i, x_{<i}, C) = \sum_{x_{>i}} p(x_{<i}, x_i, x_{>i}, C) $$

Again, this is infeasible in practice because the number of possible continuations $x_{>i}$ grows exponentially with the remaining sequence length.
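To make the cost concrete, here is a brute-force sketch on a toy model with a made-up three-token vocabulary and made-up probabilities; even computing the posterior conditional for the very first token already requires enumerating every possible continuation:

```python
import itertools

# Toy model: made-up vocabulary and made-up next-token probabilities
# (independent of the prefix, purely for illustration).
vocab = ["a", "b", "the"]
T = 3  # total sequence length

def p_next(prefix, token):
    return {"a": 0.3, "b": 0.2, "the": 0.5}[token]  # p(x_i | x_<i)

def p_seq(seq):
    p = 1.0
    for i, tok in enumerate(seq):
        p *= p_next(seq[:i], tok)
    return p

def satisfies(seq):          # constraint C: "the" never appears
    return "the" not in seq

# Exact p(x_1 | C): marginalize over all |vocab|^(T-1) continuations.
numer = {
    x1: sum(
        p_seq((x1,) + rest)
        for rest in itertools.product(vocab, repeat=T - 1)
        if satisfies((x1,) + rest)
    )
    for x1 in vocab
}
Z = sum(numer.values())
print({x1: numer[x1] / Z for x1 in vocab})  # {'a': 0.6, 'b': 0.4, 'the': 0.0}
```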

Example with forbidden tokens

Let’s consider a simple example where we want to ensure that the token “the” is not present in the generated sequence. We have a simple constraint $C$ that is 1 if the token “the” is not present in the sequence and 0 otherwise. (This is a fairly difficult constraint to satisfy, because “the” is a very common token in the English language and it’s hard to work around it.)

If you give an LLM a global system prompt like “Don’t use the token ’the’ in the following conversation” and proceed with several rounds of dialogue, the LLM will fail to satisfy the constraint.

How can we solve this problem?

A simple but not-good-enough approach

One simple approach is to mask out the token “the” in the output vocabulary at every decoding step, which ensures that the token “the” is never generated.
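A minimal sketch of this masking approach with a Hugging Face causal LM might look like the following. The model name, prompt, and the set of forbidden token ids are illustrative assumptions; a real implementation would also mask capitalized and other surface variants of “the”:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "gpt2"  # illustrative choice
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)
model.eval()

# "the" and " the" are distinct tokens in a BPE vocabulary; both are forbidden here.
forbidden_ids = [tokenizer.encode("the")[0], tokenizer.encode(" the")[0]]

input_ids = tokenizer("Once upon a time,", return_tensors="pt").input_ids
for _ in range(20):  # greedy decoding with per-step masking
    with torch.no_grad():
        next_logits = model(input_ids).logits[:, -1, :]
    next_logits[:, forbidden_ids] = float("-inf")   # mask the forbidden tokens
    probs = torch.softmax(next_logits, dim=-1)      # renormalize: this is p^c(x_i | x_<i)
    next_id = probs.argmax(dim=-1, keepdim=True)    # greedy pick
    input_ids = torch.cat([input_ids, next_id], dim=-1)

print(tokenizer.decode(input_ids[0]))
```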

This works!

But it’s not good enough, because it does not give us the exact posterior distribution $p(X|C)$ but rather a (possibly very bad) approximation of it: the distribution we actually sample from is

$$ p^c(X) := \prod_{i=1}^T p^c(x_i|x_{<i}) $$ where $p^c(x_i|x_{<i}) = \frac{p(x_i|x_{<i}) \times c(x_{\le i})}{\sum_{x_i} p(x_i|x_{<i}) \times c(x_{\le i})}$ and $c(x_{\le i})$ is 1 if the prefix $x_{\le i}$ does not violate the constraint (here: does not contain “the”) and 0 otherwise.

Thus we have $$ p^c(X) = \prod_{i=1}^T \frac{p(x_i|x_{<i}) \times c(x_{\le i})}{\sum_{x_i} p(x_i|x_{<i}) \times c(x_{\le i})} $$

For a sequence that satisfies the constraint, every factor has $c(x_{\le i}) = 1$ in the numerator, so $$ p^c(x_i|x_{<i}) = \frac{p(x_i|x_{<i})}{\sum_{x_i} p(x_i|x_{<i}) \times c(x_{\le i})} $$ and thus $$ p^c(X) = \prod_{i=1}^T \frac{p(x_i|x_{<i})}{\sum_{x_i} p(x_i|x_{<i}) \times c(x_{\le i})}, $$ which is not the exact posterior distribution $p(X|C)$ we derived above.
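A tiny numeric sketch (all probabilities made up) shows how far apart $p^c(X)$ and $p(X|C)$ can be: after the prefix “a” most of the model’s mass flows into the forbidden token, but per-step masking simply renormalizes that mass away, so prefixes that were about to violate the constraint are not penalized at all.

```python
# Toy two-token sequences over the vocabulary {"a", "b", "the"} with made-up,
# prefix-dependent probabilities. Constraint C: the token "the" never appears.
prior_first = {"a": 0.4, "b": 0.4, "the": 0.2}                   # p(x_1)
prior_second = {                                                  # p(x_2 | x_1)
    "a":   {"a": 0.10, "b": 0.10, "the": 0.80},  # after "a", "the" is very likely
    "b":   {"a": 0.45, "b": 0.45, "the": 0.10},
    "the": {"a": 1 / 3, "b": 1 / 3, "the": 1 / 3},
}

def p_seq(x1, x2):
    return prior_first[x1] * prior_second[x1][x2]

valid = [(x1, x2) for x1 in prior_first for x2 in prior_first if "the" not in (x1, x2)]

# Exact posterior p(X | C): renormalize the prior over valid sequences only.
Z = sum(p_seq(*s) for s in valid)
exact = {s: round(p_seq(*s) / Z, 3) for s in valid}

# Masked distribution p^c(X): mask "the" at each step and renormalize locally.
def mask_renorm(dist):
    keep = {t: p for t, p in dist.items() if t != "the"}
    z = sum(keep.values())
    return {t: p / z for t, p in keep.items()}

pc_first = mask_renorm(prior_first)
pc = {(x1, x2): round(pc_first[x1] * mask_renorm(prior_second[x1])[x2], 3)
      for (x1, x2) in valid}

print(exact)  # {('a','a'): 0.091, ('a','b'): 0.091, ('b','a'): 0.409, ('b','b'): 0.409}
print(pc)     # every valid sequence gets 0.25 under the masked approximation
```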

Two better approximation approaches

The goal is to compute the exact posterior distribution $p(X|C)$. This is not a new problem and it’s been studied in many different contexts.
