Lab 2 Walkthrough - Multi-Head Attention & Transformer Decoder

Goal: In this lab, you will implement the core building blocks of a Transformer: first at the C++ level (custom matrix math and an attention kernel), then at the PyTorch level (a modular deep learning architecture).

By the end, you'll understand both the numerical flow and the architectural flow of attention.


🧩 Task 1 - Matrix Multiplication (matmul)

🎯 Goal

Implement C = A × B + bias in row-major order.

This forms the foundation for the Q, K, V projections later; every linear layer in your Transformer depends on it.


๐Ÿ” Understanding the math

If

$$A \in \mathbb{R}^{d_0 \times d_1}, \quad B \in \mathbb{R}^{d_1 \times d_2}$$

then

$$C[i, j] = \sum_{k=0}^{d_1 - 1} A[i, k] \times B[k, j] + \text{bias}[j]$$

🧠 Step-by-Step

  1. Initialize C

    • If a bias is provided → start each C[i,j] = bias[j]
    • Else → start from 0.0
  2. Compute dot products

    ```cpp
    // Assumes C was initialized in step 1 (bias[j] or 0.0 in every slot).
    for (unsigned i = 0; i < d0; ++i)
      for (unsigned j = 0; j < d2; ++j)
        for (unsigned k = 0; k < d1; ++k)
          C[i*d2 + j] += A[i*d1 + k] * B[k*d2 + j];
    ```
  3. Check indexing carefully. Remember: row-major means the elements of a row are contiguous in memory, so element (i, k) of A lives at A[i*d1 + k].


โ“Common Questions

Q: Why do we add the bias per column? A: Because each output neuron (column) has its own bias term โ€” it shifts all rows by the same amount.

Q: Why three nested loops? A: Weโ€™re performing a full matrix product; hardware accelerators can parallelize it, but logically itโ€™s a triple loop.


✅ Checkpoint

Compare your result against NumPy or PyTorch (remember that C is a flat row-major buffer, so reshape it to (d0, d2) first):

```python
torch.allclose(torch.tensor(C, dtype=torch.float32).view(d0, d2), A @ B + bias, atol=1e-5)
```
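If you want to check the indexing logic before (or alongside) the C++ version, here is a hedged sketch that mirrors the same row-major triple loop in Python and compares it against torch; the shapes and variable names are illustrative, not part of the lab skeleton:

```python
import torch

# Hypothetical small shapes; any values work for a sanity check.
d0, d1, d2 = 4, 5, 3
A, B, bias = torch.randn(d0, d1), torch.randn(d1, d2), torch.randn(d2)
a, b = A.flatten().tolist(), B.flatten().tolist()   # row-major, like the C++ buffers

C = [0.0] * (d0 * d2)
for i in range(d0):
    for j in range(d2):
        C[i * d2 + j] = bias[j].item()                           # step 1: start from the bias
        for k in range(d1):
            C[i * d2 + j] += a[i * d1 + k] * b[k * d2 + j]       # step 2: accumulate

assert torch.allclose(torch.tensor(C, dtype=torch.float32).view(d0, d2),
                      A @ B + bias, atol=1e-5)
```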

🧮 Task 2 - Row-Wise Softmax

🎯 Goal

Convert each row of logits into probabilities that sum to 1.


๐Ÿ” The math

For each row i:

$$p_{i,j} = \frac{e^{A_{i,j} - \max_j A_{i,j}}}{\sum_k e^{A_{i,k} - \max_j A_{i,j}}}$$

Subtracting the row max prevents overflow in the exponential.
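A quick illustration of why (the numbers are arbitrary): both versions define the same softmax, because the common factor e^{-max} cancels in the ratio, but only the shifted one stays finite in float32.

```python
import torch

x = torch.tensor([1000.0, 1001.0, 1002.0])
print(torch.exp(x))            # tensor([inf, inf, inf])  -> exp overflows
print(torch.exp(x - x.max()))  # tensor([0.1353, 0.3679, 1.0000]) -> finite, same softmax
```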


🧠 Implementation Steps

  1. Find max per row

    ```cpp
    float maxv = A[i*d1];
    for (unsigned j = 1; j < d1; ++j)
      maxv = std::max(maxv, A[i*d1 + j]);
    ```
  2. Exponentiate and sum

    ```cpp
    float sum = 0.f;
    for (unsigned j = 0; j < d1; ++j) {
      float e = std::exp(A[i*d1 + j] - maxv);
      A[i*d1 + j] = e;
      sum += e;
    }
    ```
  3. Normalize

    ```cpp
    for (unsigned j = 0; j < d1; ++j)
      A[i*d1 + j] /= sum;
    ```

💡 Why this matters

Without this function, your attention weights would be arbitrary; the softmax enforces that they represent a probability distribution over tokens.
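For comparison, here is a hedged PyTorch reference of the same three steps (the function name is illustrative); your C++ output should match it row for row:

```python
import torch

def rowwise_softmax_reference(A):
    # A: (d0, d1) tensor; mirrors the C++ steps, vectorized per row.
    maxv = A.max(dim=1, keepdim=True).values   # step 1: row max
    e = torch.exp(A - maxv)                    # step 2: exponentiate (stable)
    return e / e.sum(dim=1, keepdim=True)      # step 3: normalize

A = torch.randn(4, 7)
assert torch.allclose(rowwise_softmax_reference(A), torch.softmax(A, dim=1), atol=1e-6)
```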


✅ Checkpoint

Each row should sum to ~1:

```cpp
// After softmax:
for (unsigned i = 0; i < d0; ++i) {
  float s = 0;
  for (unsigned j = 0; j < d1; ++j) s += A[i*d1+j];
  assert(std::abs(s - 1.0f) < 1e-3);
}
```

โš™๏ธ Task 3 โ€” Multi-Head Self-Attention (PyTorch)

๐ŸŽฏ Goal

Implement the forward() of MultiHeadSelfAttention.


๐Ÿ” Conceptual Overview

Each token produces:

  • a query vector (Q) โ€” what Iโ€™m looking for
  • a key vector (K) โ€” what I contain
  • a value vector (V) โ€” what information Iโ€™ll share

We compute attention weights using:

$$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^\top}{\sqrt{d_k}}\right)V$$
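In code, this formula for a single head is only a few tensor operations. Here is a minimal, unbatched sketch without masking (the function name is illustrative); the step-by-step version below generalizes it to batches and multiple heads:

```python
import torch

def single_head_attention(q, k, v):
    # q, k, v: (T, d_k) for one head; no batch dimension, no causal mask.
    d_k = q.size(-1)
    scores = q @ k.transpose(-2, -1) / (d_k ** 0.5)   # (T, T)
    weights = torch.softmax(scores, dim=-1)           # each row sums to 1
    return weights @ v                                # (T, d_k)
```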

🧠 Step-by-Step Implementation

  1. Project input into Q, K, V

    ```python
    q = self.q_proj(x)
    k = self.k_proj(x)
    v = self.v_proj(x)
    ```
  2. Reshape into heads

    ```python
    B, T, C = x.shape
    q = q.view(B, T, self.num_heads, self.head_dim).transpose(1, 2)
    k = k.view(B, T, self.num_heads, self.head_dim).transpose(1, 2)
    v = v.view(B, T, self.num_heads, self.head_dim).transpose(1, 2)
    ```

    → Shape becomes (B, num_heads, T, head_dim)

  3. Compute attention scores

    ```python
    attn_scores = torch.matmul(q, k.transpose(-2, -1)) / (self.head_dim ** 0.5)
    ```
  4. Apply causal mask (optional)

    ```python
    if use_causal_mask:
        mask = torch.triu(torch.ones(T, T, dtype=torch.bool, device=x.device), diagonal=1)
        attn_scores.masked_fill_(mask, float('-inf'))
    ```
  5. Softmax across last dimension

    ```python
    attn_weights = torch.softmax(attn_scores, dim=-1)
    ```
  6. Weighted sum of V

    ```python
    attn_output = torch.matmul(attn_weights, v)
    ```
  7. Reshape back and output projection

    ```python
    attn_output = attn_output.transpose(1, 2).contiguous().view(B, T, C)
    out = self.out_proj(attn_output)
    return out
    ```

โ“Common Questions

Q: Why divide by โˆšdโ‚–? A: To keep the dot-product magnitude consistent โ€” without it, large embeddings cause softmax saturation.

Q: Why transpose(1, 2)? A: To bring num_heads before seq_len so that each head runs its own matrix multiplication.

Q: What does โ€œcausalโ€ mean? A: In language models, a token canโ€™t attend to future tokens. Masking ensures i only sees โ‰ค i.
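A small demonstration of the scaling argument (the sizes are arbitrary): dot products of random d-dimensional vectors have standard deviation around √d, so dividing by √d_k brings the scores back to unit scale before the softmax.

```python
import torch

for d in (16, 256, 4096):
    q, k = torch.randn(10000, d), torch.randn(10000, d)
    scores = (q * k).sum(dim=-1)                          # raw dot products
    print(d, scores.std().item(), (scores / d ** 0.5).std().item())
# std grows like sqrt(d) unscaled, but stays ~1.0 after dividing by sqrt(d)
```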


✅ Checkpoint

You can test equivalence against PyTorch's built-in nn.MultiheadAttention (note: the two modules must share the same projection weights for the outputs to match, and the built-in needs batch_first=True to accept (B, T, C) inputs):

```python
torch.allclose(my_mhsa(x), nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)(x, x, x)[0], atol=1e-3)
```
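For reference, here is a minimal end-to-end sketch that assembles steps 1-7 into one module. The constructor details (argument names, the absence of bias/dropout options, the use_causal_mask flag) are assumptions; adapt them to the lab's skeleton rather than treating this as the required implementation.

```python
import torch
import torch.nn as nn

class MultiHeadSelfAttention(nn.Module):
    def __init__(self, embed_dim, num_heads):
        super().__init__()
        assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.out_proj = nn.Linear(embed_dim, embed_dim)

    def forward(self, x, use_causal_mask=True):
        B, T, C = x.shape
        # Steps 1-2: project and split into heads -> (B, num_heads, T, head_dim)
        q = self.q_proj(x).view(B, T, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(x).view(B, T, self.num_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(x).view(B, T, self.num_heads, self.head_dim).transpose(1, 2)
        # Steps 3-4: scaled scores, optional causal mask
        attn_scores = torch.matmul(q, k.transpose(-2, -1)) / (self.head_dim ** 0.5)
        if use_causal_mask:
            mask = torch.triu(torch.ones(T, T, dtype=torch.bool, device=x.device), diagonal=1)
            attn_scores = attn_scores.masked_fill(mask, float('-inf'))
        # Steps 5-6: softmax over keys, weighted sum of values
        attn_weights = torch.softmax(attn_scores, dim=-1)
        attn_output = torch.matmul(attn_weights, v)
        # Step 7: merge heads and apply the output projection
        attn_output = attn_output.transpose(1, 2).contiguous().view(B, T, C)
        return self.out_proj(attn_output)
```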

🧠 Task 4 - Decoder-Only Transformer

🎯 Goal

Build the GPT-style stack: embeddings → N decoder blocks → layer norm → linear output.


๐Ÿ” Flow Diagram

Input Tokens โ”€โ”€โ–ถ Embedding โ”€โ”€โ–ถ +PosEnc โ”€โ”€โ–ถ [DecoderBlock ร— N]
                                          โ”‚
                                          โ””โ”€โ”€ Each block:
                                              โ”œโ”€ LayerNorm
                                              โ”œโ”€ MHSA (+Residual)
                                              โ”œโ”€ LayerNorm
                                              โ””โ”€ FFN (+Residual)
โ”€โ”€โ–ถ LayerNorm โ”€โ–ถ Linear (vocab projection) โ”€โ–ถ Logits
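Read as code, each block in the diagram is a pre-norm residual block. A hedged sketch follows (the class and argument names DecoderBlock, ffn_dim, and the GELU/dropout choices are assumptions; it reuses the MultiHeadSelfAttention module from Task 3):

```python
import torch.nn as nn

class DecoderBlock(nn.Module):
    def __init__(self, embed_dim, num_heads, ffn_dim, dropout=0.1):
        super().__init__()
        self.ln1 = nn.LayerNorm(embed_dim)
        self.attn = MultiHeadSelfAttention(embed_dim, num_heads)  # from Task 3
        self.ln2 = nn.LayerNorm(embed_dim)
        self.ffn = nn.Sequential(
            nn.Linear(embed_dim, ffn_dim),
            nn.GELU(),                       # activation choice is an assumption
            nn.Linear(ffn_dim, embed_dim),
            nn.Dropout(dropout),             # dropout placement is an assumption
        )

    def forward(self, x):
        x = x + self.attn(self.ln1(x))   # LayerNorm -> MHSA, with residual
        x = x + self.ffn(self.ln2(x))    # LayerNorm -> FFN, with residual
        return x
```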

🧠 Implementation Steps

  1. Embed tokens + add position encoding

    ```python
    x = self.embed(x)
    x = x + self.pos_encoding[:x.size(1), :].to(x.device)
    ```
  2. Pass through layers

    ```python
    for layer in self.layers:
        x = layer(x)
    ```
  3. Normalize + project

    ```python
    x = self.ln_f(x)
    logits = self.head(x)
    return logits
    ```

โ“Questions You Might Have

Q: Why add positional encoding? A: Transformers have no notion of order; positions encode sequential structure.

Q: Why apply LayerNorm twice? A: Each normalization isolates sub-layer learning โ€” one before attention, one before feedforward.

Q: Why dropout? A: To regularize training and reduce overfitting in large models.
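Putting the flow diagram and the steps together, here is a hedged sketch of the full model. The sinusoidal positional-encoding buffer, max_len, and the constructor names are assumptions (match them to the lab's skeleton); it reuses the DecoderBlock sketch above.

```python
import math
import torch
import torch.nn as nn

class DecoderOnlyTransformer(nn.Module):
    def __init__(self, vocab_size, embed_dim, num_heads, num_layers, ffn_dim, max_len=512):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, embed_dim)
        # Sinusoidal positional encoding, one common choice (learned embeddings also work).
        pos = torch.arange(max_len).unsqueeze(1)
        div = torch.exp(torch.arange(0, embed_dim, 2) * (-math.log(10000.0) / embed_dim))
        pe = torch.zeros(max_len, embed_dim)
        pe[:, 0::2] = torch.sin(pos * div)
        pe[:, 1::2] = torch.cos(pos * div)
        self.register_buffer("pos_encoding", pe)
        self.layers = nn.ModuleList(
            [DecoderBlock(embed_dim, num_heads, ffn_dim) for _ in range(num_layers)]
        )
        self.ln_f = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, vocab_size)

    def forward(self, x):
        x = self.embed(x)                                       # (B, T, C)
        x = x + self.pos_encoding[:x.size(1), :].to(x.device)   # step 1: add positions
        for layer in self.layers:                               # step 2: decoder blocks
            x = layer(x)
        x = self.ln_f(x)                                        # step 3: final norm
        return self.head(x)                                     # (B, T, vocab_size) logits
```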


✅ Checkpoint

Feed dummy input:

```python
x = torch.randint(0, vocab_size, (2, 8))
logits = model(x)
print(logits.shape)  # [2, 8, vocab_size]
```

🧭 Summary

| Task | Concept | Core Skill | Checkpoint |
|------|---------|------------|------------|
| 1 | Matrix multiply | Memory indexing, loops | Matches PyTorch matmul |
| 2 | Softmax | Numerical stability | Rows sum ≈ 1 |
| 3 | MHSA | Attention mechanism | Matches built-in attention |
| 4 | Decoder block | Architectural flow | Output logits valid |

Next: Run your reference model side-by-side to verify numerical equivalence. Understanding why each piece exists is key; you've just built the backbone of GPT-style transformers from scratch!
