Custom Transformer
Build a transformer architecture from scratch with multi-head attention and modern optimizations.
Implement the complete transformer architecture from scratch, understanding every component that powers modern LLMs.
TL;DR
Transformers use attention (softmax(QK^T/√d_k)V) to weigh all input positions when computing each output. Multi-head attention runs this in parallel with different learned projections. Modern LLMs add: RoPE for position encoding (better length extrapolation), RMSNorm (cheaper than LayerNorm), SwiGLU activation (better quality at similar parameter count), and KV caching (reuse past keys and values so each generated token costs one attention pass over the cache instead of recomputing attention for the whole sequence).
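The core formula can be verified in a few lines of plain PyTorch (toy shapes; the tensor names here are illustrative, not part of the project code):

```python
import math
import torch
import torch.nn.functional as F

# Toy tensors: batch=1, seq_len=3, d_k=4
torch.manual_seed(0)
q = torch.randn(1, 3, 4)
k = torch.randn(1, 3, 4)
v = torch.randn(1, 3, 4)

# softmax(QK^T / sqrt(d_k)) V, exactly as in the formula above
scores = q @ k.transpose(-2, -1) / math.sqrt(q.size(-1))  # (1, 3, 3)
weights = F.softmax(scores, dim=-1)                       # rows sum to 1
out = weights @ v                                         # (1, 3, 4)

print(out.shape)        # torch.Size([1, 3, 4])
print(weights.sum(-1))  # tensor([[1.0000, 1.0000, 1.0000]])
```

Every attention variant built in this project reduces to this computation plus bookkeeping around masks, heads, and caching.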
Overview
| Aspect | Details |
|---|---|
| Difficulty | Advanced |
| Time | 5 days |
| Code | ~1000 lines |
| Prerequisites | PyTorch, linear algebra, calculus |
What You'll Build
A complete transformer implementation including:
- Scaled dot-product attention mechanism
- Multi-head attention with proper masking
- Positional encodings (sinusoidal and rotary)
- Encoder-decoder and decoder-only architectures
- A trainable GPT-style language model
- Modern optimizations (KV cache, Flash Attention concepts)
┌─────────────────────────────────────────────────────────────────────────────┐
│ Transformer Architecture │
├─────────────────────────────────────────────────────────────────────────────┤
│ │
│ ┌─────────────────┐ │
│ │ Input Tokens │ │
│ └────────┬────────┘ │
│ │ │
│ ▼ │
│ ┌─────────────────┐ │
│ │Token Embeddings │ │
│ │+ Position Enc. │ │
│ └────────┬────────┘ │
│ │ │
│ ┌────────┼────────────────────────────────────────────────────────┐ │
│ │ ▼ ENCODER STACK │ │
│ │ ┌───────────────────────────────────────────────────────┐ │ │
│ │ │ Encoder Layer 1: Self-Attention → Add&Norm → FFN → Add&Norm │ │ │
│ │ └───────────────────────────────────────────────────────┘ │ │
│ │ │ │ │
│ │ ▼ │ │
│ │ ┌───────────────────────────────────────────────────────┐ │ │
│ │ │ Encoder Layer 2 │ │ │
│ │ └───────────────────────────────────────────────────────┘ │ │
│ │ │ │ │
│ │ ▼ │ │
│ │ ┌───────────────────────────────────────────────────────┐ │ │
│ │ │ Encoder Layer N │ │ │
│ │ └───────────────────────────────────────────────────────┘ │ │
│ └────────┼────────────────────────────────────────────────────────┘ │
│ │ │
│ ┌────────┼────────────────────────────────────────────────────────┐ │
│ │ ▼ DECODER STACK │ │
│ │ ┌───────────────────────────────────────────────────────┐ │ │
│ │ │ Decoder Layer: Masked Attn → Cross Attn → FFN │ │ │
│ │ └───────────────────────────────────────────────────────┘ │ │
│ │ │ × N layers │ │
│ └────────┼────────────────────────────────────────────────────────┘ │
│ │ │
│ ▼ │
│ ┌─────────────────┐ │
│ │ Linear + Softmax│ │
│ │ (vocab probs) │ │
│ └─────────────────┘ │
│ │
└─────────────────────────────────────────────────────────────────────────────┘
Understanding the Transformer
The Attention Revolution
┌─────────────────────────────────────────────────────────────────────────────┐
│ Transformer Components │
├─────────────────────────────────────────────────────────────────────────────┤
│ │
│ ┌─────────────┐ │
│ │ TRANSFORMER │ │
│ └──────┬──────┘ │
│ ┌──────────────────────┼──────────────────────┐ │
│ │ │ │ │ │ │
│ ▼ ▼ ▼ ▼ ▼ │
│ ┌──────────┐ ┌──────────┐ ┌──────────┐ ┌──────────┐ │
│ │ Attention│ │Architect.│ │Embeddings│ │ Variants │ │
│ ├──────────┤ ├──────────┤ ├──────────┤ ├──────────┤ │
│ │• Self- │ │• Encoder │ │• Token │ │• Encoder │ │
│ │ Attention│ │• Decoder │ │ Embed. │ │ Only │ │
│ │• Cross- │ │• Feed- │ │• Position│ │ (BERT) │ │
│ │ Attention│ │ Forward │ │ - Sinus.│ │• Decoder │ │
│ │• Multi- │ │• Layer │ │ - RoPE │ │ Only │ │
│ │ Head │ │ Norm │ │ - ALiBi │ │ (GPT) │ │
│ │• Scaled │ │• Residual│ │ │ │• Enc-Dec │ │
│ │ Dot-Prod│ │ Connect.│ │ │ │ (T5) │ │
│ └──────────┘ └──────────┘ └──────────┘ └──────────┘ │
│ │
└─────────────────────────────────────────────────────────────────────────────┘
The Attention Mechanism
┌─────────────────────────────────────────────────────────────────────────────┐
│ Scaled Dot-Product Attention │
├─────────────────────────────────────────────────────────────────────────────┤
│ │
│ Attention(Q, K, V) = softmax(QK^T / √d_k) × V │
│ │
│ ┌─────────┐ ┌─────────┐ ┌─────────┐ │
│ │ Query Q │ │ Key K │ │ Value V │ │
│ │ (n×d_k) │ │ (m×d_k) │ │ (m×d_v) │ │
│ └────┬────┘ └────┬────┘ └────┬────┘ │
│ │ │ │ │
│ │ ┌────────┘ │ │
│ │ │ │ │
│ ▼ ▼ │ │
│ ┌─────────────┐ │ │
│ │ Q × K^T │ ◄── Matrix multiply (n×d_k) × (d_k×m) = (n×m) │
│ └──────┬──────┘ │ │
│ │ │ │
│ ▼ │ │
│ ┌─────────────┐ │ │
│ │ ÷ √d_k │ ◄── Scale to prevent softmax saturation │
│ └──────┬──────┘ │ │
│ │ │ │
│ ▼ │ │
│ ┌─────────────┐ │ │
│ │ Mask (opt) │ ◄── Set future positions to -∞ (causal) │
│ └──────┬──────┘ │ │
│ │ │ │
│ ▼ │ │
│ ┌─────────────┐ │ │
│ │ Softmax │ ◄── Convert to attention weights (sum=1 per row) │
│ └──────┬──────┘ │ │
│ │ │ │
│ └─────────────────────────┤ │
│ │ │
│ ▼ │
│ ┌─────────────┐ │
│ │ Attn × V │ ◄── Weighted sum of values │
│ └──────┬──────┘ │
│ │ │
│ ▼ │
│ ┌─────────────┐ │
│ │ Output │ │
│ │ (n×d_v) │ │
│ └─────────────┘ │
│ │
└─────────────────────────────────────────────────────────────────────────────┘
Project Setup
Environment Setup
# Create project directory
mkdir custom-transformer && cd custom-transformer
# Create virtual environment
python -m venv venv
source venv/bin/activate # On Windows: venv\Scripts\activate
# Install dependencies
pip install torch torchvision
pip install numpy matplotlib
pip install tiktoken # For tokenization
pip install tqdm wandb
pip install einops  # Tensor operations
Project Structure
custom-transformer/
├── src/
│ ├── __init__.py
│ ├── attention.py
│ ├── embeddings.py
│ ├── encoder.py
│ ├── decoder.py
│ ├── transformer.py
│ ├── gpt.py
│ └── utils.py
├── train_gpt.py
├── generate.py
├── config.py
└── requirements.txt
Requirements
# requirements.txt
torch>=2.0.0
numpy>=1.24.0
matplotlib>=3.7.0
tiktoken>=0.5.0
tqdm>=4.66.0
wandb>=0.16.0
einops>=0.7.0
Part 1: Attention Mechanisms
Scaled Dot-Product Attention
The fundamental building block of transformers:
Attention(Q, K, V) = softmax(QK^T / √d_k) × V
# src/attention.py
"""
Attention mechanisms for transformer models.
Implements scaled dot-product attention and multi-head attention.
"""
import math
from typing import Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat
def scaled_dot_product_attention(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
mask: Optional[torch.Tensor] = None,
dropout: float = 0.0,
training: bool = True,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Compute scaled dot-product attention.
Args:
query: Query tensor of shape (batch, heads, seq_len, d_k)
key: Key tensor of shape (batch, heads, seq_len, d_k)
value: Value tensor of shape (batch, heads, seq_len, d_v)
mask: Optional mask tensor (1 = attend, 0 = mask)
dropout: Dropout probability
training: Whether in training mode
Returns:
Tuple of (output, attention_weights)
"""
d_k = query.size(-1)
# Compute attention scores: (batch, heads, seq_len, seq_len)
# Scale by sqrt(d_k) to prevent softmax saturation
scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
# Apply mask if provided
if mask is not None:
# Mask positions with very negative number before softmax
scores = scores.masked_fill(mask == 0, float('-inf'))
# Softmax to get attention weights
attention_weights = F.softmax(scores, dim=-1)
# Handle NaN from all-masked positions
attention_weights = torch.nan_to_num(attention_weights, nan=0.0)
# Apply dropout during training
if dropout > 0.0 and training:
attention_weights = F.dropout(attention_weights, p=dropout, training=training)
# Compute weighted sum of values
output = torch.matmul(attention_weights, value)
return output, attention_weights
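As a quick sanity check, separate from the function above: a manual implementation of causal scaled dot-product attention should agree with PyTorch 2.x's built-in `F.scaled_dot_product_attention` (the fused kernel behind the Flash Attention concepts mentioned earlier):

```python
import math
import torch
import torch.nn.functional as F

torch.manual_seed(0)
q = torch.randn(2, 4, 5, 8)  # (batch, heads, seq_len, d_k)
k = torch.randn(2, 4, 5, 8)
v = torch.randn(2, 4, 5, 8)

# Manual causal attention: mask future positions with -inf before softmax
scores = q @ k.transpose(-2, -1) / math.sqrt(q.size(-1))
causal = torch.tril(torch.ones(5, 5, dtype=torch.bool))
scores = scores.masked_fill(~causal, float('-inf'))
manual = F.softmax(scores, dim=-1) @ v

# PyTorch 2.x built-in (uses a fused kernel where available)
builtin = F.scaled_dot_product_attention(q, k, v, is_causal=True)

print(torch.allclose(manual, builtin, atol=1e-5))  # True
```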
class MultiHeadAttention(nn.Module):
"""
Multi-Head Attention mechanism.
Allows the model to jointly attend to information from different
representation subspaces at different positions.
Multi-Head Attention = Concat(head_1, ..., head_h) W^O
where head_i = Attention(Q W_i^Q, K W_i^K, V W_i^V)
"""
def __init__(
self,
d_model: int,
num_heads: int,
dropout: float = 0.0,
bias: bool = True,
):
"""
Args:
d_model: Model dimension
num_heads: Number of attention heads
dropout: Dropout probability
bias: Whether to use bias in linear projections
"""
super().__init__()
assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
self.d_model = d_model
self.num_heads = num_heads
self.d_k = d_model // num_heads # Dimension per head
self.dropout = dropout
# Linear projections for Q, K, V
self.W_q = nn.Linear(d_model, d_model, bias=bias)
self.W_k = nn.Linear(d_model, d_model, bias=bias)
self.W_v = nn.Linear(d_model, d_model, bias=bias)
# Output projection
self.W_o = nn.Linear(d_model, d_model, bias=bias)
# Store attention weights for visualization
self.attention_weights = None
def forward(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
mask: Optional[torch.Tensor] = None,
need_weights: bool = False,
) -> torch.Tensor:
"""
Forward pass.
Args:
query: Query tensor (batch, seq_len, d_model)
key: Key tensor (batch, seq_len, d_model)
value: Value tensor (batch, seq_len, d_model)
mask: Optional attention mask
need_weights: Whether to return attention weights
Returns:
Output tensor (batch, seq_len, d_model)
"""
batch_size = query.size(0)
# 1. Linear projections
Q = self.W_q(query) # (batch, seq_len, d_model)
K = self.W_k(key)
V = self.W_v(value)
# 2. Reshape for multi-head attention
# (batch, seq_len, d_model) -> (batch, num_heads, seq_len, d_k)
Q = rearrange(Q, 'b s (h d) -> b h s d', h=self.num_heads)
K = rearrange(K, 'b s (h d) -> b h s d', h=self.num_heads)
V = rearrange(V, 'b s (h d) -> b h s d', h=self.num_heads)
# 3. Expand mask for multi-head
if mask is not None:
# Add head dimension: (batch, 1, seq, seq) or (batch, 1, 1, seq)
if mask.dim() == 2:
mask = mask.unsqueeze(1).unsqueeze(1)
elif mask.dim() == 3:
mask = mask.unsqueeze(1)
# 4. Apply attention
output, attention_weights = scaled_dot_product_attention(
Q, K, V,
mask=mask,
dropout=self.dropout,
training=self.training,
)
if need_weights:
self.attention_weights = attention_weights
# 5. Reshape back
# (batch, num_heads, seq_len, d_k) -> (batch, seq_len, d_model)
output = rearrange(output, 'b h s d -> b s (h d)')
# 6. Output projection
output = self.W_o(output)
return output
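If you haven't used einops, the `rearrange` calls above may be opaque. In plain PyTorch the head split is a `view` plus a `transpose`, and the round trip is lossless:

```python
import torch

torch.manual_seed(0)
b, s, h, d = 2, 5, 4, 8           # batch, seq, heads, per-head dim
x = torch.randn(b, s, h * d)      # (batch, seq, d_model)

# einops 'b s (h d) -> b h s d' in plain PyTorch: split heads, then transpose
heads = x.view(b, s, h, d).transpose(1, 2)   # (b, h, s, d)

# Inverse merge 'b h s d -> b s (h d)' recovers the input exactly
merged = heads.transpose(1, 2).contiguous().view(b, s, h * d)

print(torch.equal(merged, x))  # True
```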
class CausalSelfAttention(nn.Module):
"""
Causal (masked) self-attention for autoregressive models.
Each position can only attend to previous positions.
"""
def __init__(
self,
d_model: int,
num_heads: int,
max_seq_len: int = 2048,
dropout: float = 0.0,
bias: bool = True,
):
super().__init__()
self.num_heads = num_heads
self.d_model = d_model
self.d_k = d_model // num_heads
self.dropout = dropout
# Combined QKV projection for efficiency
self.c_attn = nn.Linear(d_model, 3 * d_model, bias=bias)
self.c_proj = nn.Linear(d_model, d_model, bias=bias)
# Causal mask (lower triangular)
# Register as buffer so it moves with the model
mask = torch.tril(torch.ones(max_seq_len, max_seq_len))
self.register_buffer('causal_mask', mask.view(1, 1, max_seq_len, max_seq_len))
def forward(
self,
x: torch.Tensor,
use_cache: bool = False,
past_kv: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
"""
Forward pass with optional KV caching.
Args:
x: Input tensor (batch, seq_len, d_model)
use_cache: Whether to return KV cache
past_kv: Previous KV cache for incremental decoding
Returns:
Tuple of (output, optional_kv_cache)
"""
batch_size, seq_len, _ = x.shape
# Compute Q, K, V
qkv = self.c_attn(x)
Q, K, V = qkv.split(self.d_model, dim=2)
# Reshape for multi-head
Q = rearrange(Q, 'b s (h d) -> b h s d', h=self.num_heads)
K = rearrange(K, 'b s (h d) -> b h s d', h=self.num_heads)
V = rearrange(V, 'b s (h d) -> b h s d', h=self.num_heads)
# Handle KV caching for efficient generation
if past_kv is not None:
past_k, past_v = past_kv
K = torch.cat([past_k, K], dim=2)
V = torch.cat([past_v, V], dim=2)
present_kv = (K, V) if use_cache else None
kv_seq_len = K.size(2)
# Compute attention scores
scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
# Apply causal mask
# For cached generation, we need to mask based on total sequence length
causal_mask = self.causal_mask[:, :, kv_seq_len - seq_len:kv_seq_len, :kv_seq_len]
scores = scores.masked_fill(causal_mask == 0, float('-inf'))
# Softmax and dropout
attention_weights = F.softmax(scores, dim=-1)
if self.dropout > 0.0 and self.training:
attention_weights = F.dropout(attention_weights, p=self.dropout)
# Compute output
output = torch.matmul(attention_weights, V)
output = rearrange(output, 'b h s d -> b s (h d)')
output = self.c_proj(output)
return output, present_kv
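The cache logic is worth verifying in isolation. With identity Q/K/V projections, attending the newest token against cached keys and values must reproduce the last row of full-sequence causal attention (a standalone single-head sketch, not the class above):

```python
import math
import torch
import torch.nn.functional as F

def causal_attn(q, k, v):
    # q: (s_q, d), queries for the *last* s_q positions; k, v: (s_k, d)
    scores = q @ k.T / math.sqrt(q.size(-1))
    s_q, s_k = scores.shape
    mask = torch.tril(torch.ones(s_q, s_k, dtype=torch.bool), diagonal=s_k - s_q)
    scores = scores.masked_fill(~mask, float('-inf'))
    return F.softmax(scores, dim=-1) @ v

torch.manual_seed(0)
x = torch.randn(6, 8)              # 6 tokens, d=8; identity projections

full = causal_attn(x, x, x)        # process all 6 tokens at once

# Incremental step: the new token reuses cached K/V from tokens 0..4
past_k, past_v = x[:5], x[:5]
k = torch.cat([past_k, x[5:6]])
v = torch.cat([past_v, x[5:6]])
out_last = causal_attn(x[5:6], k, v)

print(torch.allclose(full[5], out_last[0], atol=1e-5))  # True
```

This equivalence is exactly why generation with a cache only pays for one query row per step.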
class GroupedQueryAttention(nn.Module):
"""
Grouped Query Attention (GQA) - used in Llama 2 and similar models.
Uses fewer key-value heads than query heads to reduce memory.
If num_kv_heads == num_heads: Multi-Head Attention
If num_kv_heads == 1: Multi-Query Attention
Otherwise: Grouped Query Attention
"""
def __init__(
self,
d_model: int,
num_heads: int,
num_kv_heads: int,
max_seq_len: int = 2048,
dropout: float = 0.0,
):
super().__init__()
assert num_heads % num_kv_heads == 0, "num_heads must be divisible by num_kv_heads"
self.num_heads = num_heads
self.num_kv_heads = num_kv_heads
self.num_queries_per_kv = num_heads // num_kv_heads
self.d_model = d_model
self.d_k = d_model // num_heads
self.dropout = dropout
# Separate projections
self.W_q = nn.Linear(d_model, d_model)
self.W_k = nn.Linear(d_model, num_kv_heads * self.d_k)
self.W_v = nn.Linear(d_model, num_kv_heads * self.d_k)
self.W_o = nn.Linear(d_model, d_model)
# Causal mask
mask = torch.tril(torch.ones(max_seq_len, max_seq_len))
self.register_buffer('causal_mask', mask.view(1, 1, max_seq_len, max_seq_len))
def forward(
self,
x: torch.Tensor,
use_cache: bool = False,
past_kv: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
batch_size, seq_len, _ = x.shape
# Compute Q, K, V with different head counts
Q = self.W_q(x)
K = self.W_k(x)
V = self.W_v(x)
# Reshape
Q = rearrange(Q, 'b s (h d) -> b h s d', h=self.num_heads)
K = rearrange(K, 'b s (h d) -> b h s d', h=self.num_kv_heads)
V = rearrange(V, 'b s (h d) -> b h s d', h=self.num_kv_heads)
# Handle KV caching
if past_kv is not None:
past_k, past_v = past_kv
K = torch.cat([past_k, K], dim=2)
V = torch.cat([past_v, V], dim=2)
present_kv = (K, V) if use_cache else None
kv_seq_len = K.size(2)
# Expand K, V heads to match Q heads
# (batch, num_kv_heads, seq, d_k) -> (batch, num_heads, seq, d_k)
K = repeat(K, 'b h s d -> b (h g) s d', g=self.num_queries_per_kv)
V = repeat(V, 'b h s d -> b (h g) s d', g=self.num_queries_per_kv)
# Attention computation
scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
causal_mask = self.causal_mask[:, :, kv_seq_len - seq_len:kv_seq_len, :kv_seq_len]
scores = scores.masked_fill(causal_mask == 0, float('-inf'))
attention_weights = F.softmax(scores, dim=-1)
if self.dropout > 0.0 and self.training:
attention_weights = F.dropout(attention_weights, p=self.dropout)
output = torch.matmul(attention_weights, V)
output = rearrange(output, 'b h s d -> b s (h d)')
output = self.W_o(output)
return output, present_kv
Part 2: Embeddings and Position Encodings
Token and Position Embeddings
# src/embeddings.py
"""
Embedding layers for transformer models.
Includes token embeddings and various position encoding schemes.
"""
import math
from typing import Optional
import torch
import torch.nn as nn
class TokenEmbedding(nn.Module):
"""Token embedding layer with optional scaling."""
def __init__(self, vocab_size: int, d_model: int, scale: bool = True):
super().__init__()
self.embedding = nn.Embedding(vocab_size, d_model)
self.d_model = d_model
self.scale = scale
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Args:
x: Token indices (batch, seq_len)
Returns:
Embeddings (batch, seq_len, d_model)
"""
emb = self.embedding(x)
if self.scale:
emb = emb * math.sqrt(self.d_model)
return emb
class SinusoidalPositionalEncoding(nn.Module):
"""
Sinusoidal positional encoding from "Attention Is All You Need".
PE(pos, 2i) = sin(pos / 10000^(2i/d_model))
PE(pos, 2i+1) = cos(pos / 10000^(2i/d_model))
Advantages:
- Can extrapolate to longer sequences than seen during training
- No learnable parameters
"""
def __init__(self, d_model: int, max_seq_len: int = 5000, dropout: float = 0.1):
super().__init__()
self.dropout = nn.Dropout(p=dropout)
# Create position encoding matrix
pe = torch.zeros(max_seq_len, d_model)
position = torch.arange(0, max_seq_len, dtype=torch.float).unsqueeze(1)
div_term = torch.exp(
torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
)
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
# Register as buffer (not a parameter)
pe = pe.unsqueeze(0) # (1, max_seq_len, d_model)
self.register_buffer('pe', pe)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Args:
x: Input embeddings (batch, seq_len, d_model)
Returns:
Position-encoded embeddings (batch, seq_len, d_model)
"""
seq_len = x.size(1)
x = x + self.pe[:, :seq_len, :]
return self.dropout(x)
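You can spot-check the precomputed buffer against the closed form from the docstring (standalone recomputation of the same `pe` matrix):

```python
import math
import torch

d_model, max_len = 8, 16
pe = torch.zeros(max_len, d_model)
position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
div_term = torch.exp(
    torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
)
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)

# PE(pos, 2i) = sin(pos / 10000^(2i/d_model)); here pos=3, i=1
pos, i = 3, 1
expected = math.sin(pos / (10000 ** (2 * i / d_model)))
print(abs(pe[pos, 2 * i].item() - expected) < 1e-5)  # True
```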
class LearnedPositionalEncoding(nn.Module):
"""
Learned positional embeddings (used in BERT, GPT-2).
Each position has a learnable embedding vector.
"""
def __init__(self, d_model: int, max_seq_len: int = 5000, dropout: float = 0.1):
super().__init__()
self.position_embedding = nn.Embedding(max_seq_len, d_model)
self.dropout = nn.Dropout(p=dropout)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Args:
x: Input embeddings (batch, seq_len, d_model)
Returns:
Position-encoded embeddings (batch, seq_len, d_model)
"""
seq_len = x.size(1)
positions = torch.arange(seq_len, device=x.device)
pos_emb = self.position_embedding(positions)
return self.dropout(x + pos_emb)
class RotaryPositionalEmbedding(nn.Module):
"""
Rotary Position Embedding (RoPE) from "RoFormer".
Used in LLaMA, GPT-NeoX, and modern LLMs.
RoPE applies rotation to query and key vectors based on position,
encoding relative position information into the attention mechanism.
Advantages:
- Better extrapolation to longer sequences
- Encodes relative positions naturally
- No additional parameters needed
"""
def __init__(self, d_model: int, num_heads: int, max_seq_len: int = 4096, base: int = 10000):
super().__init__()
self.d_k = d_model // num_heads
self.max_seq_len = max_seq_len
self.base = base
# Precompute rotation matrices
self._compute_freqs(max_seq_len)
def _compute_freqs(self, seq_len: int):
"""Precompute frequency tensors for RoPE."""
# Compute inverse frequencies
inv_freq = 1.0 / (
self.base ** (torch.arange(0, self.d_k, 2).float() / self.d_k)
)
self.register_buffer('inv_freq', inv_freq)
# Compute rotation angles for each position
t = torch.arange(seq_len).type_as(self.inv_freq)
freqs = torch.outer(t, self.inv_freq) # (seq_len, d_k/2)
# Create rotation matrix components
self.register_buffer('cos', torch.cos(freqs))
self.register_buffer('sin', torch.sin(freqs))
def _rotate_half(self, x: torch.Tensor) -> torch.Tensor:
"""Rotate half the hidden dims of the input."""
x1 = x[..., :x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2:]
return torch.cat((-x2, x1), dim=-1)
def forward(
self,
q: torch.Tensor,
k: torch.Tensor,
start_pos: int = 0,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Apply rotary position embedding to query and key tensors.
Args:
q: Query tensor (batch, num_heads, seq_len, d_k)
k: Key tensor (batch, num_heads, seq_len, d_k)
start_pos: Starting position (for cached generation)
Returns:
Tuple of (rotated_q, rotated_k)
"""
seq_len = q.size(2)
# Get rotation components for this sequence
cos = self.cos[start_pos:start_pos + seq_len, :].unsqueeze(0).unsqueeze(0)
sin = self.sin[start_pos:start_pos + seq_len, :].unsqueeze(0).unsqueeze(0)
# Repeat for full dimension
cos = torch.cat([cos, cos], dim=-1)
sin = torch.cat([sin, sin], dim=-1)
# Apply rotation
q_rotated = (q * cos) + (self._rotate_half(q) * sin)
k_rotated = (k * cos) + (self._rotate_half(k) * sin)
return q_rotated, k_rotated
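RoPE's defining property, that attention scores depend only on the relative offset between query and key positions, can be checked directly (a standalone sketch mirroring the rotate-half layout above, on single vectors):

```python
import torch

d = 8
base = 10000.0
inv_freq = 1.0 / (base ** (torch.arange(0, d, 2).float() / d))

def rope(x, pos):
    # Rotate one vector x of shape (d,) to position pos, rotate-half style
    angles = pos * inv_freq                        # (d/2,)
    cos = torch.cat([angles.cos(), angles.cos()])  # (d,)
    sin = torch.cat([angles.sin(), angles.sin()])
    x1, x2 = x[: d // 2], x[d // 2:]
    rotated = torch.cat([-x2, x1])
    return x * cos + rotated * sin

torch.manual_seed(0)
q, k = torch.randn(d), torch.randn(d)

# Same relative offset (3) at different absolute positions -> same score
a = rope(q, 2) @ rope(k, 5)
b = rope(q, 12) @ rope(k, 15)
print(torch.allclose(a, b, atol=1e-4))  # True
```

This is what makes RoPE a relative position scheme even though it is applied to absolute positions.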
class ALiBi(nn.Module):
"""
Attention with Linear Biases (ALiBi) from "Train Short, Test Long".
Adds a linear bias to attention scores based on distance.
Advantages:
- Zero learnable parameters
- Strong length extrapolation
"""
def __init__(self, num_heads: int):
super().__init__()
self.num_heads = num_heads
# Compute slopes for each head (geometric sequence)
slopes = self._get_slopes(num_heads)
self.register_buffer('slopes', slopes.view(1, num_heads, 1, 1))
def _get_slopes(self, n: int) -> torch.Tensor:
"""Get ALiBi slopes for n heads."""
def get_slopes_power_of_2(n):
start = 2 ** (-(2 ** -(math.log2(n) - 3)))
ratio = start
return [start * (ratio ** i) for i in range(n)]
if math.log2(n).is_integer():
return torch.tensor(get_slopes_power_of_2(n))
# Handle non-power-of-2
closest_power_of_2 = 2 ** math.floor(math.log2(n))
slopes = get_slopes_power_of_2(closest_power_of_2)
extra_slopes = get_slopes_power_of_2(2 * closest_power_of_2)[0::2]
slopes = slopes + extra_slopes[:n - closest_power_of_2]
return torch.tensor(slopes)
def forward(self, attention_scores: torch.Tensor) -> torch.Tensor:
"""
Add ALiBi bias to attention scores.
Args:
attention_scores: (batch, num_heads, query_len, key_len)
Returns:
Biased attention scores
"""
query_len = attention_scores.size(2)
key_len = attention_scores.size(3)
# Create distance matrix: distance[i, j] = j - i
# (.abs() below makes the bias depend only on |i - j|; future
# positions are excluded separately by the causal mask)
positions = torch.arange(key_len, device=attention_scores.device)
distance = positions.unsqueeze(0) - positions.unsqueeze(1)
distance = distance.unsqueeze(0).unsqueeze(0) # (1, 1, key_len, key_len)
# Take only the relevant query positions
if query_len < key_len:
distance = distance[:, :, -query_len:, :]
# Apply slopes (negative bias for distant positions)
bias = -self.slopes * distance.abs()
return attention_scores + bias
Part 3: Feed-Forward Network and Layer Normalization
# src/layers.py
"""
Core transformer layers: Feed-forward networks and layer normalization.
"""
from typing import Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
class FeedForward(nn.Module):
"""
Position-wise Feed-Forward Network.
FFN(x) = max(0, xW_1 + b_1)W_2 + b_2
Uses a hidden dimension typically 4x the model dimension.
"""
def __init__(
self,
d_model: int,
d_ff: Optional[int] = None,
dropout: float = 0.1,
activation: str = "relu",
bias: bool = True,
):
"""
Args:
d_model: Model dimension
d_ff: Feed-forward hidden dimension (default: 4 * d_model)
dropout: Dropout probability
activation: Activation function ("relu" or "gelu")
bias: Whether to use bias
"""
super().__init__()
d_ff = d_ff or 4 * d_model
self.linear1 = nn.Linear(d_model, d_ff, bias=bias)
self.linear2 = nn.Linear(d_ff, d_model, bias=bias)
self.dropout = nn.Dropout(dropout)
if activation == "relu":
self.activation = F.relu
elif activation == "gelu":
self.activation = F.gelu
else:
raise ValueError(f"Unknown activation: {activation}")
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Args:
x: Input tensor (batch, seq_len, d_model)
Returns:
Output tensor (batch, seq_len, d_model)
"""
x = self.linear1(x)
x = self.activation(x)
x = self.dropout(x)
x = self.linear2(x)
return x
class SwiGLU(nn.Module):
"""
SwiGLU activation function from "GLU Variants Improve Transformer".
Used in LLaMA and PaLM.
SwiGLU(x) = Swish(xW_gate) ⊙ (xW_up)
More expressive than standard FFN with similar parameter count.
"""
def __init__(
self,
d_model: int,
d_ff: Optional[int] = None,
dropout: float = 0.1,
bias: bool = False,
):
super().__init__()
# LLaMA uses 2/3 * 4 * d_model for intermediate size with SwiGLU
d_ff = d_ff or int(2 * 4 * d_model / 3)
# Make divisible by 256 for efficiency
d_ff = ((d_ff + 255) // 256) * 256
self.w_gate = nn.Linear(d_model, d_ff, bias=bias)
self.w_up = nn.Linear(d_model, d_ff, bias=bias)
self.w_down = nn.Linear(d_ff, d_model, bias=bias)
self.dropout = nn.Dropout(dropout)
def forward(self, x: torch.Tensor) -> torch.Tensor:
# Gate with SiLU (Swish) activation
gate = F.silu(self.w_gate(x))
# Up projection
up = self.w_up(x)
# Element-wise product
x = gate * up
x = self.dropout(x)
# Down projection
x = self.w_down(x)
return x
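Stripped of the module plumbing, the SwiGLU computation is one line (illustrative weight tensors, no bias, as in the module above):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.randn(2, 4, 16)       # (batch, seq, d_model)
w_gate = torch.randn(16, 32)    # illustrative weights; d_ff = 32
w_up = torch.randn(16, 32)
w_down = torch.randn(32, 16)

# SwiGLU(x) = SiLU(x W_gate) ⊙ (x W_up), then down-projected
out = (F.silu(x @ w_gate) * (x @ w_up)) @ w_down
print(out.shape)  # torch.Size([2, 4, 16])

# SiLU (a.k.a. Swish) is x * sigmoid(x)
t = torch.randn(5)
print(torch.allclose(F.silu(t), t * torch.sigmoid(t)))  # True
```

The gate lets the network modulate each hidden unit multiplicatively, which is where the extra expressiveness over a plain ReLU FFN comes from.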
class RMSNorm(nn.Module):
"""
Root Mean Square Layer Normalization.
Used in LLaMA and other modern LLMs.
Simpler and more efficient than LayerNorm (no mean computation).
RMSNorm(x) = x / RMS(x) * g
where RMS(x) = sqrt(mean(x^2))
"""
def __init__(self, d_model: int, eps: float = 1e-6):
super().__init__()
self.eps = eps
self.weight = nn.Parameter(torch.ones(d_model))
def forward(self, x: torch.Tensor) -> torch.Tensor:
# Compute RMS
rms = torch.sqrt(torch.mean(x ** 2, dim=-1, keepdim=True) + self.eps)
# Normalize and scale
return x / rms * self.weight
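Two properties worth confirming: RMSNorm output is (up to eps) invariant to rescaling its input, and every output vector has unit RMS (standalone sketch with the learnable weight fixed to ones):

```python
import torch

def rms_norm(x, eps=1e-6):
    # Matches the module above with weight = ones
    rms = torch.sqrt(torch.mean(x ** 2, dim=-1, keepdim=True) + eps)
    return x / rms

torch.manual_seed(0)
x = torch.randn(2, 4, 16)

# Scale invariance: rescaling the input barely changes the output
print(torch.allclose(rms_norm(x), rms_norm(3.0 * x), atol=1e-4))  # True
# Unit RMS per vector
print(torch.allclose(rms_norm(x).pow(2).mean(-1), torch.ones(2, 4), atol=1e-4))  # True
```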
class LayerNorm(nn.Module):
"""
Standard Layer Normalization with optional bias.
LayerNorm(x) = (x - mean(x)) / std(x) * g + b
"""
def __init__(self, d_model: int, eps: float = 1e-5, bias: bool = True):
super().__init__()
self.weight = nn.Parameter(torch.ones(d_model))
self.bias = nn.Parameter(torch.zeros(d_model)) if bias else None
self.eps = eps
def forward(self, x: torch.Tensor) -> torch.Tensor:
mean = x.mean(-1, keepdim=True)
var = x.var(-1, keepdim=True, unbiased=False)
x = (x - mean) / torch.sqrt(var + self.eps)  # eps inside the sqrt, matching nn.LayerNorm
x = x * self.weight
if self.bias is not None:
x = x + self.bias
return x
Part 4: Encoder Architecture
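The encoder layer below supports both residual wirings. Stripped to essentials, the difference is only where the normalization sits (illustrative sketch, with a `Linear` standing in for the attention or FFN sublayer):

```python
import torch
import torch.nn as nn

d = 16
norm = nn.LayerNorm(d)
sublayer = nn.Linear(d, d)  # stand-in for attention or FFN

x = torch.randn(2, 5, d)

# Post-norm (original transformer): normalize *after* the residual add
post = norm(x + sublayer(x))

# Pre-norm (GPT-2 onward): normalize *before* the sublayer; the residual
# path stays a pure identity, which stabilizes training of deep stacks
pre = x + sublayer(norm(x))

print(post.shape, pre.shape)
```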
# src/encoder.py
"""
Transformer Encoder implementation.
Used in BERT-style models and encoder-decoder architectures.
"""
from typing import Optional
import torch
import torch.nn as nn
from .attention import MultiHeadAttention
from .layers import FeedForward, LayerNorm
class EncoderLayer(nn.Module):
"""
Single Transformer Encoder Layer.
Consists of:
1. Multi-head self-attention (with residual connection and layer norm)
2. Position-wise feed-forward network (with residual connection and layer norm)
"""
def __init__(
self,
d_model: int,
num_heads: int,
d_ff: int,
dropout: float = 0.1,
activation: str = "relu",
pre_norm: bool = False,
):
"""
Args:
d_model: Model dimension
num_heads: Number of attention heads
d_ff: Feed-forward hidden dimension
dropout: Dropout probability
activation: Activation function
pre_norm: Use pre-normalization (GPT-style) vs post-norm (original)
"""
super().__init__()
self.pre_norm = pre_norm
# Self-attention sublayer
self.self_attention = MultiHeadAttention(
d_model=d_model,
num_heads=num_heads,
dropout=dropout,
)
self.norm1 = LayerNorm(d_model)
self.dropout1 = nn.Dropout(dropout)
# Feed-forward sublayer
self.feed_forward = FeedForward(
d_model=d_model,
d_ff=d_ff,
dropout=dropout,
activation=activation,
)
self.norm2 = LayerNorm(d_model)
self.dropout2 = nn.Dropout(dropout)
def forward(
self,
x: torch.Tensor,
mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""
Args:
x: Input tensor (batch, seq_len, d_model)
mask: Optional padding mask
Returns:
Output tensor (batch, seq_len, d_model)
"""
if self.pre_norm:
# Pre-normalization (more stable for deep networks)
# x + Attention(Norm(x))
residual = x
x = self.norm1(x)
x = self.self_attention(x, x, x, mask=mask)
x = self.dropout1(x)
x = residual + x
residual = x
x = self.norm2(x)
x = self.feed_forward(x)
x = self.dropout2(x)
x = residual + x
else:
# Post-normalization (original transformer)
# Norm(x + Attention(x))
residual = x
x = self.self_attention(x, x, x, mask=mask)
x = self.dropout1(x)
x = self.norm1(residual + x)
residual = x
x = self.feed_forward(x)
x = self.dropout2(x)
x = self.norm2(residual + x)
return x
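The `mask` threaded through `forward()` is typically a padding mask built from token ids. A minimal construction in the 1 = attend, 0 = mask convention used throughout (pad id and token values are illustrative):

```python
import torch

pad_id = 0
# Two sequences; the second ends with two padding tokens
tokens = torch.tensor([[5, 9, 3, 7],
                       [4, 8, 0, 0]])

padding_mask = (tokens != pad_id).long()  # (batch, seq_len)
print(padding_mask)
# tensor([[1, 1, 1, 1],
#         [1, 1, 0, 0]])

# MultiHeadAttention expands a 2-D mask to (batch, 1, 1, seq_len),
# broadcasting over heads and query positions:
expanded = padding_mask.unsqueeze(1).unsqueeze(1)
print(expanded.shape)  # torch.Size([2, 1, 1, 4])
```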
class TransformerEncoder(nn.Module):
"""
Full Transformer Encoder (stack of encoder layers).
"""
def __init__(
self,
num_layers: int,
d_model: int,
num_heads: int,
d_ff: int,
dropout: float = 0.1,
activation: str = "relu",
pre_norm: bool = False,
):
super().__init__()
self.layers = nn.ModuleList([
EncoderLayer(
d_model=d_model,
num_heads=num_heads,
d_ff=d_ff,
dropout=dropout,
activation=activation,
pre_norm=pre_norm,
)
for _ in range(num_layers)
])
# Final layer norm for pre-norm architecture
self.norm = LayerNorm(d_model) if pre_norm else None
def forward(
self,
x: torch.Tensor,
mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""
Args:
x: Input tensor (batch, seq_len, d_model)
mask: Optional padding mask
Returns:
Output tensor (batch, seq_len, d_model)
"""
for layer in self.layers:
x = layer(x, mask=mask)
if self.norm is not None:
x = self.norm(x)
return x
Part 5: Decoder Architecture
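The decoder's self-attention mask combines two ingredients: a lower-triangular causal mask and, when batches are padded, a padding mask. Both stay in the 0/1 convention, so combining them is an elementwise product:

```python
import torch

seq_len = 4
# Causal mask: position i may attend only to positions j <= i
causal = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.long))
print(causal)
# tensor([[1, 0, 0, 0],
#         [1, 1, 0, 0],
#         [1, 1, 1, 0],
#         [1, 1, 1, 1]])

# Combine with a padding mask (here the last token is padding)
padding = torch.tensor([1, 1, 1, 0])
combined = causal * padding.unsqueeze(0)  # (seq, seq)
print(combined[3])  # tensor([1, 1, 1, 0])
```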
# src/decoder.py
"""
Transformer Decoder implementation.
Used in GPT-style models and encoder-decoder architectures.
"""
from typing import Optional, Tuple
import torch
import torch.nn as nn
from .attention import MultiHeadAttention, CausalSelfAttention
from .layers import FeedForward, LayerNorm, RMSNorm, SwiGLU
class DecoderLayer(nn.Module):
"""
Single Transformer Decoder Layer.
For decoder-only (GPT-style):
1. Masked multi-head self-attention
2. Feed-forward network
For encoder-decoder:
1. Masked multi-head self-attention
2. Cross-attention to encoder outputs
3. Feed-forward network
"""
def __init__(
self,
d_model: int,
num_heads: int,
d_ff: int,
dropout: float = 0.1,
activation: str = "relu",
pre_norm: bool = True,
has_cross_attention: bool = False,
):
super().__init__()
self.pre_norm = pre_norm
self.has_cross_attention = has_cross_attention
# Masked self-attention
self.self_attention = MultiHeadAttention(
d_model=d_model,
num_heads=num_heads,
dropout=dropout,
)
self.norm1 = LayerNorm(d_model)
self.dropout1 = nn.Dropout(dropout)
# Cross-attention (only for encoder-decoder)
if has_cross_attention:
self.cross_attention = MultiHeadAttention(
d_model=d_model,
num_heads=num_heads,
dropout=dropout,
)
self.norm2 = LayerNorm(d_model)
self.dropout2 = nn.Dropout(dropout)
# Feed-forward
self.feed_forward = FeedForward(
d_model=d_model,
d_ff=d_ff,
dropout=dropout,
activation=activation,
)
self.norm_ff = LayerNorm(d_model)
self.dropout_ff = nn.Dropout(dropout)
def forward(
self,
x: torch.Tensor,
encoder_output: Optional[torch.Tensor] = None,
self_attn_mask: Optional[torch.Tensor] = None,
cross_attn_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""
Args:
x: Decoder input (batch, tgt_len, d_model)
encoder_output: Encoder output for cross-attention (batch, src_len, d_model)
self_attn_mask: Causal mask for self-attention
cross_attn_mask: Padding mask for cross-attention
Returns:
Output tensor (batch, tgt_len, d_model)
"""
if self.pre_norm:
# Self-attention with pre-norm
residual = x
x = self.norm1(x)
x = self.self_attention(x, x, x, mask=self_attn_mask)
x = self.dropout1(x)
x = residual + x
# Cross-attention (if encoder-decoder)
if self.has_cross_attention and encoder_output is not None:
residual = x
x = self.norm2(x)
x = self.cross_attention(x, encoder_output, encoder_output, mask=cross_attn_mask)
x = self.dropout2(x)
x = residual + x
# Feed-forward
residual = x
x = self.norm_ff(x)
x = self.feed_forward(x)
x = self.dropout_ff(x)
x = residual + x
else:
# Post-norm version
residual = x
x = self.self_attention(x, x, x, mask=self_attn_mask)
x = self.dropout1(x)
x = self.norm1(residual + x)
if self.has_cross_attention and encoder_output is not None:
residual = x
x = self.cross_attention(x, encoder_output, encoder_output, mask=cross_attn_mask)
x = self.dropout2(x)
x = self.norm2(residual + x)
residual = x
x = self.feed_forward(x)
x = self.dropout_ff(x)
x = self.norm_ff(residual + x)
return x
class TransformerDecoder(nn.Module):
"""
Full Transformer Decoder (stack of decoder layers).
"""
def __init__(
self,
num_layers: int,
d_model: int,
num_heads: int,
d_ff: int,
dropout: float = 0.1,
activation: str = "relu",
pre_norm: bool = True,
has_cross_attention: bool = False,
):
super().__init__()
self.layers = nn.ModuleList([
DecoderLayer(
d_model=d_model,
num_heads=num_heads,
d_ff=d_ff,
dropout=dropout,
activation=activation,
pre_norm=pre_norm,
has_cross_attention=has_cross_attention,
)
for _ in range(num_layers)
])
self.norm = LayerNorm(d_model) if pre_norm else None
def forward(
self,
x: torch.Tensor,
encoder_output: Optional[torch.Tensor] = None,
self_attn_mask: Optional[torch.Tensor] = None,
cross_attn_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
for layer in self.layers:
x = layer(
x,
encoder_output=encoder_output,
self_attn_mask=self_attn_mask,
cross_attn_mask=cross_attn_mask,
)
if self.norm is not None:
x = self.norm(x)
        return x

Part 6: Complete GPT-Style Model
# src/gpt.py
"""
Complete GPT-style decoder-only transformer.
"""
from typing import Optional, Tuple, Dict, Any
from dataclasses import dataclass
import torch
import torch.nn as nn
import torch.nn.functional as F
from .attention import CausalSelfAttention, GroupedQueryAttention
from .embeddings import (
TokenEmbedding,
LearnedPositionalEncoding,
RotaryPositionalEmbedding,
)
from .layers import FeedForward, SwiGLU, LayerNorm, RMSNorm
@dataclass
class GPTConfig:
"""Configuration for GPT model."""
vocab_size: int = 50257
max_seq_len: int = 1024
d_model: int = 768
num_heads: int = 12
num_layers: int = 12
d_ff: Optional[int] = None # Default: 4 * d_model
dropout: float = 0.1
bias: bool = True
# Modern options
use_rope: bool = False
use_swiglu: bool = False
use_rms_norm: bool = False
num_kv_heads: Optional[int] = None # For GQA
class GPTBlock(nn.Module):
"""
Single GPT block with causal self-attention and feed-forward.
"""
def __init__(self, config: GPTConfig):
super().__init__()
self.config = config
# Layer normalization
norm_cls = RMSNorm if config.use_rms_norm else LayerNorm
self.ln1 = norm_cls(config.d_model)
self.ln2 = norm_cls(config.d_model)
# Attention
if config.num_kv_heads is not None:
self.attn = GroupedQueryAttention(
d_model=config.d_model,
num_heads=config.num_heads,
num_kv_heads=config.num_kv_heads,
max_seq_len=config.max_seq_len,
dropout=config.dropout,
)
else:
self.attn = CausalSelfAttention(
d_model=config.d_model,
num_heads=config.num_heads,
max_seq_len=config.max_seq_len,
dropout=config.dropout,
bias=config.bias,
)
# Feed-forward
if config.use_swiglu:
self.ffn = SwiGLU(
d_model=config.d_model,
d_ff=config.d_ff,
dropout=config.dropout,
)
else:
self.ffn = FeedForward(
d_model=config.d_model,
d_ff=config.d_ff or 4 * config.d_model,
dropout=config.dropout,
activation="gelu",
)
def forward(
self,
x: torch.Tensor,
use_cache: bool = False,
past_kv: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
# Attention with residual
residual = x
x = self.ln1(x)
attn_out, present_kv = self.attn(x, use_cache=use_cache, past_kv=past_kv)
x = residual + attn_out
# Feed-forward with residual
residual = x
x = self.ln2(x)
x = self.ffn(x)
x = residual + x
return x, present_kv
class GPT(nn.Module):
"""
GPT-style Transformer Language Model.
Architecture:
- Token embeddings + Position embeddings
- N x (LayerNorm + Attention + LayerNorm + FFN)
- Final LayerNorm
- Language model head (tied with embeddings)
"""
def __init__(self, config: GPTConfig):
super().__init__()
self.config = config
# Token embeddings
self.token_embedding = nn.Embedding(config.vocab_size, config.d_model)
# Position embeddings (unless using RoPE)
if not config.use_rope:
self.position_embedding = nn.Embedding(config.max_seq_len, config.d_model)
else:
self.position_embedding = None
self.rope = RotaryPositionalEmbedding(
config.d_model,
config.num_heads,
config.max_seq_len,
)
self.dropout = nn.Dropout(config.dropout)
# Transformer blocks
self.blocks = nn.ModuleList([
GPTBlock(config) for _ in range(config.num_layers)
])
# Final layer norm
norm_cls = RMSNorm if config.use_rms_norm else LayerNorm
self.ln_f = norm_cls(config.d_model)
# Language model head
self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
# Weight tying (share embedding weights with output)
self.lm_head.weight = self.token_embedding.weight
# Initialize weights
self.apply(self._init_weights)
# Count parameters
n_params = sum(p.numel() for p in self.parameters())
print(f"GPT model with {n_params / 1e6:.2f}M parameters")
def _init_weights(self, module: nn.Module):
"""Initialize weights."""
if isinstance(module, nn.Linear):
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
if module.bias is not None:
torch.nn.init.zeros_(module.bias)
elif isinstance(module, nn.Embedding):
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
def forward(
self,
input_ids: torch.Tensor,
labels: Optional[torch.Tensor] = None,
use_cache: bool = False,
past_kv: Optional[list] = None,
) -> Dict[str, Any]:
"""
Forward pass.
Args:
input_ids: Input token IDs (batch, seq_len)
labels: Target labels for loss computation (batch, seq_len)
use_cache: Whether to return KV cache
past_kv: Previous KV cache for generation
Returns:
Dictionary with logits, loss (if labels provided), and cache
"""
batch_size, seq_len = input_ids.shape
device = input_ids.device
# Get position offset for cached generation
start_pos = 0
if past_kv is not None and past_kv[0] is not None:
start_pos = past_kv[0][0].size(2)
# Token embeddings
x = self.token_embedding(input_ids)
# Position embeddings (if not using RoPE)
if self.position_embedding is not None:
positions = torch.arange(start_pos, start_pos + seq_len, device=device)
x = x + self.position_embedding(positions)
x = self.dropout(x)
# Process through transformer blocks
present_kv = []
for i, block in enumerate(self.blocks):
layer_past = past_kv[i] if past_kv is not None else None
x, kv = block(x, use_cache=use_cache, past_kv=layer_past)
if use_cache:
present_kv.append(kv)
# Final layer norm
x = self.ln_f(x)
# Language model head
logits = self.lm_head(x)
# Compute loss if labels provided
loss = None
if labels is not None:
# Shift logits and labels for next-token prediction
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
loss = F.cross_entropy(
shift_logits.view(-1, shift_logits.size(-1)),
shift_labels.view(-1),
ignore_index=-100,
)
return {
"logits": logits,
"loss": loss,
"past_key_values": present_kv if use_cache else None,
}
@torch.no_grad()
def generate(
self,
input_ids: torch.Tensor,
max_new_tokens: int = 100,
temperature: float = 1.0,
top_k: Optional[int] = None,
top_p: Optional[float] = None,
do_sample: bool = True,
) -> torch.Tensor:
"""
Generate text autoregressively.
Args:
input_ids: Prompt token IDs (batch, seq_len)
max_new_tokens: Maximum new tokens to generate
temperature: Sampling temperature (higher = more random)
top_k: Keep only top-k tokens for sampling
            top_p: Nucleus sampling - keep the smallest set of tokens whose cumulative probability exceeds top_p
do_sample: Whether to sample (False = greedy)
Returns:
Generated token IDs (batch, seq_len + max_new_tokens)
"""
self.eval()
for _ in range(max_new_tokens):
# Crop context if needed
idx_cond = input_ids if input_ids.size(1) <= self.config.max_seq_len else \
input_ids[:, -self.config.max_seq_len:]
            # Forward pass (full recompute each step; the KV cache section covers the O(1)-per-token variant)
            outputs = self(idx_cond, use_cache=False)
logits = outputs["logits"]
# Get logits for last position
logits = logits[:, -1, :] / temperature
# Apply top-k filtering
if top_k is not None:
v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
logits[logits < v[:, [-1]]] = float('-inf')
# Apply top-p (nucleus) filtering
if top_p is not None:
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
# Remove tokens with cumulative probability above threshold
sorted_indices_to_remove = cumulative_probs > top_p
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
sorted_indices_to_remove[..., 0] = 0
indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
logits[indices_to_remove] = float('-inf')
# Sample or greedy
probs = F.softmax(logits, dim=-1)
if do_sample:
next_token = torch.multinomial(probs, num_samples=1)
else:
next_token = torch.argmax(probs, dim=-1, keepdim=True)
# Append to sequence
input_ids = torch.cat([input_ids, next_token], dim=1)
return input_ids
def create_gpt_small() -> GPT:
"""Create GPT-small (124M params like GPT-2 small)."""
config = GPTConfig(
vocab_size=50257,
max_seq_len=1024,
d_model=768,
num_heads=12,
num_layers=12,
dropout=0.1,
)
return GPT(config)
def create_gpt_medium() -> GPT:
"""Create GPT-medium (350M params like GPT-2 medium)."""
config = GPTConfig(
vocab_size=50257,
max_seq_len=1024,
d_model=1024,
num_heads=16,
num_layers=24,
dropout=0.1,
)
return GPT(config)
def create_llama_style(
vocab_size: int = 32000,
d_model: int = 512,
num_heads: int = 8,
num_layers: int = 8,
) -> GPT:
"""Create a LLaMA-style model with modern components."""
config = GPTConfig(
vocab_size=vocab_size,
max_seq_len=2048,
d_model=d_model,
num_heads=num_heads,
num_layers=num_layers,
dropout=0.0,
bias=False,
use_rope=True,
use_swiglu=True,
use_rms_norm=True,
num_kv_heads=num_heads // 2, # GQA
)
    return GPT(config)

Part 7: Training Script
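The trainer in this part warms the learning rate up linearly, then decays it along a cosine curve. The schedule can be sketched in isolation (the constants mirror the trainer's defaults; `total` stands in for `num_epochs * len(train_loader)`):

```python
import math

def lr_at(step, base_lr=3e-4, warmup=100, total=1000):
    # Linear warmup to base_lr, then cosine decay toward zero.
    if step < warmup:
        return base_lr * step / warmup
    progress = (step - warmup) / max(1, total - warmup)
    return base_lr * 0.5 * (1.0 + math.cos(math.pi * progress))

print(lr_at(50))    # halfway through warmup -> 1.5e-4
print(lr_at(100))   # peak -> 3e-4
print(lr_at(1000))  # end of schedule -> 0.0
```

The `max(1, ...)` guard avoids a division by zero when the schedule is shorter than the warmup.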
# train_gpt.py
"""
Training script for GPT model.
"""
import os
import math
import time
from typing import Optional
from dataclasses import dataclass
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import tiktoken
from tqdm import tqdm
from src.gpt import GPT, GPTConfig
@dataclass
class TrainingConfig:
"""Training configuration."""
# Data
data_path: str = "data/input.txt"
max_seq_len: int = 256
# Model
vocab_size: int = 50257
d_model: int = 384
num_heads: int = 6
num_layers: int = 6
dropout: float = 0.1
# Training
batch_size: int = 64
learning_rate: float = 3e-4
weight_decay: float = 0.1
num_epochs: int = 10
warmup_steps: int = 100
grad_clip: float = 1.0
# Evaluation
eval_interval: int = 500
eval_steps: int = 200
# Checkpointing
save_dir: str = "./checkpoints"
save_interval: int = 1000
class TextDataset(Dataset):
"""Simple text dataset for language modeling."""
def __init__(self, text: str, tokenizer, seq_len: int):
self.tokenizer = tokenizer
self.seq_len = seq_len
# Tokenize entire text
self.tokens = torch.tensor(tokenizer.encode(text), dtype=torch.long)
print(f"Dataset size: {len(self.tokens):,} tokens")
def __len__(self):
        # Each sample needs seq_len inputs plus one target token
        return max(0, len(self.tokens) - self.seq_len)
def __getitem__(self, idx):
# Get chunk of tokens
chunk = self.tokens[idx:idx + self.seq_len + 1]
x = chunk[:-1]
y = chunk[1:]
return x, y
class Trainer:
"""Training loop for GPT."""
def __init__(self, config: TrainingConfig):
self.config = config
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {self.device}")
# Initialize tokenizer
self.tokenizer = tiktoken.get_encoding("gpt2")
# Load data
self._load_data()
# Initialize model
self._init_model()
# Initialize optimizer
self._init_optimizer()
# Tracking
self.global_step = 0
self.best_val_loss = float('inf')
def _load_data(self):
"""Load and prepare dataset."""
with open(self.config.data_path, 'r', encoding='utf-8') as f:
text = f.read()
# Split into train/val
n = len(text)
train_text = text[:int(0.9 * n)]
val_text = text[int(0.9 * n):]
self.train_dataset = TextDataset(train_text, self.tokenizer, self.config.max_seq_len)
self.val_dataset = TextDataset(val_text, self.tokenizer, self.config.max_seq_len)
self.train_loader = DataLoader(
self.train_dataset,
batch_size=self.config.batch_size,
shuffle=True,
num_workers=4,
pin_memory=True,
)
self.val_loader = DataLoader(
self.val_dataset,
batch_size=self.config.batch_size,
shuffle=False,
num_workers=4,
pin_memory=True,
)
def _init_model(self):
"""Initialize GPT model."""
model_config = GPTConfig(
vocab_size=self.config.vocab_size,
max_seq_len=self.config.max_seq_len,
d_model=self.config.d_model,
num_heads=self.config.num_heads,
num_layers=self.config.num_layers,
dropout=self.config.dropout,
)
self.model = GPT(model_config).to(self.device)
# Compile model for faster training (PyTorch 2.0+)
if hasattr(torch, 'compile'):
print("Compiling model...")
self.model = torch.compile(self.model)
def _init_optimizer(self):
"""Initialize optimizer with weight decay."""
# Separate parameters that should have weight decay
decay_params = []
no_decay_params = []
for name, param in self.model.named_parameters():
if param.requires_grad:
if 'weight' in name and 'norm' not in name and 'embedding' not in name:
decay_params.append(param)
else:
no_decay_params.append(param)
optim_groups = [
{"params": decay_params, "weight_decay": self.config.weight_decay},
{"params": no_decay_params, "weight_decay": 0.0},
]
self.optimizer = torch.optim.AdamW(
optim_groups,
lr=self.config.learning_rate,
betas=(0.9, 0.95),
)
def _get_lr(self, step: int) -> float:
"""Cosine learning rate schedule with warmup."""
if step < self.config.warmup_steps:
return self.config.learning_rate * step / self.config.warmup_steps
# Cosine decay
progress = (step - self.config.warmup_steps) / max(1, self.config.num_epochs * len(self.train_loader) - self.config.warmup_steps)
return self.config.learning_rate * 0.5 * (1.0 + math.cos(math.pi * progress))
def train_step(self, batch):
"""Single training step."""
x, y = batch
x, y = x.to(self.device), y.to(self.device)
# Forward pass
outputs = self.model(x, labels=y)
loss = outputs["loss"]
# Backward pass
self.optimizer.zero_grad()
loss.backward()
# Gradient clipping
torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config.grad_clip)
# Update learning rate
lr = self._get_lr(self.global_step)
for param_group in self.optimizer.param_groups:
param_group['lr'] = lr
# Optimizer step
self.optimizer.step()
return loss.item()
@torch.no_grad()
def evaluate(self):
"""Evaluate on validation set."""
self.model.eval()
losses = []
for batch in self.val_loader:
x, y = batch
x, y = x.to(self.device), y.to(self.device)
outputs = self.model(x, labels=y)
losses.append(outputs["loss"].item())
if len(losses) >= self.config.eval_steps:
break
self.model.train()
return sum(losses) / len(losses)
def save_checkpoint(self, filename: str):
"""Save model checkpoint."""
os.makedirs(self.config.save_dir, exist_ok=True)
path = os.path.join(self.config.save_dir, filename)
        # Unwrap torch.compile (if used) so checkpoint keys match an uncompiled model
        raw_model = getattr(self.model, "_orig_mod", self.model)
        torch.save({
            "model_state_dict": raw_model.state_dict(),
"optimizer_state_dict": self.optimizer.state_dict(),
"global_step": self.global_step,
"config": self.config,
}, path)
print(f"Saved checkpoint to {path}")
def train(self):
"""Main training loop."""
print(f"\nStarting training for {self.config.num_epochs} epochs...")
self.model.train()
for epoch in range(self.config.num_epochs):
pbar = tqdm(self.train_loader, desc=f"Epoch {epoch + 1}")
for batch in pbar:
loss = self.train_step(batch)
self.global_step += 1
# Update progress bar
pbar.set_postfix({
"loss": f"{loss:.4f}",
"lr": f"{self._get_lr(self.global_step):.2e}",
})
# Evaluation
if self.global_step % self.config.eval_interval == 0:
val_loss = self.evaluate()
print(f"\nStep {self.global_step} | Val Loss: {val_loss:.4f} | Perplexity: {math.exp(val_loss):.2f}")
if val_loss < self.best_val_loss:
self.best_val_loss = val_loss
self.save_checkpoint("best_model.pt")
# Periodic save
if self.global_step % self.config.save_interval == 0:
self.save_checkpoint(f"checkpoint_{self.global_step}.pt")
# Final save
self.save_checkpoint("final_model.pt")
print(f"\nTraining complete! Best val loss: {self.best_val_loss:.4f}")
def main():
config = TrainingConfig(
data_path="data/shakespeare.txt", # Use any text file
max_seq_len=256,
d_model=384,
num_heads=6,
num_layers=6,
batch_size=32,
learning_rate=3e-4,
num_epochs=5,
)
trainer = Trainer(config)
trainer.train()
if __name__ == "__main__":
    main()

Part 8: Text Generation
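Generation below leans on the nucleus (top-p) filter built into `generate` back in Part 6. Isolated into a standalone sketch, the filter looks like this:

```python
import torch
import torch.nn.functional as F

def top_p_filter(logits, top_p=0.9):
    # Keep the smallest set of tokens whose cumulative probability exceeds top_p;
    # mask everything else to -inf so softmax assigns it zero probability.
    sorted_logits, sorted_idx = torch.sort(logits, descending=True)
    cum_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
    remove = cum_probs > top_p
    remove[..., 1:] = remove[..., :-1].clone()  # shift right: always keep the top token
    remove[..., 0] = False
    mask = remove.scatter(-1, sorted_idx, remove)  # map back to unsorted positions
    return logits.masked_fill(mask, float('-inf'))

logits = torch.tensor([[2.0, 1.0, 0.5, -1.0]])
filtered = top_p_filter(logits, top_p=0.8)
print(filtered)  # the two most probable tokens survive; the rest become -inf
```

Note the right-shift before masking: without it, the first token whose cumulative probability crosses the threshold would itself be dropped, and a very confident distribution could end up with no tokens at all.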
# generate.py
"""
Generate text using trained GPT model.
"""
import torch
import tiktoken
from src.gpt import GPT, GPTConfig
def load_model(checkpoint_path: str, device: str = "cuda") -> GPT:
"""Load trained model from checkpoint."""
    # weights_only=False: the checkpoint stores a config dataclass, which the
    # PyTorch >= 2.6 default (weights_only=True) refuses to unpickle
    checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
config = checkpoint["config"]
model_config = GPTConfig(
vocab_size=config.vocab_size,
max_seq_len=config.max_seq_len,
d_model=config.d_model,
num_heads=config.num_heads,
num_layers=config.num_layers,
dropout=0.0, # Disable dropout for generation
)
model = GPT(model_config).to(device)
model.load_state_dict(checkpoint["model_state_dict"])
model.eval()
return model
def generate_text(
model: GPT,
prompt: str,
max_tokens: int = 200,
temperature: float = 0.8,
top_k: int = 50,
top_p: float = 0.9,
) -> str:
"""Generate text from prompt."""
device = next(model.parameters()).device
tokenizer = tiktoken.get_encoding("gpt2")
# Encode prompt
input_ids = torch.tensor([tokenizer.encode(prompt)], device=device)
# Generate
with torch.no_grad():
output_ids = model.generate(
input_ids,
max_new_tokens=max_tokens,
temperature=temperature,
top_k=top_k,
top_p=top_p,
do_sample=True,
)
# Decode
generated_text = tokenizer.decode(output_ids[0].tolist())
return generated_text
def interactive_generation():
"""Interactive text generation."""
device = "cuda" if torch.cuda.is_available() else "cpu"
model = load_model("checkpoints/best_model.pt", device)
print("GPT Text Generator (type 'quit' to exit)")
print("-" * 50)
while True:
prompt = input("\nPrompt: ")
if prompt.lower() == "quit":
break
generated = generate_text(
model,
prompt,
max_tokens=100,
temperature=0.8,
)
print(f"\nGenerated:\n{generated}")
if __name__ == "__main__":
    interactive_generation()

Modern Optimizations
Flash Attention Concept
# src/flash_attention.py
"""
Flash Attention concept (simplified implementation).
The full Flash Attention is more complex and uses CUDA kernels.
Key ideas:
1. Tile-based computation to fit in SRAM
2. Online softmax computation
3. Avoid materializing full attention matrix
"""
import torch
import torch.nn.functional as F
import math
def flash_attention_naive(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
block_size: int = 256,
) -> torch.Tensor:
"""
Simplified demonstration of Flash Attention concepts.
Real Flash Attention uses custom CUDA kernels for efficiency.
This is a pure PyTorch approximation for educational purposes.
Args:
query, key, value: (batch, heads, seq_len, d_k)
block_size: Size of each tile/block
Returns:
Output tensor (batch, heads, seq_len, d_k)
"""
batch, heads, seq_len, d_k = query.shape
scale = 1.0 / math.sqrt(d_k)
output = torch.zeros_like(query)
# Process in blocks to reduce memory
for i in range(0, seq_len, block_size):
i_end = min(i + block_size, seq_len)
q_block = query[:, :, i:i_end, :]
# Running softmax statistics
row_max = torch.full((batch, heads, i_end - i, 1), float('-inf'), device=query.device)
row_sum = torch.zeros((batch, heads, i_end - i, 1), device=query.device)
row_out = torch.zeros((batch, heads, i_end - i, d_k), device=query.device)
for j in range(0, seq_len, block_size):
j_end = min(j + block_size, seq_len)
k_block = key[:, :, j:j_end, :]
v_block = value[:, :, j:j_end, :]
# Compute attention scores for this block
scores = torch.matmul(q_block, k_block.transpose(-2, -1)) * scale
# Online softmax: update max and sum
block_max = scores.max(dim=-1, keepdim=True).values
new_max = torch.maximum(row_max, block_max)
# Update running sum with new max
exp_old = torch.exp(row_max - new_max)
exp_new = torch.exp(scores - new_max)
new_sum = row_sum * exp_old + exp_new.sum(dim=-1, keepdim=True)
# Update output
row_out = row_out * exp_old + torch.matmul(exp_new, v_block)
row_max = new_max
row_sum = new_sum
# Normalize output
output[:, :, i:i_end, :] = row_out / row_sum
return output
# For production, use torch.nn.functional.scaled_dot_product_attention
# which has Flash Attention built-in (PyTorch 2.0+)
def efficient_attention(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
is_causal: bool = True,
) -> torch.Tensor:
"""
Use PyTorch's optimized attention (includes Flash Attention when available).
"""
return F.scaled_dot_product_attention(
query, key, value,
is_causal=is_causal,
dropout_p=0.0,
    )

KV Cache for Efficient Generation
# src/kv_cache.py
"""
KV Cache implementation for efficient autoregressive generation.
"""
from typing import List, Tuple, Optional
import torch
class KVCache:
"""
Key-Value cache for efficient transformer generation.
During generation, we cache K and V from previous tokens
so we only need to compute attention for the new token.
"""
def __init__(
self,
num_layers: int,
max_batch_size: int,
max_seq_len: int,
num_heads: int,
head_dim: int,
device: torch.device,
dtype: torch.dtype = torch.float16,
):
self.num_layers = num_layers
self.max_batch_size = max_batch_size
self.max_seq_len = max_seq_len
self.num_heads = num_heads
self.head_dim = head_dim
# Pre-allocate cache tensors
# Shape: (num_layers, 2, batch, heads, seq, head_dim)
# The "2" is for K and V
self.cache = torch.zeros(
(num_layers, 2, max_batch_size, num_heads, max_seq_len, head_dim),
device=device,
dtype=dtype,
)
# Track current sequence length for each batch item
self.seq_lengths = torch.zeros(max_batch_size, dtype=torch.long, device=device)
def update(
self,
layer_idx: int,
key: torch.Tensor,
value: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Update cache with new K, V and return full cached K, V.
Args:
layer_idx: Index of the transformer layer
key: New key tensor (batch, heads, new_seq_len, head_dim)
value: New value tensor (batch, heads, new_seq_len, head_dim)
Returns:
Tuple of (cached_keys, cached_values) including new values
"""
batch_size, num_heads, new_seq_len, head_dim = key.shape
        # Current write position (assumes all active batch items share one length)
        start_pos = self.seq_lengths[:batch_size].max().item()
end_pos = start_pos + new_seq_len
# Store new K, V in cache
self.cache[layer_idx, 0, :batch_size, :, start_pos:end_pos, :] = key
self.cache[layer_idx, 1, :batch_size, :, start_pos:end_pos, :] = value
# Return all cached K, V up to current position
cached_k = self.cache[layer_idx, 0, :batch_size, :, :end_pos, :]
cached_v = self.cache[layer_idx, 1, :batch_size, :, :end_pos, :]
return cached_k, cached_v
def update_seq_length(self, batch_size: int, new_tokens: int):
"""Update sequence lengths after generation step."""
self.seq_lengths[:batch_size] += new_tokens
def reset(self, batch_indices: Optional[List[int]] = None):
"""Reset cache for specified batch indices or all."""
if batch_indices is None:
self.cache.zero_()
self.seq_lengths.zero_()
else:
for idx in batch_indices:
self.cache[:, :, idx, :, :, :].zero_()
self.seq_lengths[idx] = 0
def get_seq_length(self, batch_idx: int = 0) -> int:
"""Get current sequence length for a batch item."""
        return self.seq_lengths[batch_idx].item()

Performance Comparison
| Component | Original | Modern | Benefit |
|---|---|---|---|
| Position Encoding | Sinusoidal | RoPE | Better length extrapolation |
| Normalization | LayerNorm | RMSNorm | 10-15% faster |
| FFN | ReLU | SwiGLU | Better quality |
| Attention | Multi-Head | GQA | Lower memory, faster |
| Implementation | Naive | Flash Attention | 2-4x faster |
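The KV-cache row above can be sanity-checked in a few lines: attending with a single query over the cached keys and values reproduces exactly the last row of full causal attention, which is why generation never needs to recompute earlier positions. A minimal check in pure PyTorch (no custom classes):

```python
import math
import torch
import torch.nn.functional as F

torch.manual_seed(0)
b, h, n, d = 1, 2, 8, 16
q = torch.randn(b, h, n, d)
k = torch.randn(b, h, n, d)
v = torch.randn(b, h, n, d)

# Full causal attention over all n positions
full = F.scaled_dot_product_attention(q, k, v, is_causal=True)

# Incremental step: one query over cached K, V -- no mask needed,
# since the cache holds only past and current positions
q_last = q[:, :, -1:, :]
scores = q_last @ k.transpose(-2, -1) / math.sqrt(d)
step = F.softmax(scores, dim=-1) @ v

print(torch.allclose(full[:, :, -1:, :], step, atol=1e-5))  # True
```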
Key Concepts Recap
| Concept | What It Is | Why It Matters |
|---|---|---|
| Scaled Dot-Product Attention | softmax(QK^T/√d_k)V | Core mechanism - lets each position attend to all others |
| Multi-Head Attention | Run attention h times in parallel with different projections | Captures different types of relationships (syntax, semantics, etc.) |
| d_k Scaling | Divide by √d_k before softmax | Prevents softmax saturation with large d_k (gradients vanish) |
| Causal Mask | Lower triangular mask (future = -∞) | Ensures autoregressive property - can only see past tokens |
| RoPE | Rotary Position Embedding - encode position as rotation | Better length extrapolation than learned or sinusoidal |
| RMSNorm | Normalize by RMS only (no mean subtraction) | ~10-15% faster than LayerNorm with similar quality |
| SwiGLU | Swish(xW_gate) ⊙ (xW_up), then W_down | Better FFN quality, used in LLaMA/PaLM |
| KV Cache | Store K,V from previous tokens during generation | O(1) per new token vs O(n) - critical for fast inference |
| GQA | Grouped Query Attention - fewer K,V heads than Q heads | Reduces memory/compute while maintaining quality |
| Weight Tying | Share embedding weights with LM head | Reduces parameters, improves consistency |
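The d_k-scaling row can also be verified numerically: dot products of random unit-variance d_k-dimensional vectors have variance ≈ d_k, which pushes softmax into its saturated (vanishing-gradient) regime, and dividing by √d_k restores unit variance:

```python
import math
import torch

torch.manual_seed(0)
d_k = 512
q = torch.randn(10_000, d_k)
k = torch.randn(10_000, d_k)

raw = (q * k).sum(dim=-1)        # unscaled dot products q · k
scaled = raw / math.sqrt(d_k)    # the 1/sqrt(d_k) factor from the attention formula

print(raw.std())     # ~sqrt(512) ≈ 22.6
print(scaled.std())  # ~1.0
```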
Next Steps
After building your custom transformer:
- Scale up - Train larger models on more data
- Fine-tune - Adapt for specific tasks
- Optimize - Apply quantization and distillation
- Deploy - Serve with efficient inference
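For the "Optimize" step, PyTorch's dynamic quantization is a low-effort starting point: it converts `nn.Linear` weights to int8 at load time while keeping activations in float. A sketch on a toy module (not the trained GPT, which would need its attention layers handled with care):

```python
import torch
import torch.nn as nn

# Toy stand-in for a trained model: quantize its Linear layers to int8
model = nn.Sequential(nn.Linear(64, 256), nn.ReLU(), nn.Linear(256, 64))
model.eval()

qmodel = torch.ao.quantization.quantize_dynamic(
    model, {nn.Linear}, dtype=torch.qint8
)

x = torch.randn(1, 64)
with torch.no_grad():
    y = qmodel(x)
print(y.shape)  # torch.Size([1, 64])
```

Expect roughly 4x smaller Linear weights; the accuracy impact should be measured against a validation set before deploying.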
Resources
- Attention Is All You Need - Original paper
- The Illustrated Transformer - Visual guide
- GPT-2 Paper
- LLaMA Paper - Modern architecture
- Flash Attention - Efficient attention
- RoPE Paper - Rotary embeddings