
The Conditional Variational Autoencoder (CVAE), introduced in the paper Learning Structured Output Representation using Deep Conditional Generative Models (2015), is an extension of the Variational Autoencoder (VAE) (2013). In VAEs, we have no control over the data generation process, which is problematic if we want to generate some specific data, say, instances of the digit 6 in MNIST.

So far, I have only been able to find CVAEs that condition on discrete features (classes). Is there a CVAE that allows us to condition on continuous variables, as a kind of stochastic predictive model?

nbro
D1X

1 Answer


Whether the condition is a discrete class or a continuous variable, you can model it the same way.

Denote the encoder $q$ and the decoder $p$. Recall that the variational autoencoder's goal is to minimize the KL divergence between $q$ and the true posterior under $p$, i.e. $\min_{\theta, \phi} \ KL(q(z|x;\theta) \,\|\, p(z|x;\phi))$, where $\theta$ and $\phi$ parameterize the encoder and decoder respectively. To make this tractable, one generally maximizes the Evidence Lower Bound (ELBO) instead (it has the same optimum) and parameterizes $q$ with some form of reparameterization trick so that sampling is differentiable.
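The reparameterization trick mentioned above can be sketched in a few lines. This is a minimal NumPy illustration (not tied to any particular framework): the stochasticity is pushed into a fixed noise source $\epsilon \sim N(0, I)$, so $z = \mu + \sigma \epsilon$ is a deterministic, differentiable function of the encoder outputs $\mu$ and $\log \sigma^2$.

```python
import numpy as np

rng = np.random.default_rng(0)

def reparameterize(mu, log_var, rng):
    # z = mu + sigma * eps with eps ~ N(0, I); gradients w.r.t. mu and
    # log_var flow through this deterministic transform, while the
    # randomness lives entirely in eps.
    eps = rng.standard_normal(mu.shape)
    return mu + np.exp(0.5 * log_var) * eps
```

In an actual model, `mu` and `log_var` would be the two heads of the encoder network; here they are just arrays.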

Now your goal is to condition the sampling. In other words, you want to model $p(x|z, c;\phi)$, which in turn requires $q(z|x, c;\theta)$. Intuitively, the goal once again becomes $\min_{\theta, \phi} \ KL(q(z|x, c;\theta) \,\|\, p(z|x, c;\phi))$, and this is again transformed into the ELBO for tractability. In other words, the objective to maximize becomes $\mathbb{E}_q[\log p(x|z,c)] - KL(q(z|x,c) \,\|\, p(z|c))$.
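For concreteness, here is a minimal NumPy sketch of that objective as a loss (the negative ELBO), assuming a Bernoulli decoder for binarized data and the common simplification that the prior $p(z|c)$ is a standard normal (the paper also allows a learned conditional prior). The function name and the single-sample Monte Carlo estimate are illustrative choices, not from the original post.

```python
import numpy as np

def cvae_negative_elbo(x, x_hat, mu, log_var):
    # Reconstruction term E_q[log p(x|z,c)], estimated with the single
    # sample z that produced x_hat (Bernoulli log-likelihood).
    eps = 1e-7
    rec = np.sum(x * np.log(x_hat + eps)
                 + (1 - x) * np.log(1 - x_hat + eps), axis=-1)
    # KL(q(z|x,c) || p(z|c)) in closed form for diagonal Gaussians,
    # assuming p(z|c) = N(0, I).
    kl = -0.5 * np.sum(1 + log_var - mu**2 - np.exp(log_var), axis=-1)
    # Minimizing this mean over the batch maximizes the ELBO.
    return np.mean(kl - rec)
```

In a framework with autodiff, `x_hat`, `mu`, and `log_var` would be network outputs and this scalar would be what you backpropagate through.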

Takeaway: conditioning doesn't change much. Just embed your context and inject it into both the encoder and the decoder; the fact that it is continuous doesn't change anything. For implementation details, people normally just project/normalize the condition and concatenate it to some representation of $x$ in both the encoder and decoder.
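The concatenation idea can be sketched as follows. This is a toy NumPy forward pass, not a trainable model: the weight matrices stand in for learned layers, and all dimensions (e.g. a 1-dimensional continuous condition such as a rotation angle) are made-up examples.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical dimensions for illustration.
x_dim, c_dim, z_dim, h_dim = 784, 1, 16, 64

# Random matrices stand in for trained encoder/decoder weights.
W_enc = 0.01 * rng.standard_normal((x_dim + c_dim, h_dim))
W_mu = 0.01 * rng.standard_normal((h_dim, z_dim))
W_logvar = 0.01 * rng.standard_normal((h_dim, z_dim))
W_dec = 0.01 * rng.standard_normal((z_dim + c_dim, x_dim))

def encode(x, c):
    # The condition enters the encoder by simple concatenation: q(z|x, c).
    h = np.tanh(np.concatenate([x, c], axis=-1) @ W_enc)
    return h @ W_mu, h @ W_logvar

def decode(z, c):
    # The same continuous c is concatenated with z in the decoder: p(x|z, c).
    logits = np.concatenate([z, c], axis=-1) @ W_dec
    return 1.0 / (1.0 + np.exp(-logits))  # Bernoulli means in (0, 1)

x = rng.random((8, x_dim))   # batch of inputs
c = rng.random((8, c_dim))   # continuous condition, e.g. an angle
mu, log_var = encode(x, c)
z = mu + np.exp(0.5 * log_var) * rng.standard_normal(mu.shape)
x_hat = decode(z, c)
```

At generation time you would skip the encoder entirely: sample $z$ from the prior, pick whatever continuous $c$ you want, and call `decode(z, c)`.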

mshlis