In the training of a Generative Adversarial Network (GAN), a perfect discriminator (D) is one which outputs 1 ("true image") for all images of the training dataset and 0 ("false image") for all images created by the generator (G).
I've read on separate occasions that, when a GAN is trained with the original framework and loss function described in the Goodfellow 2014 paper, reaching this "perfect D state" during training is a "failure mode", sometimes referred to as "vanishing gradients" or "convergence failure", from which G "cannot recover".
For instance:
"When the discriminator is perfect (...), the loss function falls to zero and we end up with no gradient to update the loss during learning iterations" (source, well-received blog post)
"An optimal discriminator doesn't provide enough information for the generator to make progress." (source, Google developers)
"[T]he generator score approaches zero and does not recover. (...) In this case, the discriminator classifies most of the images correctly. In turn, the generator cannot produce any images that fool the discriminator and thus fails to learn." (source, Mathworks)
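For concreteness, the loss function all of these quotes refer to is, as far as I understand it, the original minimax objective from the Goodfellow 2014 paper:
$$\min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{\text{data}}}\big[\log D(x)\big] + \mathbb{E}_{z \sim p_z}\big[\log\big(1 - D(G(z))\big)\big]$$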
The paper Towards Principled Methods for Training GANs (Arjovsky & Bottou, 2017) is perhaps a more authoritative source on the matter, as it dives deeper into the theory: it proves that, since the loss function relates to the Jensen-Shannon divergence, it saturates when the distributions of real and generated images are disjoint (hence no gradient).
"If the two distributions we care about have supports that are disjoint or lie on low dimensional manifolds, the optimal discriminator will be perfect and its gradient will be zero almost everywhere." (section 2.2)
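If I follow the derivation correctly (it also appears as Theorem 1 in the Goodfellow 2014 paper), plugging the optimal discriminator $D^*_G(x) = \frac{p_{\text{data}}(x)}{p_{\text{data}}(x) + p_g(x)}$ into the objective gives
$$V(G, D^*_G) = 2\,\mathrm{JSD}\big(p_{\text{data}} \,\|\, p_g\big) - \log 4,$$
so when the two distributions have disjoint supports the JSD takes its maximum value $\log 2$ and the objective is stuck at the constant $0$: this is how I read the saturation claim.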
I understand the gist of the theory developed in that paper, but the result nevertheless seems counter-intuitive to me when I try to think about it in the context of the GAN training procedure. Indeed, my understanding of a GAN training iteration is the following (a rough code sketch follows the list):
1. Show true images to D, train D to output 1.
2. Show false (created by G) images to D, train D to output 0.
3. Show false images to D, train G such that D outputs 1.
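Here is a minimal sketch of that iteration as I picture it, written in PyTorch-style code for concreteness; the architectures, sizes and optimizers below are placeholders I made up, only the three steps matter:

```python
import torch
import torch.nn as nn

latent_dim, img_dim, batch = 16, 64, 32            # illustrative sizes only
G = nn.Sequential(nn.Linear(latent_dim, img_dim), nn.Tanh())    # toy generator
D = nn.Sequential(nn.Linear(img_dim, 1), nn.Sigmoid())          # toy discriminator
opt_G = torch.optim.SGD(G.parameters(), lr=1e-3)
opt_D = torch.optim.SGD(D.parameters(), lr=1e-3)
bce = nn.BCELoss()

real = torch.randn(batch, img_dim)                 # stand-in for a batch of true images
fake = G(torch.randn(batch, latent_dim))           # batch of false images from G

# Steps 1 and 2: show true images to D (target 1) and false images to D (target 0).
opt_D.zero_grad()
loss_D = bce(D(real), torch.ones(batch, 1)) + bce(D(fake.detach()), torch.zeros(batch, 1))
loss_D.backward()
opt_D.step()

# Step 3: show false images to D and train G so that D outputs 1.
opt_G.zero_grad()
loss_G = bce(D(fake), torch.ones(batch, 1))
loss_G.backward()                                  # gradient flows through D into G
opt_G.step()
```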
I can see that in the case of a perfect D, steps 1 and 2 would lead to a loss of zero and hence no gradient.
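Spelling steps 1 and 2 out for a single real image $x$ and a single generated image $G(z)$: a perfect D gives $D(x) = 1$ and $D(G(z)) = 0$, so the discriminator's cross-entropy loss is indeed
$$\mathcal{L}_D = -\log D(x) - \log\big(1 - D(G(z))\big) = -\log 1 - \log 1 = 0.$$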
But I would expect that in step 3 (I spell out the loss I have in mind right after this list):
- Since a perfect D would output 0 for all of G's images, while the loss compares its outputs to a target of all 1s, we'd have a high loss.
- Thus, we would get a high gradient.
- By backprop, this gradient would lead to the identification of the most salient features in the images that D uses to classify them as false.
- This would give G valuable information for improving its false images so that they better match the training set.
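To make the first two points concrete: with the "train G such that D outputs 1" formulation of step 3, the cross-entropy loss on one generated sample is (writing $a = D(G(z))$ for the discriminator's output)
$$\mathcal{L}_G = -\log a, \qquad \frac{\partial \mathcal{L}_G}{\partial a} = -\frac{1}{a},$$
so when a perfect D outputs $a \approx 0$, both the loss and the gradient with respect to D's output blow up, which is why I would expect backprop to pass a strong signal down to G.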
So looking at it that way, it doesn't seem to me that a perfect D should lead to "vanishing gradients" and G being unable to recover from it.
What's wrong with my understanding of the training process, and why is it not compatible with the results from the Arjovsky paper?