Multi-head attention
Many attentions in parallel
Imagine you had to summarize a paragraph and you were only allowed to track one type of relationship — say, “which noun is the subject of which verb.” You’d lose everything about tense, modifiers, references, sentiment. Real reading comprehension requires tracking many relationships at once. The model has the same problem: a single attention computation can only express one pattern of “which token cares about which” per layer. The fix is to run many attention computations side-by-side, each free to learn its own pattern. That’s multi-head attention.
Splitting d_model into heads
Each attention computation is called a head: one independent attention computation that learns its own pattern. If d_model = 4096 and we want 32 heads, we split it into 32 chunks of 128. Each head gets its own slice of the query, key, and value vectors and runs the attention formula independently. The number 128 here is called the head dimension, head_dim = d_model / num_heads, which is the dimension each attention head operates in.
Concretely:
num_heads = 32
d_model = 4096
head_dim = d_model / num_heads = 128
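"Splitting" is less mysterious than it sounds: reshaping a d_model-sized vector to (num_heads, head_dim) just cuts it into contiguous slices, so head h sees dimensions h·head_dim up to (h+1)·head_dim. A minimal NumPy sketch, using a random stand-in for one token's query vector:

```python
import numpy as np

d_model = 4096
num_heads = 32
head_dim = d_model // num_heads  # integer division: 128

rng = np.random.default_rng(0)
q_vec = rng.standard_normal(d_model)  # one token's query vector (random stand-in)

# Reshape into per-head chunks: row h is head h's slice of the query.
q_heads = q_vec.reshape(num_heads, head_dim)

# Row h is exactly the contiguous slice [h*head_dim, (h+1)*head_dim).
for h in range(num_heads):
    assert np.array_equal(q_heads[h], q_vec[h * head_dim:(h + 1) * head_dim])

print(head_dim, q_heads.shape)  # 128 (32, 128)
```

The same slicing happens row by row when a whole (seq, d_model) matrix is reshaped to (seq, num_heads, head_dim).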
Each head produces an output of size 128. The 32 outputs are concatenated back into a single vector of size 4096, then passed through one more learned projection (W_O) that mixes the heads’ outputs.
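One way to see why the output projection W_O "mixes" the heads: concatenating and then multiplying by W_O is the same as letting each head write through its own block of rows of W_O and summing the results. A NumPy sketch with small made-up sizes (d_model = 16, 4 heads) and random values, checking the identity numerically:

```python
import numpy as np

d_model, num_heads = 16, 4
head_dim = d_model // num_heads
seq = 5

rng = np.random.default_rng(1)
head_outputs = rng.standard_normal((seq, num_heads, head_dim))  # per-head attention outputs
W_O = rng.standard_normal((d_model, d_model))

# Route 1: concatenate heads into one d_model vector per token, then project.
concat = head_outputs.reshape(seq, d_model)
mixed = concat @ W_O

# Route 2: each head h multiplies its own row block of W_O; sum over heads.
summed = np.zeros((seq, d_model))
for h in range(num_heads):
    rows = W_O[h * head_dim:(h + 1) * head_dim]  # (head_dim, d_model)
    summed += head_outputs[:, h, :] @ rows

print(np.allclose(mixed, summed))  # True
```

So every output dimension can draw on every head, which is what lets later layers combine what different heads found.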
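In code (a minimal PyTorch version for one unbatched sequence, with the weight matrices passed in):
import math
import torch

def multi_head_attention(x, W_Q, W_K, W_V, W_O, num_heads):
    # x: (seq, d_model); W_Q, W_K, W_V, W_O: (d_model, d_model)
    seq, d_model = x.shape
    head_dim = d_model // num_heads
    q = (x @ W_Q).reshape(seq, num_heads, head_dim)
    k = (x @ W_K).reshape(seq, num_heads, head_dim)
    v = (x @ W_V).reshape(seq, num_heads, head_dim)
    # heads run in parallel: same formula, batched over the head axis h
    scores = torch.einsum("shd,thd->sht", q, k) / math.sqrt(head_dim)
    # causal mask: True where key position t comes after query position s
    causal_mask = torch.triu(torch.ones(seq, seq, dtype=torch.bool), diagonal=1)
    scores = scores.masked_fill(causal_mask[:, None, :], float("-inf"))
    weights = torch.softmax(scores, dim=-1)  # normalize over key positions t
    out = torch.einsum("sht,thd->shd", weights, v)
    # concat heads, project
    out = out.reshape(seq, d_model)
    return out @ W_O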
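The claim that heads "run in parallel" with the same formula can be checked directly: the batched einsum gives the same result as looping over heads and running ordinary single-head causal attention on each head's slice. This is a self-contained NumPy sketch with arbitrary small sizes and random inputs; it re-derives the per-head math rather than calling the PyTorch function above, and the helper names are made up for the example.

```python
import numpy as np

def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)  # subtract max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def causal_single_head(q, k, v):
    # q, k, v: (seq, head_dim) for one head
    seq, head_dim = q.shape
    scores = q @ k.T / np.sqrt(head_dim)
    mask = np.triu(np.ones((seq, seq), dtype=bool), k=1)  # True above diagonal = future keys
    scores = np.where(mask, -np.inf, scores)
    return softmax(scores) @ v

seq, num_heads, head_dim = 6, 4, 8
rng = np.random.default_rng(2)
q = rng.standard_normal((seq, num_heads, head_dim))
k = rng.standard_normal((seq, num_heads, head_dim))
v = rng.standard_normal((seq, num_heads, head_dim))

# Batched over heads, as in the einsum version.
scores = np.einsum("shd,thd->sht", q, k) / np.sqrt(head_dim)
mask = np.triu(np.ones((seq, seq), dtype=bool), k=1)
scores = np.where(mask[:, None, :], -np.inf, scores)
batched = np.einsum("sht,thd->shd", softmax(scores, axis=-1), v)

# One head at a time, stacked back along the head axis.
looped = np.stack(
    [causal_single_head(q[:, h], k[:, h], v[:, h]) for h in range(num_heads)], axis=1
)

print(np.allclose(batched, looped))  # True
```

The batched form is what real implementations use, since one large tensor operation keeps the hardware busier than a Python loop over heads.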
What different heads actually do
The remarkable empirical finding is that, after training, different heads end up specializing in genuinely different relationships. Researchers have catalogued patterns like:
- Previous-token heads — at position i, attend to position i − 1, the token just before. Useful for tracking local syntax.
- Same-word-class heads — nouns attend to other nouns; verbs to other verbs.
- Coreference heads — pronouns (“he”, “she”, “it”) attend to the noun they refer to.
- Punctuation heads — track sentence and clause boundaries.
- Induction heads — implement the pattern “if A B has appeared earlier and we just saw A again, attend to the position right after the earlier A,” that is, to B, so the model can predict B next. This pattern is thought to underlie much of in-context learning.
- Attention sinks — heads that dump probability on token 0 or other “low-information” tokens, acting as pressure-relief valves when nothing in the context is relevant.
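To make the previous-token and induction patterns concrete, here is a toy sketch that computes, for each position in a token sequence, which position an idealized head of each type would attend to. These are hand-written rules illustrating the patterns, not weights taken from a trained model, and the function names are hypothetical.

```python
def previous_token_targets(tokens):
    # Idealized previous-token head: position i attends to i - 1 (None at i = 0).
    return [i - 1 if i > 0 else None for i in range(len(tokens))]

def induction_targets(tokens):
    # Idealized induction head: at position i holding token A, find the most
    # recent earlier j < i with tokens[j] == A and attend to j + 1 (the token
    # that followed A last time). None if A has not appeared before.
    targets = []
    for i, tok in enumerate(tokens):
        target = None
        for j in range(i - 1, -1, -1):
            if tokens[j] == tok:
                target = j + 1
                break
        targets.append(target)
    return targets

tokens = ["Harry", "Potter", "went", "home", ".", "Harry"]
print(previous_token_targets(tokens))  # [None, 0, 1, 2, 3, 4]
targets = induction_targets(tokens)
print(targets)               # [None, None, None, None, None, 1]
print(tokens[targets[-1]])   # Potter
```

At the second "Harry", the induction rule points at "Potter", which is exactly the token a model would want to predict next.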
You can switch between a few of these patterns in the heatmap from the previous section. None of this specialization is hand-coded; the model discovers it during training, presumably because dedicating different heads to different relationships is a parameter-efficient way to do well on next-token prediction.
How big does this get?
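Take Llama-3-8B’s shapes: 32 layers, 32 attention heads per layer, head_dim = 128, so d_model = 4096. In the vanilla form described so far, each layer has four square matrices (Q, K, V, O), each 4096×4096 ≈ 16.8M parameters, for ~67M parameters per layer in attention alone. Over 32 layers that’s around 2.1 billion parameters just in attention, roughly a quarter of an 8-billion-parameter model. (The real Llama-3-8B shares keys and values across query heads, as described next, so its K and V matrices are only 4096×1024 and its attention total is closer to 1.3 billion.) The MLP blocks (next section) take most of the rest.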
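The parameter arithmetic above, as a few lines of Python. The vanilla count assumes separate 4096×4096 Q, K, V, and O matrices and ignores biases and normalization weights; the GQA count assumes Llama-3-8B’s configuration of 8 key/value heads, which is the shared-KV variant discussed next.

```python
d_model, num_layers = 4096, 32
num_heads, head_dim = 32, 128
num_kv_heads = 8  # Llama-3-8B's key/value head count under grouped-query attention

# Vanilla multi-head attention: four square d_model x d_model matrices per layer.
vanilla_per_layer = 4 * d_model * d_model
# GQA: Q and O stay square; K and V project to num_kv_heads * head_dim = 1024 dims.
gqa_per_layer = 2 * d_model * d_model + 2 * d_model * num_kv_heads * head_dim

print(f"vanilla: {vanilla_per_layer / 1e6:.1f}M per layer, "
      f"{vanilla_per_layer * num_layers / 1e9:.2f}B total")
# vanilla: 67.1M per layer, 2.15B total
print(f"GQA: {gqa_per_layer / 1e6:.1f}M per layer, "
      f"{gqa_per_layer * num_layers / 1e9:.2f}B total")
# GQA: 41.9M per layer, 1.34B total
```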
We’ve only described attention here in its “vanilla” form, where every head has its own query, key, and value projections. Modern models use a memory-saving variant called grouped-query attention (GQA) that has far fewer KV heads than query heads, so several query heads share the same keys and values. We’ll come back to that in section 12, because it’s a KV-cache optimization more than an attention idea, but it’s the reason the cache fits at all in long-context models.
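A minimal sketch of the sharing in GQA, assuming 32 query heads and 8 KV heads (Llama-3-8B’s counts) and random placeholder data: each group of 4 consecutive query heads reads the same key/value head. One common way to implement this is to repeat the KV heads along the head axis and then run ordinary multi-head attention, while the cache stores only the unrepeated heads. Only keys are shown; values are handled identically.

```python
import numpy as np

seq, head_dim = 5, 128
num_q_heads, num_kv_heads = 32, 8
group_size = num_q_heads // num_kv_heads  # 4 query heads share each KV head

rng = np.random.default_rng(3)
k = rng.standard_normal((seq, num_kv_heads, head_dim))  # what the KV cache would hold

# Expand to one K per query head: heads 0-3 get KV head 0, heads 4-7 get KV head 1, ...
k_expanded = np.repeat(k, group_size, axis=1)  # (seq, num_q_heads, head_dim)

for h in range(num_q_heads):
    assert np.array_equal(k_expanded[:, h], k[:, h // group_size])

# Per-token K+V cache size per layer, in elements:
vanilla_cache = 2 * num_q_heads * head_dim  # 8192
gqa_cache = 2 * num_kv_heads * head_dim     # 2048
print(k_expanded.shape, vanilla_cache // gqa_cache)  # (5, 32, 128) 4
```

With these head counts the cache shrinks by a factor of 4 while the number of query heads, and so the number of attention patterns, stays at 32.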
Attention is the “how do tokens look at each other” part of a transformer. But there’s something we glossed over at the start of section 4: the math you’ve just seen treats its inputs as a bag of vectors. Leave out the causal mask, permute the input tokens, and you get the same outputs in permuted order; the cat and the mat are interchangeable. Real language clearly cares about order, so we need to tell the model where each token sits in the sequence. That’s the topic of the next section: positional encoding.
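The “bag of vectors” claim is easy to check numerically. This NumPy sketch uses a single unmasked head with random weights and no positional information (sizes are arbitrary): shuffling the input rows shuffles the output rows in exactly the same way, so nothing in the computation records which token came first.

```python
import numpy as np

def attention(x, W_Q, W_K, W_V):
    # Single-head self-attention with no mask and no positional information.
    q, k, v = x @ W_Q, x @ W_K, x @ W_V
    scores = q @ k.T / np.sqrt(q.shape[-1])
    scores = scores - scores.max(axis=-1, keepdims=True)  # numerical stability
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ v

seq, d = 6, 8
rng = np.random.default_rng(4)
x = rng.standard_normal((seq, d))
W_Q, W_K, W_V = (rng.standard_normal((d, d)) for _ in range(3))

perm = rng.permutation(seq)  # shuffle the token order
out = attention(x, W_Q, W_K, W_V)
out_shuffled = attention(x[perm], W_Q, W_K, W_V)

# Shuffling the inputs just shuffles the outputs the same way.
print(np.allclose(out_shuffled, out[perm]))  # True
```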