In deep learning, generative adversarial networks (GANs) are a type of generative model architecture. The main idea is that we train two models: a discriminator model that learns to distinguish between real and fake images, and a generator model that tries to fool the discriminator by generating fake images that look real.
Minimax: the discriminator does the best it can, and the generator tries to make the discriminator as wrong as possible.
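Written out, this is the usual minimax objective, where x is a real image and z is a noise vector:

$$
\min_G \max_D \; \mathbb{E}_{x \sim p_{\text{data}}}\big[\log D(x)\big] + \mathbb{E}_{z \sim p_z}\big[\log\big(1 - D(G(z))\big)\big]
$$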
The discriminator takes an image input and outputs a binary label. We learn weights that maximise the probability it correctly labels real/fake images, with binary cross entropy as our loss function.
The generator takes a noise vector as an input and outputs a generated image. We learn weights that maximise the probability the discriminator labels a generated image as real. For our loss function, we use the discriminator itself.
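Putting the two together, a single training step might look roughly like this. It's only a sketch: `generator` and `discriminator` are assumed to be PyTorch modules like the ones defined under "In code" below, each with its own optimiser, and the discriminator is assumed to end in a sigmoid so its output is a probability.

```python
import torch
import torch.nn.functional as F

def gan_training_step(generator, discriminator, g_opt, d_opt, real, latent_dim=100):
    noise = torch.randn(real.size(0), latent_dim)

    # Discriminator step: maximise the probability of labelling real/fake correctly
    fake = generator(noise).detach()  # detach so no gradients flow back into the generator
    d_real = discriminator(real)
    d_fake = discriminator(fake)
    d_loss = (F.binary_cross_entropy(d_real, torch.ones_like(d_real))
              + F.binary_cross_entropy(d_fake, torch.zeros_like(d_fake)))
    d_opt.zero_grad()
    d_loss.backward()
    d_opt.step()

    # Generator step: maximise the probability the discriminator calls a fake "real"
    d_fake = discriminator(generator(noise))
    g_loss = F.binary_cross_entropy(d_fake, torch.ones_like(d_fake))
    g_opt.zero_grad()
    g_loss.backward()
    g_opt.step()
    return d_loss.item(), g_loss.item()
```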
Variations and applications
Other than generating images (which GANs are very good at[^1]), we can also convert monochrome images to colour (wow!). We can also conditionally generate only specific classes (by conditioning on a Boolean/one-hot vector of classes).
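For the conditional case, one common trick (a sketch, not the only approach) is to concatenate the class vector onto the generator's noise input (and similarly onto the discriminator's input):

```python
import torch
import torch.nn.functional as F

latent_dim, num_classes = 100, 10
noise = torch.randn(8, latent_dim)                # a batch of 8 noise vectors
labels = torch.randint(0, num_classes, (8,))      # the classes we want generated
one_hot = F.one_hot(labels, num_classes).float()  # the "Boolean vector of classes"

# The conditional generator sees both the noise and the desired class
generator_input = torch.cat([noise, one_hot], dim=1)  # shape: (8, latent_dim + num_classes)
```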
A CycleGAN is a type of GAN that uses a cycle consistency loss to translate images between two domains (two classes, e.g., night and day images). The key idea is that if we translate an image from A to B and back again, we should get the original image back. We have one GAN for each domain.
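The cycle consistency part can be sketched as an L1 penalty on the round trip; `g_ab` and `g_ba` are illustrative names for the two generators, and the weight is a typical but arbitrary choice:

```python
import torch.nn.functional as F

def cycle_consistency_loss(g_ab, g_ba, real_a, real_b, weight=10.0):
    # A -> B -> A and B -> A -> B should both reproduce the original images
    loss_a = F.l1_loss(g_ba(g_ab(real_a)), real_a)
    loss_b = F.l1_loss(g_ab(g_ba(real_b)), real_b)
    return weight * (loss_a + loss_b)
```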
Limitations
If the discriminator is too good, the generator won't learn: small changes in the generator weights barely change the discriminator's output, so the generator gets almost no gradient signal and gradient descent stalls. On the other hand, we want the generator to produce a variety of outputs. If it collapses to the same output (or a small subset of outputs), the best strategy for the discriminator is to reject exactly that output; and if the discriminator is trapped in a local optimum, it can't adapt to the generator, so the collapse persists.
What this means is that training vanilla GANs is very difficult, and it's hard to tell numerically or visually (i.e., from a training curve) whether there's any progress. It also takes a shit ton of time to train. To make training more stable, we use:
- LeakyReLU activations instead of ReLU
- Batch normalisation
- And regularising the discriminator weights and adding noise to the discriminator inputs (a rough sketch of these two follows)
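Something like this, say; the noise scale and weight decay values are illustrative, and `discriminator`, `real_images`, and `fake_images` are assumed to already exist:

```python
import torch

# Regularise the discriminator weights via weight decay in its optimiser
d_opt = torch.optim.Adam(discriminator.parameters(), lr=2e-4,
                         betas=(0.5, 0.999), weight_decay=1e-4)

# Add a little Gaussian noise to every image the discriminator sees
def noisy(images, std=0.05):
    return images + std * torch.randn_like(images)

d_real = discriminator(noisy(real_images))
d_fake = discriminator(noisy(fake_images))
```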
In code
In PyTorch, we can describe each model as a separate `nn.Module`.
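Here's a minimal sketch using fully connected layers for small, flattened images; a real implementation would usually use convolutions, but the overall shape is the same:

```python
import torch.nn as nn

latent_dim, img_dim = 100, 28 * 28  # e.g. flattened 28x28 images

class Generator(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(latent_dim, 256),
            nn.BatchNorm1d(256),   # batch normalisation, as mentioned above
            nn.LeakyReLU(0.2),     # LeakyReLU instead of ReLU
            nn.Linear(256, img_dim),
            nn.Tanh(),             # pixel values in [-1, 1]
        )

    def forward(self, z):
        return self.net(z)

class Discriminator(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(img_dim, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 1),
            nn.Sigmoid(),          # probability that the input image is real
        )

    def forward(self, x):
        return self.net(x)
```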
Footnotes

[^1]: See this paper by Nvidia scientists in 2018.