Attention Mechanism

Draft · 2 min read

Intuition

Attention allows a model to focus on relevant parts of the input when producing each part of the output. Instead of compressing an entire sequence into a fixed-size vector, attention computes a weighted sum over all positions.
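The weighted-sum idea can be shown in a few lines of PyTorch. Here the scores are chosen by hand purely for illustration; in a real model they come from query–key similarity:

```python
import torch

# Toy sequence of 4 value vectors (dimension 2 each).
values = torch.tensor([[1.0, 0.0],
                       [0.0, 1.0],
                       [1.0, 1.0],
                       [0.5, 0.5]])
# Hand-picked relevance scores: the first position dominates.
scores = torch.tensor([2.0, 0.1, 0.1, 0.1])
weights = torch.softmax(scores, dim=0)  # non-negative, sums to 1
output = weights @ values               # weighted sum over all positions
```

The output stays a fixed-size vector no matter how long the sequence is, but every position contributes according to its weight.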

Scaled Dot-Product Attention

Given queries $\mathbf{Q}$, keys $\mathbf{K}$, and values $\mathbf{V}$:

$$\text{Attention}(\mathbf{Q}, \mathbf{K}, \mathbf{V}) = \text{softmax}\left(\frac{\mathbf{Q}\mathbf{K}^T}{\sqrt{d_k}}\right)\mathbf{V}$$

where $d_k$ is the dimension of the keys. Scaling by $\sqrt{d_k}$ keeps the dot products from growing large in magnitude; without it, the softmax would produce extremely peaked distributions with vanishingly small gradients.
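The formula above translates almost line-for-line into PyTorch. Here is a minimal sketch (the function name and shapes are illustrative):

```python
import math
import torch

def scaled_dot_product_attention(q, k, v):
    # q, k: (..., seq_len, d_k); v: (..., seq_len, d_v)
    d_k = q.size(-1)
    # Similarity of every query with every key, scaled by sqrt(d_k).
    scores = q @ k.transpose(-2, -1) / math.sqrt(d_k)
    # Each row of weights sums to 1 over the key positions.
    weights = torch.softmax(scores, dim=-1)
    # Weighted sum of the values.
    return weights @ v, weights

q = torch.randn(2, 5, 8)
k = torch.randn(2, 5, 8)
v = torch.randn(2, 5, 8)
out, w = scaled_dot_product_attention(q, k, v)
```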

Multi-Head Attention

Instead of one attention function, use $h$ parallel heads:

$$\text{MultiHead}(\mathbf{Q}, \mathbf{K}, \mathbf{V}) = \text{Concat}(\text{head}_1, \ldots, \text{head}_h)\mathbf{W}^O$$

where $\text{head}_i = \text{Attention}(\mathbf{Q}\mathbf{W}_i^Q, \mathbf{K}\mathbf{W}_i^K, \mathbf{V}\mathbf{W}_i^V)$. Each head can attend to a different representation subspace.
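In practice the $h$ heads are computed in one batch by splitting the embedding dimension, rather than running $h$ separate attention calls. A minimal sketch, with illustrative names and no masking or dropout:

```python
import torch
import torch.nn as nn

class MultiHeadAttention(nn.Module):
    """Minimal multi-head attention sketch (names are illustrative)."""
    def __init__(self, embed_dim, num_heads):
        super().__init__()
        assert embed_dim % num_heads == 0
        self.num_heads = num_heads
        self.d_k = embed_dim // num_heads
        self.w_q = nn.Linear(embed_dim, embed_dim)
        self.w_k = nn.Linear(embed_dim, embed_dim)
        self.w_v = nn.Linear(embed_dim, embed_dim)
        self.w_o = nn.Linear(embed_dim, embed_dim)  # the W^O projection

    def forward(self, q, k, v):
        batch = q.size(0)

        def split(x, proj):
            # (batch, seq, embed) -> (batch, heads, seq, d_k)
            return proj(x).view(batch, -1, self.num_heads, self.d_k).transpose(1, 2)

        q, k, v = split(q, self.w_q), split(k, self.w_k), split(v, self.w_v)
        # Scaled dot-product attention within each head.
        scores = q @ k.transpose(-2, -1) / self.d_k ** 0.5
        out = torch.softmax(scores, dim=-1) @ v
        # Concatenate the heads back into (batch, seq, embed) and project.
        out = out.transpose(1, 2).reshape(batch, -1, self.num_heads * self.d_k)
        return self.w_o(out)
```

The single `view`/`transpose` split is why head count must divide the embedding dimension evenly.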

Self-Attention

When Q, K, V all come from the same sequence, it's self-attention. Each token attends to every token in the sequence, including itself.

PyTorch Implementation

import torch
import torch.nn as nn

class SelfAttention(nn.Module):
    def __init__(self, embed_dim, num_heads):
        super().__init__()
        # batch_first defaults to False, so inputs are (seq_len, batch, embed_dim).
        self.attn = nn.MultiheadAttention(embed_dim, num_heads)

    def forward(self, x):
        # Self-attention: the same tensor serves as query, key, and value.
        # x shape: (seq_len, batch, embed_dim)
        out, weights = self.attn(x, x, x)
        return out
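To see the expected shapes, `nn.MultiheadAttention` can be exercised directly the same way the module above uses it (dimensions here are arbitrary):

```python
import torch
import torch.nn as nn

# Default batch_first=False: inputs are (seq_len, batch, embed_dim).
attn = nn.MultiheadAttention(embed_dim=16, num_heads=4)
x = torch.randn(10, 2, 16)  # seq_len=10, batch=2
out, weights = attn(x, x, x)
# out:     (seq_len, batch, embed_dim)
# weights: (batch, seq_len, seq_len), averaged over heads by default
```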

Why Attention Works

  • Parallelizable: unlike RNNs, all positions are computed simultaneously
  • Long-range dependencies: any position can directly attend to any other
  • Interpretable: attention weights give a (partial) window into what the model focuses on