Transformer from Scratch

Reimplementing “Attention Is All You Need” in raw PyTorch — what finally clicked, and what I got wrong the first three times.

Published: April 2025
Reading time: ~12 min
Stack: PyTorch · Python

I'd read the paper a dozen times. I'd used HuggingFace and shipped real models. But I'd never sat down and written every line — embeddings, attention, masking, the encoder, the decoder, the training loop — without copying anything. So I did. This article is what I'd hand to past-me to skip the worst of the confusion.

Why bother

You don't learn a transformer by reading about it. You learn it by debugging a shape mismatch at 1am and realising your attention mask is broadcasting the wrong way. Doing the implementation end-to-end forces every fuzzy mental model to commit to a tensor shape.

The pieces, in order

  1. Token + positional embeddings. The smallest piece and the easiest to get wrong. The positional encoding is just a fixed sinusoidal pattern added to the token embedding, but it's the only thing that tells the model about order: attention on its own treats its input as an unordered set.
  2. Scaled dot-product attention. Three projections (Q, K, V) and the famous softmax(QK^T / sqrt(d_k)) V. The sqrt(d_k) isn't cosmetic; without it, the dot products grow with d_k and push softmax into its saturated region, where gradients vanish.
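  3. Multi-head attention. Split the projections into h heads of size d_model/h, attend in parallel, concat, then apply one more output projection. The total cost is about the same as a single full-width head, but the model gets to look at multiple relations at once.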
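  4. Position-wise feed-forward. Two linear layers and a ReLU, applied to each position independently. Surprisingly, most of an encoder layer's parameter count lives here, not in attention: with the paper's d_ff = 4 × d_model, that's roughly 8 × d_model^2 weights against 4 × d_model^2 for the four attention projections.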
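  5. Encoder + decoder stacks. Residual connections around every sub-layer, with layer norm either after the residual add (post-LN, as in the paper) or before the sub-layer (pre-LN). The order matters more than you'd think: post-LN is the variant that leans hardest on learning-rate warmup.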
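
A minimal sketch of pieces 1 and 2, assuming an even model dimension and a boolean mask where True means "may attend". The function names `sinusoidal_positions` and `attention` are mine for illustration, not names from the paper or from the repo.

```python
import math
import torch


def sinusoidal_positions(seq_len: int, d_model: int) -> torch.Tensor:
    """Fixed (seq_len, d_model) sinusoidal table. Assumes d_model is even."""
    position = torch.arange(seq_len, dtype=torch.float32).unsqueeze(1)  # (S, 1)
    # Frequencies 1 / 10000^(2i / d_model), computed in log space.
    div_term = torch.exp(
        torch.arange(0, d_model, 2, dtype=torch.float32)
        * (-math.log(10000.0) / d_model)
    )  # (D/2,)
    pe = torch.zeros(seq_len, d_model)
    pe[:, 0::2] = torch.sin(position * div_term)  # even dims
    pe[:, 1::2] = torch.cos(position * div_term)  # odd dims
    return pe


def attention(q, k, v, mask=None):
    """softmax(QK^T / sqrt(d_k)) V as a pure function.

    q, k, v: (..., S, d_k). mask: bool tensor, True = may attend,
    broadcastable to the (..., S_q, S_k) score shape.
    """
    d_k = q.size(-1)
    scores = q @ k.transpose(-2, -1) / math.sqrt(d_k)
    if mask is not None:
        # Mask before softmax so the surviving weights still sum to one.
        scores = scores.masked_fill(~mask, float("-inf"))
    weights = torch.softmax(scores, dim=-1)
    return weights @ v, weights


# Usage on dummy tensors (sizes are arbitrary).
B, S, D = 2, 5, 16
token_emb = torch.randn(B, S, D)                # stand-in for nn.Embedding output
x = token_emb + sinusoidal_positions(S, D)      # (S, D) broadcasts over the batch
causal = torch.tril(torch.ones(S, S, dtype=torch.bool))
out, weights = attention(x, x, x, mask=causal)  # out: (B, S, D)
```

Keeping `attention` free of any module state makes it easy to poke at in a REPL before the learned projections are wrapped around it.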

What I got wrong the first three times

  • Mask shape. The causal mask in the decoder is (seq, seq), not (batch, seq, seq). Let broadcasting do the work.
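  • Padding mask in attention. You mask before softmax (fill the masked scores with -inf), not after. Zeroing the weights after softmax leaves each row summing to less than one, because the padded positions have already taken part in the normalisation. Nothing errors; the model just quietly trains worse.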
  • Weight tying. Sharing the output projection with the input embedding matrix is a one-line change that cuts parameters and often improves perplexity. The paper mentions it (along with scaling the embeddings by sqrt(d_model)); almost every tutorial skips it.
  • Warmup matters. The Noam schedule (linear warmup, then 1/sqrt(step) decay) isn't an optional extra; it's load-bearing. Train the paper's post-LN model without it and the loss tends to go flat or diverge.
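
Two of the fixes above are small enough to sketch directly: the Noam learning-rate formula from the paper and the weight tie. The class `TiedEmbeddingHead` is a hypothetical toy with no transformer body, and the vocabulary size and model dimension are arbitrary; the Adam settings are the ones given in the paper.

```python
import torch
from torch import nn


def noam_lr(step: int, d_model: int, warmup: int = 4000) -> float:
    """lr = d_model^-0.5 * min(step^-0.5, step * warmup^-1.5)."""
    step = max(step, 1)  # the formula is undefined at step 0
    return d_model ** -0.5 * min(step ** -0.5, step * warmup ** -1.5)


class TiedEmbeddingHead(nn.Module):
    """Toy model: embedding in, tied projection out. Illustration only."""

    def __init__(self, vocab_size: int, d_model: int):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, d_model)
        self.proj = nn.Linear(d_model, vocab_size, bias=False)
        # The one-line tie: both weights are (vocab_size, d_model).
        self.proj.weight = self.embed.weight

    def forward(self, tokens: torch.Tensor) -> torch.Tensor:
        # A real model would run the transformer stack between these two.
        return self.proj(self.embed(tokens))


model = TiedEmbeddingHead(vocab_size=100, d_model=16)

# Base lr of 1.0 so the lambda's return value is the actual learning rate.
optimizer = torch.optim.Adam(model.parameters(), lr=1.0, betas=(0.9, 0.98), eps=1e-9)
scheduler = torch.optim.lr_scheduler.LambdaLR(
    optimizer, lambda s: noam_lr(s + 1, d_model=16)
)
```

The learning rate peaks at `step == warmup`, where the two terms inside `min` meet. Call `scheduler.step()` once per optimizer step, not once per epoch.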

The shape cheat-sheet

Carry this in your head and most of the work disappears:

  • x: (B, S, D) — batch, sequence length, model dim.
  • q, k, v: (B, H, S, D/H) — split across heads.
  • scores: (B, H, S, S) — attention scores per head.
  • mask: (S, S) or (B, 1, 1, S) — broadcasts over what it doesn't cover.
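
The same cheat-sheet as a runnable walk-through, to show the two mask shapes broadcasting against the scores. The sizes, the `w_q`/`w_k`/`w_v` stand-ins, and the `split_heads` helper are illustrative assumptions, not code from the repo.

```python
import math
import torch
from torch import nn

B, S, D, H = 2, 5, 16, 4  # arbitrary small sizes
x = torch.randn(B, S, D)  # (B, S, D)

# Stand-ins for the learned Q, K, V projections.
w_q, w_k, w_v = (nn.Linear(D, D) for _ in range(3))


def split_heads(t: torch.Tensor) -> torch.Tensor:
    # (B, S, D) -> (B, H, S, D/H)
    return t.view(B, S, H, D // H).transpose(1, 2)


q, k, v = split_heads(w_q(x)), split_heads(w_k(x)), split_heads(w_v(x))
scores = q @ k.transpose(-2, -1) / math.sqrt(D // H)  # (B, H, S, S)

# Causal mask: (S, S), True where a query may look at a key.
causal = torch.tril(torch.ones(S, S, dtype=torch.bool))

# Padding mask: (B, 1, 1, S), True for real tokens.
# Here the second sequence has only 3 real tokens.
lengths = torch.tensor([5, 3])
padding = (torch.arange(S)[None, :] < lengths[:, None])[:, None, None, :]

allowed = causal & padding                            # broadcasts to (B, 1, S, S)
scores = scores.masked_fill(~allowed, float("-inf"))  # mask before softmax
weights = torch.softmax(scores, dim=-1)               # (B, H, S, S)

# Concatenate the heads back: (B, H, S, D/H) -> (B, S, D)
out = (weights @ v).transpose(1, 2).reshape(B, S, D)
```

Every query here can still see key 0, so no row is fully masked. A row that is fully masked would come out of softmax as NaN, which is worth guarding against if a real batch can contain one.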

What I'd do differently next time

  • Start with the inference path, not the training loop. Get a forward pass working on dummy tensors before touching loss.
  • Write the attention block as a pure function first; wrap it in nn.Module only once the shapes are nailed down.
  • Sanity-check on a tiny copy-task before scaling up. If a 2-layer transformer can't learn to copy a sequence in 1000 steps, something's wrong with your masking.
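
One possible way to build the copy-task data, sketched for a decoder-style model: the sequence, a separator, then the sequence again, with the loss restricted to the second half. The helper name `make_copy_batch`, the separator id, and the layout are my assumptions; an encoder-decoder setup would feed source and target separately.

```python
import torch


def make_copy_batch(batch_size: int, seq_len: int, vocab_size: int, sep_id: int = 0):
    """Return (inputs, targets), each (B, 2 * seq_len), for next-token prediction.

    Layout per row: [t_0 .. t_{S-1}, SEP, t_0 .. t_{S-1}].
    Targets for the first half are set to -100, the default ignore_index
    of torch.nn.functional.cross_entropy, since random tokens are unpredictable.
    """
    # Token ids start at 1 so they never collide with the separator.
    tokens = torch.randint(1, vocab_size, (batch_size, seq_len))
    sep = torch.full((batch_size, 1), sep_id, dtype=torch.long)
    full = torch.cat([tokens, sep, tokens], dim=1)  # (B, 2S + 1)

    inputs = full[:, :-1]           # (B, 2S)
    targets = full[:, 1:].clone()   # (B, 2S), shifted left by one
    targets[:, :seq_len] = -100     # score only the copied half
    return inputs, targets


inputs, targets = make_copy_batch(batch_size=4, seq_len=5, vocab_size=20)
```

With this layout the model can only get the second half right by attending back across the separator, so a broken causal mask or broken positional information shows up quickly.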

The code

The full implementation lives on GitHub. It's deliberately small — one file per concept, no abstraction for its own sake. Read it top-to-bottom and you'll have a working transformer in your head.