Multi-Head Attention in Transformers
Multi-Head Attention is a key component of the Transformer, a deep learning architecture widely used for natural language processing tasks such as machine translation, text classification, and sentiment analysis. It allows the model to attend to different parts of the input sequence at once, enabling it to learn more complex and flexible relationships between the input and output.
To understand Multi-Head Attention, it is helpful to first understand the concept of self-attention. Self-attention is a mechanism that allows a model to attend to different parts of the input sequence when making predictions about a particular part of the sequence. It does this by creating a weighted representation of the input sequence, where each element is weighted according to its relevance to the current prediction.
In a self-attention mechanism, each element of the input sequence is first projected into three vectors: a query, a key, and a value. These vectors are then used to compute the attention weights for each element in the input sequence.
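As a rough sketch of a single attention head in PyTorch (the function name and the assumption that queries, keys, and values share a last dimension are choices made here for illustration, not part of the original post):

import torch

def attention(query, key, value):
    # query, key, value: (..., seq_len, d) tensors; scores compare every query with every key
    scores = query @ key.transpose(-2, -1) / (query.size(-1) ** 0.5)
    # Softmax over the key dimension turns the scores into attention weights
    weights = torch.softmax(scores, dim=-1)
    # Each output row is a weighted average of the value vectors
    return weights @ value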
Multi-Head Attention extends this idea by computing multiple sets of query, key, and value vectors, and then combining the results. This allows the model to attend to multiple parts of the input sequence simultaneously, and to learn more complex and flexible relationships between the input and output.
Here is an example implementation of Multi-Head Attention in PyTorch:
import torch
import torch.nn as nn

class MultiHeadAttention(nn.Module):
    def __init__(self, embed_dim, nhead):
        super().__init__()
        assert embed_dim % nhead == 0, "embed_dim must be divisible by nhead"
        self.embed_dim = embed_dim
        self.nhead = nhead
        self.head_dim = embed_dim // nhead
        # Linear projections for the queries, keys, values, and the final output
        self.query_proj = nn.Linear(embed_dim, embed_dim)
        self.key_proj = nn.Linear(embed_dim, embed_dim)
        self.value_proj = nn.Linear(embed_dim, embed_dim)
        self.out_proj = nn.Linear(embed_dim, embed_dim)

    def forward(self, x):
        batch_size, seq_len, embed_dim = x.size()
        # Project the input to queries, keys, and values and split into heads:
        # (batch, seq_len, embed_dim) -> (batch, nhead, seq_len, head_dim)
        query = self.query_proj(x).view(batch_size, seq_len, self.nhead, self.head_dim).transpose(1, 2)
        key = self.key_proj(x).view(batch_size, seq_len, self.nhead, self.head_dim).transpose(1, 2)
        value = self.value_proj(x).view(batch_size, seq_len, self.nhead, self.head_dim).transpose(1, 2)
        # Scaled dot-product attention: scale by the per-head dimension, not embed_dim
        attn_weights = torch.matmul(query, key.transpose(-2, -1)) / (self.head_dim ** 0.5)
        attn_weights = torch.softmax(attn_weights, dim=-1)
        # Weighted sum of the values, then merge the heads back into one vector per position
        attn_output = torch.matmul(attn_weights, value).transpose(1, 2).contiguous().view(batch_size, seq_len, embed_dim)
        # Project the concatenated heads back to the embedding dimension
        output = self.out_proj(attn_output)
        return output
In this code, the MultiHeadAttention class is defined with an __init__ method and a forward method. The __init__ method initializes the linear projections used in the computation, and the forward method performs the actual attention computation.

The forward method first computes the query, key, and value vectors by applying linear projections to the input, then reshapes them so that each attention head operates on its own slice of the embedding. The attention weights are computed by taking the dot product between the queries and keys, scaling by the square root of the per-head dimension, and normalizing the result with a softmax. The output is then computed as a matrix multiplication between the attention weights and the value vectors, i.e. a weighted sum of the values. Finally, the per-head outputs are concatenated back into a single vector per position and projected with one more linear transformation.
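As a quick usage sketch, assuming the class above is in scope (the batch size of 2, sequence length of 10, and embedding size of 512 with 8 heads are arbitrary illustration values):

x = torch.randn(2, 10, 512)                       # (batch, seq_len, embed_dim)
mha = MultiHeadAttention(embed_dim=512, nhead=8)
out = mha(x)
print(out.shape)                                  # torch.Size([2, 10, 512]), same shape as the input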
This implementation can be used as part of a larger Transformer model for natural language processing tasks. By attending to different parts of the input sequence with multiple sets of query, key, and value vectors, Multi-Head Attention lets the model capture several kinds of relationships between positions at once, which can lead to better performance on natural language processing tasks.
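For a sense of how this might fit into a larger model, below is a minimal, hypothetical encoder-block sketch that wraps the module with residual connections, layer normalization, and a position-wise feed-forward sub-layer; the class name EncoderBlock and the ff_dim parameter are illustrative assumptions, not part of the implementation above.

class EncoderBlock(nn.Module):
    """Sketch of a Transformer encoder block: attention and feed-forward, each with a residual connection."""
    def __init__(self, embed_dim, nhead, ff_dim):
        super().__init__()
        self.attn = MultiHeadAttention(embed_dim, nhead)
        self.norm1 = nn.LayerNorm(embed_dim)
        self.ff = nn.Sequential(
            nn.Linear(embed_dim, ff_dim),
            nn.ReLU(),
            nn.Linear(ff_dim, embed_dim),
        )
        self.norm2 = nn.LayerNorm(embed_dim)

    def forward(self, x):
        x = self.norm1(x + self.attn(x))   # attention sub-layer with residual connection
        x = self.norm2(x + self.ff(x))     # feed-forward sub-layer with residual connection
        return x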