\(f\)-divergence

In probability theory, the \(f\)-divergences form a family of measures of the difference between two probability distributions \(P\) and \(Q\). In this article, we focus on applications of \(f\)-divergence in AI/ML, particularly LLM training. This focus brings a couple of natural simplifications: (a) we deal only with discrete probability distributions, and (b) we do not need to worry about singularities (e.g., \(q(x_i) = 0\)). To keep the notation simple, we use Einstein summation notation throughout.

1 Definition

When \(f\) is a convex function and \(f(1)=0\), \(f\)-divergence is defined as:

\[D_f(P || Q) := q(x_i)f \left(\frac{p(x_i)}{q(x_i)}\right)=\mathop{\mathbb{E}}_{x\sim Q}[f(r)]\]

where I introduced the shorthand \(r(x_i) := p(x_i)/q(x_i)\) and used Einstein's summation notation (i.e., repeated indices imply summation over \(i\)). Throughout this article I will also write \(p_i = p(x_i)\) and \(q_i = q(x_i)\) to keep the notation simple.
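To make the definition concrete, here is a minimal Python sketch (assuming NumPy; the helper name `f_divergence` is hypothetical) that evaluates \(q_i f(p_i/q_i)\) for two discrete distributions given as arrays. Like the article, it assumes every \(q_i > 0\).

```python
import numpy as np

def f_divergence(p, q, f):
    """Hypothetical helper: D_f(P || Q) = sum_i q_i * f(p_i / q_i).

    p, q: 1-D arrays of probabilities over the same outcomes.
    f:    vectorized convex function with f(1) = 0.
    Assumes q_i > 0 for every i (no singularities).
    """
    p = np.asarray(p, dtype=float)
    q = np.asarray(q, dtype=float)
    r = p / q                        # likelihood ratio r_i = p_i / q_i
    return float(np.sum(q * f(r)))   # expectation of f(r) under Q

# Example generator f(r) = r log r (this choice gives the KL-divergence).
def r_log_r(r):
    return r * np.log(r)

print(f_divergence([0.9, 0.1], [0.5, 0.5], r_log_r))  # about 0.368
print(f_divergence([0.3, 0.7], [0.3, 0.7], r_log_r))  # 0.0 since P = Q
```

Passing \(f\) as an argument keeps the helper generic: any convex generator with \(f(1) = 0\) yields a member of the family.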

It is worth looking at some examples.

  1. In information theory, \(Q\) is the null hypothesis, i.e., what we already know about the world, and \(P\) is the observation, i.e., the new data we are seeing. The \(f\)-divergence measures how much the observation deviates from the known \(Q\).
  2. In LLMs, \(P\) is the true probability distribution (e.g., of human language) and \(Q\) is the distribution produced by the model. We often measure how far our model diverges from the real world.
  3. Another example is when \(P\) is the teacher model distribution and \(Q\) is the student model distribution.
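Example 3 maps directly onto code when both models output logits over the same vocabulary. The following is a minimal sketch, assuming NumPy and made-up logits of shape `(positions, vocab)` standing in for real model outputs; the function names are hypothetical. It converts logits to log-probabilities and evaluates \(q_i f(p_i/q_i)\) at each position, with the teacher as \(P\) and the student as \(Q\).

```python
import numpy as np

def log_softmax(logits):
    """Numerically stable log-softmax over the last axis."""
    shifted = logits - logits.max(axis=-1, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))

def f_divergence_from_logits(teacher_logits, student_logits, f):
    """Hypothetical helper: per-position D_f(P || Q) with P = teacher, Q = student.

    Both inputs have shape (positions, vocab); the result has shape (positions,).
    """
    log_p = log_softmax(np.asarray(teacher_logits, dtype=float))
    log_q = log_softmax(np.asarray(student_logits, dtype=float))
    q = np.exp(log_q)
    r = np.exp(log_p - log_q)          # ratio p_i / q_i, computed in log space
    return np.sum(q * f(r), axis=-1)   # sum over the vocabulary at each position

# Made-up logits standing in for real model outputs.
rng = np.random.default_rng(0)
teacher = rng.normal(size=(3, 10))
student = rng.normal(size=(3, 10))
print(f_divergence_from_logits(teacher, student, lambda r: r * np.log(r)))
```

Working in log space avoids normalizing probabilities twice, though for very peaked distributions the ratio \(r\) can still become large.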

One interesting thing to note is that the expectation is taken with respect to \(Q\). Practically, this makes sense: we often have control over the null hypothesis or the model, and we want to compare it against the true distribution, such as human language. This subtlety often trips people up, especially when trying to relate \(f\)-divergence to KL-divergence.
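Because the expectation is over \(Q\), \(D_f\) can be estimated from samples of \(Q\) alone, as long as \(p(x)\) can be evaluated at those samples. Below is a minimal Monte Carlo sketch, assuming NumPy and a small discrete example; the helper name and the sample count are illustrative choices, not part of the original text.

```python
import numpy as np

def mc_f_divergence(p, q, f, n_samples=100_000, seed=0):
    """Hypothetical helper: Monte Carlo estimate of D_f(P || Q) = E_{x~Q}[f(r(x))].

    Samples are drawn from Q only; P is used just to evaluate r(x) = p(x) / q(x).
    """
    p = np.asarray(p, dtype=float)
    q = np.asarray(q, dtype=float)
    rng = np.random.default_rng(seed)
    x = rng.choice(len(q), size=n_samples, p=q)   # sample outcome indices x ~ Q
    r = p[x] / q[x]                                # ratio at the sampled outcomes
    return float(np.mean(f(r)))                    # sample mean approximates E_Q[f(r)]

p, q = [0.9, 0.1], [0.5, 0.5]
f = lambda r: r * np.log(r)
exact = sum(qi * f(pi / qi) for pi, qi in zip(p, q))
print(exact, mc_f_divergence(p, q, f))  # the estimate should be close to the exact value
```

The sample mean is an unbiased estimate, and its standard error shrinks like \(1/\sqrt{n}\) in the number of samples.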

2 Example: KL-divergence

KL-divergence is one of the most commonly used measures of the difference between \(P\) and \(Q\). Concretely, it is defined as:

\[D_{KL}(P || Q) := -p_i\log q_i + p_i\log p_i\]

In other words, it is the expected number of extra bits (using \(\log_2\)) needed to encode samples from \(P\) with a code optimized for \(Q\) instead of the optimal code for \(P\). As we will show in the next section, \(f\)-divergence is always non-negative, and KL-divergence is an \(f\)-divergence (shown below); therefore, we can use KL-divergence as a loss function, where minimizing it drives \(Q\) toward \(P\).
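The coding interpretation can be checked numerically. This sketch assumes NumPy, base-2 logarithms (so the result is in bits), and strictly positive probabilities; `kl_bits` is a hypothetical name. It computes KL-divergence as the cross-entropy \(-p_i\log_2 q_i\) minus the entropy \(-p_i\log_2 p_i\):

```python
import numpy as np

def kl_bits(p, q):
    """Hypothetical helper: KL(P || Q) in bits as cross-entropy minus entropy.

    Assumes all entries of p and q are strictly positive.
    """
    p = np.asarray(p, dtype=float)
    q = np.asarray(q, dtype=float)
    cross_entropy = -np.sum(p * np.log2(q))  # average code length with a code built for Q
    entropy = -np.sum(p * np.log2(p))        # average code length with the optimal code for P
    return float(cross_entropy - entropy)

# P = (1/2, 1/4, 1/4), Q = (1/4, 1/4, 1/2): cross-entropy 1.75 bits, entropy 1.5 bits.
print(kl_bits([0.5, 0.25, 0.25], [0.25, 0.25, 0.5]))  # 0.25
```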

KL-divergence is often written as:

\[D_{KL} = \mathop{\mathbb{E}}_{x\sim P}[\log r]\]

Here the expectation is taken over \(P\), in contrast to the definition of \(f\)-divergence, where it is taken over \(Q\). However, we can rewrite it to fit the \(f\)-divergence framework:

\[D_{KL} = q_i \frac{p_i}{q_i} \log r = \mathop{\mathbb{E}}_{x \sim Q}[r\log r]\]

Note that \(f(r) := r\log r\) is indeed convex, as \(f''(r) = 1/r > 0\) for all \(r > 0\), and \(f(1) = 0\).
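The three expressions above (the definition, the expectation under \(P\), and the \(f\)-divergence form under \(Q\)) can be compared numerically. A minimal sketch, assuming NumPy, natural logarithms, and randomly generated strictly positive distributions; `kl_three_ways` is a hypothetical name:

```python
import numpy as np

def kl_three_ways(p, q):
    """Hypothetical helper: KL(P || Q) from the three expressions in this section."""
    p = np.asarray(p, dtype=float)
    q = np.asarray(q, dtype=float)
    r = p / q
    definition = np.sum(-p * np.log(q) + p * np.log(p))  # -p_i log q_i + p_i log p_i
    under_p = np.sum(p * np.log(r))                       # E_{x~P}[log r]
    under_q = np.sum(q * (r * np.log(r)))                 # E_{x~Q}[f(r)] with f(r) = r log r
    return float(definition), float(under_p), float(under_q)

rng = np.random.default_rng(1)
p = rng.dirichlet(np.ones(5))  # random distribution over 5 outcomes
q = rng.dirichlet(np.ones(5))
print(kl_three_ways(p, q))     # three values, equal up to floating-point rounding
```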

3 \(f\)-divergence is always non-negative

We can show that \(f\)-divergence is always non-negative by applying Jensen’s inequality:

\[ D_f(P||Q) = \mathop{\mathbb{E}}_Q [f(r)] \ge f(\mathop{\mathbb{E}}_Q [r]) = f\left(q_i \frac{p_i}{q_i}\right) = f\left(\sum p_i\right) = f(1) = 0\]

Equality holds when \(P = Q\), since then \(r = 1\) everywhere and \(f(1) = 0\). Conversely, Jensen's inequality becomes an equality only if \(f\) is linear (affine) over the range of values taken by \(r\), or if \(r\) is constant for all \(x_i\). If \(f\) is strictly convex, then \(r\) must be constant, say \(p_i = c\,q_i\) for all \(i\); summing over \(i\) gives \(1 = c\), so \(p_i = q_i\) for all \(i\). Therefore, for strictly convex \(f\), the \(f\)-divergence is \(0\) if and only if \(P\) and \(Q\) are identical.
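The argument can be sanity-checked numerically. This sketch, assuming NumPy and random strictly positive distributions (names are hypothetical), evaluates a few generators that are convex with \(f(1) = 0\): \(r\log r\) and \((r-1)^2\) (the Pearson \(\chi^2\) divergence) are strictly convex, while \(f(r) = r - 1\) is convex but linear. The linear generator shows why strict convexity matters: \(q_i(r_i - 1) = p_i - q_i\), which sums to \(0\) for every pair of distributions.

```python
import numpy as np

def f_divergence(p, q, f):
    """D_f(P || Q) = sum_i q_i f(p_i / q_i); assumes every q_i > 0."""
    p = np.asarray(p, dtype=float)
    q = np.asarray(q, dtype=float)
    return float(np.sum(q * f(p / q)))

# Illustrative generators, each convex with f(1) = 0.
GENERATORS = {
    "r_log_r": lambda r: r * np.log(r),  # strictly convex: f''(r) = 1/r > 0
    "chi2": lambda r: (r - 1.0) ** 2,    # strictly convex: f''(r) = 2 > 0
    "linear": lambda r: r - 1.0,         # convex but not strictly convex
}

def smallest_divergences(n_trials=1000, k=4, seed=0):
    """Smallest value of each divergence over random pairs (P, Q), which differ almost surely."""
    rng = np.random.default_rng(seed)
    smallest = {name: np.inf for name in GENERATORS}
    for _ in range(n_trials):
        p = rng.dirichlet(np.ones(k))
        q = rng.dirichlet(np.ones(k))
        for name, f in GENERATORS.items():
            smallest[name] = min(smallest[name], f_divergence(p, q, f))
    return smallest

print(smallest_divergences())
# r_log_r and chi2 stay positive; linear is 0 up to floating-point rounding,
# even though P and Q differ in every trial.
```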

In the next article, we will take a closer look at KL-divergence, perhaps the most famous \(f\)-divergence. We will discuss the difference between forward and reverse KL and show how to approximate it in practice.