
I found the following PyTorch code (from this link)

-0.5 * torch.sum(1 + sigma - mu.pow(2) - sigma.exp())

where mu is the mean parameter produced by the encoder and sigma is the corresponding sigma parameter. This expression is apparently equivalent to the KL divergence, but I don't see how it computes the KL divergence for the latent variable.

asked by user8714896

2 Answers


This is the analytical form of the KL divergence between two multivariate Gaussian densities with diagonal covariance matrices (i.e. we assume independence). More precisely, it's the KL divergence between the variational distribution

$$ q_{\boldsymbol{\phi}}(\mathbf{z}) = \mathcal{N}\left(\mathbf{z} ; \boldsymbol{\mu}, \mathbf{\Sigma} = \boldsymbol{\sigma}^{2}\mathbf{I}\right) = \frac{\exp \left(-\frac{1}{2}\left(\mathbf{z} - \boldsymbol{\mu}\right)^{\mathrm{T}} \mathbf{\Sigma}^{-1}\left(\mathbf{z}-\boldsymbol{\mu} \right)\right)}{\sqrt{(2 \pi)^{J}\left|\mathbf{\Sigma}\right|}} \tag{1}\label{1} $$

and the prior (it's the same as above, but with mean and covariance equal to the zero vector and the identity matrix, respectively)

$$ p(\mathbf{z})=\mathcal{N}(\mathbf{z} ;\boldsymbol{0}, \mathbf{I}) = \frac{\exp \left(-\frac{1}{2}\mathbf{z}^{\mathrm{T}}\mathbf{z}\right)}{\sqrt{(2 \pi)^{J}}} \tag{2}\label{2} $$

where

  • $\boldsymbol{\mu} \in \mathbb{R}^J$ is the mean vector (we assume column vectors, so $\boldsymbol{\mu}^T$ would be a row vector)
  • $\mathbf{\Sigma} = \boldsymbol{\sigma}^{2}\mathbf{I} \in \mathbb{R}^{J \times J}$ is a diagonal covariance matrix (the elements of the vector $\boldsymbol{\sigma}^{2}$ are placed on the diagonal)
  • $\mathbf{z} \in \mathbb{R}^J$ is a sample (latent vector) from these Gaussians with dimensionality $J$ (or, at the same time, the input variable of the density)
  • $\left|\mathbf{\Sigma}\right| = \operatorname{det} \mathbf{\Sigma}$ is the determinant (so a number) of the diagonal covariance matrix, which is just the product of the diagonal elements for a diagonal matrix (which is the case); so, in the case of the identity, the determinant is $1$
  • $\boldsymbol{0} \in \mathbb{R}^J$ is the zero vector
  • $\mathbf{I} \in \mathbb{R}^{J \times J}$ is an identity matrix
  • $\mathbf{z}^{\mathrm{T}}\mathbf{z} = \sum_{i=1}^J z_i^2 \in \mathbb{R}$ is the dot product (hence a number)

Now, the (negative of the) KL divergence is defined as follows

\begin{align} -D_{K L}\left(q_{\boldsymbol{\phi}}(\mathbf{z}) \| p(\mathbf{z})\right) &= \int q_{\boldsymbol{\phi}}(\mathbf{z})\left(\log p(\mathbf{z})-\log q_{\boldsymbol{\phi}}(\mathbf{z})\right) d \mathbf{z} \\ &= \mathbb{E}_{q_{\boldsymbol{\phi}}(\mathbf{z})} \left[ \log p(\mathbf{z})-\log q_{\boldsymbol{\phi}}(\mathbf{z})\right] \label{3}\tag{3} \end{align}
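As a sanity check of equation \ref{3} (a minimal sketch with arbitrary, made-up parameter values), the expectation can be estimated by Monte Carlo with samples from $q_{\boldsymbol{\phi}}$ and compared with the closed-form expression derived at the end of this answer:

```python
import torch
from torch.distributions import MultivariateNormal

torch.manual_seed(0)
J = 3
mu = torch.randn(J)            # arbitrary mean vector
var = torch.rand(J) + 0.1      # arbitrary positive variances (diagonal of Sigma)

q = MultivariateNormal(mu, covariance_matrix=torch.diag(var))  # eq. (1)
p = MultivariateNormal(torch.zeros(J), torch.eye(J))           # eq. (2)

z = q.sample((200_000,))                                # samples from q
mc_estimate = (p.log_prob(z) - q.log_prob(z)).mean()    # Monte Carlo estimate of eq. (3)

closed_form = 0.5 * (1 + var.log() - var - mu.pow(2)).sum()  # formula derived below
print(mc_estimate.item(), closed_form.item())                # the two values should be close
```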

Given that we have logarithms here, let's compute the logarithm of equations \ref{1} and \ref{2}

\begin{align} \log \left( \mathcal{N}\left(\mathbf{z} ; \boldsymbol{\mu}, \mathbf{\Sigma} \right) \right) &= \dots \\ &= -\frac{1}{2}(\mathbf{z}-\boldsymbol{\mu})^{\mathrm{T}} \mathbf{\Sigma}^{-1}(\mathbf{z}-\boldsymbol{\mu})-\frac{J}{2} \log (2 \pi)-\frac{1}{2} \log |\mathbf{\Sigma} | \end{align}

and

\begin{align} \log \left( \mathcal{N}(\mathbf{z} ;\boldsymbol{0}, \mathbf{I}) \right) &= \dots \\ &= -\frac{1}{2}\mathbf{z}^{\mathrm{T}} \mathbf{z}-\frac{J}{2} \log (2 \pi) \end{align}
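These log-density expressions can be checked numerically (a sketch with arbitrary values; for a diagonal $\mathbf{\Sigma}$ the quadratic form and $\log |\mathbf{\Sigma}|$ reduce to sums over the $J$ dimensions, so the joint log-density is a sum of univariate ones):

```python
import math
import torch
from torch.distributions import Normal

J = 3
mu = torch.randn(J)
var = torch.rand(J) + 0.1      # diagonal of Sigma
z = torch.randn(J)             # an arbitrary point at which to evaluate the densities

# log N(z; mu, Sigma) for a diagonal Sigma, as written above
log_q = (-0.5 * ((z - mu) ** 2 / var).sum()
         - 0.5 * J * math.log(2 * math.pi)
         - 0.5 * var.log().sum())

# log N(z; 0, I), as written above
log_p = -0.5 * (z ** 2).sum() - 0.5 * J * math.log(2 * math.pi)

# a diagonal Gaussian factorizes across dimensions
print(torch.allclose(log_q, Normal(mu, var.sqrt()).log_prob(z).sum()))  # True
print(torch.allclose(log_p, Normal(0.0, 1.0).log_prob(z).sum()))        # True
```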

We can now replace these in equation \ref{3} (below, I have already performed some simplifications, to remove verbosity, but you can check them!)

\begin{align} \frac{1}{2} \mathbb{E}_{q_{\boldsymbol{\phi}}(\mathbf{z})} \left[ -\mathbf{z}^{\mathrm{T}} \mathbf{z} + (\mathbf{z}-\boldsymbol{\mu})^{\mathrm{T}} \mathbf{\Sigma}^{-1}(\mathbf{z}-\boldsymbol{\mu}) + \log |\mathbf{\Sigma} | \right] \tag{4}\label{4} \end{align} Now, given that $\mathbf{\Sigma}$ is diagonal and the log of a product is just a sum of the logarithms, we have $\log |\mathbf{\Sigma} | = \sum_{i=1}^J \log \sigma_{ii}$, so we can continue

\begin{align} \frac{1}{2} \left( - \mathbb{E}_{q_{\boldsymbol{\phi}}(\mathbf{z})} \left[ \mathbf{z}^{\mathrm{T}} \mathbf{z} \right] + \mathbb{E}_{q_{\boldsymbol{\phi}}(\mathbf{z})} \left[ (\mathbf{z}-\boldsymbol{\mu})^{\mathrm{T}} \mathbf{\Sigma}^{-1}(\mathbf{z}-\boldsymbol{\mu}) \right] + \sum_{i=1}^J \log \sigma_{ii} \right) &= \\ \frac{1}{2} \left( - \mathbb{E}_{q_{\boldsymbol{\phi}}(\mathbf{z})} \left[ \mathbf{z}^{\mathrm{T}} \mathbf{z} \right] + \mathbb{E}_{q_{\boldsymbol{\phi}}(\mathbf{z})} \left[ \operatorname{tr} \left( \mathbf{\Sigma}^{-1}(\mathbf{z}-\boldsymbol{\mu}) (\mathbf{z}-\boldsymbol{\mu})^{\mathrm{T}} \right) \right] + \sum_{i=1}^J \log \sigma_{ii} \right) &= \\ \frac{1}{2} \left( - \mathbb{E}_{q_{\boldsymbol{\phi}}(\mathbf{z})} \left[ \mathbf{z}^{\mathrm{T}} \mathbf{z} \right] + \operatorname{tr} \left( \mathbb{E}_{q_{\boldsymbol{\phi}}(\mathbf{z})} \left[ \mathbf{\Sigma}^{-1}(\mathbf{z}-\boldsymbol{\mu}) (\mathbf{z}-\boldsymbol{\mu})^{\mathrm{T}} \right] \right) + \sum_{i=1}^J \log \sigma_{ii} \right) &= \\ \frac{1}{2} \left( - \mathbb{E}_{q_{\boldsymbol{\phi}}(\mathbf{z})} \left[ \mathbf{z}^{\mathrm{T}} \mathbf{z} \right] + \operatorname{tr} \left( \mathbf{\Sigma}^{-1} \mathbb{E}_{q_{\boldsymbol{\phi}}(\mathbf{z})} \left[ (\mathbf{z}-\boldsymbol{\mu}) (\mathbf{z}-\boldsymbol{\mu})^{\mathrm{T}} \right] \right) + \sum_{i=1}^J \log \sigma_{ii} \right) &= \\ \frac{1}{2} \left( - \mathbb{E}_{q_{\boldsymbol{\phi}}(\mathbf{z})} \left[ \mathbf{z}^{\mathrm{T}} \mathbf{z} \right] + \operatorname{tr} \left( \mathbf{\Sigma}^{-1} \mathbf{\Sigma} \right) + \sum_{i=1}^J \log \sigma_{ii} \right) &= \\ \frac{1}{2} \left( - \mathbb{E}_{q_{\boldsymbol{\phi}}(\mathbf{z})} \left[ \mathbf{z}^{\mathrm{T}} \mathbf{z} \right] + J + \sum_{i=1}^J \log \sigma_{ii} \right) &= \\ \frac{1}{2} \left( - \mathbb{E}_{q_{\boldsymbol{\phi}}(\mathbf{z})} \left[ \operatorname{tr} \left( \mathbf{z} \mathbf{z}^{\mathrm{T}} \right) \right] + \sum_{i=1}^J 1 + \sum_{i=1}^J \log \sigma_{ii} \right) &= \\ \frac{1}{2} \left( - \operatorname{tr} \left( \mathbf{\Sigma}\right) - \operatorname{tr} \left( \boldsymbol{\mu} \boldsymbol{\mu}^T \right) + \sum_{i=1}^J 1 + \sum_{i=1}^J \log \sigma_{ii} \right) &= \\ \frac{1}{2} \left( - \sum_{i=1}^J \sigma_{ii} - \sum_{i=1}^J \mu_{i}^2 + \sum_{i=1}^J 1 + \sum_{i=1}^J \log \sigma_{ii} \right) &= \\ \frac{1}{2} \sum_{i=1}^J \left( 1 + \log \sigma_{ii} - \sigma_{ii} - \mu_{i}^2 \right) \end{align}

In the above simplifications, I also applied the following rules: the trace trick $\mathbf{x}^{\mathrm{T}} \mathbf{A} \mathbf{x} = \operatorname{tr}\left( \mathbf{A} \mathbf{x} \mathbf{x}^{\mathrm{T}} \right)$ (a scalar equals its own trace), the fact that the trace and the expectation are both linear and can therefore be swapped, and the identities $\mathbb{E}_{q_{\boldsymbol{\phi}}(\mathbf{z})}\left[ (\mathbf{z}-\boldsymbol{\mu})(\mathbf{z}-\boldsymbol{\mu})^{\mathrm{T}} \right] = \mathbf{\Sigma}$ and $\mathbb{E}_{q_{\boldsymbol{\phi}}(\mathbf{z})}\left[ \mathbf{z}\mathbf{z}^{\mathrm{T}} \right] = \mathbf{\Sigma} + \boldsymbol{\mu}\boldsymbol{\mu}^{\mathrm{T}}$.

The official PyTorch implementation of the VAE, which can be found here, also uses this formula. The formula can also be found in Appendix B of the VAE paper, although the long proof written above is not given there. Note that, in my proof above, the diagonal entries $\sigma_{ii}$ denote the variances, which the paper writes as $\sigma_i^2$ (the variance is conventionally denoted by the square of the standard deviation $\sigma_i$).
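To connect the result to the code in the question (a minimal sketch with arbitrary values; here log_var stands for the encoder output that the question's code calls sigma, i.e. the log of the variance), the derived formula with $\sigma_{ii} = \exp(\texttt{log\_var}_i)$ is exactly the quoted expression, up to sign:

```python
import torch

torch.manual_seed(0)
mu = torch.randn(5)        # encoder mean (arbitrary values for illustration)
log_var = torch.randn(5)   # encoder log-variance (what the question's code calls `sigma`)

var = log_var.exp()        # sigma_ii in the derivation above is the variance

# negative KL, as derived: 0.5 * sum_i (1 + log sigma_ii - sigma_ii - mu_i^2)
neg_kl_derived = 0.5 * (1 + var.log() - var - mu.pow(2)).sum()

# the expression from the question, which is the (positive) KL divergence
kl_code = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())

print(torch.allclose(-neg_kl_derived, kl_code))  # True
```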

answered by nbro

The code is correct. Since OP asked for a proof, one follows.

The usage in the code is straightforward once you observe that the authors are using the symbols unconventionally: here sigma is the natural logarithm of the variance, whereas a normal distribution is usually characterized in terms of a mean $\mu$ and a variance $\sigma^2$. Some of the functions in OP's link even have arguments named log_var.$^*$

If you're not sure how to derive the standard expression for the KL divergence in this case, you can start from the definition of KL divergence and crank through the arithmetic. In this case, $p$ is the normal distribution given by the encoder, with parameters $\mu_1, \sigma_1$, and $q$ is the standard normal distribution, with parameters $\mu_2, \sigma_2$. $$\begin{align} D_\text{KL}(P \| Q) &= \int_{-\infty}^{\infty} p(x) \log\left(\frac{p(x)}{q(x)}\right) dx \\ &= \int_{-\infty}^{\infty} p(x) \log(p(x)) dx - \int_{-\infty}^{\infty} p(x) \log(q(x)) dx \end{align}$$ The first integral is recognizable as the negative of the entropy of a Gaussian. $$ \int_{-\infty}^{\infty} p(x) \log(p(x)) dx = -\frac{1}{2}\left(1 + \log(2\pi\sigma_1^2) \right) $$ The second one is more involved. $$ \begin{align} -\int_{-\infty}^{\infty} p(x) \log(q(x)) dx &= \frac{1}{2}\log(2\pi\sigma_2^2) - \int p(x) \left(-\frac{\left(x - \mu_2\right)^2}{2 \sigma_2^2}\right)dx \\ &= \frac{1}{2}\log(2\pi\sigma_2^2) + \frac{\mathbb{E}_{x\sim p}[x^2] - 2 \mathbb{E}_{x\sim p}[x]\mu_2 +\mu_2^2} {2\sigma_2^2} \\ &= \frac{1}{2}\log(2\pi\sigma_2^2) + \frac{\sigma_1^2 + \mu_1^2-2\mu_1\mu_2+\mu_2^2}{2\sigma_2^2} \\ &= \frac{1}{2}\log(2\pi\sigma_2^2) + \frac{\sigma_1^2 + (\mu_1 - \mu_2)^2}{2\sigma_2^2} \end{align} $$ The key is recognizing that expanding the square gives a sum of several integrals, to each of which we can apply the law of the unconscious statistician. Then we use the fact that $\text{Var}(x)=\mathbb{E}[x^2]-\mathbb{E}[x]^2$. The rest is just rearranging.

Putting it all together: $$ \begin{align} D_\text{KL}(P \| Q) &= -\frac{1}{2}\left(1 + \log(2\pi\sigma_1^2) \right) + \frac{1}{2}\log(2\pi\sigma_2^2) + \frac{\sigma_1^2 + (\mu_1 - \mu_2)^2}{2\sigma_2^2} \\ &= \log (\sigma_2) - \log(\sigma_1) + \frac{\sigma_1^2 + (\mu_1 - \mu_2)^2}{2\sigma_2^2} - \frac{1}{2} \end{align} $$
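As a quick check of this general expression (a sketch with arbitrary parameter values), it agrees with torch.distributions.kl_divergence for two univariate Gaussians:

```python
import math
import torch
from torch.distributions import Normal, kl_divergence

mu1, sigma1 = 0.7, 1.3   # parameters of P (arbitrary)
mu2, sigma2 = -0.2, 0.8  # parameters of Q (arbitrary)

# the closed form derived above
formula = (math.log(sigma2) - math.log(sigma1)
           + (sigma1 ** 2 + (mu1 - mu2) ** 2) / (2 * sigma2 ** 2) - 0.5)

# the library's analytic KL between two Normal distributions
library = kl_divergence(Normal(mu1, sigma1), Normal(mu2, sigma2)).item()

print(formula, library)  # the two values coincide (up to floating point)
```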

In this special case, we know that $q$ is a standard normal, so $$ \begin{align} D_\text{KL}(P \| Q) &= -\log \sigma_1 + \frac{1}{2}\left(\sigma_1^2 + \mu_1^2 - 1 \right) \\ &= - \frac{1}{2}\left(1 + 2\log \sigma_1- \mu_1^2 -\sigma_1^2 \right) \end{align} $$ In the case that we have a $k$-variate normal with diagonal covariance for $p$, and a multivariate normal with covariance $I$ for $q$, the KL divergence is the sum of $k$ univariate KL divergences of this form, because the dimensions are independent.

The code is a correct implementation of this expression because $\log(\sigma_1^2) = 2 \log(\sigma_1)$ and in the code, sigma is the logarithm of the variance.
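A direct numerical check (a sketch with arbitrary values; log_var is the quantity the questioned code calls sigma): the quoted one-liner matches the sum of per-dimension KL divergences to a standard normal computed with torch.distributions.

```python
import torch
from torch.distributions import Normal, kl_divergence

torch.manual_seed(0)
mu = torch.randn(10)        # encoder means (arbitrary)
log_var = torch.randn(10)   # encoder log-variances (called `sigma` in the question's code)

# the expression from the question
kl_code = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())

# sum of univariate KL( N(mu_i, sigma_i^2) || N(0, 1) ) over the independent dimensions
std = (0.5 * log_var).exp()   # standard deviation per dimension
kl_lib = kl_divergence(Normal(mu, std),
                       Normal(torch.zeros_like(mu), torch.ones_like(mu))).sum()

print(torch.allclose(kl_code, kl_lib))  # True
```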


$^*$The reason that it's convenient to work on the scale of the log-variance is that the log-variance can be any real number, while the variance is constrained to be non-negative by definition. It's easier to perform optimization on the unconstrained scale than on the constrained scale of $\sigma^2$. We also want to avoid "round-tripping," where we compute $\exp(y)$ in one step and then $\log(\exp(y))$ in a later step, because this incurs a loss of precision. In any case, autograd takes care of all of the messy details, including the adjustments to gradients that result from moving from one scale to another.
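A tiny illustration of the round-tripping point (a sketch; the values are arbitrary): exponentiating and then taking the log again loses information once the variance underflows, whereas staying on the log scale does not.

```python
import torch

log_var = torch.tensor([-5.0, -50.0, -500.0])  # log-variances of increasingly small variances

var = log_var.exp()       # the last entry underflows to 0 in float32
round_trip = var.log()    # log(exp(log_var)): the last entry becomes -inf

print(var)          # roughly tensor([6.7379e-03, 1.9287e-22, 0.0000e+00])
print(round_trip)   # roughly tensor([-5., -50., -inf]); the original -500 is unrecoverable
```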

answered by Sycorax
  • The part I'm a little confused by is: how do you know the mean of $p(x)$, especially in the neural network case? Isn't it intractable, or does reducing it down to a multivariate Gaussian make it tractable? And if so, are the mean and SD based on what the data produces? – user8714896 Feb 17 '21 at 05:12
  • 1
    The latent representation of a VAE emits a mean and (a transformation of) the variance for each input: the $i$th sample has an associated distribution with parameters $(\mu_i, \sigma_i^2)$. In other words, there's a distribution $p_i(x)$ for each of the $i$ inputs. – Sycorax Feb 17 '21 at 05:15