Generative Adversarial Networks

Generative Adversarial Networks (GANs) are a type of deep learning model that can generate new data samples similar to a given dataset. A GAN consists of two neural networks: a generator and a discriminator. The generator creates new data samples, while the discriminator tries to distinguish the generated samples from the real ones. The two networks are trained together in a game-like fashion, with the generator trying to fool the discriminator and the discriminator trying to identify the fake samples. Over time, the generator learns to produce increasingly realistic samples, until it can create new data that closely resembles the original dataset.

Here's a brief explanation of how GANs work, followed by code examples in Python using the popular deep learning library, PyTorch.

How GANs work

The generator network takes a random noise vector as input and produces a sample that is similar to the real data. The discriminator network takes a sample as input and produces a binary classification (real or fake) output. During training, the two networks are played off against each other, with the generator trying to generate samples that can fool the discriminator, and the discriminator trying to identify the fake samples.
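To make the shapes concrete, here is a minimal sketch of that data flow using small, untrained stand-in networks (the layer sizes mirror the MNIST example further below and are otherwise arbitrary):

```python
import torch
import torch.nn as nn

# Stand-in generator: 100-dim noise vector -> flattened 28x28 image (784 values)
G = nn.Sequential(nn.Linear(100, 256), nn.ReLU(), nn.Linear(256, 784), nn.Tanh())
# Stand-in discriminator: flattened image -> probability the input is real
D = nn.Sequential(nn.Linear(784, 256), nn.ReLU(), nn.Linear(256, 1), nn.Sigmoid())

z = torch.randn(16, 100)   # batch of 16 random noise vectors
fake = G(z)                # -> shape (16, 784), one fake sample per noise vector
score = D(fake)            # -> shape (16, 1), estimated probability each sample is real
print(fake.shape, score.shape)
```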

The generator is trained to minimize the probability that the discriminator correctly classifies its generated samples as fake, while the discriminator is trained to maximize that probability. This sets up an adversarial game in which the two networks try to outsmart each other, and over time the generator learns to produce samples that become increasingly hard to distinguish from the real data.
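Formally, this game is usually written as the minimax objective from the original GAN paper (Goodfellow et al., 2014), where D(x) is the discriminator's estimated probability that x is real and G(z) is the sample the generator produces from noise z:

\min_G \max_D \; V(D, G) = \mathbb{E}_{x \sim p_{\mathrm{data}}}\big[\log D(x)\big] + \mathbb{E}_{z \sim p_z}\big[\log\big(1 - D(G(z))\big)\big]

In practice, the generator is often trained to maximize log D(G(z)) instead of minimizing log(1 - D(G(z))), because this non-saturating variant gives stronger gradients early in training; computing the generator loss with binary cross-entropy against "real" labels, as the code below does, implements exactly that.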

Code examples

The following code example shows how to implement a GAN in PyTorch for a simple image-generation task:

MNIST digit generation

In this example, we will train a GAN to generate new images of handwritten digits that look similar to the MNIST dataset. We will use a simple generator and discriminator architecture with fully connected layers.

Here is the full code for MNIST digit generation:

```python
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
from torchvision.datasets import MNIST
from torchvision import transforms
from torch.utils.data import DataLoader


class Generator(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(Generator, self).__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, hidden_size)
        self.fc3 = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = torch.tanh(self.fc3(x))  # outputs in [-1, 1] to match the normalized images
        return x


class Discriminator(nn.Module):
    def __init__(self, input_size, hidden_size):
        super(Discriminator, self).__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, hidden_size)
        self.fc3 = nn.Linear(hidden_size, 1)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = torch.sigmoid(self.fc3(x))  # probability that the input is real
        return x


def train(discriminator, generator, dataloader, num_epochs, lr):
    criterion = nn.BCELoss()
    d_optimizer = optim.Adam(discriminator.parameters(), lr=lr)
    g_optimizer = optim.Adam(generator.parameters(), lr=lr)

    for epoch in range(num_epochs):
        for i, (real_images, _) in enumerate(dataloader):
            batch_size = real_images.shape[0]
            # Flatten the images to vectors for the fully connected networks
            real_images = real_images.view(batch_size, -1)
            real_labels = torch.ones(batch_size, 1)
            fake_labels = torch.zeros(batch_size, 1)

            # Train the discriminator on real and fake samples
            discriminator.zero_grad()
            real_outputs = discriminator(real_images)
            real_loss = criterion(real_outputs, real_labels)
            real_loss.backward()

            noise = torch.randn(batch_size, 100)
            fake_images = generator(noise)
            fake_outputs = discriminator(fake_images.detach())
            fake_loss = criterion(fake_outputs, fake_labels)
            fake_loss.backward()
            d_optimizer.step()

            # Train the generator to fool the discriminator
            generator.zero_grad()
            fake_outputs = discriminator(fake_images)
            g_loss = criterion(fake_outputs, real_labels)
            g_loss.backward()
            g_optimizer.step()

            if i % 100 == 0:
                print(f"Epoch [{epoch}/{num_epochs}], Batch [{i}/{len(dataloader)}], "
                      f"d_loss: {(real_loss + fake_loss).item():.4f}, g_loss: {g_loss.item():.4f}")


def generate_samples(generator, num_samples):
    noise = torch.randn(num_samples, 100)
    with torch.no_grad():
        generated_images = generator(noise)
    return generated_images


if __name__ == '__main__':
    # Hyperparameters
    hidden_size = 256
    num_epochs = 100
    batch_size = 64
    lr = 0.0002
    input_size = 784
    output_size = 784

    # Load the dataset, normalized to [-1, 1] to match the generator's tanh output
    transform = transforms.Compose([transforms.ToTensor(),
                                    transforms.Normalize(mean=[0.5], std=[0.5])])
    train_dataset = MNIST(root='./data', train=True, transform=transform, download=True)
    dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

    # Create the generator and discriminator networks
    generator = Generator(100, hidden_size, output_size)
    discriminator = Discriminator(input_size, hidden_size)

    # Train the GAN
    train(discriminator, generator, dataloader, num_epochs, lr)

    # Generate some samples using the trained generator and display them
    generated_images = generate_samples(generator, 10)
    for i in range(10):
        plt.subplot(1, 10, i + 1)
        plt.imshow(generated_images[i].view(28, 28), cmap='gray')
        plt.axis('off')
    plt.show()
```

This code generates 10 samples with the trained generator and displays them in a row.

The generated digits are usually not as clear and well-defined as the original MNIST digits, but they do resemble digits to some extent. With more training epochs and a more expressive architecture, the quality of the generated digits can be improved.
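As one example of a more complex architecture, a convolutional, DCGAN-style generator that upsamples the noise vector with transposed convolutions usually produces noticeably sharper digits than the fully connected generator above. The sketch below is illustrative rather than tuned; the layer sizes are assumptions, and in practice the discriminator would be replaced with a convolutional network as well.

```python
import torch.nn as nn

# Hedged sketch: a DCGAN-style generator for 28x28 grayscale images.
# Channel counts and kernel sizes are illustrative assumptions, not tuned values.
class ConvGenerator(nn.Module):
    def __init__(self, noise_dim=100):
        super().__init__()
        self.net = nn.Sequential(
            # (noise_dim, 1, 1) -> (128, 7, 7)
            nn.ConvTranspose2d(noise_dim, 128, kernel_size=7, stride=1, padding=0),
            nn.BatchNorm2d(128),
            nn.ReLU(True),
            # (128, 7, 7) -> (64, 14, 14)
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(True),
            # (64, 14, 14) -> (1, 28, 28)
            nn.ConvTranspose2d(64, 1, kernel_size=4, stride=2, padding=1),
            nn.Tanh(),
        )

    def forward(self, z):
        # z: (batch, noise_dim) -> reshape to (batch, noise_dim, 1, 1) for the conv stack
        return self.net(z.view(z.size(0), -1, 1, 1))
```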

