Simplex, sphere, and Fisher–Rao metric

manifolds
Author

Nicolas Boumal

Published

November 24, 2023

Abstract
For optimization problems on the simplex, the Fisher–Rao metric makes sense. This post reviews why that is, and spells out how to use that metric by simply lifting the problem to optimization on a standard sphere.

The Fisher–Rao metric on the interior of the simplex

Let the interior of the simplex in \(\Rd\) be denoted by \[ \Ddplus = \{ x \in \Rd : x_1 + \cdots + x_d = 1 \textrm{ and } x_1, \ldots, x_d > 0 \}. \] The tangent space at \(x\) is given by \[ \T_x\Ddplus = \{ u \in \Rd : u_1 + \cdots + u_d = 0 \} = \spann(\One)^\perp. \] The Fisher–Rao metric (also called the Shahshahani metric) defines a Riemannian metric on \(\Ddplus\), as follows: \[ \inner{u}{v}_x = \sum_{i = 1}^{d} \frac{u_i v_i}{x_i}. \tag{1}\] Why? This is actually the Fisher information metric for a natural estimation problem on \(\Ddplus\).

Here is one way to see it. Each point \(x \in \Ddplus\) encodes a discrete probability distribution over \(d\) objects. Drawing \(n\) i.i.d. samples from that distribution yields \(d\) integers \(n_1, \ldots, n_d\) that sum to \(n\), where \(n_i\) is the number of times object \(i\) was selected. Those numbers follow a multinomial distribution with parameters \(x\) and \(n\). A natural question to ask here is: given a realization of those numbers, how best to estimate \(x\)? Up to an additive constant that does not depend on \(x\), the log-likelihood is as follows: \[ \ell(n_1, \ldots, n_d | x) = \log\!\left( \prod_{i = 1}^{d} x_i^{n_i} \right) = \sum_{i = 1}^{d} n_i \log(x_i). \tag{2}\] The Euclidean gradient of that log-likelihood with respect to \(x\) in \(\Rd\) is: \[ \nabla\!\left( x \mapsto \ell(n_1, \ldots, n_d | x) \right)(x) = \begin{bmatrix} n_1 / x_1 \\ \vdots \\ n_d / x_d \end{bmatrix}. \tag{3}\] The set \(\Ddplus\) is an open subset of the affine hyperplane \(\{x : \One\transpose x = 1\}\), so its tangent vectors are exactly the vectors orthogonal to \(\One\). Hence, for \(x\) to be stationary, the above gradient must be orthogonal to all of them, that is, it must be a constant vector. In other words, if \(x\) is stationary, then there is a constant \(c\) such that \(n_i/x_i = c\) for all \(i\) (assuming all \(n_i > 0\)). Since \(x_1 + \cdots + x_d = 1\) and \(n_1 + \cdots + n_d = n\), it follows that \(c = n\). The log-likelihood is concave, so this unique stationary point is the maximizer, and the MLE for \(x\) is: \[ \hat{x} = \begin{bmatrix} n_1 / n \\ \vdots \\ n_d / n \end{bmatrix}. \] Moreover, for a multinomial the expectation of \(n_i\) is exactly \(n x_i\). Thus, the MLE is unbiased: \[ \EE\{\hat x\} = x. \] Now for the covariance: \[ \mathrm{Cov}(\hat x) = \EE\left\{ (\hat x - x) (\hat x - x)\transpose \right\}. \] Again, the properties of multinomials tell us that \[ \EE\{ (\hat x - x) (\hat x - x)\transpose \}_{ij} = \frac{1}{n^2} \EE\{ (n_i - \EE\{n_i\}) (n_j - \EE\{n_j\}) \} = \begin{cases} \frac{x_i(1-x_i)}{n} & \textrm{ if } i = j, \\ \frac{-x_ix_j}{n} & \textrm{ if } i \neq j. 
\end{cases} \] Thus, overall, \[ \mathrm{Cov}(\hat x) = \frac{1}{n} \left( \diag(x) - xx\transpose \, \right). \] Notice that the matrix \(\diag(x) - xx\transpose\) is the Laplacian of the complete graph on \(d\) nodes with edge weights \(x_ix_j\). That graph is connected because all weights are positive. Hence the covariance is positive semidefinite with rank \(d - 1\), and its kernel is the span of the all-ones vector \(\One\). In particular, the covariance is positive definite when restricted to the tangent space \(\T_x\Ddplus\). The key fact is this:

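Restricted to the tangent space \(\T_x\Ddplus\), the inverse of the covariance of the MLE is exactly \(n \diag(x)^{-1}\). Equivalently, \(u\transpose (\diag(x) - xx\transpose)^\dagger v = u\transpose \diag(x)^{-1} v\) for all \(u, v \in \T_x\Ddplus\).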

Recall that the maximum likelihood estimator (MLE) is typically efficient, in the sense that its covariance matches the inverse of the Fisher information. The next section confirms this directly in the present case. It follows that the metric (Eq. 1) is, up to the factor \(n\), the Fisher information metric associated to the estimation problem described above. The keyword here is information geometry.
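For small \(n\) and \(d\), the mean and covariance formulas above can be checked exactly by enumerating every count vector together with its multinomial probability. The sketch below does this in Python using only the standard library. We chose Python over Matlab so that it runs without a license, and all helper names are ours. It uses \(\hat x = (n_1, \ldots, n_d)/n\) for every outcome, including outcomes with zero counts, where that point lies on the boundary of the simplex.

```python
import math
from itertools import product

def count_vectors(n, d):
    """All tuples (n_1, ..., n_d) of nonnegative integers summing to n."""
    for head in product(range(n + 1), repeat=d - 1):
        if sum(head) <= n:
            yield head + (n - sum(head),)

def multinomial_pmf(counts, x):
    """Probability of a count vector under the multinomial with parameters x and n."""
    coef = math.factorial(sum(counts))
    for c in counts:
        coef //= math.factorial(c)      # exact: stays an integer at every step
    return coef * math.prod(xi ** c for xi, c in zip(x, counts))

def mle_mean_and_cov(x, n):
    """Exact mean and covariance of x_hat = (n_1, ..., n_d) / n, by enumeration."""
    d = len(x)
    mean = [0.0] * d
    second = [[0.0] * d for _ in range(d)]
    for counts in count_vectors(n, d):
        p = multinomial_pmf(counts, x)
        x_hat = [c / n for c in counts]
        for i in range(d):
            mean[i] += p * x_hat[i]
            for j in range(d):
                second[i][j] += p * x_hat[i] * x_hat[j]
    cov = [[second[i][j] - mean[i] * mean[j] for j in range(d)] for i in range(d)]
    return mean, cov

x = [0.2, 0.3, 0.5]
n = 6
mean, cov = mle_mean_and_cov(x, n)
# Prediction from the text: Cov(x_hat) = (diag(x) - x x^T) / n.
predicted = [[((x[i] if i == j else 0.0) - x[i] * x[j]) / n for j in range(3)]
             for i in range(3)]
print(mean)
print(cov)
```

The rows of `cov` also sum to zero up to rounding, which reflects the fact that \(\One\) lies in the kernel of the covariance.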

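To show the key fact, we need some linear algebra: Lemma 1 below gives the pseudo-inverse of \(\diag(x) - xx\transpose\) explicitly. Take \(u, v \in \T_x\Ddplus\), so that \(\One\transpose u = \One\transpose v = 0\). In \(u\transpose (\diag(x) - xx\transpose)^\dagger v\), the correction term of Lemma 1 then reduces to \((u\transpose x^{(-1)})\) times the bottom-right entry of the \(2 \times 2\) middle matrix times \(((x^{(-1)})\transpose v)\). That entry is zero. Hence \(u\transpose (\diag(x) - xx\transpose)^\dagger v = u\transpose \diag(x)^{-1} v\) for all tangent \(u, v\). Since \(\mathrm{Cov}(\hat x) = \frac{1}{n}(\diag(x) - xx\transpose)\), its inverse on \(\T_x\Ddplus\) is \(n \diag(x)^{-1}\).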

Lemma 1 Given \(x \in \Ddplus\), and writing \(x^{(-1)}\) for the entry-wise inverse of \(x\), we have the following Moore–Penrose pseudo-inverse: \[ \left(\diag(x) - xx\transpose\,\right)^\dagger = \diag(x)^{-1} - \begin{bmatrix} \One & x^{(-1)} \end{bmatrix} \begin{bmatrix} -\trace(\diag(x)^{-1}) / d^2 & 1/d \\ 1/d & 0 \end{bmatrix} \begin{bmatrix} \One\transpose \\ (x^{(-1)})\transpose \end{bmatrix}. \]

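Proof. One option is to add \(\One\One\transpose\) to make the matrix invertible, invert it with the Woodbury identity, and subtract \(\frac{1}{d^2} \One\One\transpose\). This works because \(\diag(x) - xx\transpose\) is symmetric with kernel \(\spann(\One)\), so that \(\left(\diag(x) - xx\transpose + \One\One\transpose\right)^{-1} = \left(\diag(x) - xx\transpose\right)^\dagger + \frac{1}{d^2}\One\One\transpose\). You can also check the formula numerically with Matlab code: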

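```matlab
d = 5;
x = rand(d, 1);
x = x / sum(x);                    % random point in the interior of the simplex
A = diag(x) - x*x';                % n times the covariance of the MLE
Z1 = pinv(A)                       % numerical Moore–Penrose pseudo-inverse
K = [ones(d, 1), 1./x];            % K = [1, x^(-1)]
Z2 = diag(1./x) - K*[-sum(1./x)/d^2, 1/d ; 1/d, 0]*K'   % formula from Lemma 1
norm(Z1 - Z2)  % Outputs a machine-precision zero
```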
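The Matlab check above relies on `pinv`. As an alternative, here is a minimal Python sketch that uses only the standard library, and its helper names are ours. It first verifies the four Moore–Penrose conditions, which characterize the pseudo-inverse uniquely, for the matrix \(Z\) given by Lemma 1. It then checks the key fact numerically: for random tangent vectors \(u, v\), the value \(u\transpose Z v\) should match \(\inner{u}{v}_x\) from Eq. (1).

```python
import random

def matmul(A, B):
    """Product of two matrices stored as lists of rows."""
    return [[sum(A[i][k] * B[k][j] for k in range(len(B))) for j in range(len(B[0]))]
            for i in range(len(A))]

def transpose(A):
    return [list(row) for row in zip(*A)]

def max_abs_diff(A, B):
    return max(abs(a - b) for ra, rb in zip(A, B) for a, b in zip(ra, rb))

random.seed(0)
d = 5
w = [0.1 + random.random() for _ in range(d)]   # strictly positive weights
x = [wi / sum(w) for wi in w]                   # a point in the interior of the simplex

# A = diag(x) - x x^T
A = [[(x[i] if i == j else 0.0) - x[i] * x[j] for j in range(d)] for i in range(d)]

# Lemma 1: Z = diag(x)^{-1} - K M K^T with K = [1, x^(-1)].
s = sum(1 / xi for xi in x)                     # trace(diag(x)^{-1})
K = [[1.0, 1 / xi] for xi in x]
M = [[-s / d**2, 1 / d], [1 / d, 0.0]]
KMKt = matmul(matmul(K, M), transpose(K))
Z = [[(1 / x[i] if i == j else 0.0) - KMKt[i][j] for j in range(d)] for i in range(d)]

# The four Moore–Penrose conditions.
AZ, ZA = matmul(A, Z), matmul(Z, A)
mp_errors = [
    max_abs_diff(matmul(AZ, A), A),     # A Z A = A
    max_abs_diff(matmul(ZA, Z), Z),     # Z A Z = Z
    max_abs_diff(AZ, transpose(AZ)),    # A Z is symmetric
    max_abs_diff(ZA, transpose(ZA)),    # Z A is symmetric
]

def random_tangent(d):
    """Random vector with entries summing to zero, i.e., in span(1)^perp."""
    v = [random.gauss(0.0, 1.0) for _ in range(d)]
    mean = sum(v) / d
    return [vi - mean for vi in v]

# Key fact: on tangent vectors, Z acts like diag(x)^{-1}.
u, v = random_tangent(d), random_tangent(d)
lhs = sum(u[i] * Z[i][j] * v[j] for i in range(d) for j in range(d))
rhs = sum(ui * vi / xi for ui, vi, xi in zip(u, v, x))
print(max(mp_errors), abs(lhs - rhs))   # both at rounding-error level
```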

Fisher information metric directly

Actually, the detour through the MLE is unnecessary. Just go back to the Euclidean gradient (Eq. 3) and compute the Euclidean Hessian: \[ \nabla^2\!\left( x \mapsto \ell(n_1, \ldots, n_d | x) \right)(x) = -\begin{bmatrix} n_1 / x_1^2 & & \\ & \ddots & \\ & & n_d / x_d^2 \end{bmatrix}. \tag{4}\] The Fisher information matrix (FIM) \(I(x)\) is the expectation of the negative of that Hessian, with the counts drawn from the multinomial with parameters \(x\) and \(n\). Since \(\EE\{n_i\} = n x_i\), it follows that \[ I(x) = n \begin{bmatrix} 1 / x_1 & & \\ & \ddots & \\ & & 1 / x_d \end{bmatrix}. \] So, yes, it's completely clear from this direct calculation that, on \(\T_x\Ddplus\), the Fisher–Rao metric is the Fisher information metric (up to the scaling by \(n\)). The story above only serves as a roundabout way to prove that the MLE reaches the Cramér–Rao bound, but we could have figured that out from general results.
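As a brute-force check of this computation, the sketch below evaluates \(\EE\{n_i\}\) term by term. It uses the fact that each multinomial count has the binomial marginal \(n_i \sim \mathrm{Binomial}(n, x_i)\). From these expectations and Eq. (4), it forms the diagonal of \(I(x)\). This is a small illustrative Python sketch with made-up values of \(x\) and \(n\), and the function names are ours.

```python
import math

def expected_count(n, p):
    """E{n_i} when n_i ~ Binomial(n, p), the marginal law of a multinomial count."""
    return sum(k * math.comb(n, k) * p**k * (1 - p) ** (n - k) for k in range(n + 1))

def fisher_information_diag(x, n):
    """Diagonal of I(x) = E{-Hessian}; by Eq. (4) its entries are E{n_i} / x_i^2."""
    return [expected_count(n, xi) / xi**2 for xi in x]

x = [0.1, 0.2, 0.3, 0.4]
n = 7
fim = fisher_information_diag(x, n)
predicted = [n / xi for xi in x]        # the formula I(x) = n diag(1/x_1, ..., 1/x_d)
print(fim)
print(predicted)
```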

Relation to the usual metric on the sphere

Let \(\Sdplus\) denote the intersection of the unit sphere in \(\Rd\) with the open positive orthant. The map \(F \colon \Sdplus \to \Ddplus\) defined using the entry-wise product \(\odot\) by \[ F(y) = y \odot y \] is a diffeomorphism, and its inverse takes entry-wise square roots. Let \(\Sdplus\) be a Riemannian submanifold of \(\Rd\) with the usual inner product \(\inner{u}{v} = u\transpose v\), and let \(\Ddplus\) be equipped with the Fisher–Rao metric. Then, for all \(y \in \Sdplus\) and tangent vectors \(u, v\) at \(y\), writing \(x = F(y)\), we have \[\begin{align} \inner{\D F(y)[u]}{\D F(y)[v]}_x & = \inner{2 y \odot u}{2 y \odot v}_x \\ & = 4 \sum_{i = 1}^{d} \frac{y_i u_i y_i v_i}{x_i} \\ & = 4 \sum_{i = 1}^{d} u_i v_i \\ & = 4\inner{u}{v}. \end{align}\] This gives a nice relation between the Fisher–Rao metric on the simplex and the usual metric on the unit sphere: up to a factor 4, the two manifolds are isometric (in a Riemannian sense). Scaling a metric by 4 scales distances by 2. So one way to compute the Fisher–Rao distance between two points \(x, x'\) in (the interior of) the simplex is to take their entry-wise square roots, compute the Riemannian distance between those on the sphere (the arccosine of their inner product), then multiply by 2: \[ \mathrm{dist}(x, x') = 2 \arccos\!\left( \sum_{i = 1}^{d} \sqrt{x_i x'_i} \right). \] This distance is indicative of how difficult it is to determine (by a statistical test) whether samples came from \(x\) or \(x'\).
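Here is a minimal Python sketch of the sphere recipe. It uses only the standard library, and the functions `F`, `dF` and `fisher_rao_distance` are our own illustrative names. The sketch checks the pull-back identity numerically at one point. It then computes the Fisher–Rao distance as twice the great-circle distance between entry-wise square roots, and compares a short distance with the norm from Eq. (1).

```python
import math

def fisher_rao_inner(x, u, v):
    """Fisher–Rao inner product <u, v>_x = sum_i u_i v_i / x_i, Eq. (1)."""
    return sum(ui * vi / xi for ui, vi, xi in zip(u, v, x))

def F(y):
    """F(y) = y ⊙ y, from the positive orthant of the sphere to the simplex."""
    return [yi * yi for yi in y]

def dF(y, u):
    """Differential of F at y along u: D F(y)[u] = 2 y ⊙ u."""
    return [2 * yi * ui for yi, ui in zip(y, u)]

def fisher_rao_distance(x1, x2):
    """Twice the great-circle distance between sqrt(x1) and sqrt(x2) on the sphere."""
    c = sum(math.sqrt(a * b) for a, b in zip(x1, x2))   # <sqrt(x1), sqrt(x2)>
    return 2 * math.acos(min(1.0, c))                  # clamp: rounding may give c > 1

# Pull-back check at y = (1, 2, 3) / sqrt(14) with tangent vectors u, v (y^T u = y^T v = 0).
y = [k / math.sqrt(14) for k in (1, 2, 3)]
u = [2.0, -1.0, 0.0]
v = [3.0, 0.0, -1.0]
lhs = fisher_rao_inner(F(y), dF(y, u), dF(y, v))
rhs = 4 * sum(ui * vi for ui, vi in zip(u, v))

# First-order check against the metric: dist(x, x + h w) ≈ h * ||w||_x for small h.
x, w, h = [0.2, 0.3, 0.5], [1.0, -1.0, 0.0], 1e-4
dist_small = fisher_rao_distance(x, [xi + h * wi for xi, wi in zip(x, w)])
norm_w = math.sqrt(fisher_rao_inner(x, w, w))
print(lhs, rhs, dist_small / h, norm_w)
```

The ratio `dist_small / h` should agree with \(\|w\|_x\) to a few digits. The small discrepancy is the error of the first-order approximation.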

Thus, a convenient way to minimize a function \(f\) over the whole simplex while using the Fisher–Rao metric on its interior is to minimize \(g = f \circ F\) over the whole unit sphere \(\Sd\) with its standard metric. Here \(F(y) = y \odot y\) now maps \(\Sd\) onto the whole simplex, boundary included. A neat fact in this regard is that if \(y \in \Sd\) is second-order critical for \(g\), then \(x = F(y)\) is first-order critical for \(f\) on the simplex: see (Levin, Kileel, and Boumal 2024) for this and a more general picture of such facts. In particular, if \(f\) is convex, then every second-order critical point of \(g \colon \Sd \to \reals\) is a global minimizer. Optimizing on the sphere has the added benefit that we need not worry about the boundary of the simplex (where the Fisher–Rao metric breaks down, but the metric on the sphere still works).
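To illustrate the recipe, the sketch below runs plain Riemannian gradient descent on the sphere for \(g = f \circ F\), with a retraction that simply normalizes back to the sphere. Several ingredients are our own choices for the example and are neither tuned nor taken from the references. These are the test function \(f(x) = \frac{1}{2}\|x - b\|^2\), which amounts to the Euclidean projection of a point \(b\) onto the simplex, the fixed step size 0.1, and the iteration count. For \(b = (0.8, 0.6, -0.2)\), the minimizer over the simplex is \((0.6, 0.4, 0)\), which lies on the boundary.

```python
import math

def minimize_on_simplex_via_sphere(grad_f, d, step=0.1, iters=5000):
    """Riemannian gradient descent for g(y) = f(y ⊙ y) on the unit sphere in R^d.

    grad_f: Euclidean gradient of f, evaluated at x = y ⊙ y.
    Returns x = y ⊙ y, a point of the simplex (possibly on its boundary).
    """
    y = [1 / math.sqrt(d)] * d                         # F(y) is the barycenter of the simplex
    for _ in range(iters):
        gx = grad_f([yi * yi for yi in y])
        egrad = [2 * yi * gi for yi, gi in zip(y, gx)]     # chain rule: 2 y ⊙ grad f(y ⊙ y)
        c = sum(yi * ei for yi, ei in zip(y, egrad))
        rgrad = [ei - c * yi for ei, yi in zip(egrad, y)]  # project onto the tangent space at y
        z = [yi - step * ri for yi, ri in zip(y, rgrad)]
        norm_z = math.sqrt(sum(zi * zi for zi in z))
        y = [zi / norm_z for zi in z]                      # retraction: normalize back to the sphere
    return [yi * yi for yi in y]

# Example (our choice): f(x) = 0.5 * ||x - b||^2, i.e., project b onto the simplex.
b = [0.8, 0.6, -0.2]
x_star = minimize_on_simplex_via_sphere(lambda x: [xi - bi for xi, bi in zip(x, b)], d=3)
print(x_star)   # expected to approach [0.6, 0.4, 0.0]
```

Starting from the barycenter, the iterates drive the third coordinate to zero and so reach the boundary of the simplex, with no special handling of the constraints.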

Relevance of that metric for conditioning of MLE landscape

The MLE \(\hat{x}\) maximizes the log-likelihood, or equivalently minimizes the negative log-likelihood. Thus, at that point, the gradient is zero and the Hessian of the negative log-likelihood is positive semidefinite, regardless of the metric we choose. But depending on the metric, the condition number of the Hessian might be different. The general story goes like this. The negative log-likelihood (to be minimized) is some function \(f \colon \Ddplus \to \reals\). We compute its gradient and Hessian with respect to the Fisher–Rao metric. Let \(\bar f\) be the smooth extension of \(f\) to a neighborhood of \(\Ddplus\) in \(\Rd\) defined by the same formula as \(f\), namely \(\bar f(x) = -\sum_{i=1}^{d} n_i \log(x_i)\) (see Eq. 2). We'll compute its gradient and Hessian with respect to the Euclidean metric in \(\Rd\). We have \[ u\transpose \diag(x)^{-1} \nabla f(x) = \inner{\nabla f(x)}{u}_x = \D f(x)[u] = \D \bar f(x)[u] = \inner{\nabla \bar f(x)}{u} = u\transpose \nabla \bar f(x). \] This holds for all \(u\) in \(\T_x \Ddplus = \spann(\One)^\perp\), so \[ P \diag(x)^{-1} \nabla f(x) = P \nabla \bar f(x), \] where \(P\) is the orthogonal projector to \(\spann(\One)^\perp\). We find: \[ \nabla f(x) = \diag(x) \nabla \bar f(x) + \alpha x, \] where \(\alpha\) is the real number such that \(\One\transpose \nabla f(x) = 0\), namely \(\alpha = -x\transpose \nabla \bar f(x)\). Explicitly: \[ \nabla f(x) = (\diag(x) - xx\transpose) \nabla \bar f(x). \] It's easy to pick a smooth extension for \(\nabla f\); here is one: \[ G(x) = (\diag(x) - xx\transpose) \nabla \bar f(x). \] The directional derivative is: \[ \D(\nabla f)(x)[u] = \D G(x)[u] = (\diag(u) - ux\transpose - xu\transpose) \nabla \bar f(x) + (\diag(x) - xx\transpose) \nabla^2 \bar f(x)[u]. \] At the MLE, we have \(\nabla \bar f(\hat x) = - n \One\) and \(\nabla^2 \bar f(\hat x) = n^2 \diag(1/n_1, \ldots, 1/n_d)\). Also, since the MLE is a critical point, the Riemannian Hessian there reduces to the derivative of the gradient: \(\nabla^2 f(\hat x) = \D(\nabla f)(\hat x)\). For tangent \(u\), the first term of \(\D G(\hat x)[u]\) vanishes. Indeed, \((\diag(u) - u\hat x\transpose - \hat x u\transpose)(-n\One) = -n(u - u - 0) = 0\), using \(\hat x\transpose \One = 1\) and \(u\transpose \One = 0\). Thus: \[ \nabla^2 f(\hat x) = n^2 (\diag(\hat x) - \hat x (\hat x)\transpose) \diag(1/n_1, \ldots, 1/n_d) = nI - \begin{bmatrix} n_1 \\ \vdots \\ n_d \end{bmatrix} \One\transpose. 
\] This operator only acts on tangent vectors \(u \in \spann(\One)^\perp\). For those, \(\One\transpose u = 0\), so the second term vanishes and \(\nabla^2 f(\hat x)[u] = nu\). This confirms that the Riemannian Hessian at the MLE has condition number 1.
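Here is a numerical sanity check of this computation, written as a small Python sketch using only the standard library. The counts \((10, 20, 70)\) and all function names are ours. The sketch builds \(G(x)\) as above and approximates \(\D G(\hat x)[u]\) by central differences, which gives the Riemannian Hessian at the critical point \(\hat x\). It then compares the condition number of that Hessian on \(\T_{\hat x}\Ddplus\) with that of the Euclidean Hessian \(P \nabla^2 \bar f(\hat x) P\). Both operators are written in the same orthonormal basis of \(\spann(\One)^\perp \subset \reals^3\). This is fine because the eigenvalues of a linear map do not depend on the basis used to represent it.

```python
import math

def fr_gradient(x, counts):
    """G(x) = (diag(x) - x x^T) grad fbar(x) for fbar(x) = -sum_i n_i log(x_i)."""
    g = [-c / xi for c, xi in zip(counts, x)]        # Euclidean gradient of fbar
    xg = sum(xi * gi for xi, gi in zip(x, g))        # x^T grad fbar(x)
    return [xi * gi - xi * xg for xi, gi in zip(x, g)]

def riemannian_hessian_at_critical(x, counts, u, h=1e-6):
    """D G(x)[u] by central differences; at a critical point this is the Riemannian Hessian."""
    gp = fr_gradient([xi + h * ui for xi, ui in zip(x, u)], counts)
    gm = fr_gradient([xi - h * ui for xi, ui in zip(x, u)], counts)
    return [(a - b) / (2 * h) for a, b in zip(gp, gm)]

def euclidean_hessian(x, counts, u):
    """grad^2 fbar(x)[u] = diag(n_i / x_i^2) u (the basis below takes care of P)."""
    return [c / xi**2 * ui for c, xi, ui in zip(counts, x, u)]

def matrix_in_basis(apply, basis):
    """2x2 matrix with entries e_r^T apply(e_c) for an orthonormal basis of span(1)^perp."""
    cols = [apply(e) for e in basis]
    return [[sum(a * b for a, b in zip(basis[r], cols[c])) for c in range(2)] for r in range(2)]

def condition_number_2x2(m):
    """lambda_max / lambda_min, assuming both eigenvalues are real and positive."""
    (a, b), (c, d) = m
    disc = math.sqrt(max(0.0, (a - d) ** 2 + 4 * b * c))
    return (a + d + disc) / (a + d - disc)

counts = [10, 20, 70]
n = sum(counts)
x_hat = [c / n for c in counts]                   # the MLE
basis = [[1 / math.sqrt(2), -1 / math.sqrt(2), 0.0],
         [1 / math.sqrt(6), 1 / math.sqrt(6), -2 / math.sqrt(6)]]

H_fr = matrix_in_basis(lambda u: riemannian_hessian_at_critical(x_hat, counts, u), basis)
H_eu = matrix_in_basis(lambda u: euclidean_hessian(x_hat, counts, u), basis)
print(H_fr)                                       # close to n * identity
print(condition_number_2x2(H_fr), condition_number_2x2(H_eu))
```

In this example the Euclidean condition number comes out around 2.7. The Fisher–Rao one is 1 up to finite-difference and rounding error.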

Entropy

Let \(\phi \colon \Ddplus \to \reals\) be defined by \[ \phi(x) = \sum_{i = 1}^{d} x_i \log(x_i). \] This is the negative entropy of the distribution \(x\). Let \(\bar{\phi}\) be defined by the same formula on the whole positive orthant \(\Rdplus\), which is open in \(\Rd\). Then, \[\begin{align} \nabla \bar{\phi}(x) & = \begin{bmatrix} \log(x_1) + 1 \\ \vdots \\ \log(x_d) + 1 \end{bmatrix}, & \nabla^2 \bar{\phi}(x) & = \begin{bmatrix} \frac{1}{x_1} & & \\ & \ddots & \\ & & \frac{1}{x_d} \end{bmatrix}. \end{align}\] This gives another perspective on the Fisher–Rao metric. It is the Riemannian submanifold metric for \(\Ddplus\) embedded in \(\Rdplus\), where \(\Rdplus\) is equipped with the Hessian-of-(negative)-entropy metric. Equivalently, let \(P\) be the orthogonal projector to \(\spann(\One)^\perp\). With respect to the Euclidean metric restricted to \(\Ddplus\), the gradient and Hessian of \(\phi\) are \(P \nabla \bar{\phi}(x)\) and \(P \nabla^2 \bar{\phi}(x) P\), respectively. For tangent \(u, v\) we have \(u\transpose P \nabla^2 \bar{\phi}(x) P v = \inner{u}{v}_x\). So the Fisher–Rao metric on \(\Ddplus\) is the Hessian metric induced by (negative) entropy. This also connects well with the idea that entropy is the expected amount of information revealed by samples. Here, the Hessian of negative entropy matches the FIM, which is itself the expectation of the Hessian of the log-likelihood, up to signs. In fact, when the counts are drawn from \(x\) itself, \(n\phi(x)\) is exactly the expected log-likelihood (Eq. 2).
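Here is a quick finite-difference check of this claim, written as a Python sketch using only the standard library, with function names and test values of our own choosing. Along a tangent direction \(u\), the second derivative of \(t \mapsto \phi(x + tu)\) at \(t = 0\) equals \(u\transpose \nabla^2 \bar\phi(x) u\). It should therefore match \(\inner{u}{u}_x\) from Eq. (1).

```python
import math

def neg_entropy(x):
    """phi(x) = sum_i x_i log(x_i), the negative entropy, on the positive orthant."""
    return sum(xi * math.log(xi) for xi in x)

def second_directional_derivative(f, x, u, h=1e-4):
    """Central finite-difference estimate of d^2/dt^2 f(x + t u) at t = 0."""
    xp = [xi + h * ui for xi, ui in zip(x, u)]
    xm = [xi - h * ui for xi, ui in zip(x, u)]
    return (f(xp) - 2 * f(x) + f(xm)) / h**2

def fisher_rao_sq_norm(x, u):
    """<u, u>_x from Eq. (1)."""
    return sum(ui * ui / xi for ui, xi in zip(u, x))

x = [0.1, 0.2, 0.3, 0.4]
u = [0.5, -0.2, -0.4, 0.1]      # tangent vector: entries sum to zero
estimate = second_directional_derivative(neg_entropy, x, u)
exact = fisher_rao_sq_norm(x, u)
print(estimate, exact)
```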

Not to be confused with

Over the positive orthant, it's also common to use the entry-wise change of variables \(x = \exp(z)\), with inverse \(z = \log(x)\), which makes the positive orthant isometric to \(\Rd\). That leads to a complete metric over the positive orthant, with metric matrix \(G(x) = \diag(1/x_1^2, \ldots, 1/x_d^2)\). Note that the second derivative of \(-\log(x)\) is \(1/x^2\), so this is the Hessian metric induced by the log-barrier function \(x \mapsto -\sum_{i=1}^{d} \log(x_i)\).
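To see the difference with the Fisher–Rao metric concretely, the sketch below compares distances from the center of the simplex in \(\reals^2\) to points approaching its boundary. It is written in Python using only the standard library, and the names are ours. For the \(\exp\)/\(\log\) metric, it uses the distance \(\|\log x - \log y\|_2\) in the positive orthant. For points of the simplex, this is a lower bound on the distance measured within the simplex. For Fisher–Rao, it uses \(2\arccos(\sum_i \sqrt{x_i y_i})\).

```python
import math

def log_metric_distance(x, y):
    """Distance in the positive orthant for the metric diag(1/x_i^2): ||log x - log y||_2."""
    return math.sqrt(sum((math.log(a) - math.log(b)) ** 2 for a, b in zip(x, y)))

def fisher_rao_distance(x, y):
    """Fisher–Rao distance on the simplex: 2 arccos(sum_i sqrt(x_i y_i))."""
    c = sum(math.sqrt(a * b) for a, b in zip(x, y))
    return 2 * math.acos(min(1.0, c))

center = [0.5, 0.5]
results = []
for eps in (1e-2, 1e-4, 1e-8):
    p = [eps, 1 - eps]                      # approaches the boundary point (0, 1)
    results.append((eps, log_metric_distance(center, p), fisher_rao_distance(center, p)))
for eps, d_log, d_fr in results:
    print(f"eps={eps:g}  log-metric: {d_log:.3f}  Fisher-Rao: {d_fr:.3f}")
```

The first distance grows without bound as the point approaches the boundary, which is consistent with completeness. The Fisher–Rao distance stays bounded (by \(\pi\) on the whole simplex).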

References

Levin, E., J. Kileel, and N. Boumal. 2024. “The Effect of Smooth Parametrizations on Nonconvex Optimization Landscapes.” Mathematical Programming. https://doi.org/10.1007/s10107-024-02058-3.