Transfer Learning
Transfer learning is a machine learning technique in which a model pre-trained on a large dataset is used as the starting point for training on a new, smaller dataset. This approach is particularly useful when the new dataset is too small to train a complex model from scratch.
Here's an example of transfer learning in Python using the Keras API:
from keras.applications import VGG16
from keras.models import Model
from keras.layers import Dense, Flatten
from keras.optimizers import SGD
# load the pre-trained model (VGG16)
base_model = VGG16(weights='imagenet', include_top=False, input_shape=(224, 224, 3))
# add new classification layers on top
x = base_model.output
x = Flatten()(x)
x = Dense(256, activation='relu')(x)
predictions = Dense(10, activation='softmax')(x)
# define the full model
model = Model(inputs=base_model.input, outputs=predictions)
# freeze the base model layers
for layer in base_model.layers:
    layer.trainable = False
# compile the model
optimizer = SGD(learning_rate=0.001, momentum=0.9)
model.compile(optimizer=optimizer, loss='categorical_crossentropy', metrics=['accuracy'])
# train the new layers on the target dataset (X_train, y_train, X_val,
# and y_val are assumed to be prepared beforehand)
model.fit(X_train, y_train, epochs=10, batch_size=32, validation_data=(X_val, y_val))
# unfreeze the deeper layers for fine-tuning (model.layers contains the base
# model's layers followed by the new head, so index 10 onward is trainable)
for layer in model.layers[:10]:
    layer.trainable = False
for layer in model.layers[10:]:
    layer.trainable = True
# recompile the model with a lower learning rate
optimizer = SGD(learning_rate=0.0001, momentum=0.9)
model.compile(optimizer=optimizer, loss='categorical_crossentropy', metrics=['accuracy'])
# fine-tune the model on the new dataset
model.fit(X_train, y_train, epochs=10, batch_size=32, validation_data=(X_val, y_val))
In this example, we use the VGG16 model pre-trained on the ImageNet dataset as the starting point for transfer learning. We load the pre-trained network with the VGG16 class in Keras, then add new classification layers on top to adapt the model to our new dataset.
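Note that X_train, y_train, X_val, and y_val are not defined in the listing. A minimal sketch of how they might be prepared, assuming hypothetical arrays images (raw RGB images of shape (n, 224, 224, 3)) and integer labels for 10 classes, could look like this:
from keras.applications.vgg16 import preprocess_input
from keras.utils import to_categorical
# `images` and `labels` are hypothetical: raw image array and integer class ids
X = preprocess_input(images.astype('float32'))  # VGG16-specific input scaling
y = to_categorical(labels, num_classes=10)      # one-hot labels for categorical_crossentropy
# simple 80/20 train/validation split
split = int(0.8 * len(X))
X_train, X_val = X[:split], X[split:]
y_train, y_val = y[:split], y[split:]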
We freeze the base model's layers so that their weights are not updated while the new classification layers are trained, which preserves the features learned on ImageNet. We then compile the model and train it on the new dataset.
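To verify that the freeze took effect before training, you can inspect the model; model.summary() reports trainable versus non-trainable parameter counts:
# only the new Flatten/Dense head should contribute trainable parameters now
model.summary()
print('trainable weight tensors:', len(model.trainable_weights))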
After this initial training, we fine-tune the network by unfreezing its deeper layers and lowering the learning rate; the smaller learning rate keeps the pre-trained weights from being disrupted by large updates. We must recompile the model because Keras only picks up changes to a layer's trainable flag at compile time, and then we train again on the new dataset.
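The cutoff at index 10 is a judgment call rather than a rule. To choose a boundary for your own model, you can list each layer's index and name and decide where to split:
# print index, name, and trainable flag for every layer in the combined model
for i, layer in enumerate(model.layers):
    print(i, layer.name, layer.trainable)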
Transfer learning can be used in many different contexts and with different pre-trained models. This example is just a starting point for understanding the basics of transfer learning in Keras. Depending on the problem you are trying to solve, you may need to modify the architecture and hyperparameters of the model, as in the sketch below.
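For instance, one common variation (a sketch, not part of the original example) replaces Flatten with global average pooling, which produces a much smaller head, and swaps in a different ImageNet backbone such as ResNet50:
from keras.applications import ResNet50
from keras.layers import GlobalAveragePooling2D, Dense
from keras.models import Model
# same transfer-learning pattern with a different backbone and a pooling head
base = ResNet50(weights='imagenet', include_top=False, input_shape=(224, 224, 3))
x = GlobalAveragePooling2D()(base.output)  # far fewer parameters than Flatten
x = Dense(256, activation='relu')(x)
outputs = Dense(10, activation='softmax')(x)
model = Model(inputs=base.input, outputs=outputs)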