
I'm working on understanding VAEs, mostly through the video lectures of Stanford's cs231n; lecture 13 in particular covers this topic, and I think I have a good theoretical grasp.

However, when I look at actual implementation code, such as the code from this blog post on VAEs, I see some differences that I can't quite understand.

Please take a look at this VAE architecture visualization from the class, specifically the decoder part. As it is presented there, I understand that the decoder network outputs a mean and covariance for the data distribution. To get an actual output (i.e. an image), we need to sample from the distribution parametrized by that mean and covariance, which are the outputs of the decoder.
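To make my understanding concrete, here is a minimal sketch of what I believe the slide describes (the layer sizes and names such as `x_mu` and `x_log_var` are my own, assuming a diagonal Gaussian over pixels; this is not code from the lecture):

```python
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers

latent_dim, hidden_dim, original_dim = 2, 256, 784  # placeholder sizes, not from the lecture

# Decoder that outputs distribution parameters, as I understand the slide:
z_in = layers.Input(shape=(latent_dim,))
h = layers.Dense(hidden_dim, activation="relu")(z_in)
x_mu = layers.Dense(original_dim)(h)        # mean of p(x|z)
x_log_var = layers.Dense(original_dim)(h)   # log-variance of p(x|z), diagonal covariance
decoder = tf.keras.Model(z_in, [x_mu, x_log_var])

# To get an actual image, sample from N(x_mu, exp(x_log_var)):
z = np.random.normal(size=(1, latent_dim)).astype("float32")
mu, log_var = decoder.predict(z)
x_sample = mu + np.exp(0.5 * log_var) * np.random.normal(size=mu.shape)
```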

Now, if you look at the code from the Keras blog VAE implementation, you will see that there is no such thing: the decoder takes a sample from the latent space and maps its input (the sampled $z$) directly to an output (e.g. an image), not to the parameters of a distribution from which the output would then be sampled.
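Roughly, that decoder looks something like the following (again a sketch from memory with placeholder sizes, not the blog's exact code):

```python
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers

latent_dim, hidden_dim, original_dim = 2, 256, 784  # placeholder sizes

# Direct-mapping decoder, roughly what the Keras example does:
z_in = layers.Input(shape=(latent_dim,))
h = layers.Dense(hidden_dim, activation="relu")(z_in)
x_out = layers.Dense(original_dim, activation="sigmoid")(h)  # used directly as the image
decoder = tf.keras.Model(z_in, x_out)

z = np.random.normal(size=(1, latent_dim)).astype("float32")
x_image = decoder.predict(z)  # no further sampling step: this is taken as the output
```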

Am I missing something or does this implementation not correspond to the one presented in the lecture? I've been trying to make sense of it for quite some time now but still can't seem to understand the discrepancy.

nbro
ytolochko
  • Hi, sorry I answered incorrectly before, but you can check out my modified answer. I hope it clears up the ambiguity! –  Apr 02 '19 at 09:04

2 Answers

0

Thanks @nbro for pointing this out.

The pictorial architecture in the slides uses a Gaussian loss, which, when coupled with maximum-likelihood estimation, gives the squared-error loss (this is not done to remove any tractability issues). The main reason we do the Gaussian trick on the encoder side is to force the latent variable $z$ to be normal, so that we can apply the KL divergence to optimise an otherwise intractable integral. You can get better intuition and reasoning from this video.
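To spell out the first claim, assume a Gaussian decoder $p_\theta(x \mid z) = \mathcal N\big(x;\ \mu_\theta(z),\ \sigma^2 I\big)$ with a fixed $\sigma$ (my notation, not the slides'). Then

$$
-\log p_\theta(x \mid z) = \frac{1}{2\sigma^2}\,\lVert x - \mu_\theta(z)\rVert^2 + \frac{D}{2}\log\!\left(2\pi\sigma^2\right),
$$

where $D$ is the data dimension. The second term does not depend on $\theta$, so maximising the likelihood is the same as minimising the squared error.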

So the pictorial architecture effectively ends up with the squared-error loss via the Gaussian assumption. The loss term used in the blog you link is exactly the same loss term as in the original paper, except that the blog uses the cross-entropy (CE) loss, which is the more common loss for classification. I am not sure how they are using the CE loss, though, since it is strictly valid only for values of $0$ and $1$, and AFAIK the MNIST dataset has grayscale images.
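For reference, the reconstruction term those Keras examples compute is, as far as I can tell, the per-pixel binary cross-entropy (the negative Bernoulli log-likelihood), with the grayscale pixel values $x_i \in [0, 1]$ plugged in as if they were Bernoulli targets:

$$
-\log p_\theta(x \mid z) = -\sum_{i=1}^{D}\Big[x_i \log \hat{x}_i + (1 - x_i)\log\big(1 - \hat{x}_i\big)\Big], \qquad \hat{x} = \mathrm{decoder}(z).
$$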

I am not exactly sure how they would implement the randomness of the Gaussian loss in the decoder structure, but in the simplest of cases they just take the MSE (see the sketch below).
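As a rough illustration of the two reconstruction terms in Keras-style code (my own sketch, not taken from the blog; the function names and the `axis=-1` reduction are placeholders):

```python
from tensorflow.keras import backend as K

def gaussian_reconstruction_loss(x_true, x_decoded):
    # Squared error, i.e. a Gaussian likelihood with fixed variance (up to constants).
    return K.sum(K.square(x_true - x_decoded), axis=-1)

def bernoulli_reconstruction_loss(x_true, x_decoded):
    # Per-pixel binary cross-entropy, as in the typical Keras VAE example.
    return K.sum(K.binary_crossentropy(x_true, x_decoded), axis=-1)
```

In either case, the KL-divergence term on $z$ is added on top of the reconstruction term.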

Check out this blog on VAEs (where they take a mean $\Sigma$, which they abbreviate as "mean"; I have not checked their implementation details to know what exactly they mean by that) and also this answer on Data Science about the implementation of VAEs (both of which give a more general form of the loss). For the exact mathematics, check out Appendix C of the original paper.

  • The paper that introduced VAE states that _the decoding term $\log p_\theta(x^{(i)} \mid z^{(i, l)})$, is a Bernoulli or Gaussian MLP, depending on the type of data we are modelling_. I think you should read it in order to confirm your statements (because some of them I think are not correct, and this is the accepted answer). – nbro Apr 02 '19 at 08:34
  • I think that, in the implementation, the decoder just implements the desired probability distribution (in the case of a Gaussian, with some specific mean and variance, so you do not sample the mean and variance before sampling the image), that is, a function (which is what NNs are good for). – nbro Apr 02 '19 at 08:39
  • @nbro it seems like that. All the blogs are using the CE loss, but the authors have said we can use either the CE loss or the Gaussian loss (I was unaware such a loss existed). I'll try to reflect that in my answer. Thanks for the heads up! –  Apr 02 '19 at 08:43
0

The VAE architecture from the cs231n class is just a more general version of what the Keras code implements; the Keras code is the special case in which the covariance matrix of the output distribution is $\mathbf 0$. You can see this from the reparametrization trick:
$$
\begin{align}
x &= \mu + \Sigma\,\epsilon \\
  &= \mu \qquad \text{if } \Sigma = \mathbf 0.
\end{align}
$$
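In code the point is just that the decoder's output-noise term is dropped; a minimal sketch (the variable names and the diagonal log-variance parametrization are mine, not from either source):

```python
import numpy as np

def decode_sample(x_mu, x_log_var=None):
    # Reparametrised sample from the decoder distribution: x = mu + sigma * eps.
    # Passing x_log_var=None (i.e. zero covariance) collapses this to x = mu,
    # which is effectively what the Keras example returns.
    if x_log_var is None:
        return x_mu
    eps = np.random.normal(size=np.shape(x_mu))
    return x_mu + np.exp(0.5 * x_log_var) * eps
```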

Maybe