
I know I can get a VAE to generate using a latent prior with a mean of 0 and std-dev of 1. I tested it with the following loss function:

def loss(self, data, reconst, mu, sig):
    # Reconstruction term (MSE between input and reconstruction)
    rl = self.reconLoss(reconst, data)
    #dl = self.divergenceLoss(mu, sig)
    # sig is the encoder's log-variance, so convert it to a std-dev
    std = torch.exp(0.5 * sig)
    # Parameters of the target prior: N(0, 1)
    compMeans = torch.full(std.size(), 0.0)
    compStd = torch.full(std.size(), 1.0)
    dl = kld(mu, std, compMeans, compStd)
    # rw and dw are weights for the two loss terms
    totalLoss = self.rw * rl + self.dw * dl
    return (totalLoss, rl, dl)

def kld(mu1, std1, mu2, std2):
    # Mean KL divergence between two diagonal Gaussians,
    # KL( N(mu1, std1^2) || N(mu2, std2^2) )
    p = torch.distributions.Normal(mu1, std1)
    q = torch.distributions.Normal(mu2, std2)
    return torch.distributions.kl_divergence(p, q).mean()

In this case, mu and sig are the encoder outputs for the latent vector (sig being the log-variance, hence std = torch.exp(0.5 * sig)), and reconLoss is MSE. This works well, and I am able to generate MNIST digits by feeding noise from a standard normal distribution into the decoder.
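As a sanity check (not part of my training code), the KL term between diagonal Gaussians has a closed form, which for a prior with unit std-dev reduces to the expression below; it should agree with the kld function above up to floating-point error:

import torch

def kld_closed_form(mu, std, prior_mean=0.0):
    # KL( N(mu, std^2) || N(prior_mean, 1) ), averaged over all elements.
    # Note that the prior mean only enters through the (mu - prior_mean)^2 term.
    return (-torch.log(std) + (std ** 2 + (mu - prior_mean) ** 2) / 2 - 0.5).mean()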

However, I'd now like to concentrate the latent distribution around a normal with a mean of 10 and a std-dev of 1. I tried changing the loss like this:

compMeans = torch.full(std.size(), 10.0)

I made the same change in the reparameterization and generation functions. But what worked for the standard normal distribution is not working for the mean-10 normal: reconstruction still works fine, but generation only produces strange shapes. Oddly, the divergence loss is still going down, and it reaches a level similar to what it reached with the standard normal prior.
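For context, my generation step amounts to sampling a latent tensor from the chosen prior and feeding it to the decoder, roughly like the sketch below (the function name and self.decoder are simplified stand-ins for my actual code):

def generate(self, numSamples, latentDim):
    # Sample latent vectors from the prior the model was trained towards
    # (here N(10, 1)) and decode them into images.
    means = torch.full((numSamples, latentDim), 10.0)
    stds = torch.full((numSamples, latentDim), 1.0)
    z = torch.normal(means, stds)
    return self.decoder(z)  # self.decoder is a stand-in name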

Does anyone know why this isn't working? Is there something about KL divergence that does not work with non-standard priors?

Other things I've tried:

  • Generating from N(0, 1) after training on N(10, 1): failed
  • Generating from N(-10, 1) after training on N(10, 1): failed
  • A custom implementation of KL divergence: worked on N(0, 1), failed on N(10, 1)
  • Using sig directly instead of std = torch.exp(0.5 * sig): failed

Edit 1: Below are my loss plots for the N(0, 1) distribution.

Reconstruction loss: [image]

Divergence loss: [image]

Generation samples: [image]

Reconstruction samples (left is input, right is output): [image]

And here are the plots for the N(10, 1) distribution.

Reconstruction loss: [image]

Divergence loss: [image]

Generation sample: [image]

Note: when I ran it this time, it actually seemed to learn the generation a bit, though it still produces mostly 8's, or shapes that are structurally close to an 8. This is not the case for the standard normal distribution. The only difference from the last run is the random seed.

Reconstruction sample: [image]

Sampled latent:

tensor([[ 9.6411,  9.9796,  9.9829, 10.0024,  9.6115,  9.9056,  9.9095, 10.0684,
         10.0435,  9.9308],
        [ 9.8364, 10.0890,  9.8836, 10.0544,  9.4017, 10.0457, 10.0134,  9.9539,
         10.0986, 10.0434],
        [ 9.9301,  9.9534, 10.0042, 10.1110,  9.8654,  9.4630, 10.0256,  9.9237,
          9.8614,  9.7408],
        [ 9.3332, 10.1289, 10.0212,  9.7660,  9.7731,  9.9771,  9.8550, 10.0152,
          9.9879, 10.1816],
        [10.0605,  9.8872, 10.0057,  9.6858,  9.9998,  9.4429,  9.8378, 10.0389,
          9.9264,  9.8789],
        [10.0931,  9.9347, 10.0870,  9.9941, 10.0001, 10.1102,  9.8260, 10.1521,
          9.9961, 10.0989],
        [ 9.5413,  9.8965,  9.2484,  9.7604,  9.9095,  9.8409,  9.3402,  9.8552,
          9.7309,  9.7300],
        [10.0113,  9.5318,  9.9867,  9.6139,  9.9422, 10.1269,  9.9375,  9.9242,
          9.9532,  9.9053],
        [ 9.8866, 10.1696,  9.9437, 10.0858,  9.5781, 10.1011,  9.8957,  9.9684,
          9.9904,  9.9017],
        [ 9.6977, 10.0545, 10.0383,  9.9647,  9.9738,  9.9795,  9.9165, 10.0705,
          9.9072,  9.9659],
        [ 9.6819, 10.0224, 10.0547,  9.9457,  9.9592,  9.9380,  9.8731, 10.0825,
          9.8949, 10.0187],
        [ 9.6339,  9.9985,  9.7757,  9.4039,  9.7309,  9.8588,  9.7938,  9.8712,
          9.9763, 10.0186],
        [ 9.7688, 10.0575, 10.0515, 10.0153,  9.9782, 10.0115,  9.9269, 10.1228,
          9.9738, 10.0615],
        [ 9.8575,  9.8241,  9.9603, 10.0220,  9.9342,  9.9557, 10.1162, 10.0428,
         10.1363, 10.3070],
        [ 9.6856,  9.7924,  9.9174,  9.5064,  9.8072,  9.7176,  9.7449,  9.7004,
          9.8268,  9.9878],
        [ 9.8630, 10.0470, 10.0227,  9.7871, 10.0410,  9.9470, 10.0638, 10.1259,
         10.1669, 10.1097]])

Note that the sampled latents do seem to be in the right distribution (concentrated around 10).

Just in case, here's my reparameterization method too, currently set up for the N(10, 1) distribution:

def reparamaterize(self, mu, sig):
    # sig is the encoder's log-variance, so convert it to a std-dev
    std = torch.exp(0.5 * sig)
    # Noise is sampled from N(10, 1) instead of the usual N(0, 1)
    epsMeans = torch.full(std.size(), 10.0)
    epsStd = torch.full(std.size(), 1.0)
    eps = torch.normal(epsMeans, epsStd)
    return eps * std + mu
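
For comparison, the version that worked with the standard normal prior differed only in the mean of the sampled noise (a sketch reconstructed from the change described above):

def reparamaterize(self, mu, sig):
    std = torch.exp(0.5 * sig)
    # N(0, 1) setup: noise sampled from the standard normal
    epsMeans = torch.full(std.size(), 0.0)
    epsStd = torch.full(std.size(), 1.0)
    eps = torch.normal(epsMeans, epsStd)
    return eps * std + mu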
  • If you use a prior of N(10, 1), this should mean that, after training, the value of the elements of the latent vectors should be around 10 (because the variational distribution should be shifted towards the prior in order to minimize the loss): can you confirm whether that is the case or not? I don't remember now all the details of the VAE, i.e. which steps are used to generate the image from the latent vector: can you recall them? Moreover, could you also please tell us the exact magnitude of the KL divergence in both the case of N(10, 1) and N(0, 1)? – nbro Feb 18 '21 at 17:23
  • The fact that the KL divergence decreases is not necessarily a sign of good performance. The performance typically comes from the MSE part (the KL is there more as a regularizer). Could you please provide some info about that in both cases N(0, 1) and N(10, 1)? – nbro Feb 18 '21 at 17:24
  • I have edited the question with the new information. – axon Feb 19 '21 at 00:17
  • Generating in a VAE involves training it to expect input from the selected distribution, so when you actually generate you just need to get a random tensor from that expected distribution and feed it to the decoder. – axon Feb 19 '21 at 00:19
  • Why is the KL increasing in the case of N(0, 1)? That doesn't look right, it should decrease. – nbro Feb 19 '21 at 13:50
  • It is part of a compound loss function with the reconstruction loss, defined as totalLoss = self.rw * rl + self.dw * dl, where dw and rw are hyperparameters that scale the two losses. The KL goes up so that the reconstruction loss can go down. – axon Feb 19 '21 at 19:20
