In deep learning, generative adversarial networks (GANs) are a type of generative model architecture. The main idea is that we train two models: a discriminator model that learns to distinguish between real and fake images, and a generator model that tries to fool the discriminator by generating fake images that look real.

Minimax: the discriminator does the best it can, and the generator tries to make the discriminator as wrong as possible.
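Written out, this is the minimax objective from the original GAN paper (Goodfellow et al., 2014), where D is the discriminator, G the generator, x a real image and z a noise vector:

$$\min_G \max_D \; \mathbb{E}_{x \sim p_{\text{data}}}\!\left[\log D(x)\right] + \mathbb{E}_{z \sim p_z}\!\left[\log\!\left(1 - D(G(z))\right)\right]$$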

The discriminator takes an image as input and outputs a binary label. We learn weights that maximise the probability it correctly labels real/fake images, with binary cross entropy as our loss function.

The generator takes a noise vector as input and outputs a generated image. We learn weights that maximise the probability the discriminator labels a generated image as real. For our loss function, we use the discriminator itself.

Variations and applications

Other than generating images (which GANs are very good at [1]), we can also convert monochrome images to colour (wow!). We can also conditionally output only specific classes, by feeding in a Boolean (one-hot) vector of the desired class.
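To make the conditional idea concrete, here's a rough sketch (the ConditionalGenerator name and layer sizes are made up for illustration, mirroring the MNIST-sized Generator in the code section below): we concatenate a one-hot class vector with the noise vector before the first layer.

import torch
import torch.nn as nn

class ConditionalGenerator(nn.Module):
	def __init__(self, num_classes=10):
		super(ConditionalGenerator, self).__init__()
		self.model = nn.Sequential(
			nn.Linear(100 + num_classes, 300),
			nn.LeakyReLU(0.2),
			nn.Linear(300, 28*28),
			nn.Sigmoid())

	def forward(self, noise, class_onehot):
		# Condition the output by concatenating the class vector with the noise
		x = torch.cat([noise, class_onehot], dim=1)
		return self.model(x).view(x.size(0), 1, 28, 28)

At sampling time we choose the class we want, one-hot encode it, and pass it in alongside the noise; the discriminator gets the same class vector concatenated with its flattened image input.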

A CycleGAN is a type of GAN that uses a cycle consistency loss to translate images between two domains (e.g., night and day images). The key idea is that if we translate an image from domain A to domain B and back again, we should get the original image back. We have one GAN for each domain (two generators and two discriminators in total).
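A minimal sketch of just the cycle consistency term, assuming two generator modules g_ab (A to B) and g_ba (B to A); the two discriminators and their adversarial losses are left out:

import torch.nn.functional as F

def cycle_consistency_loss(g_ab, g_ba, real_a, real_b):
	# Translate A -> B -> A and B -> A -> B, then compare with the originals
	reconstructed_a = g_ba(g_ab(real_a))
	reconstructed_b = g_ab(g_ba(real_b))
	# L1 distance between each original image and its reconstruction
	return F.l1_loss(reconstructed_a, real_a) + F.l1_loss(reconstructed_b, real_b)

In the full model this term is added (with a weighting factor) to the usual adversarial losses of the two GANs.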

Limitations

If the discriminator is too good, the generator won't learn: small changes in the generator weights barely change the discriminator's output, so the gradient (almost) vanishes and we can't perform gradient descent at all. On the other hand, we want the generator to generate a variety of outputs. If it produces the same output (or a small subset of outputs), a failure known as mode collapse, the best strategy for the discriminator is simply to reject that output; and if the discriminator is trapped in a local optimum, it can't adapt to the generator.

What this means is that training vanilla GANs is very difficult, and it's hard to tell, numerically or visually, whether training is actually progressing (the loss curves don't mean much). It also takes a shit ton of time to train. To mitigate the time, we use:

In code

In PyTorch, we can describe each model as a separate nn.Module.

import torch
import torch.nn as nn

class Discriminator(nn.Module):
	def __init__(self):
		super(Discriminator, self).__init__()
		# Maps a flattened 28x28 image to a single logit (real vs. fake)
		self.model = nn.Sequential(
			nn.Linear(28*28, 300),
			nn.LeakyReLU(0.2),
			nn.Linear(300, 100),
			nn.LeakyReLU(0.2),
			nn.Linear(100, 1))

	def forward(self, x):
		x = x.view(x.size(0), -1)   # Flatten the image
		out = self.model(x)
		return out.view(x.size(0))  # One logit per image in the batch
 
# The discriminator outputs raw logits, so use the logit version of BCE
criterion = nn.BCEWithLogitsLoss()

def train_discriminator(discriminator, generator, images):
	batch_size = images.size(0)
	noise = torch.randn(batch_size, 100)
	# Detach so this step only updates the discriminator, not the generator
	fake_images = generator(noise).detach()
	inputs = torch.cat([images, fake_images])
	labels = torch.cat([torch.zeros(batch_size),  # Real
	                    torch.ones(batch_size)])  # Fake
	outputs = discriminator(inputs)
	loss = criterion(outputs, labels)
	return outputs, loss

class Generator(nn.Module):
	def __init__(self):
		super(Generator, self).__init__()
		# Maps a 100-dimensional noise vector to a 28x28 image in [0, 1]
		self.model = nn.Sequential(
			nn.Linear(100, 300),
			nn.LeakyReLU(0.2),
			nn.Linear(300, 28*28),
			nn.Sigmoid())

	def forward(self, x):
		# Reshape the flat output back into an image
		return self.model(x).view(x.size(0), 1, 28, 28)
 
def train_generator(discriminator, generator, batch_size):
	noise = torch.randn(batch_size, 100)
	fake_images = generator(noise)
	outputs = discriminator(fake_images)
	# Only looks at fake outputs: the generator is rewarded
	# if it fools the discriminator!
	labels = torch.zeros(batch_size)  # "Real" under the convention above
	loss = criterion(outputs, labels)
	return fake_images, loss
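
The two train functions above only compute losses, so here's one way they might be wired into a training loop (a sketch: dataloader and num_epochs are placeholders, and the Adam settings are just common defaults, not anything prescribed above):

discriminator = Discriminator()
generator = Generator()
d_optimizer = torch.optim.Adam(discriminator.parameters(), lr=2e-4)
g_optimizer = torch.optim.Adam(generator.parameters(), lr=2e-4)

for epoch in range(num_epochs):      # num_epochs: placeholder
	for images, _ in dataloader:     # dataloader: e.g. an MNIST DataLoader
		# 1. Train the discriminator on a batch of real + fake images
		d_optimizer.zero_grad()
		_, d_loss = train_discriminator(discriminator, generator, images)
		d_loss.backward()
		d_optimizer.step()

		# 2. Train the generator to fool the (now fixed) discriminator
		g_optimizer.zero_grad()
		_, g_loss = train_generator(discriminator, generator, images.size(0))
		g_loss.backward()
		g_optimizer.step()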

Footnotes

  1. See this paper by Nvidia scientists in 2018.