Generative Adversarial Networks
Generative Adversarial Networks (GANs) are a class of deep learning models that learn to generate new data samples resembling a given dataset. A GAN consists of two neural networks: a generator and a discriminator. The generator creates new data samples, while the discriminator tries to distinguish the generated samples from real ones. The two networks are trained together in a game-like fashion: the generator tries to fool the discriminator, and the discriminator tries to identify the fake samples. Over time, the generator learns to produce increasingly realistic samples, until it can generate new data that closely resembles the original dataset.
Here's a brief explanation of how GANs work, followed by a code example in Python using PyTorch, a popular deep learning library.
How GANs work
The generator network takes a random noise vector as input and produces a sample that resembles the real data. The discriminator network takes a sample as input and outputs a binary classification: real or fake. During training, the two networks are pitted against each other, with the generator trying to produce samples that fool the discriminator, and the discriminator trying to identify the fakes.
The generator is trained to minimize the probability that the discriminator correctly classifies its samples as fake, while the discriminator is trained to maximize that probability. This sets up a game in which the two networks try to outsmart each other; over time, the generator learns to produce samples that are increasingly hard to distinguish from the real data.
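To make this concrete, here is a minimal sketch of the two loss terms using binary cross-entropy. The function name gan_losses and its arguments are placeholders chosen for illustration; the same two loss terms appear inside the training loop in the full example below.
import torch
import torch.nn as nn

def gan_losses(discriminator, generator, real, noise):
    # Binary cross-entropy formulation of the two objectives.
    criterion = nn.BCELoss()
    batch_size = real.shape[0]
    real_labels = torch.ones(batch_size, 1)   # target "real"
    fake_labels = torch.zeros(batch_size, 1)  # target "fake"

    fake = generator(noise)

    # Discriminator: classify real samples as 1 and generated samples as 0.
    # detach() so this loss does not backpropagate into the generator.
    d_loss = criterion(discriminator(real), real_labels) + \
             criterion(discriminator(fake.detach()), fake_labels)

    # Generator: try to get its samples classified as 1, i.e. fool the discriminator.
    g_loss = criterion(discriminator(fake), real_labels)

    return d_loss, g_loss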
Code examples
The following example in PyTorch illustrates how to implement a GAN for image generation:
MNIST digit generation
In this example, we will train a GAN to generate new images of handwritten digits that look similar to the MNIST dataset. We will use a simple generator and discriminator architecture with fully connected layers.
Code for MNIST digit generation:
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
from torchvision.datasets import MNIST
from torchvision import transforms
from torch.utils.data import DataLoader
class Generator(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(Generator, self).__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, hidden_size)
        self.fc3 = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = torch.tanh(self.fc3(x))
        return x
class Discriminator(nn.Module):
    def __init__(self, input_size, hidden_size):
        super(Discriminator, self).__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, hidden_size)
        self.fc3 = nn.Linear(hidden_size, 1)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = torch.sigmoid(self.fc3(x))
        return x
def train(discriminator, generator, dataloader, num_epochs, lr):
    criterion = nn.BCELoss()
    d_optimizer = optim.Adam(discriminator.parameters(), lr=lr)
    g_optimizer = optim.Adam(generator.parameters(), lr=lr)
    for epoch in range(num_epochs):
        for i, (real_images, _) in enumerate(dataloader):
            batch_size = real_images.shape[0]
            # Flatten the 1x28x28 images to vectors of size 784 for the fully connected discriminator
            real_images = real_images.view(batch_size, -1)
            real_labels = torch.ones(batch_size, 1)
            fake_labels = torch.zeros(batch_size, 1)
            # Train the discriminator
            discriminator.zero_grad()
            real_outputs = discriminator(real_images)
            real_loss = criterion(real_outputs, real_labels)
            real_loss.backward()
            noise = torch.randn(batch_size, 100)
            fake_images = generator(noise)
            fake_outputs = discriminator(fake_images.detach())
            fake_loss = criterion(fake_outputs, fake_labels)
            fake_loss.backward()
            d_optimizer.step()
            # Train the generator
            generator.zero_grad()
            fake_outputs = discriminator(fake_images)
            g_loss = criterion(fake_outputs, real_labels)
            g_loss.backward()
            g_optimizer.step()
            if i % 100 == 0:
                print(f"Epoch [{epoch}/{num_epochs}], Batch [{i}/{len(dataloader)}], "
                      f"d_loss: {(real_loss + fake_loss).item():.4f}, g_loss: {g_loss.item():.4f}")
def generate_samples(generator, num_samples):
    noise = torch.randn(num_samples, 100)
    # No gradients are needed for sampling, and this lets the images be plotted directly
    with torch.no_grad():
        generated_images = generator(noise)
    return generated_images
if __name__ == '__main__':
    # Hyperparameters
    hidden_size = 256
    num_epochs = 100
    batch_size = 64
    lr = 0.0002
    input_size = 784
    output_size = 784
    # Load the dataset
    transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean=[0.5], std=[0.5])])
    train_dataset = MNIST(root='./data', train=True, transform=transform, download=True)
    dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    # Create the generator and discriminator networks
    generator = Generator(100, hidden_size, output_size)
    discriminator = Discriminator(input_size, hidden_size)
    # Train the GAN
    train(discriminator, generator, dataloader, num_epochs, lr)
    # Generate some samples using the trained generator
    generated_images = generate_samples(generator, 10)
    for i in range(10):
        plt.subplot(1, 10, i + 1)
        plt.imshow(generated_images[i].view(28, 28), cmap='gray')
        plt.axis('off')
    plt.show()
This code will generate 10 samples using the trained generator and display them in a row. The generated digits are usually not as clear and well-defined as the original MNIST digits, but they do resemble digits to some extent. With more training epochs and a more complex architecture, the quality of the generated digits can be improved.
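One common improvement is to replace the fully connected generator with a convolutional, DCGAN-style one. The sketch below is only an illustration of that idea, not part of the example above; the class name ConvGenerator and the layer sizes are assumptions chosen for 28x28 MNIST images.
import torch
import torch.nn as nn

class ConvGenerator(nn.Module):
    # Illustrative DCGAN-style generator for 28x28 grayscale images (assumed sizes).
    def __init__(self, noise_size=100):
        super(ConvGenerator, self).__init__()
        self.fc = nn.Linear(noise_size, 128 * 7 * 7)
        self.net = nn.Sequential(
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),  # 7x7 -> 14x14
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 1, kernel_size=4, stride=2, padding=1),    # 14x14 -> 28x28
            nn.Tanh(),  # outputs in [-1, 1], matching the Normalize(mean=0.5, std=0.5) transform
        )

    def forward(self, z):
        x = self.fc(z).view(-1, 128, 7, 7)
        return self.net(x)
Paired with a matching convolutional discriminator and more training epochs, this kind of architecture generally produces sharper digits than the fully connected version above.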