Multi-Head Attention in Transformers

Multi-Head Attention is a key component of Transformers, a deep learning architecture commonly used for natural language processing tasks such as machine translation, text classification, and sentiment analysis. Multi-Head Attention allows the model to attend to different parts of the input sequence, enabling it to learn more complex and flexible relationships between the input and output.

To understand Multi-Head Attention, it is helpful to first understand the concept of self-attention. Self-attention is a mechanism that allows a model to attend to different parts of the input sequence when making predictions about a particular part of the sequence. It does this by creating a weighted representation of the input sequence, where each element is weighted according to its relevance to the current prediction.

In a self-attention mechanism, each element of the input sequence is first transformed into three vectors: a query vector, a key vector, and a value vector. These vectors are then used to compute the attention weights for each element in the input sequence.
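
To make this concrete, here is a minimal sketch of the underlying scaled dot-product attention for a single head. The function name, the tensor shapes, and the example sizes are illustrative assumptions; for brevity, the linear projections that would normally produce the query, key, and value vectors are omitted and the raw input is reused for all three roles.

import torch

def single_head_attention(query, key, value):
    # query, key, value: tensors of shape (batch, seq_len, dim)
    dim = query.size(-1)
    # Dot product of every query with every key, scaled by sqrt(dim)
    scores = torch.matmul(query, key.transpose(-2, -1)) / (dim ** 0.5)
    # Softmax over the key positions gives the attention weights
    weights = torch.softmax(scores, dim=-1)
    # The output is a weighted sum of the value vectors
    return torch.matmul(weights, value), weights

x = torch.randn(2, 5, 64)                         # (batch, seq_len, dim)
output, weights = single_head_attention(x, x, x)  # self-attention: all three come from the same input
print(output.shape, weights.shape)                # torch.Size([2, 5, 64]) torch.Size([2, 5, 5])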

Multi-Head Attention extends this idea by computing several independent sets of query, key, and value vectors, one per attention head, and then combining the results. Each head can attend to a different part of the input sequence, so the model can capture several kinds of relationships between positions at the same time.
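
In the notation of the original Transformer paper ("Attention Is All You Need"), where d_k is the per-head dimension and the W matrices are learned projections, this is commonly written as:

\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{QK^\top}{\sqrt{d_k}}\right)V

\mathrm{head}_i = \mathrm{Attention}(Q W_i^Q,\; K W_i^K,\; V W_i^V)

\mathrm{MultiHead}(Q, K, V) = \mathrm{Concat}(\mathrm{head}_1, \ldots, \mathrm{head}_h)\, W^O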

Here is an example implementation of Multi-Head Attention in PyTorch:

import torch
import torch.nn as nn

class MultiHeadAttention(nn.Module):
    def __init__(self, embed_dim, nhead):
        super().__init__()
        self.embed_dim = embed_dim
        self.nhead = nhead
        self.head_dim = embed_dim // nhead
        self.query_proj = nn.Linear(embed_dim, embed_dim)
        self.key_proj = nn.Linear(embed_dim, embed_dim)
        self.value_proj = nn.Linear(embed_dim, embed_dim)
        self.out_proj = nn.Linear(embed_dim, embed_dim)

    def forward(self, x):
        batch_size, seq_len, embed_dim = x.size()

        # Project the input to the query, key, and value vectors and split
        # them into nhead heads of size head_dim
        query = self.query_proj(x).view(batch_size, seq_len, self.nhead, self.head_dim).transpose(1, 2)
        key = self.key_proj(x).view(batch_size, seq_len, self.nhead, self.head_dim).transpose(1, 2)
        value = self.value_proj(x).view(batch_size, seq_len, self.nhead, self.head_dim).transpose(1, 2)

        # Compute the attention weights, scaled by the square root of the
        # per-head dimension, and normalize them with a softmax
        attn_weights = torch.matmul(query, key.transpose(-2, -1)) / (self.head_dim ** 0.5)
        attn_weights = torch.softmax(attn_weights, dim=-1)

        # Weighted sum of the value vectors, with the heads merged back together
        attn_output = torch.matmul(attn_weights, value).transpose(1, 2).contiguous().view(batch_size, seq_len, embed_dim)

        # Project the concatenated heads to the output dimension
        output = self.out_proj(attn_output)
        return output

In this code, the MultiHeadAttention class is defined with an __init__ method and a forward method. The __init__ method initializes the various linear transformations used in the computation, and the forward method performs the actual computation.

The forward method first computes the query, key, and value vectors by applying linear transformations to the input, then reshapes them so that each attention head works on its own slice of the embedding dimension. The attention weights are computed by taking the dot product between the query and key vectors, scaling the result by the square root of the per-head dimension, and normalizing it with a softmax. Each head's output is then a weighted sum of the value vectors, obtained by multiplying the attention weights with the values. Finally, the heads are concatenated back together and projected to the output dimension with another linear transformation.
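
The following standalone snippet traces the tensor shapes through these steps; the batch size, sequence length, embedding dimension, and head count are arbitrary example values.

import torch

batch_size, seq_len, embed_dim, nhead = 2, 5, 64, 8
head_dim = embed_dim // nhead

# After the linear projections and the view/transpose, each of query, key,
# and value has shape (batch, nhead, seq_len, head_dim)
query = torch.randn(batch_size, nhead, seq_len, head_dim)
key = torch.randn(batch_size, nhead, seq_len, head_dim)
value = torch.randn(batch_size, nhead, seq_len, head_dim)

# One (seq_len x seq_len) matrix of attention weights per head
attn_weights = torch.softmax(
    torch.matmul(query, key.transpose(-2, -1)) / (head_dim ** 0.5), dim=-1
)
print(attn_weights.shape)  # torch.Size([2, 8, 5, 5])

# Weighted sum of the values, then the heads are merged back into embed_dim
attn_output = torch.matmul(attn_weights, value).transpose(1, 2).reshape(batch_size, seq_len, embed_dim)
print(attn_output.shape)   # torch.Size([2, 5, 64])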

This implementation can be used as part of a larger Transformer model for natural language processing tasks. By letting several heads attend to different parts of the input sequence through their own query, key, and value projections, Multi-Head Attention gives the model a richer view of the sequence than a single attention head, which typically translates into better accuracy on natural language processing tasks.
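
As a minimal sketch of that integration, the block below wraps the MultiHeadAttention module defined above in a simplified encoder layer with residual connections, layer normalization, and a feed-forward network. The TransformerBlock name and the layer sizes are illustrative assumptions; a production layer (for example, nn.TransformerEncoderLayer in PyTorch) would also handle attention masks and dropout.

import torch
import torch.nn as nn

class TransformerBlock(nn.Module):
    # Simplified encoder layer: attention and a feed-forward network, each
    # followed by a residual connection and layer normalization
    def __init__(self, embed_dim, nhead, ff_dim):
        super().__init__()
        self.attn = MultiHeadAttention(embed_dim, nhead)
        self.norm1 = nn.LayerNorm(embed_dim)
        self.ff = nn.Sequential(
            nn.Linear(embed_dim, ff_dim),
            nn.ReLU(),
            nn.Linear(ff_dim, embed_dim),
        )
        self.norm2 = nn.LayerNorm(embed_dim)

    def forward(self, x):
        x = self.norm1(x + self.attn(x))
        x = self.norm2(x + self.ff(x))
        return x

block = TransformerBlock(embed_dim=64, nhead=8, ff_dim=256)
out = block(torch.randn(2, 5, 64))  # (batch, seq_len, embed_dim)
print(out.shape)                    # torch.Size([2, 5, 64])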

