Training an SLM from Scratch
Pre-train your own small language model
TL;DR
Train a custom SLM by building a data pipeline (collect, clean, dedupe), training a BPE tokenizer with SentencePiece, implementing a modern transformer with RoPE, GQA, RMSNorm, and SwiGLU, running a PyTorch training loop with gradient accumulation and mixed precision, then exporting to GGUF for deployment.
Build and train a small language model from the ground up, from data preparation to a working model.
Project Overview
| Aspect | Details |
|---|---|
| Difficulty | Advanced |
| Time | 4-5 days |
| Prerequisites | PyTorch, deep learning fundamentals, GPU access |
| Learning Outcomes | Data preparation, tokenizer training, architecture design, pre-training |
What You'll Learn
- Collect and preprocess training data
- Train custom BPE tokenizers with SentencePiece
- Implement transformer architecture from scratch
- Set up distributed training with DeepSpeed
- Monitor and evaluate model training
- Export to GGUF for deployment
Architecture Overview
┌─────────────────────────────────────────────────────────────────────────────┐
│ SLM TRAINING PIPELINE │
├─────────────────────────────────────────────────────────────────────────────┤
│ │
│ ┌─────────────────────────────────────────────────────────────────────┐ │
│ │ DATA PIPELINE │ │
│ │ │ │
│ │ Raw Text ──► Cleaning ──► Tokenizer Training ──► Encoded Data │ │
│ │ Data • HTML • SentencePiece • Token IDs │ │
│ │ • URLs • BPE vocab • Chunked │ │
│ │ • Dedupe • Special tokens • Packed │ │
│ └─────────────────────────────────────────────────────────────────────┘ │
│ │ │
│ ▼ │
│ ┌─────────────────────────────────────────────────────────────────────┐ │
│ │ MODEL ARCHITECTURE │ │
│ │ │ │
│ │ Token Embeddings ──► Transformer Blocks ──► LM Head │ │
│ │ • Vocab → Hidden • RMSNorm • Hidden → Vocab │ │
│ │ • Self-Attention • Tied weights │ │
│ │ • SwiGLU FFN │ │
│ └─────────────────────────────────────────────────────────────────────┘ │
│ │ │
│ ▼ │
│ ┌─────────────────────────────────────────────────────────────────────┐ │
│ │ TRAINING LOOP │ │
│ │ │ │
│ │ DataLoader ──► Forward ──► Loss ──► Backward ──► Optimizer │ │
│ │ • Batching • AMP • CE • Gradient • AdamW │ │
│ │ • Shuffling • Dropout • Clipping • Warmup │ │
│ └─────────────────────────────────────────────────────────────────────┘ │
│ │ │
│ ▼ │
│ ┌─────────────────────────────────────────────────────────────────────┐ │
│ │ OUTPUT │ │
│ │ │ │
│ │ Checkpoints ────────► GGUF Export ────────► Deployment │ │
│ │ • PyTorch .pt • Quantized • llama.cpp │ │
│ │ • Best model • Metadata • Ollama │ │
│ └─────────────────────────────────────────────────────────────────────┘ │
│ │
└─────────────────────────────────────────────────────────────────────────────┘
Project Setup
Dependencies
# Create project directory
mkdir slm-training && cd slm-training
# Create virtual environment
python -m venv venv
source venv/bin/activate # On Windows: venv\Scripts\activate
# Core training dependencies
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121
# Tokenizer and data
pip install sentencepiece tokenizers datasets
# Training utilities
pip install transformers accelerate deepspeed
# Monitoring and logging
pip install wandb tensorboard tqdm
# Export utilities
pip install safetensors gguf
Hardware Requirements
| Model Size | Parameters | Training VRAM | Recommended GPU |
|---|---|---|---|
| Tiny | ~50M | 4GB | RTX 3060 |
| Small | ~125M | 8GB | RTX 3080 |
| Medium | ~350M | 16GB | RTX 4090 |
| Large | ~1B | 24GB+ | A100 40GB |
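The VRAM column can be sanity-checked with back-of-the-envelope arithmetic. A rough sketch (assuming mixed-precision AdamW with fp32 master weights and moments; activations, which scale with batch size and sequence length, are excluded, so real usage is higher):

```python
def training_vram_gb(n_params: float, bytes_per_param: float = 16.0) -> float:
    """Rough VRAM for mixed-precision AdamW training, activations excluded.

    ~16 bytes/param: fp16 weights (2) + fp16 grads (2) + fp32 master
    weights (4) + fp32 Adam first and second moments (4 + 4).
    """
    return n_params * bytes_per_param / 1024**3

for name, params in [("tiny", 50e6), ("small", 125e6), ("medium", 350e6), ("large", 1e9)]:
    print(f"{name}: ~{training_vram_gb(params):.1f} GB before activations")
```

A ~125M-parameter model already needs roughly 2 GB for weights plus optimizer state alone; the table's larger figures leave headroom for activations and batches.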
Part 1: Data Collection and Preparation
Build a high-quality training dataset.
# data/collector.py
"""
Data collection and preprocessing for SLM training.
"""
import os
import json
import re
from pathlib import Path
from typing import Iterator, Optional
from dataclasses import dataclass
import hashlib
from datasets import load_dataset, Dataset, concatenate_datasets
from tqdm import tqdm
@dataclass
class DataConfig:
"""Configuration for data collection."""
output_dir: str = "./training_data"
min_length: int = 100 # Minimum characters per document
max_length: int = 100000 # Maximum characters per document
dedup: bool = True # Remove exact duplicates
language: str = "en" # Target language
class DataCollector:
"""
Collect and preprocess training data from various sources.
"""
def __init__(self, config: DataConfig):
self.config = config
self.output_dir = Path(config.output_dir)
self.output_dir.mkdir(parents=True, exist_ok=True)
self.seen_hashes = set()
    def collect_from_huggingface(
        self,
        dataset_name: str,
        split: str = "train",
        text_column: str = "text",
        max_samples: Optional[int] = None,
        config_name: Optional[str] = None
    ) -> Iterator[str]:
        """Collect data from HuggingFace datasets."""
        print(f"Loading {dataset_name}...")
        # Some datasets (e.g. wikitext) require a config name as the second argument
        dataset = load_dataset(dataset_name, config_name, split=split, streaming=True)
count = 0
for item in tqdm(dataset, desc=f"Processing {dataset_name}"):
if max_samples and count >= max_samples:
break
text = item.get(text_column, "")
if text:
cleaned = self._clean_text(text)
if cleaned and self._should_include(cleaned):
count += 1
yield cleaned
def collect_from_files(self, directory: str, pattern: str = "*.txt") -> Iterator[str]:
"""Collect data from local text files."""
dir_path = Path(directory)
for file_path in tqdm(list(dir_path.glob(pattern)), desc="Processing files"):
try:
text = file_path.read_text(encoding="utf-8")
cleaned = self._clean_text(text)
if cleaned and self._should_include(cleaned):
yield cleaned
except Exception as e:
print(f"Error reading {file_path}: {e}")
    def _clean_text(self, text: str) -> str:
        """Clean and normalize text."""
        # Remove HTML tags (replace with a space so adjacent words don't fuse)
        text = re.sub(r'<[^>]+>', ' ', text)
        # Remove URLs
        text = re.sub(r'https?://\S+', '', text)
        # Collapse whitespace runs into single spaces
        text = re.sub(r'\s+', ' ', text)
        # Drop any characters that cannot be UTF-8 encoded
        text = text.encode('utf-8', errors='ignore').decode('utf-8')
        return text.strip()
def _should_include(self, text: str) -> bool:
"""Check if text should be included in dataset."""
# Length checks
if len(text) < self.config.min_length:
return False
if len(text) > self.config.max_length:
return False
# Deduplication
if self.config.dedup:
text_hash = hashlib.md5(text.encode()).hexdigest()
if text_hash in self.seen_hashes:
return False
self.seen_hashes.add(text_hash)
return True
def build_dataset(
self,
sources: list[dict],
output_name: str = "training_data"
) -> Dataset:
"""
Build a dataset from multiple sources.
Args:
sources: List of source configurations
output_name: Name for the output dataset
Returns:
HuggingFace Dataset
"""
all_texts = []
for source in sources:
source_type = source.get("type")
            if source_type == "huggingface":
                # Forward the optional dataset config (e.g. wikitext variants)
                extra = {}
                if source.get("config_name"):
                    extra["config_name"] = source["config_name"]
                texts = list(self.collect_from_huggingface(
                    source["name"],
                    source.get("split", "train"),
                    source.get("text_column", "text"),
                    source.get("max_samples"),
                    **extra
                ))
elif source_type == "files":
texts = list(self.collect_from_files(
source["directory"],
source.get("pattern", "*.txt")
))
else:
print(f"Unknown source type: {source_type}")
continue
print(f"Collected {len(texts)} documents from {source.get('name', source.get('directory'))}")
all_texts.extend(texts)
print(f"\nTotal documents: {len(all_texts)}")
print(f"Total characters: {sum(len(t) for t in all_texts):,}")
# Create dataset
dataset = Dataset.from_dict({"text": all_texts})
# Save dataset
output_path = self.output_dir / output_name
dataset.save_to_disk(str(output_path))
print(f"Dataset saved to {output_path}")
return dataset
class DataProcessor:
"""Process and chunk data for training."""
def __init__(self, tokenizer, max_length: int = 2048):
self.tokenizer = tokenizer
self.max_length = max_length
def chunk_and_tokenize(
self,
dataset: Dataset,
num_proc: int = 4
) -> Dataset:
"""Chunk documents and tokenize for training."""
def tokenize_function(examples):
# Tokenize all texts
tokenized = self.tokenizer(
examples["text"],
truncation=False,
add_special_tokens=False
)
# Concatenate all tokens
all_tokens = []
for tokens in tokenized["input_ids"]:
all_tokens.extend(tokens)
all_tokens.append(self.tokenizer.eos_token_id)
            # Chunk into max_length sequences (+1 keeps a final chunk that fits exactly)
            chunks = []
            for i in range(0, len(all_tokens) - self.max_length + 1, self.max_length):
                chunks.append(all_tokens[i:i + self.max_length])
return {
"input_ids": chunks,
"attention_mask": [[1] * len(chunk) for chunk in chunks]
}
# Process in batches
processed = dataset.map(
tokenize_function,
batched=True,
remove_columns=dataset.column_names,
num_proc=num_proc,
desc="Tokenizing"
)
return processed
# Example data sources configuration
EXAMPLE_SOURCES = [
{
"type": "huggingface",
"name": "roneneldan/TinyStories",
"split": "train",
"text_column": "text",
"max_samples": 100000
},
    {
        "type": "huggingface",
        "name": "wikitext",
        "config_name": "wikitext-103-raw-v1",  # this dataset requires a config name
        "split": "train",
        "text_column": "text",
        "max_samples": 50000
    }
]
if __name__ == "__main__":
config = DataConfig(output_dir="./training_data")
collector = DataCollector(config)
# Use a smaller subset for demo
sources = [
{
"type": "huggingface",
"name": "roneneldan/TinyStories",
"split": "train",
"text_column": "text",
"max_samples": 10000
}
]
dataset = collector.build_dataset(sources, "tiny_stories_10k")
    print(f"Dataset created with {len(dataset)} documents")
★ Insight ─────────────────────────────────────
Data Quality Matters Most: For small models, data quality is even more critical than quantity. A 100M parameter model trained on 10B high-quality tokens often outperforms the same model trained on 100B low-quality tokens. TinyStories demonstrates this—simple, clean narratives produce surprisingly capable small models.
─────────────────────────────────────────────────
Understanding the Data Pipeline:
┌─────────────────────────────────────────────────────────────────────────────┐
│ WHY EACH PREPROCESSING STEP MATTERS │
├─────────────────────────────────────────────────────────────────────────────┤
│ │
│ Raw Web Data: After Cleaning: │
│ ┌─────────────────────────────┐ ┌─────────────────────────────────┐ │
│ │ <html><head>...</head> │ │ Machine learning is a branch │ │
│ │ <body>Machine learning is │ ───► │ of artificial intelligence │ │
│ │ a branch of artificial │ │ that focuses on building │ │
│ │ intelligence... http://... │ │ applications that learn. │ │
│ │ </body></html> │ └─────────────────────────────────┘ │
│ └─────────────────────────────┘ │
│ │
│ Deduplication (Hash-based): │
│ ┌─────────────────────────────────────────────────────────────────────┐ │
│ │ Doc 1: "The cat sat on the mat" → MD5: abc123 → Keep │ │
│ │ Doc 2: "Dogs are loyal pets" → MD5: def456 → Keep │ │
│ │ Doc 3: "The cat sat on the mat" → MD5: abc123 → SKIP (duplicate!) │ │
│ │ │ │
│ │ Why? Duplicates cause: │ │
│ │ • Model memorizes instead of generalizing │ │
│ │ • Overrepresentation of certain patterns │ │
│ │ • Wasted compute on redundant data │ │
│ └─────────────────────────────────────────────────────────────────────┘ │
│ │
│ Length Filtering: │
│ ┌─────────────────────────────────────────────────────────────────────┐ │
│ │ min_length=100: Filters out: │ │
│ │ • Empty/near-empty docs │ │
│ │ • Single sentences (not enough context) │ │
│ │ • Metadata/headers without content │ │
│ │ │ │
│ │ max_length=100000: Filters out: │ │
│ │ • Concatenated mega-documents │ │
│ │ • Log files/data dumps │ │
│ │ • Memory-heavy outliers │ │
│ └─────────────────────────────────────────────────────────────────────┘ │
│ │
└─────────────────────────────────────────────────────────────────────────────┘
Common Data Sources for SLM Training:
| Source | Type | Best For | Quality |
|---|---|---|---|
| TinyStories | HuggingFace | Story generation SLMs | Very high |
| Wikitext | HuggingFace | General knowledge | High |
| Books/Gutenberg | Files | Narrative ability | High |
| StackOverflow | API | Code/technical | Medium |
| Common Crawl | S3 | Scale | Varies (needs filtering) |
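The hash-based deduplication sketched in the pipeline diagram reduces to a set of digests. A minimal standalone version of the check in `DataCollector._should_include`:

```python
import hashlib

def is_new(text: str, seen: set) -> bool:
    """Exact-duplicate check mirroring DataCollector._should_include."""
    digest = hashlib.md5(text.encode()).hexdigest()
    if digest in seen:
        return False
    seen.add(digest)
    return True

seen = set()
assert is_new("The cat sat on the mat", seen)       # first occurrence: keep
assert is_new("Dogs are loyal pets", seen)          # different text: keep
assert not is_new("The cat sat on the mat", seen)   # exact duplicate: skip
```

Note this catches only exact duplicates; near-duplicate detection (e.g. MinHash) is a separate, heavier step.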
Part 2: Tokenizer Training
Train a custom BPE tokenizer with SentencePiece.
# tokenizer/train_tokenizer.py
"""
Train a custom BPE tokenizer using SentencePiece.
"""
import os
from pathlib import Path
from typing import Iterator, Optional
import sentencepiece as spm
from tokenizers import Tokenizer, models, trainers, pre_tokenizers, decoders
class TokenizerTrainer:
"""
Train custom tokenizers for SLM pre-training.
"""
def __init__(self, output_dir: str = "./tokenizer"):
self.output_dir = Path(output_dir)
self.output_dir.mkdir(parents=True, exist_ok=True)
def train_sentencepiece(
self,
input_files: list[str],
vocab_size: int = 32000,
model_type: str = "bpe", # unigram, bpe, char, word
character_coverage: float = 0.9995,
model_prefix: str = "tokenizer"
) -> str:
"""
Train a SentencePiece tokenizer.
Args:
input_files: List of text files for training
vocab_size: Target vocabulary size
model_type: Type of model (bpe recommended for SLMs)
character_coverage: Fraction of characters to cover
model_prefix: Prefix for output files
Returns:
Path to trained model
"""
output_prefix = str(self.output_dir / model_prefix)
# Prepare training arguments
train_args = [
f"--input={','.join(input_files)}",
f"--model_prefix={output_prefix}",
f"--vocab_size={vocab_size}",
f"--model_type={model_type}",
f"--character_coverage={character_coverage}",
"--pad_id=0",
"--unk_id=1",
"--bos_id=2",
"--eos_id=3",
"--pad_piece=<pad>",
"--unk_piece=<unk>",
"--bos_piece=<s>",
"--eos_piece=</s>",
"--user_defined_symbols=<|im_start|>,<|im_end|>", # Chat tokens
"--byte_fallback=true", # Handle unknown bytes
"--split_digits=true", # Split numbers
"--allow_whitespace_only_pieces=true",
]
# Train
spm.SentencePieceTrainer.Train(" ".join(train_args))
print(f"Tokenizer saved to {output_prefix}.model")
return f"{output_prefix}.model"
def train_huggingface_bpe(
self,
text_iterator: Iterator[str],
vocab_size: int = 32000,
min_frequency: int = 2,
model_name: str = "tokenizer"
) -> Tokenizer:
"""
Train a HuggingFace BPE tokenizer.
Args:
text_iterator: Iterator yielding text samples
vocab_size: Target vocabulary size
min_frequency: Minimum token frequency
model_name: Name for the tokenizer
Returns:
Trained Tokenizer
"""
# Initialize tokenizer
tokenizer = Tokenizer(models.BPE())
# Pre-tokenization
tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=False)
# Decoder
tokenizer.decoder = decoders.ByteLevel()
# Trainer
trainer = trainers.BpeTrainer(
vocab_size=vocab_size,
min_frequency=min_frequency,
special_tokens=[
"<pad>",
"<unk>",
"<s>",
"</s>",
"<|im_start|>",
"<|im_end|>"
]
)
# Train
tokenizer.train_from_iterator(text_iterator, trainer=trainer)
# Save
output_path = self.output_dir / f"{model_name}.json"
tokenizer.save(str(output_path))
print(f"Tokenizer saved to {output_path}")
return tokenizer
class SLMTokenizer:
"""Wrapper for SLM tokenizer."""
def __init__(self, model_path: str):
self.sp = spm.SentencePieceProcessor()
self.sp.Load(model_path)
# Special token IDs
self.pad_token_id = self.sp.pad_id()
self.unk_token_id = self.sp.unk_id()
self.bos_token_id = self.sp.bos_id()
self.eos_token_id = self.sp.eos_id()
@property
def vocab_size(self) -> int:
return self.sp.GetPieceSize()
def encode(self, text: str, add_bos: bool = False, add_eos: bool = False) -> list[int]:
"""Encode text to token IDs."""
tokens = self.sp.EncodeAsIds(text)
if add_bos:
tokens = [self.bos_token_id] + tokens
if add_eos:
tokens = tokens + [self.eos_token_id]
return tokens
def decode(self, tokens: list[int]) -> str:
"""Decode token IDs to text."""
return self.sp.DecodeIds(tokens)
def tokenize(self, text: str) -> list[str]:
"""Tokenize text to pieces."""
return self.sp.EncodeAsPieces(text)
def __call__(self, texts, **kwargs):
"""Batch encoding for compatibility."""
if isinstance(texts, str):
texts = [texts]
encoded = {
"input_ids": [self.encode(t) for t in texts]
}
return encoded
# Example usage
if __name__ == "__main__":
# Create sample data
sample_texts = [
"Once upon a time, there was a small robot who wanted to learn.",
"The robot read many books and practiced every day.",
"Eventually, the robot became very smart and helpful.",
"It helped people solve problems and answer questions.",
]
# Save sample data
os.makedirs("./sample_data", exist_ok=True)
with open("./sample_data/stories.txt", "w") as f:
f.write("\n".join(sample_texts * 1000)) # Repeat for training
# Train tokenizer
trainer = TokenizerTrainer(output_dir="./tokenizer")
model_path = trainer.train_sentencepiece(
input_files=["./sample_data/stories.txt"],
vocab_size=1000, # Small for demo
model_type="bpe"
)
# Test tokenizer
tokenizer = SLMTokenizer(model_path)
test_text = "The robot learned to code."
tokens = tokenizer.encode(test_text)
decoded = tokenizer.decode(tokens)
print(f"\nOriginal: {test_text}")
print(f"Tokens: {tokens}")
print(f"Decoded: {decoded}")
    print(f"Pieces: {tokenizer.tokenize(test_text)}")
Part 3: Model Architecture
Implement a modern transformer architecture.
# model/architecture.py
"""
Small Language Model architecture implementation.
"""
import math
from typing import Optional, Tuple
from dataclasses import dataclass
import torch
import torch.nn as nn
import torch.nn.functional as F
@dataclass
class SLMConfig:
"""Configuration for Small Language Model."""
vocab_size: int = 32000
hidden_size: int = 768
intermediate_size: int = 2048
num_hidden_layers: int = 12
num_attention_heads: int = 12
num_key_value_heads: int = 4 # For GQA
max_position_embeddings: int = 2048
rms_norm_eps: float = 1e-6
rope_theta: float = 10000.0
attention_dropout: float = 0.0
hidden_dropout: float = 0.0
tie_word_embeddings: bool = True
@property
def head_dim(self) -> int:
return self.hidden_size // self.num_attention_heads
class RMSNorm(nn.Module):
"""Root Mean Square Layer Normalization."""
def __init__(self, hidden_size: int, eps: float = 1e-6):
super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.eps = eps
def forward(self, x: torch.Tensor) -> torch.Tensor:
variance = x.pow(2).mean(-1, keepdim=True)
x = x * torch.rsqrt(variance + self.eps)
return self.weight * x
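RMSNorm rescales each hidden vector to unit root-mean-square before applying the learned gain, skipping LayerNorm's mean subtraction and bias. A quick standalone check of that property:

```python
import torch

torch.manual_seed(0)
x = torch.randn(4, 16) * 7.0 + 3.0   # arbitrary scale and offset
eps = 1e-6

# Same computation as RMSNorm.forward, with the weight omitted (ones at init)
y = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)

# Every row of the output has root-mean-square ~1
assert torch.allclose(y.pow(2).mean(-1).sqrt(), torch.ones(4), atol=1e-3)
```

Dropping the mean-centering saves one reduction per call, which is part of why RMSNorm trains slightly faster than LayerNorm at equal quality.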
class RotaryEmbedding(nn.Module):
"""Rotary Position Embedding (RoPE)."""
def __init__(self, dim: int, max_seq_len: int = 2048, base: float = 10000.0):
super().__init__()
self.dim = dim
self.max_seq_len = max_seq_len
self.base = base
        # Precompute frequencies (non-persistent: rebuilt on load, kept out of checkpoints)
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        # Build cache
        self._build_cache(max_seq_len)

    def _build_cache(self, seq_len: int):
        self.max_seq_len = max(seq_len, self.max_seq_len)  # avoid rebuilding every call
        t = torch.arange(seq_len, device=self.inv_freq.device)
        freqs = torch.outer(t, self.inv_freq)
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos_cached", emb.cos(), persistent=False)
        self.register_buffer("sin_cached", emb.sin(), persistent=False)
def forward(self, x: torch.Tensor, seq_len: int) -> Tuple[torch.Tensor, torch.Tensor]:
if seq_len > self.max_seq_len:
self._build_cache(seq_len)
return self.cos_cached[:seq_len], self.sin_cached[:seq_len]
def rotate_half(x: torch.Tensor) -> torch.Tensor:
"""Rotate half of the hidden dims."""
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)
def apply_rotary_pos_emb(
q: torch.Tensor,
k: torch.Tensor,
cos: torch.Tensor,
sin: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Apply rotary position embeddings to queries and keys."""
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
return q_embed, k_embed
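The point of RoPE is that attention scores between rotated queries and keys depend only on the relative offset between positions, never on absolute position. A small numerical check of that property (a standalone re-sketch of the functions above, operating on single vectors rather than batched tensors):

```python
import torch

def rotate_half(x: torch.Tensor) -> torch.Tensor:
    half = x.shape[-1] // 2
    return torch.cat((-x[..., half:], x[..., :half]), dim=-1)

def rope(x: torch.Tensor, pos: int, base: float = 10000.0) -> torch.Tensor:
    """Rotate a single head-dim vector to position `pos`."""
    dim = x.shape[-1]
    inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
    angles = torch.cat((pos * inv_freq, pos * inv_freq))
    return x * angles.cos() + rotate_half(x) * angles.sin()

torch.manual_seed(0)
q, k = torch.randn(8), torch.randn(8)

def score(m: int, n: int) -> torch.Tensor:
    """Unnormalized attention score between q at position m and k at position n."""
    return torch.dot(rope(q, m), rope(k, n))

# Same offset (m - n = 2) -> same score, regardless of absolute position
assert torch.allclose(score(5, 3), score(9, 7), atol=1e-4)
assert torch.allclose(score(12, 10), score(3, 1), atol=1e-4)
```

This relative-position property is why RoPE models extrapolate to longer contexts better than learned absolute position embeddings.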
class SLMAttention(nn.Module):
"""Multi-head attention with grouped query attention (GQA) support."""
def __init__(self, config: SLMConfig):
super().__init__()
self.config = config
self.hidden_size = config.hidden_size
self.num_heads = config.num_attention_heads
self.num_kv_heads = config.num_key_value_heads
self.head_dim = config.head_dim
self.num_kv_groups = self.num_heads // self.num_kv_heads
# Projections
self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
self.k_proj = nn.Linear(self.hidden_size, self.num_kv_heads * self.head_dim, bias=False)
self.v_proj = nn.Linear(self.hidden_size, self.num_kv_heads * self.head_dim, bias=False)
self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
# Rotary embeddings
self.rotary_emb = RotaryEmbedding(
self.head_dim,
max_seq_len=config.max_position_embeddings,
base=config.rope_theta
)
self.dropout = nn.Dropout(config.attention_dropout)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
use_cache: bool = False
) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
batch_size, seq_len, _ = hidden_states.shape
# Project to Q, K, V
query = self.q_proj(hidden_states)
key = self.k_proj(hidden_states)
value = self.v_proj(hidden_states)
# Reshape for attention
query = query.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
key = key.view(batch_size, seq_len, self.num_kv_heads, self.head_dim).transpose(1, 2)
value = value.view(batch_size, seq_len, self.num_kv_heads, self.head_dim).transpose(1, 2)
        # Apply rotary embeddings, offsetting positions by the cached length so
        # tokens generated with a KV cache get the correct absolute positions
        past_len = past_key_value[0].shape[2] if past_key_value is not None else 0
        cos, sin = self.rotary_emb(hidden_states, past_len + seq_len)
        cos = cos[past_len:past_len + seq_len].unsqueeze(0).unsqueeze(0)
        sin = sin[past_len:past_len + seq_len].unsqueeze(0).unsqueeze(0)
        query, key = apply_rotary_pos_emb(query, key, cos, sin)
# Handle past key-values (for generation)
if past_key_value is not None:
key = torch.cat([past_key_value[0], key], dim=2)
value = torch.cat([past_key_value[1], value], dim=2)
past_key_value = (key, value) if use_cache else None
# Repeat K/V for GQA
if self.num_kv_groups > 1:
key = key.repeat_interleave(self.num_kv_groups, dim=1)
value = value.repeat_interleave(self.num_kv_groups, dim=1)
# Compute attention
attn_weights = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(self.head_dim)
# Apply attention mask
if attention_mask is not None:
attn_weights = attn_weights + attention_mask
attn_weights = F.softmax(attn_weights, dim=-1)
attn_weights = self.dropout(attn_weights)
# Apply attention to values
attn_output = torch.matmul(attn_weights, value)
# Reshape and project output
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.view(batch_size, seq_len, self.hidden_size)
attn_output = self.o_proj(attn_output)
return attn_output, past_key_value
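The benefit of GQA shows up at inference time: the KV cache stores only `num_key_value_heads` heads per layer instead of all query heads. Rough cache-size arithmetic for the medium config (fp16 cache assumed):

```python
def kv_cache_mb(layers: int, kv_heads: int, head_dim: int,
                seq_len: int, batch: int = 1, bytes_per_value: int = 2) -> float:
    """fp16 KV-cache size in MB: one K and one V tensor per layer."""
    return 2 * layers * kv_heads * head_dim * seq_len * batch * bytes_per_value / 1024**2

# "medium" config: 12 layers, 12 query heads, head_dim 64, 2048-token context
mha = kv_cache_mb(layers=12, kv_heads=12, head_dim=64, seq_len=2048)  # no GQA
gqa = kv_cache_mb(layers=12, kv_heads=4, head_dim=64, seq_len=2048)   # GQA

assert mha == 72.0 and gqa == 24.0   # cache shrinks by num_heads / num_kv_heads = 3x
```

The compute cost is unchanged (K/V are repeat-interleaved back to full width), but a 3x smaller cache directly raises the batch size or context length a given GPU can serve.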
class SLMMlp(nn.Module):
"""SwiGLU MLP block."""
def __init__(self, config: SLMConfig):
super().__init__()
self.hidden_size = config.hidden_size
self.intermediate_size = config.intermediate_size
self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
self.dropout = nn.Dropout(config.hidden_dropout)
def forward(self, x: torch.Tensor) -> torch.Tensor:
gate = F.silu(self.gate_proj(x))
up = self.up_proj(x)
return self.dropout(self.down_proj(gate * up))
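SwiGLU uses three weight matrices where a classic GELU FFN uses two, which is why `intermediate_size` here is roughly 8/3 of `hidden_size` rather than the traditional 4x. At that ratio the parameter budget matches exactly:

```python
hidden = 768  # "medium" config

# Classic 2-matrix FFN: up (hidden -> 4h) + down (4h -> hidden)
gelu_ffn_params = 2 * hidden * (4 * hidden)

# SwiGLU 3-matrix FFN with intermediate = (2/3) * 4 * hidden = 2048
swiglu_intermediate = 2048
swiglu_ffn_params = 3 * hidden * swiglu_intermediate

# Same parameter budget, spent on a gated activation instead of a wider layer
assert gelu_ffn_params == swiglu_ffn_params == 4_718_592
```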
class SLMDecoderLayer(nn.Module):
"""Single transformer decoder layer."""
def __init__(self, config: SLMConfig):
super().__init__()
self.self_attn = SLMAttention(config)
self.mlp = SLMMlp(config)
self.input_layernorm = RMSNorm(config.hidden_size, config.rms_norm_eps)
self.post_attention_layernorm = RMSNorm(config.hidden_size, config.rms_norm_eps)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
use_cache: bool = False
) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
# Self attention with residual
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
hidden_states, past_key_value = self.self_attn(
hidden_states,
attention_mask=attention_mask,
past_key_value=past_key_value,
use_cache=use_cache
)
hidden_states = residual + hidden_states
# MLP with residual
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
return hidden_states, past_key_value
class SLMModel(nn.Module):
"""Small Language Model for pre-training."""
def __init__(self, config: SLMConfig):
super().__init__()
self.config = config
# Embeddings
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
# Transformer layers
self.layers = nn.ModuleList([
SLMDecoderLayer(config) for _ in range(config.num_hidden_layers)
])
# Final norm
self.norm = RMSNorm(config.hidden_size, config.rms_norm_eps)
# Initialize weights
self.apply(self._init_weights)
def _init_weights(self, module):
if isinstance(module, nn.Linear):
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
if module.bias is not None:
torch.nn.init.zeros_(module.bias)
elif isinstance(module, nn.Embedding):
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
def forward(
self,
input_ids: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
past_key_values: Optional[list] = None,
use_cache: bool = False
) -> Tuple[torch.Tensor, Optional[list]]:
batch_size, seq_len = input_ids.shape
# Get embeddings
hidden_states = self.embed_tokens(input_ids)
        # Build the additive causal mask, folding in the padding mask if one was given
        causal_mask = self._make_causal_mask(batch_size, seq_len, input_ids.device)
        if attention_mask is not None:
            key_padding = attention_mask[:, None, None, :].to(torch.bool)
            causal_mask = causal_mask.masked_fill(~key_padding, float("-inf"))
# Apply layers
new_past_key_values = [] if use_cache else None
for i, layer in enumerate(self.layers):
past_kv = past_key_values[i] if past_key_values else None
hidden_states, new_past_kv = layer(
hidden_states,
attention_mask=causal_mask,
past_key_value=past_kv,
use_cache=use_cache
)
if use_cache:
new_past_key_values.append(new_past_kv)
# Final norm
hidden_states = self.norm(hidden_states)
return hidden_states, new_past_key_values
def _make_causal_mask(
self,
batch_size: int,
seq_len: int,
device: torch.device
) -> torch.Tensor:
mask = torch.full((seq_len, seq_len), float("-inf"), device=device)
mask = torch.triu(mask, diagonal=1)
mask = mask.unsqueeze(0).unsqueeze(0)
return mask.expand(batch_size, 1, seq_len, seq_len)
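For a 4-token sequence the additive mask produced above looks like this (0 = may attend, -inf = blocked), so each position sees only itself and earlier tokens:

```python
import torch

seq_len = 4
mask = torch.triu(torch.full((seq_len, seq_len), float("-inf")), diagonal=1)
# [[  0., -inf, -inf, -inf],
#  [  0.,   0., -inf, -inf],
#  [  0.,   0.,   0., -inf],
#  [  0.,   0.,   0.,   0.]]

assert mask[1, 2] == float("-inf")   # position 1 cannot see future position 2
assert mask[2, 1] == 0.0             # position 2 can see past position 1
```

Adding -inf before the softmax drives the corresponding attention weights to exactly zero.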
class SLMForCausalLM(nn.Module):
"""SLM with language modeling head."""
def __init__(self, config: SLMConfig):
super().__init__()
self.config = config
self.model = SLMModel(config)
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
# Tie embeddings
if config.tie_word_embeddings:
self.lm_head.weight = self.model.embed_tokens.weight
def forward(
self,
input_ids: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
labels: Optional[torch.Tensor] = None,
past_key_values: Optional[list] = None,
use_cache: bool = False
) -> dict:
hidden_states, past_key_values = self.model(
input_ids,
attention_mask=attention_mask,
past_key_values=past_key_values,
use_cache=use_cache
)
logits = self.lm_head(hidden_states)
loss = None
if labels is not None:
# Shift for next token prediction
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
loss = F.cross_entropy(
shift_logits.view(-1, self.config.vocab_size),
shift_labels.view(-1),
ignore_index=-100
)
return {
"loss": loss,
"logits": logits,
"past_key_values": past_key_values
}
    def generate(
        self,
        input_ids: torch.Tensor,
        max_new_tokens: int = 100,
        temperature: float = 1.0,
        top_k: int = 50,
        top_p: float = 0.9,
        eos_token_id: Optional[int] = None
    ) -> torch.Tensor:
        """Generate text autoregressively (greedy when temperature == 0)."""
        past_key_values = None
        for _ in range(max_new_tokens):
            # Forward pass (only the newest token once the KV cache is warm)
            outputs = self.forward(
                input_ids if past_key_values is None else input_ids[:, -1:],
                past_key_values=past_key_values,
                use_cache=True
            )
            logits = outputs["logits"][:, -1, :]
            past_key_values = outputs["past_key_values"]
            if temperature == 0:
                # Greedy decoding
                next_token = logits.argmax(dim=-1, keepdim=True)
            else:
                # Temperature scaling
                logits = logits / temperature
                # Top-k filtering
                if top_k > 0:
                    indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
                    logits[indices_to_remove] = float("-inf")
                # Top-p (nucleus) filtering
                if top_p < 1.0:
                    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
                    cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
                    sorted_indices_to_remove = cumulative_probs > top_p
                    sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
                    sorted_indices_to_remove[..., 0] = 0
                    indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
                    logits[indices_to_remove] = float("-inf")
                # Sample from the filtered distribution
                probs = F.softmax(logits, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1)
            input_ids = torch.cat([input_ids, next_token], dim=1)
            # Stop at EOS (pass the tokenizer's ID; the SentencePiece setup above uses 3)
            if eos_token_id is not None and next_token.item() == eos_token_id:
                break
        return input_ids
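The loss in forward shifts logits and labels so that the logits at position t are scored against the token at t+1; labels should therefore be passed in unshifted (identical to input_ids). A minimal sketch of that alignment with toy tensors:

```python
import torch
import torch.nn.functional as F

vocab_size = 10
tokens = torch.tensor([[2, 5, 7, 1]])      # labels are the raw token stream
logits = torch.randn(1, 4, vocab_size)     # stand-in model output, one row per position

shift_logits = logits[..., :-1, :].contiguous()   # positions 0..2 predict ...
shift_labels = tokens[..., 1:].contiguous()       # ... tokens 1..3

loss = F.cross_entropy(
    shift_logits.view(-1, vocab_size),
    shift_labels.view(-1)
)
assert shift_labels.tolist() == [[5, 7, 1]]   # each target is simply the next token
assert loss.ndim == 0                         # scalar mean over the 3 predictions
```

Shifting again in the data pipeline would silently make the model predict two tokens ahead, so the shift should live in exactly one place.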
# Model size configurations
MODEL_CONFIGS = {
"tiny": SLMConfig(
hidden_size=256,
intermediate_size=512,
num_hidden_layers=6,
num_attention_heads=8,
num_key_value_heads=2,
),
"small": SLMConfig(
hidden_size=512,
intermediate_size=1024,
num_hidden_layers=8,
num_attention_heads=8,
num_key_value_heads=4,
),
"medium": SLMConfig(
hidden_size=768,
intermediate_size=2048,
num_hidden_layers=12,
num_attention_heads=12,
num_key_value_heads=4,
),
"base": SLMConfig(
hidden_size=1024,
intermediate_size=2816,
num_hidden_layers=16,
num_attention_heads=16,
num_key_value_heads=4,
),
}
def count_parameters(model: nn.Module) -> int:
"""Count trainable parameters."""
return sum(p.numel() for p in model.parameters() if p.requires_grad)
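The printed counts can be cross-checked analytically from the SLMConfig fields (a sketch assuming tied embeddings, so the LM head adds no extra weights, and bias-free linear layers as above):

```python
def analytic_param_count(vocab: int, hidden: int, inter: int,
                         layers: int, heads: int, kv_heads: int) -> int:
    """Parameter count implied by an SLMConfig with tied embeddings, no biases."""
    head_dim = hidden // heads
    attn = hidden * heads * head_dim * 2       # q_proj + o_proj
    attn += hidden * kv_heads * head_dim * 2   # k_proj + v_proj (smaller under GQA)
    mlp = 3 * hidden * inter                   # gate + up + down
    norms = 2 * hidden                         # two RMSNorm weights per layer
    per_layer = attn + mlp + norms
    embed = vocab * hidden                     # lm_head is tied, counted once
    return embed + layers * per_layer + hidden  # + final RMSNorm

# "tiny" config: vocab 32000, hidden 256, inter 512, 6 layers, 8 heads, 2 KV heads
assert analytic_param_count(32000, 256, 512, 6, 8, 2) == 11_537_664  # ~11.5M
```

Note how the embedding table (32000 x 256 = 8.2M) dominates the tiny model; growing the vocabulary is expensive at this scale.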
if __name__ == "__main__":
# Test model creation
for name, config in MODEL_CONFIGS.items():
model = SLMForCausalLM(config)
params = count_parameters(model)
print(f"{name}: {params:,} parameters ({params / 1e6:.1f}M)")
# Test forward pass
config = MODEL_CONFIGS["tiny"]
model = SLMForCausalLM(config)
batch_size = 2
seq_len = 128
input_ids = torch.randint(0, config.vocab_size, (batch_size, seq_len))
labels = input_ids.clone()
outputs = model(input_ids, labels=labels)
print(f"\nLoss: {outputs['loss'].item():.4f}")
    print(f"Logits shape: {outputs['logits'].shape}")
★ Insight ─────────────────────────────────────
Modern SLM Architecture Choices: This implementation uses several modern techniques: (1) RMSNorm instead of LayerNorm for faster training, (2) RoPE for position encoding which generalizes better than learned positional embeddings, (3) SwiGLU activation for the FFN which improves quality, (4) Grouped Query Attention (GQA) to reduce memory during inference while maintaining quality.
─────────────────────────────────────────────────
Part 4: Training Loop
Implement the training pipeline with distributed training support.
# training/trainer.py
"""
Training loop for SLM pre-training.
"""
import os
import math
import time
from pathlib import Path
from typing import Optional
from dataclasses import dataclass
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.cuda.amp import autocast, GradScaler
from tqdm import tqdm
try:
import wandb
WANDB_AVAILABLE = True
except ImportError:
WANDB_AVAILABLE = False
@dataclass
class TrainingConfig:
"""Training configuration."""
# Data
batch_size: int = 32
gradient_accumulation_steps: int = 4
max_seq_length: int = 2048
# Optimizer
learning_rate: float = 3e-4
weight_decay: float = 0.1
beta1: float = 0.9
beta2: float = 0.95
max_grad_norm: float = 1.0
# Schedule
warmup_steps: int = 1000
total_steps: int = 100000
lr_scheduler: str = "cosine"
# Training
mixed_precision: bool = True
compile_model: bool = False # PyTorch 2.0 compile
# Checkpointing
checkpoint_dir: str = "./checkpoints"
save_every: int = 1000
eval_every: int = 500
# Logging
log_every: int = 10
wandb_project: Optional[str] = None
wandb_run_name: Optional[str] = None
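The effective batch size and total token budget follow directly from these defaults:

```python
# Defaults from TrainingConfig above
batch_size = 32
gradient_accumulation_steps = 4
max_seq_length = 2048
total_steps = 100_000

tokens_per_step = batch_size * gradient_accumulation_steps * max_seq_length
total_tokens = tokens_per_step * total_steps

assert tokens_per_step == 262_144           # ~0.26M tokens per optimizer step
assert total_tokens == 26_214_400_000       # ~26B tokens over the full run
```

By the rough ~20-tokens-per-parameter rule of thumb, a 26B-token budget is reasonably matched to a model around the 1B-parameter mark; smaller models from the table above would finish in correspondingly fewer steps.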
class TokenDataset(Dataset):
"""Dataset for pre-tokenized data."""
def __init__(self, data_path: str, seq_length: int):
self.seq_length = seq_length
self.data = torch.load(data_path)
def __len__(self):
return len(self.data) - self.seq_length
    def __getitem__(self, idx):
        chunk = self.data[idx:idx + self.seq_length]
        # SLMForCausalLM already shifts logits/labels internally, so labels are
        # the same tokens as the inputs (shifting here too would skip a token)
        return {
            "input_ids": chunk,
            "labels": chunk.clone()
        }
class SLMTrainer:
"""Trainer for Small Language Models."""
def __init__(
self,
model: nn.Module,
config: TrainingConfig,
train_dataset: Dataset,
eval_dataset: Optional[Dataset] = None,
tokenizer = None
):
self.model = model
self.config = config
self.train_dataset = train_dataset
self.eval_dataset = eval_dataset
self.tokenizer = tokenizer
# Setup device
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
self.model.to(self.device)
# Compile model if requested
if config.compile_model and hasattr(torch, "compile"):
self.model = torch.compile(self.model)
# Setup optimizer
self.optimizer = self._create_optimizer()
# Setup scheduler
self.scheduler = self._create_scheduler()
# Mixed precision
self.scaler = GradScaler() if config.mixed_precision else None
# Checkpointing
self.checkpoint_dir = Path(config.checkpoint_dir)
self.checkpoint_dir.mkdir(parents=True, exist_ok=True)
# Tracking
self.global_step = 0
self.best_eval_loss = float("inf")
# Wandb
if config.wandb_project and WANDB_AVAILABLE:
wandb.init(
project=config.wandb_project,
name=config.wandb_run_name,
config=vars(config)
)
def _create_optimizer(self) -> AdamW:
"""Create optimizer with weight decay."""
# Separate parameters that should/shouldn't have weight decay
decay_params = []
no_decay_params = []
for name, param in self.model.named_parameters():
if not param.requires_grad:
continue
if "bias" in name or "norm" in name:
no_decay_params.append(param)
else:
decay_params.append(param)
param_groups = [
{"params": decay_params, "weight_decay": self.config.weight_decay},
{"params": no_decay_params, "weight_decay": 0.0}
]
return AdamW(
param_groups,
lr=self.config.learning_rate,
betas=(self.config.beta1, self.config.beta2)
)
def _create_scheduler(self):
"""Create learning rate scheduler with warmup."""
def lr_lambda(step):
if step < self.config.warmup_steps:
return step / self.config.warmup_steps
else:
progress = (step - self.config.warmup_steps) / (
self.config.total_steps - self.config.warmup_steps
)
return 0.5 * (1.0 + math.cos(math.pi * progress))
return torch.optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda)
def train(self):
"""Run training loop."""
# Create data loader
train_loader = DataLoader(
self.train_dataset,
batch_size=self.config.batch_size,
shuffle=True,
num_workers=4,
pin_memory=True
)
# Training loop
self.model.train()
accumulated_loss = 0.0
num_accumulated = 0
progress_bar = tqdm(total=self.config.total_steps, desc="Training")
while self.global_step < self.config.total_steps:
for batch in train_loader:
if self.global_step >= self.config.total_steps:
break
# Move to device
input_ids = batch["input_ids"].to(self.device)
labels = batch["labels"].to(self.device)
# Forward pass
with autocast(enabled=self.config.mixed_precision):
outputs = self.model(input_ids, labels=labels)
loss = outputs["loss"] / self.config.gradient_accumulation_steps
# Backward pass
if self.scaler:
self.scaler.scale(loss).backward()
else:
loss.backward()
accumulated_loss += loss.item()
num_accumulated += 1
# Gradient accumulation step
if num_accumulated >= self.config.gradient_accumulation_steps:
# Gradient clipping
if self.scaler:
self.scaler.unscale_(self.optimizer)
torch.nn.utils.clip_grad_norm_(
self.model.parameters(),
self.config.max_grad_norm
)
# Optimizer step
if self.scaler:
self.scaler.step(self.optimizer)
self.scaler.update()
else:
self.optimizer.step()
self.scheduler.step()
self.optimizer.zero_grad()
self.global_step += 1
progress_bar.update(1)
# Logging
if self.global_step % self.config.log_every == 0:
avg_loss = accumulated_loss * self.config.gradient_accumulation_steps / num_accumulated
lr = self.scheduler.get_last_lr()[0]
progress_bar.set_postfix({
"loss": f"{avg_loss:.4f}",
"lr": f"{lr:.2e}"
})
if WANDB_AVAILABLE and self.config.wandb_project:
wandb.log({
"train/loss": avg_loss,
"train/learning_rate": lr,
"train/step": self.global_step
})
# Evaluation
if self.global_step % self.config.eval_every == 0:
self._evaluate()
# Checkpointing
if self.global_step % self.config.save_every == 0:
self._save_checkpoint()
accumulated_loss = 0.0
num_accumulated = 0
progress_bar.close()
# Final save
self._save_checkpoint(final=True)
def _evaluate(self):
"""Run evaluation."""
if self.eval_dataset is None:
return
self.model.eval()
total_loss = 0.0
num_batches = 0
eval_loader = DataLoader(
self.eval_dataset,
batch_size=self.config.batch_size,
shuffle=False
)
with torch.no_grad():
for batch in tqdm(eval_loader, desc="Evaluating", leave=False):
input_ids = batch["input_ids"].to(self.device)
labels = batch["labels"].to(self.device)
outputs = self.model(input_ids, labels=labels)
total_loss += outputs["loss"].item()
num_batches += 1
avg_loss = total_loss / num_batches
perplexity = math.exp(avg_loss)
print(f"\nEval Loss: {avg_loss:.4f}, Perplexity: {perplexity:.2f}")
if WANDB_AVAILABLE and self.config.wandb_project:
wandb.log({
"eval/loss": avg_loss,
"eval/perplexity": perplexity,
"eval/step": self.global_step
})
# Save best model
if avg_loss < self.best_eval_loss:
self.best_eval_loss = avg_loss
self._save_checkpoint(best=True)
self.model.train()
def _save_checkpoint(self, best: bool = False, final: bool = False):
"""Save model checkpoint."""
if best:
path = self.checkpoint_dir / "best_model.pt"
elif final:
path = self.checkpoint_dir / "final_model.pt"
else:
path = self.checkpoint_dir / f"checkpoint_{self.global_step}.pt"
checkpoint = {
"model_state_dict": self.model.state_dict(),
"optimizer_state_dict": self.optimizer.state_dict(),
"scheduler_state_dict": self.scheduler.state_dict(),
"global_step": self.global_step,
"config": self.model.config if hasattr(self.model, "config") else None
}
torch.save(checkpoint, path)
print(f"Saved checkpoint to {path}")
def load_checkpoint(self, path: str):
"""Load model checkpoint."""
checkpoint = torch.load(path, map_location=self.device)
self.model.load_state_dict(checkpoint["model_state_dict"])
self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
self.scheduler.load_state_dict(checkpoint["scheduler_state_dict"])
self.global_step = checkpoint["global_step"]
print(f"Loaded checkpoint from {path} at step {self.global_step}")
# Training script
if __name__ == "__main__":
from model.architecture import SLMForCausalLM, MODEL_CONFIGS
# Create model
config = MODEL_CONFIGS["tiny"]
model = SLMForCausalLM(config)
# Create dummy dataset for testing
class DummyDataset(Dataset):
def __init__(self, size: int, seq_len: int, vocab_size: int):
self.size = size
self.seq_len = seq_len
self.vocab_size = vocab_size
def __len__(self):
return self.size
def __getitem__(self, idx):
tokens = torch.randint(0, self.vocab_size, (self.seq_len,))
return {"input_ids": tokens, "labels": tokens}
train_dataset = DummyDataset(10000, 512, config.vocab_size)
eval_dataset = DummyDataset(1000, 512, config.vocab_size)
# Training config
train_config = TrainingConfig(
batch_size=4,
gradient_accumulation_steps=2,
learning_rate=1e-4,
total_steps=100,
warmup_steps=10,
log_every=5,
eval_every=50,
save_every=100
)
# Create trainer
trainer = SLMTrainer(
model=model,
config=train_config,
train_dataset=train_dataset,
eval_dataset=eval_dataset
)
# Train
    trainer.train()
─────────────────────────────────────────────────
Part 5: Export and Deployment
Convert trained models to GGUF for deployment.
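The Q8_0 scheme the writer below implements is easy to illustrate with a pure-Python round trip on a single 32-value block (toy values, for illustration only):

```python
# Q8_0: each block of 32 floats shares one scale; values are stored as int8 codes.
block = [0.5 * i - 8.0 for i in range(32)]      # toy block spanning [-8.0, 7.5]
scale = max(abs(v) for v in block) / 127.0      # one scale per block
codes = [round(v / scale) for v in block]       # int8 codes in [-127, 127]
restored = [c * scale for c in codes]
max_err = max(abs(a - b) for a, b in zip(block, restored))
# rounding error is bounded by half a quantization step
assert max_err <= scale / 2 + 1e-9
```

Note that the simplified `GGUFWriter` below keeps the scales in a separate array; llama.cpp's actual Q8_0 layout interleaves each block's scale with its 32 codes.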
# export/to_gguf.py
"""
Export trained SLM to GGUF format for llama.cpp deployment.
"""
import struct
import numpy as np
from pathlib import Path
from typing import Any
import torch
class GGUFWriter:
"""
Write models to GGUF format.
GGUF is the format used by llama.cpp for efficient inference.
"""
GGUF_MAGIC = 0x46554747 # "GGUF"
GGUF_VERSION = 3
# GGML types
GGML_TYPE_F32 = 0
GGML_TYPE_F16 = 1
GGML_TYPE_Q4_0 = 2
GGML_TYPE_Q8_0 = 8
def __init__(self, output_path: str):
self.output_path = Path(output_path)
self.kv_data = {}
self.tensors = []
def add_architecture(self, arch: str):
"""Set model architecture."""
self.kv_data["general.architecture"] = arch
def add_name(self, name: str):
"""Set model name."""
self.kv_data["general.name"] = name
def add_uint32(self, key: str, value: int):
"""Add uint32 metadata."""
self.kv_data[key] = ("uint32", value)
def add_float32(self, key: str, value: float):
"""Add float32 metadata."""
self.kv_data[key] = ("float32", value)
def add_string(self, key: str, value: str):
"""Add string metadata."""
self.kv_data[key] = ("string", value)
def add_tensor(self, name: str, tensor: np.ndarray, quantize: bool = False):
"""Add tensor data."""
if quantize and tensor.dtype == np.float32:
            # Simplified Q8_0: scales are tracked in a separate array here;
            # llama.cpp's real layout interleaves each block's scale with its codes
tensor, scale = self._quantize_q8_0(tensor)
self.tensors.append({
"name": name,
"data": tensor,
"type": self.GGML_TYPE_Q8_0,
"scale": scale
})
else:
dtype_map = {
np.float32: self.GGML_TYPE_F32,
np.float16: self.GGML_TYPE_F16,
}
self.tensors.append({
"name": name,
"data": tensor,
"type": dtype_map.get(tensor.dtype, self.GGML_TYPE_F32)
})
def _quantize_q8_0(self, tensor: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
"""Quantize to Q8_0 format (int8 with block-wise scaling)."""
block_size = 32
shape = tensor.shape
flat = tensor.flatten()
# Pad to multiple of block_size
padded_len = ((len(flat) + block_size - 1) // block_size) * block_size
padded = np.zeros(padded_len, dtype=np.float32)
padded[:len(flat)] = flat
# Reshape into blocks
blocks = padded.reshape(-1, block_size)
# Compute scales
scales = np.max(np.abs(blocks), axis=1, keepdims=True) / 127.0
scales = np.where(scales == 0, 1.0, scales)
# Quantize
quantized = np.round(blocks / scales).astype(np.int8)
return quantized.flatten()[:len(flat)].reshape(shape), scales.flatten()
def write(self):
"""Write GGUF file."""
with open(self.output_path, "wb") as f:
# Write header
f.write(struct.pack("<I", self.GGUF_MAGIC))
f.write(struct.pack("<I", self.GGUF_VERSION))
f.write(struct.pack("<Q", len(self.tensors))) # n_tensors
f.write(struct.pack("<Q", len(self.kv_data))) # n_kv
# Write KV pairs
for key, value in self.kv_data.items():
self._write_string(f, key)
if isinstance(value, tuple):
dtype, val = value
if dtype == "uint32":
f.write(struct.pack("<I", 4)) # type
f.write(struct.pack("<I", val))
elif dtype == "float32":
f.write(struct.pack("<I", 6)) # type
f.write(struct.pack("<f", val))
elif dtype == "string":
f.write(struct.pack("<I", 8)) # type
self._write_string(f, val)
else:
f.write(struct.pack("<I", 8)) # string type
self._write_string(f, str(value))
# Write tensor info
for tensor in self.tensors:
self._write_string(f, tensor["name"])
f.write(struct.pack("<I", len(tensor["data"].shape)))
for dim in tensor["data"].shape:
f.write(struct.pack("<Q", dim))
f.write(struct.pack("<I", tensor["type"]))
f.write(struct.pack("<Q", 0)) # offset (will be updated)
# Align to 32 bytes
current_pos = f.tell()
padding = (32 - (current_pos % 32)) % 32
f.write(b'\x00' * padding)
# Write tensor data
for tensor in self.tensors:
f.write(tensor["data"].tobytes())
# Align each tensor
current_pos = f.tell()
padding = (32 - (current_pos % 32)) % 32
f.write(b'\x00' * padding)
print(f"Wrote GGUF to {self.output_path}")
def _write_string(self, f, s: str):
"""Write length-prefixed string."""
encoded = s.encode("utf-8")
f.write(struct.pack("<Q", len(encoded)))
f.write(encoded)
def export_slm_to_gguf(
model,
tokenizer,
output_path: str,
quantize: bool = True
):
"""
Export SLM model to GGUF format.
Args:
model: Trained SLM model
tokenizer: Tokenizer
output_path: Output file path
quantize: Whether to quantize weights
"""
writer = GGUFWriter(output_path)
# Add metadata
config = model.config
writer.add_architecture("slm")
writer.add_name("Custom SLM")
writer.add_uint32("slm.vocab_size", config.vocab_size)
writer.add_uint32("slm.hidden_size", config.hidden_size)
writer.add_uint32("slm.num_hidden_layers", config.num_hidden_layers)
writer.add_uint32("slm.num_attention_heads", config.num_attention_heads)
writer.add_uint32("slm.num_key_value_heads", config.num_key_value_heads)
writer.add_uint32("slm.max_position_embeddings", config.max_position_embeddings)
writer.add_float32("slm.rms_norm_eps", config.rms_norm_eps)
writer.add_float32("slm.rope_theta", config.rope_theta)
# Add tokenizer info
writer.add_uint32("tokenizer.ggml.bos_token_id", tokenizer.bos_token_id)
writer.add_uint32("tokenizer.ggml.eos_token_id", tokenizer.eos_token_id)
writer.add_uint32("tokenizer.ggml.pad_token_id", tokenizer.pad_token_id)
# Export weights
state_dict = model.state_dict()
# Map weight names to GGUF format
weight_map = {
"model.embed_tokens.weight": "token_embd.weight",
"model.norm.weight": "output_norm.weight",
"lm_head.weight": "output.weight",
}
for name, tensor in state_dict.items():
np_tensor = tensor.cpu().numpy()
# Map layer weights
if "layers." in name:
layer_num = name.split(".")[2]
suffix = ".".join(name.split(".")[3:])
layer_map = {
"self_attn.q_proj.weight": f"blk.{layer_num}.attn_q.weight",
"self_attn.k_proj.weight": f"blk.{layer_num}.attn_k.weight",
"self_attn.v_proj.weight": f"blk.{layer_num}.attn_v.weight",
"self_attn.o_proj.weight": f"blk.{layer_num}.attn_output.weight",
"mlp.gate_proj.weight": f"blk.{layer_num}.ffn_gate.weight",
"mlp.up_proj.weight": f"blk.{layer_num}.ffn_up.weight",
"mlp.down_proj.weight": f"blk.{layer_num}.ffn_down.weight",
"input_layernorm.weight": f"blk.{layer_num}.attn_norm.weight",
"post_attention_layernorm.weight": f"blk.{layer_num}.ffn_norm.weight",
}
if suffix in layer_map:
gguf_name = layer_map[suffix]
else:
gguf_name = name.replace(".", "_")
elif name in weight_map:
gguf_name = weight_map[name]
else:
gguf_name = name.replace(".", "_")
# Add tensor (quantize large matrices)
should_quantize = quantize and np_tensor.ndim == 2 and np_tensor.size > 1024
writer.add_tensor(gguf_name, np_tensor, quantize=should_quantize)
writer.write()
print(f"Exported model to {output_path}")
# Example usage
if __name__ == "__main__":
from model.architecture import SLMForCausalLM, MODEL_CONFIGS
from tokenizer.train_tokenizer import SLMTokenizer
# Create a small model for demo
config = MODEL_CONFIGS["tiny"]
model = SLMForCausalLM(config)
# Mock tokenizer
class MockTokenizer:
bos_token_id = 2
eos_token_id = 3
pad_token_id = 0
tokenizer = MockTokenizer()
# Export
export_slm_to_gguf(
model=model,
tokenizer=tokenizer,
output_path="./custom_slm.gguf",
quantize=True
    )
Exercises
Exercise 1: Custom Dataset
Create a training dataset from:
- A specific domain (medical, legal, code)
- Multiple languages
- Conversation data
Train a tokenizer optimized for your domain.
Exercise 2: Architecture Experiments
Modify the architecture to experiment with:
- Different attention patterns (sliding window)
- Alternative activations (GELU vs SwiGLU)
- Varying depth vs width tradeoffs
Benchmark each variant on your target tasks.
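For the sliding-window variant, the only change from a plain causal mask is an extra lower bound on how far back each position can look. A hypothetical mask builder (plain Python, not the tensorized version you would use in the model):

```python
def sliding_window_mask(seq_len, window):
    # Position i may attend to positions j with i - window < j <= i:
    # causal, but limited to the most recent `window` tokens.
    return [[1 if 0 <= i - j < window else 0 for j in range(seq_len)]
            for i in range(seq_len)]

mask = sliding_window_mask(5, 3)
# row 4 attends only to positions 2, 3, 4
```

With `window >= seq_len` this degenerates to the standard causal mask, which makes it easy to A/B test against the baseline.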
Exercise 3: Training Optimization
Implement and compare:
- Different learning rate schedules
- Various optimizers (AdamW vs AdaFactor)
- Gradient checkpointing for memory efficiency
- Mixed precision with different dtypes
Exercise 4: Model Evaluation
Build an evaluation suite that measures:
- Perplexity on held-out data
- Performance on downstream tasks
- Generation quality
- Inference speed
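Perplexity, which the trainer's `_evaluate` already reports, is just the exponential of the mean token-level cross-entropy. A worked example with hypothetical loss values:

```python
import math

# Mean cross-entropy (nats per token) over an evaluation set, then perplexity
batch_losses = [3.2, 3.0, 2.9]                  # hypothetical per-batch eval losses
avg_loss = sum(batch_losses) / len(batch_losses)
perplexity = math.exp(avg_loss)
# a drop of ~0.1 nats in loss is roughly a 10% drop in perplexity
```

Because perplexity is an exponential, small loss improvements late in training translate into smaller and smaller perplexity gains; report both numbers in your evaluation suite.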
Summary
You've learned to train SLMs from scratch:
- Data Pipeline: Collection, cleaning, and tokenization
- Custom Tokenizer: BPE training with SentencePiece
- Model Architecture: Modern transformer with RoPE, GQA, SwiGLU
- Training Loop: Gradient accumulation, mixed precision, and checkpointing
- Export: GGUF conversion for deployment
Key insights:
- Data quality matters more than quantity for small models
- Modern architecture choices (RoPE, GQA, RMSNorm) improve efficiency
- Warmup and proper learning rate scheduling are critical
- Regular checkpointing enables experimentation and recovery
Key Concepts Recap
| Concept | What It Is | Why It Matters |
|---|---|---|
| Data Collection | Gathering and preprocessing text from multiple sources | Data quality determines model capability more than quantity |
| BPE Tokenizer | Byte Pair Encoding splits text into subword units | Efficient vocabulary handles unseen words via subword composition |
| SentencePiece | Language-agnostic tokenizer training library | Treats input as raw text, no pre-tokenization needed |
| RoPE | Rotary Position Embedding encodes positions via rotation | Better length generalization than learned positional embeddings |
| GQA | Grouped Query Attention shares K/V heads across Q heads | Reduces memory during inference while maintaining quality |
| RMSNorm | Root Mean Square normalization without centering | Faster than LayerNorm with comparable training stability |
| SwiGLU | Swish-Gated Linear Unit activation for FFN | Improves model quality over standard GELU activation |
| Gradient Accumulation | Sum gradients over multiple mini-batches | Enables large effective batch sizes on limited GPU memory |
| Mixed Precision | Use FP16/BF16 for compute, FP32 for master weights and accumulation | Roughly halves activation memory and speeds up training with minimal accuracy loss |
| GGUF Export | llama.cpp's model format with quantization support | Enables efficient CPU inference and wide deployment compatibility |
Next Steps
- Speculative Decoding - Accelerate inference with draft models
- Production SLM System - Deploy your model at scale