A few more clarifications. While the correct thing to do is draw from the prior, we have no guarantee that the aggregated posterior will cover the prior. Think of the aggregated posterior as the distribution of the latent variables for your dataset (see here for a nice explanation and visualization). Our hope is that this will match the prior, but in practice we often get a mismatch between the prior and the aggregated posterior. In that case sampling from the prior might fail, because part of it is not covered by the aggregated posterior. This can be solved in various ways, like learning the prior or computing the aggregated posterior after training.
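As a rough illustration of the last option, here is a minimal sketch of sampling from the aggregated posterior after training. It assumes a trained `encoder` that returns `(mu, log_var)` for a batch of images, a `decoder` that maps latents back to images, and a dataset yielding `(image, label)` pairs; all names are placeholders, not a specific library API.

```python
import torch

@torch.no_grad()
def sample_from_aggregated_posterior(encoder, decoder, dataset, n_samples):
    # Encode the whole dataset and keep the per-image posterior parameters.
    mus, log_vars = [], []
    for x, _ in torch.utils.data.DataLoader(dataset, batch_size=256):
        mu, log_var = encoder(x)
        mus.append(mu)
        log_vars.append(log_var)
    mus, log_vars = torch.cat(mus), torch.cat(log_vars)

    # The aggregated posterior is a mixture of the per-image posteriors:
    # pick a random image's posterior, then sample from that Gaussian.
    idx = torch.randint(0, mus.shape[0], (n_samples,))
    z = mus[idx] + torch.randn_like(mus[idx]) * (0.5 * log_vars[idx]).exp()
    return decoder(z)
```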
Maybe there's a misconception: we are not learning a `mu` and `log_var`, but a mapping (the encoder) from an image to `mu` and `log_var`. This is quite different, because `mu` and `log_var` are not two fixed vectors for the whole dataset; they are computed separately for each image.
In a similar fashion, the decoder is a learned mapping from the latent space (where the prior is $N(0,I)$) back to the image space.
Essentially the encoder takes an image as input and spits out the parameters of another Gaussian (the posterior). This means that during training the input of the decoder is conditioned on the image. Let's take MNIST as an example. We hope that after training the encoder has learned to spit out similar `mu` and `log_var` for similar digits, and that the decoder has learned to decode noise from a posterior into a specific digit.
For example, with a 1-dimensional latent, what we hope for is something like this:
Input digit 0 --> Encoder gives mu 0.1 log_var 0.3
Input digit 0 --> Encoder gives mu 0.2 log_var 0.2
Input digit 1 --> Encoder gives mu 1.4 log_var 0.2
Input digit 1 --> Encoder gives mu 1.5 log_var 0.1
...
Input digit 9 --> Encoder gives mu -4.5 log_var 0.3
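To make the two mappings concrete, here is a minimal MNIST-sized sketch of what the encoder and decoder could look like. Layer sizes, class names, and the `reparameterize` helper are illustrative assumptions, not the only way to set this up.

```python
import torch
import torch.nn as nn

class Encoder(nn.Module):
    """Outputs a *different* mu and log_var for every input image."""
    def __init__(self, latent_dim=1):
        super().__init__()
        self.body = nn.Sequential(nn.Flatten(), nn.Linear(784, 256), nn.ReLU())
        self.mu = nn.Linear(256, latent_dim)
        self.log_var = nn.Linear(256, latent_dim)

    def forward(self, x):
        h = self.body(x)
        return self.mu(h), self.log_var(h)   # parameters of the posterior q(z|x)

class Decoder(nn.Module):
    """Maps a latent sample back to image space."""
    def __init__(self, latent_dim=1):
        super().__init__()
        self.body = nn.Sequential(nn.Linear(latent_dim, 256), nn.ReLU(),
                                  nn.Linear(256, 784), nn.Sigmoid())

    def forward(self, z):
        return self.body(z).view(-1, 1, 28, 28)

# During training the decoder's input is conditioned on the image via the
# reparameterization trick: z = mu + eps * sigma, with eps ~ N(0, I).
def reparameterize(mu, log_var):
    eps = torch.randn_like(mu)
    return mu + eps * (0.5 * log_var).exp()
```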
This blogpost has a nice visualization with 2d latents.
If we didn't have the encoder, we would always draw noise from the same $N(0,I)$ Gaussian. This could also work, but then we'd need a different training technique, like in GANs.
At test time we often want to draw a sample from the whole data distribution, and for that reason we should use the prior $N(0,I)$. If for some reason you want to condition the output to look like a specific sample, then you can use the posterior. For example, if you only want digits of 1, you can pass an image of a 1 through the encoder and then use its `mu` and `log_var` to draw samples.
So the question is: do you want a sample from the whole distribution? Then use the prior.
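A short sketch of the two test-time options, assuming the illustrative `Encoder`/`Decoder` from above (function names are hypothetical): sample the prior when you want anything from the data distribution, or encode a specific image and sample its posterior when you want outputs that stay close to it.

```python
import torch

@torch.no_grad()
def sample_prior(decoder, n_samples, latent_dim=1):
    z = torch.randn(n_samples, latent_dim)          # z ~ N(0, I)
    return decoder(z)

@torch.no_grad()
def sample_posterior(encoder, decoder, x_one, n_samples):
    mu, log_var = encoder(x_one)                    # e.g. a single image of a 1
    eps = torch.randn(n_samples, mu.shape[-1])
    z = mu + eps * (0.5 * log_var).exp()            # z ~ N(mu, sigma^2)
    return decoder(z)
```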