
Transfer Learning

Transfer learning is a technique in which a model pre-trained on one task is used as the starting point for another, related task. The idea is that the pre-trained model has already learned useful features from a large dataset, so it can be fine-tuned on a smaller target dataset to perform well on the new task.

Here's an example of how to use transfer learning with Python and the PyTorch library:

```python
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms

# Define the transforms for the training and test datasets.
# CIFAR10 images are 32x32, so both pipelines scale them up to the
# 224x224 input size the ImageNet-trained model expects.
transform_train = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

transform_test = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Load the pre-trained model and replace the last layer.
# On torchvision versions older than 0.13, pass pretrained=True instead of weights=...
model = torchvision.models.resnet18(weights=torchvision.models.ResNet18_Weights.IMAGENET1K_V1)
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, 10)

# Define the loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

# Load the training and test datasets
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=32, shuffle=True, num_workers=2)
testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
testloader = torch.utils.data.DataLoader(testset, batch_size=32, shuffle=False, num_workers=2)

# Train the model
model.train()  # enable training behavior for batch normalization
for epoch in range(5):
    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        inputs, labels = data
        optimizer.zero_grad()
        # Forward pass through the pre-trained layers and the new final layer
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        if i % 100 == 99:
            # Report the average loss over the last 100 mini-batches
            print('[%d, %5d] loss: %.3f' % (epoch + 1, i + 1, running_loss / 100))
            running_loss = 0.0

# Test the model
model.eval()  # use the running batch-norm statistics during evaluation
correct = 0
total = 0
with torch.no_grad():
    for data in testloader:
        images, labels = data
        outputs = model(images)
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print('Accuracy of the network on the 10000 test images: %d %%' % (100 * correct / total))
```

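In this example, we're using the CIFAR10 dataset, which consists of 60,000 32x32 color images in 10 classes (50,000 for training and 10,000 for testing). We first define transforms for the training and test datasets. The training pipeline applies data augmentation (random resized crops and horizontal flips), and both pipelines scale the images to 224x224 and normalize each channel with the ImageNet mean and standard deviation.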

We then load a pre-trained resnet18 model from the torchvision library, which has been trained on the ImageNet dataset. We replace the last layer of the model with a new fully connected layer that has 10 output classes to match the CIFAR10 dataset. This new layer is randomly initialized and will be trained from scratch.

Next, we define the loss function (cross-entropy) and the optimizer (SGD with momentum), and load the training and test datasets using PyTorch's DataLoader class.
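As a small sanity check on the loss function, nn.CrossEntropyLoss takes raw scores (logits) and integer class labels and applies log-softmax internally. The sketch below uses made-up logits and labels to show that a uniform prediction over 10 classes yields a loss of ln(10), so the loss printed early in training should start in roughly that range.

```python
import math

import torch
import torch.nn as nn

criterion = nn.CrossEntropyLoss()

# Four samples, ten classes; all-zero logits correspond to a uniform prediction
logits = torch.zeros(4, 10)
labels = torch.tensor([0, 3, 7, 9])  # arbitrary example labels

loss = criterion(logits, labels)
print(loss.item(), math.log(10))  # both are about 2.30
```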

During training, each batch of images passes through the pre-trained layers, which extract features, and then through the new fully connected layer, which predicts the output class. We compute the loss between the predicted and true class labels and update the parameters using backpropagation. Because the optimizer receives model.parameters(), every layer is fine-tuned, not just the new one.

After training is complete, we test the model on the test dataset and compute the classification accuracy.
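The accuracy calculation can be traced on a tiny hand-made batch. The logits and labels below are hypothetical values chosen for illustration; argmax(dim=1) returns the same indices as the second result of torch.max(outputs, 1).

```python
import torch

# Hypothetical logits for four samples and three classes
outputs = torch.tensor([[2.0, 0.1, 0.3],
                        [0.2, 1.5, 0.1],
                        [0.3, 0.2, 0.9],
                        [1.2, 0.4, 0.1]])
labels = torch.tensor([0, 1, 2, 1])

# The predicted class is the index of the largest score in each row
predicted = outputs.argmax(dim=1)

# Count matches and divide by the number of samples
correct = (predicted == labels).sum().item()
accuracy = correct / labels.size(0)
print(accuracy)  # 0.75: the last sample is predicted as class 0 but labeled 1
```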

With transfer learning, a model can typically reach good classification accuracy on CIFAR10 after only a few epochs of additional training, because the pre-trained model has already learned useful features from a large dataset (ImageNet). This approach is particularly useful when you have limited training data for your target task but a pre-trained model is available for a related task.

