
Understanding Attention Mechanisms in Transformers

2026-01-10 · Deep Learning

Tags: Deep Learning, Transformers, NLP, Attention


Attention mechanisms have revolutionized the field of deep learning, particularly in natural language processing. In this note, we'll dive deep into how attention works.

Self-Attention

The self-attention mechanism allows each position in the sequence to attend to all positions in the same sequence. The key formula is:

\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V

Where:

  • $Q$ (Query): The query matrix
  • $K$ (Key): The key matrix
  • $V$ (Value): The value matrix
  • $d_k$: The dimension of the key vectors
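Before packaging this into a module, the formula can be checked in isolation with a few lines of PyTorch (the function name and toy shapes here are illustrative, not part of the original post):

```python
import torch
import torch.nn.functional as F

def scaled_dot_product_attention(Q, K, V):
    # Q, K, V: (seq_len, d_k) -- a single unbatched sequence
    d_k = Q.size(-1)
    scores = Q @ K.transpose(-2, -1) / d_k ** 0.5  # (seq_len, seq_len)
    weights = F.softmax(scores, dim=-1)            # each row sums to 1
    return weights @ V, weights

torch.manual_seed(0)
Q, K, V = torch.randn(4, 8), torch.randn(4, 8), torch.randn(4, 8)
out, w = scaled_dot_product_attention(Q, K, V)
# out has shape (4, 8); w is a (4, 4) row-stochastic matrix
```

Row $i$ of `w` is the distribution over positions that position $i$ attends to, which is what the attention-weight visualizations in the literature plot.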

Multi-Head Attention

Multi-head attention extends self-attention by computing attention in parallel across multiple "heads":

\text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, \ldots, \text{head}_h)W^O

where each head is computed as:

\text{head}_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)

PyTorch Implementation

Here's a simple implementation of multi-head attention in PyTorch:

import torch
import torch.nn as nn
import torch.nn.functional as F

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model: int, num_heads: int):
        super().__init__()
        assert d_model % num_heads == 0
        
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)
        
    def forward(self, query, key, value, mask=None):
        # query/key/value: (batch, seq_len, d_model)
        batch_size = query.size(0)
        
        # Linear projections
        Q = self.W_q(query)
        K = self.W_k(key)
        V = self.W_v(value)
        
        # Reshape for multi-head attention
        Q = Q.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        K = K.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        V = V.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        
        # Compute attention scores
        scores = torch.matmul(Q, K.transpose(-2, -1)) / (self.d_k ** 0.5)
        
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)
        
        # Apply softmax and compute output
        attn_weights = F.softmax(scores, dim=-1)
        output = torch.matmul(attn_weights, V)
        
        # Reshape and apply final linear projection
        output = output.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)
        return self.W_o(output)
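As a shape sanity check, the snippet below exercises PyTorch's built-in `nn.MultiheadAttention`, which follows the same interface as the module above when `batch_first=True` (sizes here are arbitrary):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
mha = nn.MultiheadAttention(embed_dim=64, num_heads=8, batch_first=True)
x = torch.randn(2, 10, 64)   # (batch, seq_len, d_model)
out, attn = mha(x, x, x)     # self-attention: query = key = value
# out: (2, 10, 64); attn: (2, 10, 10), averaged over heads by default
```

The custom module above can be dropped in the same way, `MultiHeadAttention(64, 8)(x, x, x)`, the only difference being that it returns just the output tensor rather than an (output, weights) pair.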

Key Takeaways

  1. Parallelization: Unlike RNNs, attention allows parallel processing of all positions
  2. Long-range dependencies: Direct connections between any two positions
  3. Interpretability: Attention weights can reveal what the model focuses on

Note: The scaling factor $\sqrt{d_k}$ prevents the dot products from growing too large, which would push the softmax into regions with very small gradients.
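A quick numeric check of why the scaling matters: for unit-variance query and key entries, the dot product of two $d_k$-dimensional vectors has variance roughly $d_k$, and dividing by $\sqrt{d_k}$ restores unit variance (the dimensions below are chosen arbitrarily):

```python
import torch

torch.manual_seed(0)
d_k = 512
q = torch.randn(1000, d_k)         # 1000 sample query vectors
k = torch.randn(1000, d_k)         # 1000 sample key vectors

raw = (q * k).sum(-1)              # unscaled dot products
scaled = raw / d_k ** 0.5

print(raw.std())     # roughly sqrt(512) ~ 22.6
print(scaled.std())  # roughly 1.0
```

Softmax over logits with standard deviation ~22 puts almost all mass on a single position, so gradients through the other positions vanish; the scaled version keeps the distribution soft enough to train.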

Further Reading