Gemma 2
Distillation as a pre-training objective
Paper: Gemma 2: Improving Open Language Models at a Practical Size — Gemma Team, 2024
Gemma 2 (Gemma Team, Google DeepMind, 2024) is a family of small open models (2B, 9B, 27B) whose most interesting pre-training idea isn’t about data or scale; it’s about what the model trains against. For its 2B and 9B models, Gemma 2 replaces the one-hot next-token target with the soft predictions of a much larger teacher. That’s knowledge distillation, used as a pre-training objective.
Distillation: a richer target than one-hot
Recall the cross-entropy objective: the target is a one-hot vector, with all probability on the single correct next token and zero on everything else. Knowledge distillation changes the target. A large, already-trained teacher model produces a full probability distribution over the next token, and the smaller student is trained to match that distribution instead of (or alongside) the hard label. Richer targets let the student learn more from each token.
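A minimal NumPy sketch of the two targets follows. It assumes a toy vocabulary and random logits in place of real model outputs, and the helper names (`one_hot`, `cross_entropy`, `distillation_loss`) are hypothetical, not taken from any Gemma codebase. The same cross-entropy function serves both cases; only the target changes, from a one-hot row to the teacher's softmax. Because cross-entropy against the teacher equals the teacher's entropy plus KL(teacher ‖ student), and the teacher's entropy does not depend on the student, minimizing this loss is equivalent to minimizing that KL divergence. Mixing in the hard label or adding a softmax temperature are common variants not shown here.

```python
import numpy as np


def log_softmax(logits: np.ndarray) -> np.ndarray:
    """Numerically stable log-softmax over the last (vocabulary) axis."""
    shifted = logits - logits.max(axis=-1, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))


def softmax(logits: np.ndarray) -> np.ndarray:
    return np.exp(log_softmax(logits))


def one_hot(token_ids: np.ndarray, vocab_size: int) -> np.ndarray:
    """Hard target: 1.0 at the correct next token, 0.0 everywhere else."""
    target = np.zeros((len(token_ids), vocab_size))
    target[np.arange(len(token_ids)), token_ids] = 1.0
    return target


def cross_entropy(target_probs: np.ndarray, student_logits: np.ndarray) -> float:
    """Mean over positions of -sum_x target(x) * log p_student(x)."""
    return float(-(target_probs * log_softmax(student_logits)).sum(axis=-1).mean())


def distillation_loss(teacher_logits: np.ndarray, student_logits: np.ndarray) -> float:
    """Soft target: the teacher's full next-token distribution at each position."""
    return cross_entropy(softmax(teacher_logits), student_logits)


# Toy example: 4 positions, vocabulary of 10 tokens (random stand-in logits).
rng = np.random.default_rng(0)
teacher_logits = rng.normal(size=(4, 10))
student_logits = rng.normal(size=(4, 10))
labels = teacher_logits.argmax(axis=-1)  # stand-in for the true next tokens

print("one-hot loss:     ", cross_entropy(one_hot(labels, 10), student_logits))
print("distillation loss:", distillation_loss(teacher_logits, student_logits))
```

With the one-hot target, only the student's log-probability of the correct token enters the loss; with the teacher's distribution, every token the teacher considers plausible contributes to the signal at that position.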
This answers the data-scarcity problem differently from Qwen’s “get more tokens”: instead of more data, get richer targets from a model that has already learned from lots of data.
Two architecture efficiency tricks
Gemma 2 also brings two changes worth adding to our running list of modern techniques:
- Interleaved local/global attention. Rather than every layer attending over the full sequence, Gemma 2 alternates sliding-window (local) attention layers with global layers, switching every other layer. Local layers only attend to a fixed window of nearby tokens (4,096 in Gemma 2, against an 8,192-token global span), which is much cheaper and shrinks the KV cache, while the interleaved global layers preserve long-range information. This local/global interleaving becomes a defining feature of the Gemma line.
- Logit soft-capping. The model bounds its logits (and attention scores) with a scaled tanh, cap · tanh(x / cap), so they can’t grow without limit; Gemma 2 uses a cap of 50.0 on attention scores and 30.0 on the final logits. This improves numerical stability during training, a small regularizing touch in the same spirit as gradient clipping.
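Both tricks can be sketched in a few lines of NumPy. This is an illustrative sketch, not Gemma 2's implementation: the function names are hypothetical, the tiny sequence length and window stand in for real sizes, and two conventions are assumptions, namely that even-indexed layers are local and odd-indexed layers are global, and that scores use plain 1/√d scaling. `soft_cap` computes cap · tanh(x / cap), which is close to the identity for small x and saturates at ±cap.

```python
import numpy as np


def causal_mask(seq_len: int) -> np.ndarray:
    """Global layer: query i may attend to every key j <= i."""
    i = np.arange(seq_len)[:, None]
    j = np.arange(seq_len)[None, :]
    return j <= i


def sliding_window_mask(seq_len: int, window: int) -> np.ndarray:
    """Local layer: query i attends only to the `window` most recent keys (i - window < j <= i)."""
    i = np.arange(seq_len)[:, None]
    j = np.arange(seq_len)[None, :]
    return (j <= i) & (j > i - window)


def layer_mask(layer_idx: int, seq_len: int, window: int) -> np.ndarray:
    """Interleave every other layer; assumed convention: even = local, odd = global."""
    if layer_idx % 2 == 0:
        return sliding_window_mask(seq_len, window)
    return causal_mask(seq_len)


def soft_cap(x: np.ndarray, cap: float) -> np.ndarray:
    """Smoothly bound values to (-cap, cap); near-identity when |x| << cap."""
    return cap * np.tanh(x / cap)


def attention_weights(q: np.ndarray, k: np.ndarray, mask: np.ndarray, cap: float) -> np.ndarray:
    """Single-head attention weights with soft-capped scores (assumed 1/sqrt(d) scaling)."""
    scores = q @ k.T / np.sqrt(q.shape[-1])
    scores = soft_cap(scores, cap)            # cap first...
    scores = np.where(mask, scores, -np.inf)  # ...then mask, so masked keys stay at -inf
    scores = scores - scores.max(axis=-1, keepdims=True)
    weights = np.exp(scores)
    return weights / weights.sum(axis=-1, keepdims=True)


# Toy example: 8 tokens, window of 3, head dimension 4 (random stand-in q and k).
rng = np.random.default_rng(0)
seq_len, window, d = 8, 3, 4
q = rng.normal(size=(seq_len, d))
k = rng.normal(size=(seq_len, d))

for layer in range(2):
    mask = layer_mask(layer, seq_len, window)
    w = attention_weights(q, k, mask, cap=50.0)  # 50.0: attention-score cap reported for Gemma 2
    kind = "local" if layer % 2 == 0 else "global"
    print(kind, "keys per query:", mask.sum(axis=1).tolist(), "rows sum to 1:", np.allclose(w.sum(axis=1), 1.0))

print(soft_cap(np.array([-1000.0, -1.0, 0.0, 1.0, 1000.0]), cap=30.0))
```

Capping before masking matters in this sketch: tanh maps −∞ to −1, so soft-capping an already-masked score would turn it into −cap and give masked keys a small nonzero weight. The local mask limits each query to at most `window` keys, which is why a local layer only has to keep the most recent `window` keys and values in its KV cache.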
Both sit alongside the now-standard GQA (grouped-query attention: several query heads share one K/V head, shrinking the KV cache) and RMSNorm (each activation vector is divided by its root-mean-square, then multiplied by a learned per-dimension scale).