// article · ml fundamentals
Transformer from Scratch
Reimplementing “Attention Is All You Need” in raw PyTorch — what finally clicked, and what I got wrong the first three times.
- Published: April 2025
- Reading time: ~12 min
- Stack: PyTorch · Python
- Code: GitHub
I'd read the paper a dozen times. I'd used HuggingFace and shipped real models. But I'd never sat down and written every line — embeddings, attention, masking, the encoder, the decoder, the training loop — without copying anything. So I did. This article is what I'd hand to past-me to skip the worst of the confusion.
Why bother
You don't learn a transformer by reading about it. You learn it by debugging a shape mismatch at 1am and realising your attention mask is broadcasting the wrong way. Doing the implementation end-to-end forces every fuzzy mental model to commit to a tensor shape.
The pieces, in order
- Token + positional embeddings. The smallest piece and the easiest to get wrong. The positional encoding is just a fixed sinusoidal pattern added to the token embedding, but it's the only thing that tells attention about order; without it, attention treats the sequence as an unordered set.
- Scaled dot-product attention. Three projections (Q, K, V) and the famous softmax(QK^T / sqrt(d_k)) V. The sqrt(d_k) isn't cosmetic; without it, large dimensions push softmax into vanishing-gradient territory.
- Multi-head attention. Split the projections into h heads, attend in parallel, concat. The cheapest way to let the model look at multiple relations at once.
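- Position-wise feed-forward. Two linear layers and a ReLU. Surprisingly, most of an encoder layer's parameter count lives here, not in attention: with d_ff = 4 × d_model, the two linear layers hold about 8 × d_model² weights against about 4 × d_model² for the four attention projections.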
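- Encoder + decoder stacks. Residual connections around every sub-layer, with layer norm either after the residual add (post-norm, as in the paper) or before the sub-layer (pre-norm, usually the easier one to train). The order matters more than you'd think.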
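To make the first two pieces concrete, here is a minimal sketch of the sinusoidal table and scaled dot-product attention as plain functions. It is not the code from the repo: the function names, the assumption that d_model is even, and the mask convention (boolean, True means "may attend") are all choices made for this illustration.

```python
import math
import torch

def sinusoidal_positions(seq_len: int, d_model: int) -> torch.Tensor:
    """Fixed sinusoidal table of shape (seq_len, d_model). Assumes d_model is even."""
    position = torch.arange(seq_len, dtype=torch.float32).unsqueeze(1)  # (S, 1)
    # One frequency per (sin, cos) pair: 1 / 10000^(2i / d_model).
    div_term = torch.exp(
        torch.arange(0, d_model, 2, dtype=torch.float32) * (-math.log(10000.0) / d_model)
    )  # (D/2,)
    pe = torch.zeros(seq_len, d_model)
    pe[:, 0::2] = torch.sin(position * div_term)  # even dimensions
    pe[:, 1::2] = torch.cos(position * div_term)  # odd dimensions
    return pe

def scaled_dot_product_attention(q, k, v, mask=None):
    """q, k, v: (..., S, d_k). mask: bool tensor, True = may attend (assumed convention)."""
    d_k = q.size(-1)
    scores = q @ k.transpose(-2, -1) / math.sqrt(d_k)  # (..., S_q, S_k)
    if mask is not None:
        # Mask BEFORE softmax so the surviving weights still sum to one.
        scores = scores.masked_fill(~mask, float("-inf"))
    weights = torch.softmax(scores, dim=-1)
    return weights @ v, weights

# Dummy forward pass: random "token embeddings" plus positions, then self-attention.
x = torch.randn(2, 5, 16)              # (B, S, D)
x = x + sinusoidal_positions(5, 16)    # (S, D) broadcasts over the batch
out, weights = scaled_dot_product_attention(x, x, x)
```

Keeping attention as a pure function like this makes it easy to poke at shapes before any projections or modules are involved.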
What I got wrong the first three times
- Mask shape. The causal mask in the decoder is
(seq, seq), not(batch, seq, seq). Let broadcasting do the work. - Padding mask in attention. You mask before softmax, not after. Masking after gives you a distribution that doesn't sum to one and breaks gradients silently.
- Weight tying. Sharing the output projection with the input embedding matrix is a one-line change that cuts parameters and often improves perplexity. The paper mentions it; almost every tutorial skips it.
- Warmup matters. The Noam schedule (linear warmup, then 1/sqrt(step) decay) isn't a hyperparameter you tune — it's load-bearing. Train without it and the loss goes flat.
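The last two items are short enough to sketch together. The schedule below follows the formula from the paper, lr = d_model^-0.5 × min(step^-0.5, step × warmup^-1.5); the function name noam_lr, the tiny vocabulary and model size, and the ModuleDict container are assumptions made for this illustration, not the repo's code.

```python
import torch
import torch.nn as nn

def noam_lr(step: int, d_model: int = 512, warmup: int = 4000) -> float:
    """Linear warmup for `warmup` steps, then 1/sqrt(step) decay."""
    step = max(step, 1)  # avoid 0 ** -0.5 on the very first call
    return d_model ** -0.5 * min(step ** -0.5, step * warmup ** -1.5)

vocab_size, d_model = 1000, 64  # toy sizes, chosen only for the example

model = nn.ModuleDict({
    "embed": nn.Embedding(vocab_size, d_model),           # weight: (vocab, d_model)
    "proj": nn.Linear(d_model, vocab_size, bias=False),   # weight: (vocab, d_model)
})
# Weight tying: the output projection reuses the embedding matrix.
model["proj"].weight = model["embed"].weight

# parameters() yields a shared tensor once, so the tied matrix is counted once.
n_params = sum(p.numel() for p in model.parameters())

# Base lr of 1.0 so the lambda's return value is the actual learning rate.
optimizer = torch.optim.Adam(model.parameters(), lr=1.0, betas=(0.9, 0.98), eps=1e-9)
scheduler = torch.optim.lr_scheduler.LambdaLR(
    optimizer, lr_lambda=lambda s: noam_lr(s + 1, d_model)
)
```

In a training loop you would call optimizer.step() and then scheduler.step() once per batch, so the learning rate climbs for the first warmup steps and decays afterwards.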
The shape cheat-sheet
Carry this in your head and most of the work disappears:
- x: (B, S, D). Batch, sequence length, model dim.
- q, k, v: (B, H, S, D/H). Split across heads.
- scores: (B, H, S, S). Attention scores per head.
- mask: (S, S) or (B, 1, 1, S). Broadcasts over what it doesn't cover.
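The same shapes can be checked on dummy tensors. This is a sketch under a few assumptions: the Q/K/V projections are left out (q, k and v are all the split input), split_heads is a helper invented for the example, the sequence lengths are made up, and masks are boolean with True meaning "may attend".

```python
import math
import torch

B, S, D, H = 2, 5, 16, 4
x = torch.randn(B, S, D)                                   # (B, S, D)

def split_heads(t: torch.Tensor, num_heads: int) -> torch.Tensor:
    """(B, S, D) -> (B, H, S, D/H)."""
    b, s, d = t.shape
    return t.view(b, s, num_heads, d // num_heads).transpose(1, 2)

q = k = v = split_heads(x, H)                              # (B, H, S, D/H)
scores = q @ k.transpose(-2, -1) / math.sqrt(D // H)       # (B, H, S, S)

# Causal mask: (S, S), lower triangular. No batch dimension needed.
causal = torch.tril(torch.ones(S, S, dtype=torch.bool))

# Padding mask: (B, 1, 1, S), True where the key position is a real token.
lengths = torch.tensor([5, 3])                             # made-up sequence lengths
padding = (torch.arange(S)[None, :] < lengths[:, None])[:, None, None, :]

mask = causal & padding                                    # broadcasts to (B, 1, S, S)

# Mask before softmax, then merge the heads back together.
scores = scores.masked_fill(~mask, float("-inf"))
weights = torch.softmax(scores, dim=-1)                    # (B, H, S, S)
out = (weights @ v).transpose(1, 2).reshape(B, S, D)       # (B, S, D)
```

Neither mask carries a head dimension, and the causal mask carries no batch dimension either; broadcasting fills in whatever each one doesn't cover.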
What I'd do differently next time
- Start with the inference path, not the training loop. Get a forward pass working on dummy tensors before touching loss.
- Write the attention block as a pure function first, wrap it in nn.Module only once shapes are nailed down.
- Sanity-check on a tiny copy-task before scaling up. If a 2-layer transformer can't learn to copy a sequence in 1000 steps, something's wrong with your masking.
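For the copy-task check, the data side is only a few lines. The sketch below is one possible way to generate batches; the make_copy_batch helper, the PAD and BOS token ids, and the sizes are all hypothetical, and the model and training loop are left out.

```python
import torch

PAD, BOS = 0, 1  # hypothetical special-token ids reserved for this example

def make_copy_batch(batch_size: int, seq_len: int, vocab_size: int, generator=None):
    """Return (src, tgt_in, tgt_out) for a copy task; each is (batch_size, seq_len)."""
    # Draw tokens from [2, vocab_size) so they never collide with PAD or BOS.
    src = torch.randint(2, vocab_size, (batch_size, seq_len), generator=generator)
    bos = torch.full((batch_size, 1), BOS, dtype=src.dtype)
    # Decoder input is the target shifted right by one, starting with BOS.
    tgt_in = torch.cat([bos, src[:, :-1]], dim=1)
    # The decoder should reproduce the source exactly.
    tgt_out = src.clone()
    return src, tgt_in, tgt_out

src, tgt_in, tgt_out = make_copy_batch(batch_size=4, seq_len=10, vocab_size=20)
```

If the causal mask is leaking, the decoder can read the answer directly from tgt_in and the loss will drop suspiciously fast; if the mask is too strict or transposed, the loss tends to stall. Either way, the task is small enough that the failure shows up within minutes.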
The code
The full implementation lives on GitHub. It's deliberately small — one file per concept, no abstraction for its own sake. Read it top-to-bottom and you'll have a working transformer in your head.