It is recommended to read this post as a PDF, available here.

Introduction

Suppose you’re in Vegas, and you’ve had the misfortune of encountering a crooked croupier. You suspect he is using loaded dice (call the dice’s true distribution $p$). You’ve been watching him for a while, and you’ve developed a theory for how the dice are loaded (call your theory $q$). The KL divergence of $q$ from $p$ measures how much extra surprise you (and your wallet) would suffer, on average, if you bet according to your theory $q$ while the dice actually behave according to $p$. This extra surprise also gauges how far your guess was from the croupier’s distribution, although KL is not a true distance: swapping $p$ and $q$ generally gives a different value.

Now imagine you rightfully called out this crooked croupier, and we all got kicked out of the casino. But of course we must get even. So we find some long-suffering friend of yours. We need a method of long-distance communication, so we decide that you’ll blink at him, Morse-code style, from the buffet next door. We teach him to read the blinks and send him in.

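To make the code as efficient as possible, we use information theory and our distribution $q$. If a play $x$ appears with probability $q(x)$ according to us, the optimal number of bits to encode it is $-\log_2 q(x)$. Note that this assigns fewer bits to more likely plays, reducing the total amount we need to blink in our friend’s direction. But the plays actually follow the croupier’s distribution $p$, for which the optimal code would spend $-\log_2 p(x)$ bits. So every time play $x$ comes up, the bits we waste relative to that ideal code are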

$$ -\log_2 q(x) - (-\log_2 p(x)) = \log_2\frac{p(x)}{q(x)}. $$

If you leave your friend playing at this table for a while, the expected waste per play is

$$ \mathbb{E}_{x \sim p} \left[ \log_2\frac{p(x)}{q(x)} \right] = \sum_x p(x) \log_2 \frac{p(x)}{q(x)}. $$

This is the forward KL divergence, $\mathrm{KL}[p\|q]$: the number of bits you waste per play, on average, because you built your code from your distribution $q$ instead of the true distribution $p$.
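To make the blinking-budget arithmetic concrete, here is a minimal Python sketch. The die probabilities are hypothetical, chosen only for illustration. It computes the per-play waste $\log_2\frac{p(x)}{q(x)}$ for each face and its expectation under $p$.

```python
import math

def kl_bits(p, q):
    """Forward KL[p || q] in bits: expected extra code length per play
    when plays come from p but the code was built for q."""
    total = 0.0
    for px, qx in zip(p, q):
        if px == 0:
            continue  # 0 * log(0 / q) is taken to be 0
        if qx == 0:
            return math.inf  # p says possible, q says impossible
        total += px * math.log2(px / qx)
    return total

# Hypothetical loaded die: the croupier's true distribution p over faces 1..6
p = [0.1, 0.1, 0.1, 0.1, 0.1, 0.5]
# Hypothetical theory q of how the die is loaded
q = [0.125, 0.125, 0.125, 0.125, 0.125, 0.375]

# Per-play waste: bits used by the q-code minus bits the ideal p-code would use
waste = [-math.log2(qx) - (-math.log2(px)) for px, qx in zip(p, q)]
print([round(w, 3) for w in waste])
print(kl_bits(p, q))  # ≈ 0.0466 bits wasted per play on average
```

The waste for faces 1 through 5 is negative (the $q$-code happens to be shorter than the ideal one there), but face 6 comes up half the time and its positive waste dominates, so the average is positive.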

KL Divergence and Entropy

Playing with this equation, we can discover something else quite insightful. First, we know that:

$$ \mathrm{KL}[p\|q] = \sum_x p(x) \log\frac{p(x)}{q(x)}. $$

Let’s expand that log ratio:

$$ \mathrm{KL}[p\|q] = \sum_x p(x)\,\bigl[\log p(x) - \log q(x)\bigr]. $$

Separate the terms:

$$ \mathrm{KL}[p\|q] = \sum_x p(x) \log p(x) - \sum_x p(x) \log q(x). $$

By definition,

$$ H(p) = - \sum_x p(x)\log p(x) $$

is the entropy of the true distribution (“entropy of reality”), and

$$ H_p(q) = - \sum_x p(x)\log q(x) $$

is the cross-entropy of using $q$ when samples come from $p$.

Therefore:

$$ \boxed{\mathrm{KL}[p\|q] = H_p(q) - H(p)} $$
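As a quick numerical check of the boxed identity, the sketch below computes KL directly from its definition and as cross-entropy minus entropy, for hypothetical die distributions. It uses natural logs (nats); dividing by $\ln 2$ converts to bits. It assumes $q(x) > 0$ wherever $p(x) > 0$.

```python
import math

def entropy(p):
    """H(p) = -sum_x p(x) log p(x), in nats."""
    return -sum(px * math.log(px) for px in p if px > 0)

def cross_entropy(p, q):
    """H_p(q) = -sum_x p(x) log q(x); assumes q(x) > 0 wherever p(x) > 0."""
    return -sum(px * math.log(qx) for px, qx in zip(p, q) if px > 0)

def kl(p, q):
    """KL[p || q] straight from the definition."""
    return sum(px * math.log(px / qx) for px, qx in zip(p, q) if px > 0)

# Hypothetical die distributions: true p and guessed q
p = [0.1, 0.1, 0.1, 0.1, 0.1, 0.5]
q = [0.125, 0.125, 0.125, 0.125, 0.125, 0.375]

print(kl(p, q))                          # ≈ 0.0323 nats
print(cross_entropy(p, q) - entropy(p))  # same value, up to floating-point rounding
```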

KL divergence can also be thought of as the regret, or “surprise tax,” you pay for using the wrong distribution $q$ when the true distribution is $p$: it is the gap between the code length you actually incur (cross-entropy $H_p(q)$) and the optimal code length you could have achieved if you had known $p$ (entropy $H(p)$).

KL $\geq 0$

The code tuned to the true distribution $p$ is, in expectation, unbeatable. $H(p)$ is the optimal average code length when the world is $p$, while $H_p(q)$ is the average code length you get when you insist the world looks like $q$. At best you tie the optimal code, and only when you guessed perfectly ($q = p$); otherwise you pay the surprise tax. The difference $\mathrm{KL}[p\|q]$ should therefore always be $\ge 0$.

This fact is known as Gibbs’ inequality, which we’ll succinctly derive now.

Lemma. For all $x > 0$,

$$ \log x \le x - 1, $$

with equality if and only if $x = 1$.

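This holds because $\log x$ is concave, so its graph lies below its tangent line at $x = 1$, which is exactly the line $y = x - 1$; since $\log x$ is strictly concave, the two touch only at $x = 1$. Now apply this to KL. Start with: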

$$ \mathrm{KL}[p\|q] = \sum_x p(x)\log \frac{p(x)}{q(x)}. $$

Let

$$ u(x) = \frac{q(x)}{p(x)}, $$

so that

$$ \log \frac{p(x)}{q(x)} = -\log u(x). $$

Assume, for simplicity, that $p(x) > 0$ and $q(x) > 0$ for every $x$ (if $q(x) = 0$ somewhere that $p(x) > 0$, the KL is infinite and there is nothing to prove). Then $u(x) > 0$ for all $x$, and from $\log u \le u - 1$ for all $u > 0$, we get

$$ -\log u(x) \ge 1 - u(x). $$

Multiply both sides by $p(x)$:

$$ p(x)\log \frac{p(x)}{q(x)} \ge p(x)\bigl(1 - u(x)\bigr) = p(x) - q(x). $$

Now sum over all $x$:

$$ \begin{aligned} \sum_x p(x)\log \frac{p(x)}{q(x)} &\ge \sum_x \bigl(p(x) - q(x)\bigr) \\ &= \sum_x p(x) - \sum_x q(x) \\ &= 1 - 1 = 0 \end{aligned} $$

since both $p$ and $q$ are probability distributions and thus each sum to 1. The left-hand side is exactly $\mathrm{KL}[p\|q]$, so we conclude

\[ \mathrm{KL}[p\|q] \ge 0, \]

with equality if and only if $\log u(x) = u(x) - 1$ for all $x$, i.e. $u(x) = 1$ for all $x$, which means $p(x) = q(x)$ everywhere.
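The derivation can be sanity-checked numerically. This sketch draws random, strictly positive distributions (an illustrative test setup, not a proof), checks the term-by-term bound $p(x)\log\frac{p(x)}{q(x)} \ge p(x) - q(x)$, and confirms that the sum is nonnegative and that $q = p$ gives exactly zero.

```python
import math
import random

random.seed(0)

def random_distribution(n):
    """Random probability vector with strictly positive entries (illustrative)."""
    weights = [random.random() + 1e-9 for _ in range(n)]
    total = sum(weights)
    return [w / total for w in weights]

def kl(p, q):
    """KL[p || q]; assumes q(x) > 0 wherever p(x) > 0."""
    return sum(px * math.log(px / qx) for px, qx in zip(p, q) if px > 0)

for _ in range(1_000):
    p = random_distribution(5)
    q = random_distribution(5)
    # Lemma applied term by term: p log(p/q) >= p - q (up to rounding)
    assert all(px * math.log(px / qx) >= (px - qx) - 1e-12 for px, qx in zip(p, q))
    # Summing the terms: KL[p || q] >= sum(p) - sum(q) = 0 (up to rounding)
    assert kl(p, q) >= -1e-12

# Equality case: q = p makes every log ratio log(1) = 0
p = random_distribution(5)
print(kl(p, p))  # 0.0
```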

Log Likelihood

Suppose the real world has some unknown distribution $p(x)$, and we build a model $q_\theta(x)$ with parameters $\theta$ to approximate it. In practice, we fit $\theta$ by maximizing the average log-likelihood of the observed data, which, for many samples drawn from $p$, approximates the expected log-likelihood:

$$ \max_\theta \; E_{x \sim p}[\log q_\theta(x)]. $$

This has a close relationship with KL Divergence. Start from the forward KL:

$$ \begin{aligned} \mathrm{KL}[p\|q_\theta] &= \sum_x p(x)\log \frac{p(x)}{q_\theta(x)} \\ &= \sum_x p(x)\log p(x) - \sum_x p(x)\log q_\theta(x). \end{aligned} $$

The first term,

$$ \sum_x p(x)\log p(x), $$

depends only on the true distribution $p$, which we do not control. So, as a function of $\theta$,

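$$ \mathrm{KL}[p\|q_\theta] = \text{constant} - E_{x \sim p}[\log q_\theta(x)]. $$

Since the constant does not depend on $\theta$, maximizing the expected log-likelihood is the same as minimizing the forward KL:

$$ \begin{aligned} \theta^\star &= \arg\max_\theta E_{x \sim p}[\log q_\theta(x)] \\ &= \arg\min_\theta \mathrm{KL}[p\|q_\theta]. \end{aligned} $$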

Maximum likelihood training thus chooses the model whose predictions make the observed world least surprising on average. Put another way: among all $q_\theta$, we pick the one that wastes the fewest extra bits compared to the (unknowable) optimal code for $p$. The closer $q_\theta$ is to $p$, the better the model compresses its data distribution, and the more it “understands”.
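Here is a small sketch of this equivalence under illustrative assumptions: a hypothetical loaded die for $p$, and a made-up one-parameter model family $q_\theta(x) \propto e^{\theta x}$ over the six faces. A grid search shows that maximizing expected log-likelihood and minimizing forward KL select the same $\theta$, and that fitting on a finite sample lands nearby.

```python
import math
import random

random.seed(0)

faces = [1, 2, 3, 4, 5, 6]
p = [0.1, 0.1, 0.1, 0.1, 0.1, 0.5]  # hypothetical "true" loaded die

def log_q(theta, x):
    """Hypothetical model family: q_theta(x) proportional to exp(theta * x)."""
    log_z = math.log(sum(math.exp(theta * y) for y in faces))
    return theta * x - log_z

def expected_log_lik(theta):
    """E_{x ~ p}[log q_theta(x)]."""
    return sum(px * log_q(theta, x) for px, x in zip(p, faces))

def forward_kl(theta):
    """KL[p || q_theta] = sum_x p(x) (log p(x) - log q_theta(x))."""
    return sum(px * (math.log(px) - log_q(theta, x)) for px, x in zip(p, faces))

thetas = [i / 100 for i in range(-200, 201)]  # grid over [-2, 2]
theta_ll = max(thetas, key=expected_log_lik)
theta_kl = min(thetas, key=forward_kl)
print(theta_ll, theta_kl)  # the same grid point (about 0.37)

# With only samples, maximize the average log-likelihood instead
n = 10_000
data = random.choices(faces, weights=p, k=n)
counts = {x: data.count(x) for x in faces}

def avg_log_lik(theta):
    return sum(counts[x] * log_q(theta, x) for x in faces) / n

theta_sample = max(thetas, key=avg_log_lik)
print(theta_sample)  # close to theta_ll
```

The two population objectives differ only by the constant $\sum_x p(x)\log p(x)$, so they rank every grid point identically; the sample-based fit differs only through sampling noise.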

Forward vs Reverse KL Divergence

Where forward KL is $\mathrm{KL}[\text{reality}\|\text{guess}]$, reverse KL is $\mathrm{KL}[\text{guess}\|\text{reality}]$. While their equations look nearly identical, the behavior of a model optimized under either KL could not be more different.

Forward KL is mode-covering: on a multi-modal distribution, it tries to split the difference and cover as much of $p$ as possible. Imagine $p(x) = 0.01$ and $q(x) = 0$: then $\log(p(x)/q(x)) = \infty$, and since this term is weighted by $p(x) > 0$, the whole KL is infinite. Even a tiny bit of probability in $p$ where $q$ says “impossible” makes forward KL infinite, so forward KL pushes $q$ to spread out over everything $p$ might do.

Reverse KL is mode-seeking: it tends to pick one mode of $p$ and match it closely. Where $q(x) = 0$, the term $q(x)\log\frac{q(x)}{p(x)}$ is zero (by the convention $0 \log 0 = 0$), so there is no penalty. But where $p(x) = 0$ and $q(x) > 0$, the KL is infinite. So reverse KL says: “You can ignore regions of $p$, even entire modes, but you absolutely cannot claim something is possible when it is actually impossible.”

Forward KL is the default in many contexts; maximum likelihood training, for instance, minimizes it. Reverse KL is used in generative models like VAEs, where we’d like clear, sharp faces of a specific ethnicity or age rather than blurry “average” human faces. We also use reverse KL in model distillation: a large model might say an answer “could be A, B, or C,” and you may want your small model to commit to “definitely A” rather than fail to capture the A/B/C nuance and learn nothing at all.

Figure — Evolution of $q$ under forward KL (mode-covering) versus reverse KL (mode-seeking) optimization. Left: Initial configuration with $q$ starting between two modes of $p$. Right: Final convergence after optimization—forward KL spreads to cover both peaks while reverse KL commits to matching a single mode perfectly.
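The figure’s behavior can be reproduced in miniature. This sketch discretizes the real line and compares two hand-picked candidates for $q$ against a bimodal $p$ (an equal mixture of $\mathcal{N}(-3, 1)$ and $\mathcal{N}(3, 1)$, chosen for illustration): a broad Gaussian with the same mean and variance as $p$, and a narrow Gaussian sitting on one mode. It does not run an optimizer; it only evaluates both divergences at those two candidates.

```python
import math

# Discretize [-20, 20] finely so that sums approximate integrals
xs = [i / 100 for i in range(-2000, 2001)]

def normalize(weights):
    total = sum(weights)
    return [w / total for w in weights]

def bump(x, mu, sigma):
    """Unnormalized Gaussian shape; normalize() supplies the constant."""
    return math.exp(-0.5 * ((x - mu) / sigma) ** 2)

def kl(a, b):
    """KL[a || b] on the grid; every entry here is strictly positive."""
    return sum(ai * math.log(ai / bi) for ai, bi in zip(a, b))

# Bimodal "reality": equal mixture of N(-3, 1) and N(3, 1)
p = normalize([0.5 * bump(x, -3, 1) + 0.5 * bump(x, 3, 1) for x in xs])
# Candidate A: broad Gaussian with p's mean (0) and variance (1 + 3**2 = 10)
q_cover = normalize([bump(x, 0, math.sqrt(10)) for x in xs])
# Candidate B: narrow Gaussian sitting on the right-hand mode only
q_mode = normalize([bump(x, 3, 1) for x in xs])

print("forward KL[p || q]:", kl(p, q_cover), kl(p, q_mode))
print("reverse KL[q || p]:", kl(q_cover, p), kl(q_mode, p))
```

Forward KL favors the broad, mode-covering candidate (about 0.46 versus 8.3 nats), while reverse KL favors the single-mode candidate (about 0.69, close to $\ln 2$, versus 0.94).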

KL Divergence Estimators

This section is heavily inspired by this blog post. It is slightly lighter on mathematical theory than the source, and puts slightly more effort into motivating the various estimators—all flaws my own.

In RLHF, KL divergence is used to keep models from going completely off the rails by penalizing the policy for drifting too far from a fixed reference model $\pi_\text{ref}$. We write $\pi$ for the policy being updated, and $\pi(x_t \mid x_{<t})$ for its distribution over the token at position $t$, conditioned on all tokens before position $t$. The exact KL penalty at position $t$ would be:

$$ \begin{aligned} \text{KL penalty} &= \mathrm{KL}\bigl[\pi(\cdot \mid x_{<t}) \,\|\, \pi_{\text{ref}}(\cdot \mid x_{<t})\bigr] \\ &= \sum_{v \in \text{vocab}} \pi(v \mid x_{<t}) \log \frac{\pi(v \mid x_{<t})}{\pi_{\text{ref}}(v \mid x_{<t})} \end{aligned} $$
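For a single position, the exact penalty is just a sum over the vocabulary. The sketch below computes it from raw logits using a numerically stable log-softmax; the five-token vocabulary and the logit values are hypothetical.

```python
import math

def log_softmax(logits):
    """Numerically stable log-softmax: subtract the max before exponentiating."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(l - m) for l in logits))
    return [l - log_z for l in logits]

def exact_position_kl(policy_logits, ref_logits):
    """Exact KL[pi || pi_ref] at one position: a sum over the whole vocabulary."""
    log_pi = log_softmax(policy_logits)
    log_ref = log_softmax(ref_logits)
    return sum(math.exp(lp) * (lp - lr) for lp, lr in zip(log_pi, log_ref))

# Hypothetical next-token logits over a toy 5-token vocabulary
policy_logits = [2.0, 1.0, 0.5, -1.0, 0.0]
ref_logits = [1.5, 1.2, 0.3, -0.5, 0.1]
print(exact_position_kl(policy_logits, ref_logits))
```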

Computing this exactly would require evaluating the probabilities $\pi(v \mid x_{<t})$ and $\pi_\text{ref}(v \mid x_{<t})$ for every token $v$ in the vocabulary, at every position $t$ in the sequence, for every sequence in the batch. With typical values (vocab size = 50,000, sequence length = 2,048, batch size = 32), that is roughly 3.3 billion probabilities per model per batch, which can be prohibitively expensive in memory or compute. So we estimate the value instead.
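The count is simple multiplication. The sketch below also converts it to memory, assuming (as an illustration) that every probability is stored as a 4-byte fp32 value and all of them are materialized at once.

```python
vocab_size = 50_000
seq_len = 2_048
batch_size = 32

# One probability per (sequence, position, vocabulary token)
probs_per_model = vocab_size * seq_len * batch_size
# Assumption: each probability stored as a 4-byte fp32 value
bytes_fp32 = 4 * probs_per_model

print(f"{probs_per_model:,} probabilities per model")  # 3,276,800,000
print(f"{bytes_fp32 / 1e9:.1f} GB per model in fp32")   # 13.1
```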

A good estimator is unbiased (its expected value equals the true KL) and preferably has low variance.

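A naive estimator samples $x \sim q$ and uses the ratio $r = \frac{p(x)}{q(x)}$ directly. (In the RLHF setting, $q$ plays the role of the policy $\pi$ and $p$ that of the reference $\pi_\text{ref}$.)

$$ \begin{aligned} k_1 &= - \log \frac{p(x)}{q(x)} = - \log r, \\ \mathrm{KL}[q\|p] &= E_{x \sim q}[k_1] = E_{x \sim q}[- \log r]. \end{aligned} $$

Averaging $k_1$ over samples is unbiased, but it has very high variance: individual samples are negative whenever $p(x) > q(x)$, even though $\mathrm{KL} \geq 0$.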
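To see the variance problem concretely, the sketch below uses a toy setup of two unit-variance Gaussians (chosen for illustration): samples come from $q = \mathcal{N}(0, 1)$, and $p = \mathcal{N}(0.1, 1)$, so the true $\mathrm{KL}[q\|p] = 0.1^2/2 = 0.005$. It averages $k_1$ over samples from $q$.

```python
import math
import random

random.seed(0)

# Illustrative toy setup: sample from q = N(0, 1); the other distribution is p = N(0.1, 1).
# For equal-variance Gaussians, KL[q || p] = (0.1 - 0)**2 / 2 = 0.005.
MU_P = 0.1
TRUE_KL = MU_P ** 2 / 2

def log_r(x):
    """log r = log p(x) - log q(x); the Gaussian normalizers cancel."""
    return -0.5 * (x - MU_P) ** 2 + 0.5 * x ** 2

samples = [random.gauss(0.0, 1.0) for _ in range(200_000)]
k1 = [-log_r(x) for x in samples]

mean = sum(k1) / len(k1)
std = math.sqrt(sum((v - mean) ** 2 for v in k1) / len(k1))
frac_negative = sum(v < 0 for v in k1) / len(k1)
print(f"true KL {TRUE_KL}, k1 mean {mean:.5f}, std {std:.4f}, negative {frac_negative:.1%}")
```

The mean lands near 0.005, but the standard deviation is about 0.1, twenty times the quantity being estimated, and roughly 48% of the individual samples are negative.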

We can sample from $q$, and for each sample $x$ we can compute

$$ \log r(x) = \log \frac{p(x)}{q(x)}. $$

We’ll look for an estimator of the form $g(\log r)$, so our goal is to pick $g$ such that $E_q[g(\log r)]$ approximates KL well.

Let $t = \log r$ to keep things clean. When $p = q$, we have $r = 1$, so $t = 0$. Ideally, our estimator should have the following properties:

  1. When $p = q$, KL is zero. So we want $g(0) = 0$.
  2. $g$ locally matches KL when $p$ and $q$ are close to each other (often the case in practice). KL reaches its minimum of zero at $p = q$, so it has no first-order term there, and we want $g'(0) = 0$.
  3. It has lower per-sample variance than $-\log r$, while still measuring distance on the same scale as KL. To avoid a dive into Fisher information theory, we simply note that near $p = q$, $\mathrm{KL}[q\|p] \approx \tfrac{1}{2} E_q[(\log r)^2]$ to second order, so to match that scale we want $g''(0) = 1$ (read the original blog if you want to dig further in).

Looking at the Taylor expansion of $g$ around 0:

$$ g(t) = a_0 + a_1 t + \tfrac{1}{2} a_2 t^2 + O(t^3). $$

Given our constraints, we have $a_0 = g(0) = 0$, $a_1 = g'(0) = 0$, and $a_2 = g''(0) = 1$.

So

$$ g(t) = \tfrac{1}{2} t^2 + O(t^3), $$

and since we want the simplest $g$, we drop the higher-order terms and get the $k_2$ estimator:

$$ k_2 = g(\log r) = \frac{1}{2} (\log r)^2. $$

Some nice things fall out of this estimator:

  • It’s always nonnegative (like the true KL).
  • Each sample measures how far apart $p$ and $q$ are at that point, regardless of sign.
  • It has lower variance than our naive estimator when $p$ and $q$ are close.

We also have

$$ E_q[k_2] = \mathrm{KL}[q\|p] + O(\delta^3), $$

for small deviations $\delta$ between $p$ and $q$. So $k_2$ is biased, though the bias is small when $p$ and $q$ are close, and in practice we can be quite happy with it.
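This sketch compares $k_1$ and $k_2$ on identical samples, using a toy setup of two unit-variance Gaussians chosen for illustration: $q = \mathcal{N}(0, 1)$ and $p = \mathcal{N}(0.1, 1)$, so the true $\mathrm{KL}[q\|p]$ is 0.005.

```python
import math
import random

random.seed(0)

# Illustrative toy setup: q = N(0, 1), p = N(0.1, 1), true KL[q || p] = 0.005
MU_P = 0.1
TRUE_KL = MU_P ** 2 / 2

def log_r(x):
    """log p(x) - log q(x) for the two unit-variance Gaussians."""
    return -0.5 * (x - MU_P) ** 2 + 0.5 * x ** 2

def mean_std(values):
    m = sum(values) / len(values)
    return m, math.sqrt(sum((v - m) ** 2 for v in values) / len(values))

samples = [random.gauss(0.0, 1.0) for _ in range(200_000)]
k1 = [-log_r(x) for x in samples]
k2 = [0.5 * log_r(x) ** 2 for x in samples]  # never negative

for name, values in (("k1", k1), ("k2", k2)):
    m, s = mean_std(values)
    print(f"{name}: mean {m:.5f}, std {s:.5f}  (true KL {TRUE_KL})")
```

In this setup $k_2$ has a standard deviation of roughly 0.007 versus roughly 0.1 for $k_1$, and its exact expectation is 0.0050125, slightly above the true value: that gap is the small bias.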

But we can yet do better! Is there a way to make an unbiased estimator with lower variance? Quoting the original blog: “The general way to lower variance is with a control variate—take $k_1$ and add something that has expectation 0 but is negatively correlated with $k_1$.”

What do we know that has expectation zero? Since we sample from $q$,

$$ E_q[r] = \sum_x q(x)\frac{p(x)}{q(x)} = \sum_x p(x) = 1, $$

so $(r - 1)$ is guaranteed to have zero expectation. If we can find a $\lambda$ such that

$$ -\log r + \lambda (r - 1) $$

has lower variance than $-\log r$, we’ll have an unbiased estimator with lower variance.

Calculating the optimal $\lambda$ is hard, but $\lambda = 1$ turns out to be a reasonable choice (see the original blog for why). This gives the $k_3$ estimator:

$$ k_3 = (r - 1) - \log r. $$

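This is an example of a Bregman divergence: the vertical gap between the convex function $-\log r$ and its tangent line at $r = 1$, which is $1 - r$. By the lemma $\log r \le r - 1$ from earlier, $k_3$ is never negative, so it keeps the unbiasedness of $k_1$ without its negative samples.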
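A toy experiment can check the claims about $k_3$: that the control variate $r - 1$ averages to zero under $q$, and that adding it to $k_1$ keeps the mean while cutting the variance. This is a sketch under the illustrative assumption $q = \mathcal{N}(0, 1)$, $p = \mathcal{N}(0.1, 1)$ (true KL 0.005); it uses `math.expm1` to compute $r - 1$ accurately when $r$ is close to 1.

```python
import math
import random

random.seed(0)

# Illustrative toy setup: q = N(0, 1), p = N(0.1, 1), true KL[q || p] = 0.005
MU_P = 0.1
TRUE_KL = MU_P ** 2 / 2

def log_r(x):
    """log p(x) - log q(x) for the two unit-variance Gaussians."""
    return -0.5 * (x - MU_P) ** 2 + 0.5 * x ** 2

def mean_std(values):
    m = sum(values) / len(values)
    return m, math.sqrt(sum((v - m) ** 2 for v in values) / len(values))

samples = [random.gauss(0.0, 1.0) for _ in range(200_000)]
ts = [log_r(x) for x in samples]

control = [math.expm1(t) for t in ts]    # r - 1, computed accurately near r = 1
k1 = [-t for t in ts]                    # naive estimator
k3 = [math.expm1(t) - t for t in ts]     # k1 + (r - 1), i.e. lambda = 1

print(f"mean of r - 1: {mean_std(control)[0]:+.5f}")  # close to 0
for name, values in (("k1", k1), ("k3", k3)):
    m, s = mean_std(values)
    print(f"{name}: mean {m:.5f}, std {s:.5f}  (true KL {TRUE_KL})")
```

Here $k_3$ has a spread similar to $k_2$ (standard deviation around 0.007), but unlike $k_2$ it is unbiased, and unlike $k_1$ its samples are mathematically never negative.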

Further Readings of Note

For the curious reader who wants to pursue an even deeper understanding, this document only scratches the surface of the following:

  • The relationship between log-likelihood and KL
  • $f$-divergences and Bregman divergences
  • Local geometry and Fisher information

I’d welcome any amendments, fixes, or improvements to this document.