Knowledge Distillation
Compress large models into smaller, faster versions
TL;DR
Train smaller "student" models by mimicking larger "teacher" models. The key insight: soft probability distributions (e.g., cat: 0.65, tiger: 0.25) contain "dark knowledge" about class relationships that hard labels (cat: 1.0) discard. Temperature scaling (T=4) softens outputs to reveal this information.
Overview
| Attribute | Value |
|---|---|
| Difficulty | Intermediate |
| Time | ~6 hours |
| Prerequisites | PyTorch, neural network training |
| Learning Outcomes | Teacher-student training, soft labels, feature distillation |
Introduction
Knowledge distillation compresses large, accurate "teacher" models into smaller, efficient "student" models while maintaining most of the teacher's performance. This enables deployment on resource-constrained devices.
┌─────────────────────────────────────────────────────────────────────────────┐
│ Knowledge Distillation Flow │
├─────────────────────────────────────────────────────────────────────────────┤
│ │
│ Teacher Model (Frozen) Student Model (Training) │
│ ────────────────────── ───────────────────────── │
│ │
│ ┌─────────────┐ ┌─────────────┐ │
│ │ Input │ │ Input │ │
│ └──────┬──────┘ └──────┬──────┘ │
│ │ │ │
│ ▼ ▼ │
│ ┌─────────────┐ ┌─────────────┐ │
│ │ Large Model │ │ Small Model │ │
│ │ BERT-Large │ │ DistilBERT │ │
│ └──────┬──────┘ └──────┬──────┘ │
│ │ │ │
│ ▼ ▼ │
│ ┌─────────────┐ Distillation ┌─────────────┐ │
│ │ Soft Labels │ ─ ─ ─ Loss ─ ─ ─ ─► │ Predictions │ │
│ │0.7, 0.2, 0.1│ (KL Div) │ │ │
│ └─────────────┘ └─────────────┘ │
│ │
└─────────────────────────────────────────────────────────────────────────────┘
Why Knowledge Distillation?
The Dark Knowledge Hypothesis
Hard labels (one-hot vectors) discard valuable information about class relationships. Soft probability distributions from teachers encode:
- Similarity between classes: P(tiger|cat_image) > P(car|cat_image)
- Uncertainty: How confident is the model?
- Edge cases: Which samples are ambiguous?
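The effect is easy to see numerically. Below is a minimal pure-Python sketch (stdlib only, with made-up logits for cat/tiger/dog) showing how raising the temperature reveals class relationships that a near-one-hot output hides:

```python
import math

def softmax(logits, temperature=1.0):
    """Temperature-scaled softmax; higher T yields a softer distribution."""
    scaled = [z / temperature for z in logits]
    m = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]

# Hypothetical teacher logits for a cat image over (cat, tiger, dog)
logits = [5.0, 3.0, 1.0]

hard = softmax(logits, temperature=1.0)  # ~[0.87, 0.12, 0.02]: near one-hot
soft = softmax(logits, temperature=4.0)  # ~[0.51, 0.31, 0.19]: "cat looks like tiger"
```

At T=1 the tiger probability is barely visible; at T=4 the ranking cat > tiger > dog becomes an explicit training signal for the student.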
┌─────────────────────────────────────────────────────────────────────────────┐
│ Hard Labels vs Soft Labels │
├─────────────────────────────────────────────────────────────────────────────┤
│ │
│ Hard Labels (One-Hot) Soft Labels (T=4) │
│ ───────────────────── ──────────────────── │
│ │
│ ┌──────────────────┐ ┌──────────────────┐ │
│ │ Cat: 1.0 │ │ Cat: 0.65 │ │
│ │ Tiger: 0.0 │ │ Tiger: 0.25 │ ──► "Cat looks │
│ │ Dog: 0.0 │ │ Dog: 0.10 │ like tiger!" │
│ └──────────────────┘ └──────────────────┘ │
│ │
│ Information: Information: │
│ • Only correct class • Class similarities │
│ • No relationships • Model uncertainty │
│ • Discards nuance • Edge case handling │
│ │ │
│ ▼ │
│ ┌──────────────────┐ │
│ │ Dark Knowledge │ │
│ │ (Hidden in soft │ │
│ │ distributions) │ │
│ └──────────────────┘ │
│ │
└─────────────────────────────────────────────────────────────────────────────┘
Compression Benefits
| Model | Parameters | Latency | Accuracy |
|---|---|---|---|
| BERT-Large | 340M | 100ms | 93.0% |
| BERT-Base | 110M | 35ms | 91.5% |
| DistilBERT | 66M | 20ms | 90.5% |
| TinyBERT | 14M | 8ms | 87.0% |
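Parameter count translates directly into checkpoint size: each float32 weight costs 4 bytes. A quick back-of-the-envelope sketch (parameter counts taken from the table above):

```python
def fp32_size_mb(num_params):
    """Approximate float32 checkpoint size in MB (4 bytes per parameter)."""
    return num_params * 4 / (1024 ** 2)

# Parameter counts from the table above
for name, n in [("BERT-Large", 340_000_000),
                ("DistilBERT", 66_000_000),
                ("TinyBERT", 14_000_000)]:
    print(f"{name}: ~{fp32_size_mb(n):.0f} MB")
# → BERT-Large: ~1297 MB, DistilBERT: ~252 MB, TinyBERT: ~53 MB
```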
Project Setup
# Create project directory
mkdir knowledge-distillation && cd knowledge-distillation
# Create virtual environment
python -m venv venv
source venv/bin/activate
# Install dependencies
pip install torch transformers datasets accelerate
pip install wandb evaluate scikit-learn
Project Structure
knowledge-distillation/
├── models/
│ ├── teacher.py # Teacher model wrapper
│ ├── student.py # Student model definition
│ └── distiller.py # Distillation logic
├── losses/
│ ├── soft_target.py # Soft target loss
│ ├── feature.py # Feature distillation
│ └── attention.py # Attention distillation
├── training/
│ ├── trainer.py # Training loop
│ └── evaluation.py # Evaluation metrics
├── scripts/
│ ├── train.py # Training script
│ └── benchmark.py # Benchmarking
└── requirements.txt
Teacher and Student Models
Teacher Model Wrapper
# models/teacher.py
import torch
import torch.nn as nn
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from typing import Dict, Any, Optional
class Teacher(nn.Module):
"""Teacher model wrapper for distillation."""
def __init__(
self,
model_name: str = "bert-large-uncased",
num_labels: int = 2,
output_hidden_states: bool = True,
output_attentions: bool = True,
):
super().__init__()
self.model = AutoModelForSequenceClassification.from_pretrained(
model_name,
num_labels=num_labels,
output_hidden_states=output_hidden_states,
output_attentions=output_attentions,
)
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
# Freeze teacher weights
for param in self.model.parameters():
param.requires_grad = False
self.model.eval()
@torch.no_grad()
def forward(
self,
input_ids: torch.Tensor,
attention_mask: torch.Tensor,
labels: Optional[torch.Tensor] = None,
) -> Dict[str, torch.Tensor]:
"""Forward pass returning all outputs."""
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
labels=labels,
)
return {
"logits": outputs.logits,
"hidden_states": outputs.hidden_states, # Tuple of layer outputs
"attentions": outputs.attentions, # Tuple of attention weights
"loss": outputs.loss if labels is not None else None,
}
def get_soft_labels(
self,
input_ids: torch.Tensor,
attention_mask: torch.Tensor,
temperature: float = 4.0,
) -> torch.Tensor:
"""Get soft probability distribution."""
outputs = self.forward(input_ids, attention_mask)
logits = outputs["logits"]
# Apply temperature scaling
soft_labels = torch.softmax(logits / temperature, dim=-1)
return soft_labels
class PretrainedTeacher(Teacher):
"""Load a pretrained teacher from checkpoint."""
def __init__(self, checkpoint_path: str, num_labels: int = 2):
# Initialize base model
super().__init__(num_labels=num_labels)
# Load fine-tuned weights
state_dict = torch.load(checkpoint_path, map_location="cpu")
self.model.load_state_dict(state_dict)
# Freeze
for param in self.model.parameters():
param.requires_grad = False
        self.model.eval()
Student Model Definition
# models/student.py
import torch
import torch.nn as nn
from transformers import AutoModel, AutoConfig
from typing import Dict, Any, Optional
class Student(nn.Module):
"""Smaller student model for distillation."""
def __init__(
self,
model_name: str = "distilbert-base-uncased",
num_labels: int = 2,
hidden_size: int = 768,
output_hidden_states: bool = True,
output_attentions: bool = True,
):
super().__init__()
self.config = AutoConfig.from_pretrained(
model_name,
output_hidden_states=output_hidden_states,
output_attentions=output_attentions,
)
self.encoder = AutoModel.from_pretrained(model_name, config=self.config)
self.dropout = nn.Dropout(0.1)
self.classifier = nn.Linear(self.config.hidden_size, num_labels)
# Feature projection layers (if teacher has different hidden size)
self.hidden_size = hidden_size
if self.config.hidden_size != hidden_size:
self.feature_projector = nn.Linear(
self.config.hidden_size, hidden_size
)
else:
self.feature_projector = nn.Identity()
def forward(
self,
input_ids: torch.Tensor,
attention_mask: torch.Tensor,
labels: Optional[torch.Tensor] = None,
) -> Dict[str, torch.Tensor]:
"""Forward pass."""
outputs = self.encoder(
input_ids=input_ids,
attention_mask=attention_mask,
)
# Get [CLS] token representation
pooled_output = outputs.last_hidden_state[:, 0, :]
pooled_output = self.dropout(pooled_output)
logits = self.classifier(pooled_output)
loss = None
if labels is not None:
loss_fn = nn.CrossEntropyLoss()
loss = loss_fn(logits, labels)
return {
"logits": logits,
"hidden_states": outputs.hidden_states,
"attentions": outputs.attentions,
"loss": loss,
}
def get_projected_hidden_states(
self,
hidden_states: tuple,
) -> tuple:
"""Project hidden states to match teacher dimension."""
return tuple(
self.feature_projector(h) for h in hidden_states
)
class TinyStudent(nn.Module):
"""Very small student for aggressive compression."""
def __init__(
self,
vocab_size: int = 30522,
hidden_size: int = 256,
num_layers: int = 4,
num_heads: int = 4,
intermediate_size: int = 512,
num_labels: int = 2,
max_position: int = 512,
):
super().__init__()
# Embeddings
self.word_embeddings = nn.Embedding(vocab_size, hidden_size)
self.position_embeddings = nn.Embedding(max_position, hidden_size)
self.layer_norm = nn.LayerNorm(hidden_size)
self.dropout = nn.Dropout(0.1)
# Transformer layers
encoder_layer = nn.TransformerEncoderLayer(
d_model=hidden_size,
nhead=num_heads,
dim_feedforward=intermediate_size,
dropout=0.1,
batch_first=True,
)
self.encoder = nn.TransformerEncoder(encoder_layer, num_layers)
# Classification head
self.classifier = nn.Linear(hidden_size, num_labels)
def forward(
self,
input_ids: torch.Tensor,
attention_mask: torch.Tensor,
labels: Optional[torch.Tensor] = None,
) -> Dict[str, torch.Tensor]:
"""Forward pass."""
batch_size, seq_len = input_ids.shape
# Embeddings
positions = torch.arange(seq_len, device=input_ids.device).unsqueeze(0)
embeddings = self.word_embeddings(input_ids) + self.position_embeddings(positions)
embeddings = self.layer_norm(embeddings)
embeddings = self.dropout(embeddings)
# Attention mask for transformer
src_key_padding_mask = (attention_mask == 0)
# Encode
hidden_states = self.encoder(embeddings, src_key_padding_mask=src_key_padding_mask)
# Classify using [CLS] token
pooled = hidden_states[:, 0, :]
logits = self.classifier(pooled)
loss = None
if labels is not None:
loss_fn = nn.CrossEntropyLoss()
loss = loss_fn(logits, labels)
return {
"logits": logits,
"hidden_states": (hidden_states,), # Simplified
"loss": loss,
        }
Distillation Losses
Soft Target Loss (Response-Based)
# losses/soft_target.py
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Dict
class SoftTargetLoss(nn.Module):
"""
Soft target distillation loss.
The student learns from soft probability distributions
produced by the teacher at temperature T.
"""
def __init__(self, temperature: float = 4.0):
super().__init__()
self.temperature = temperature
def forward(
self,
student_logits: torch.Tensor,
teacher_logits: torch.Tensor,
) -> torch.Tensor:
"""
KL divergence between teacher and student soft predictions.
Args:
student_logits: Student model logits [batch, num_classes]
teacher_logits: Teacher model logits [batch, num_classes]
"""
# Apply temperature scaling
student_soft = F.log_softmax(student_logits / self.temperature, dim=-1)
teacher_soft = F.softmax(teacher_logits / self.temperature, dim=-1)
# KL divergence
loss = F.kl_div(
student_soft,
teacher_soft,
reduction="batchmean",
)
# Scale by T^2 (Hinton et al. recommendation)
loss = loss * (self.temperature ** 2)
return loss
class DistillationLoss(nn.Module):
"""
Combined distillation loss.
Loss = alpha * CE(student, labels) + (1-alpha) * KL(student, teacher)
"""
def __init__(
self,
temperature: float = 4.0,
alpha: float = 0.5,
):
super().__init__()
self.temperature = temperature
self.alpha = alpha
self.soft_loss = SoftTargetLoss(temperature)
self.hard_loss = nn.CrossEntropyLoss()
def forward(
self,
student_logits: torch.Tensor,
teacher_logits: torch.Tensor,
labels: torch.Tensor,
) -> Dict[str, torch.Tensor]:
"""Combined loss computation."""
# Hard label loss (CE)
hard_loss = self.hard_loss(student_logits, labels)
# Soft target loss (KL)
soft_loss = self.soft_loss(student_logits, teacher_logits)
# Combined loss
total_loss = self.alpha * hard_loss + (1 - self.alpha) * soft_loss
return {
"loss": total_loss,
"hard_loss": hard_loss,
"soft_loss": soft_loss,
        }
Understanding the Loss Components:
┌─────────────────────────────────────────────────────────────────────────────┐
│ DISTILLATION LOSS FORMULA │
├─────────────────────────────────────────────────────────────────────────────┤
│ │
│ Total Loss = α × Hard Loss + (1-α) × Soft Loss │
│ │
│ ┌─────────────────────────────────────────────────────────────────────┐ │
│ │ │ │
│ │ Hard Loss (Cross-Entropy): │ │
│ │ CE(student_logits, ground_truth_labels) │ │
│ │ │ │
│ │ Purpose: Learn correct classification │ │
│ │ │ │
│ └─────────────────────────────────────────────────────────────────────┘ │
│ │
│ ┌─────────────────────────────────────────────────────────────────────┐ │
│ │ │ │
│ │ Soft Loss (KL Divergence): │ │
│ │ T² × KL(softmax(student/T) || softmax(teacher/T)) │ │
│ │ │ │
│ │ Purpose: Learn teacher's knowledge about class relationships │ │
│ │ │ │
│ │ Why T² scaling? At high T, gradients become T² smaller. │ │
│ │ Multiplying by T² restores gradient magnitude. │ │
│ │ │ │
│ └─────────────────────────────────────────────────────────────────────┘ │
│ │
└─────────────────────────────────────────────────────────────────────────────┘
Typical Values:
| Parameter | Typical Value | Effect |
|---|---|---|
| α = 0.5 | Balanced | Equal weight to labels and teacher |
| α = 0.1 | Teacher-heavy | Trust teacher more than labels |
| T = 4 | Standard | Good balance of soft/hard info |
| T = 20 | Very soft | Maximum knowledge transfer |
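The full formula can be checked end-to-end in pure Python. This is an illustrative sketch with made-up logits, mirroring the PyTorch `DistillationLoss` above (not a replacement for it):

```python
import math

def softmax(logits, t=1.0):
    """Temperature-scaled softmax (stable: subtract the max)."""
    m = max(logits)
    exps = [math.exp((z - m) / t) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

def distillation_loss(student_logits, teacher_logits, label, t=4.0, alpha=0.5):
    """alpha * CE(student, label) + (1 - alpha) * T^2 * KL(teacher_T || student_T)."""
    # Hard loss: cross-entropy against the ground-truth label
    hard = -math.log(softmax(student_logits)[label])
    # Soft loss: KL divergence between temperature-softened distributions
    s_soft = softmax(student_logits, t)
    t_soft = softmax(teacher_logits, t)
    kl = sum(ti * math.log(ti / si) for ti, si in zip(t_soft, s_soft))
    return alpha * hard + (1 - alpha) * (t ** 2) * kl
```

When student and teacher logits are identical, the KL term vanishes and only the hard cross-entropy remains, which is a handy sanity check when debugging the PyTorch version.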
Feature-Based Distillation
# losses/feature.py
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Tuple
class HiddenStateLoss(nn.Module):
"""
Feature-based distillation on hidden states.
Matches intermediate layer representations between
teacher and student.
"""
def __init__(
self,
student_layers: Tuple[int] = (2, 4, 6),
teacher_layers: Tuple[int] = (4, 8, 12),
normalize: bool = True,
):
super().__init__()
self.student_layers = student_layers
self.teacher_layers = teacher_layers
self.normalize = normalize
assert len(student_layers) == len(teacher_layers)
def forward(
self,
student_hidden: Tuple[torch.Tensor],
teacher_hidden: Tuple[torch.Tensor],
) -> torch.Tensor:
"""
MSE loss between selected layer hidden states.
Args:
student_hidden: Tuple of student hidden states
teacher_hidden: Tuple of teacher hidden states
"""
total_loss = 0.0
for s_layer, t_layer in zip(self.student_layers, self.teacher_layers):
s_hidden = student_hidden[s_layer]
t_hidden = teacher_hidden[t_layer]
# Normalize if requested
if self.normalize:
s_hidden = F.normalize(s_hidden, dim=-1)
t_hidden = F.normalize(t_hidden, dim=-1)
# MSE loss
loss = F.mse_loss(s_hidden, t_hidden)
total_loss += loss
return total_loss / len(self.student_layers)
class CosineEmbeddingLoss(nn.Module):
"""
Cosine similarity loss for hidden states.
More robust to scale differences between teacher and student.
"""
def __init__(self, margin: float = 0.0):
super().__init__()
self.margin = margin
def forward(
self,
student_hidden: torch.Tensor,
teacher_hidden: torch.Tensor,
) -> torch.Tensor:
"""
Cosine similarity loss.
Args:
student_hidden: Student hidden states [batch, seq, hidden]
teacher_hidden: Teacher hidden states [batch, seq, hidden]
"""
# Flatten sequence dimension
s_flat = student_hidden.view(-1, student_hidden.size(-1))
t_flat = teacher_hidden.view(-1, teacher_hidden.size(-1))
# Target: maximize similarity (target = 1)
target = torch.ones(s_flat.size(0), device=s_flat.device)
loss = F.cosine_embedding_loss(s_flat, t_flat, target, margin=self.margin)
return loss
class PKDLoss(nn.Module):
"""
Patient Knowledge Distillation loss.
Distills from multiple intermediate layers with patience.
"""
def __init__(
self,
num_student_layers: int = 6,
num_teacher_layers: int = 12,
):
super().__init__()
self.num_student_layers = num_student_layers
self.num_teacher_layers = num_teacher_layers
# Create layer mapping
self.layer_mapping = self._create_mapping()
def _create_mapping(self) -> list:
"""Create uniform layer mapping."""
skip = self.num_teacher_layers // self.num_student_layers
return [i * skip for i in range(self.num_student_layers)]
def forward(
self,
student_hidden: Tuple[torch.Tensor],
teacher_hidden: Tuple[torch.Tensor],
) -> torch.Tensor:
"""PKD loss computation."""
total_loss = 0.0
for s_idx, t_idx in enumerate(self.layer_mapping):
if s_idx >= len(student_hidden) or t_idx >= len(teacher_hidden):
continue
s_hidden = student_hidden[s_idx]
t_hidden = teacher_hidden[t_idx]
# Normalize hidden states
s_norm = F.normalize(s_hidden, dim=-1)
t_norm = F.normalize(t_hidden, dim=-1)
# MSE on normalized representations
loss = F.mse_loss(s_norm, t_norm)
total_loss += loss
        return total_loss / len(self.layer_mapping)
Attention Distillation
# losses/attention.py
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Tuple
class AttentionDistillationLoss(nn.Module):
"""
Attention-based knowledge distillation.
Transfers attention patterns from teacher to student.
"""
def __init__(
self,
student_layers: Tuple[int] = (1, 3, 5),
teacher_layers: Tuple[int] = (3, 7, 11),
):
super().__init__()
self.student_layers = student_layers
self.teacher_layers = teacher_layers
def forward(
self,
student_attentions: Tuple[torch.Tensor],
teacher_attentions: Tuple[torch.Tensor],
) -> torch.Tensor:
"""
KL divergence between attention distributions.
Args:
student_attentions: Tuple of student attention weights
teacher_attentions: Tuple of teacher attention weights
"""
total_loss = 0.0
num_pairs = 0
for s_layer, t_layer in zip(self.student_layers, self.teacher_layers):
if s_layer >= len(student_attentions):
continue
if t_layer >= len(teacher_attentions):
continue
s_attn = student_attentions[s_layer] # [batch, heads, seq, seq]
t_attn = teacher_attentions[t_layer]
# Handle different number of attention heads
if s_attn.size(1) != t_attn.size(1):
# Average teacher heads to match student
t_attn = self._align_heads(t_attn, s_attn.size(1))
# KL divergence on attention distributions
s_attn_log = torch.log(s_attn + 1e-10)
loss = F.kl_div(s_attn_log, t_attn, reduction="batchmean")
total_loss += loss
num_pairs += 1
return total_loss / max(num_pairs, 1)
def _align_heads(
self,
attention: torch.Tensor,
target_heads: int,
) -> torch.Tensor:
"""Align number of attention heads."""
batch, heads, seq, _ = attention.shape
if heads > target_heads:
# Group and average heads
group_size = heads // target_heads
attention = attention.view(batch, target_heads, group_size, seq, seq)
attention = attention.mean(dim=2)
else:
# Repeat heads
repeat = target_heads // heads
attention = attention.repeat(1, repeat, 1, 1)
return attention
class TinyBERTLoss(nn.Module):
"""
TinyBERT distillation loss.
Combines embedding, hidden state, attention, and prediction losses.
"""
def __init__(
self,
temperature: float = 1.0,
alpha_embedding: float = 1.0,
alpha_hidden: float = 1.0,
alpha_attention: float = 1.0,
alpha_prediction: float = 1.0,
):
super().__init__()
self.temperature = temperature
self.alpha_embedding = alpha_embedding
self.alpha_hidden = alpha_hidden
self.alpha_attention = alpha_attention
self.alpha_prediction = alpha_prediction
self.mse_loss = nn.MSELoss()
def forward(
self,
student_outputs: dict,
teacher_outputs: dict,
labels: torch.Tensor = None,
) -> dict:
"""Compute TinyBERT loss."""
losses = {}
# Embedding layer loss
if "embeddings" in student_outputs and "embeddings" in teacher_outputs:
losses["embedding_loss"] = self.mse_loss(
student_outputs["embeddings"],
teacher_outputs["embeddings"],
) * self.alpha_embedding
# Hidden state loss
if "hidden_states" in student_outputs and "hidden_states" in teacher_outputs:
hidden_loss = self._hidden_state_loss(
student_outputs["hidden_states"],
teacher_outputs["hidden_states"],
)
losses["hidden_loss"] = hidden_loss * self.alpha_hidden
# Attention loss
if "attentions" in student_outputs and "attentions" in teacher_outputs:
attn_loss = self._attention_loss(
student_outputs["attentions"],
teacher_outputs["attentions"],
)
losses["attention_loss"] = attn_loss * self.alpha_attention
# Prediction layer loss (soft targets)
if "logits" in student_outputs and "logits" in teacher_outputs:
pred_loss = self._prediction_loss(
student_outputs["logits"],
teacher_outputs["logits"],
)
losses["prediction_loss"] = pred_loss * self.alpha_prediction
# Total loss
losses["loss"] = sum(losses.values())
return losses
def _hidden_state_loss(
self,
student_hidden: tuple,
teacher_hidden: tuple,
) -> torch.Tensor:
"""Hidden state matching loss."""
# Map student layers to teacher layers
s_layers = len(student_hidden)
t_layers = len(teacher_hidden)
mapping = [int(i * t_layers / s_layers) for i in range(s_layers)]
total_loss = 0.0
for s_idx, t_idx in enumerate(mapping):
s_h = student_hidden[s_idx]
t_h = teacher_hidden[min(t_idx, t_layers - 1)]
total_loss += self.mse_loss(s_h, t_h)
return total_loss / s_layers
def _attention_loss(
self,
student_attn: tuple,
teacher_attn: tuple,
) -> torch.Tensor:
"""Attention distribution matching loss."""
s_layers = len(student_attn)
t_layers = len(teacher_attn)
mapping = [int(i * t_layers / s_layers) for i in range(s_layers)]
total_loss = 0.0
for s_idx, t_idx in enumerate(mapping):
s_a = student_attn[s_idx]
t_a = teacher_attn[min(t_idx, t_layers - 1)]
# Average over heads if different
if s_a.size(1) != t_a.size(1):
t_a = t_a.mean(dim=1, keepdim=True).expand_as(s_a)
total_loss += self.mse_loss(s_a, t_a)
return total_loss / s_layers
def _prediction_loss(
self,
student_logits: torch.Tensor,
teacher_logits: torch.Tensor,
) -> torch.Tensor:
"""Soft target prediction loss."""
s_soft = F.log_softmax(student_logits / self.temperature, dim=-1)
t_soft = F.softmax(teacher_logits / self.temperature, dim=-1)
loss = F.kl_div(s_soft, t_soft, reduction="batchmean")
        return loss * (self.temperature ** 2)
Distillation Trainer
# training/trainer.py
import torch
from torch.utils.data import DataLoader
from transformers import AutoTokenizer, get_linear_schedule_with_warmup
from typing import Dict, Any, Optional
import wandb
from tqdm import tqdm
from dataclasses import dataclass
from models.teacher import Teacher
from models.student import Student
from losses.soft_target import DistillationLoss
from losses.feature import HiddenStateLoss
from losses.attention import AttentionDistillationLoss
@dataclass
class DistillationConfig:
"""Configuration for distillation training."""
# Loss weights
alpha: float = 0.5 # Hard vs soft label weight
temperature: float = 4.0 # Softmax temperature
feature_weight: float = 0.1 # Feature distillation weight
attention_weight: float = 0.1 # Attention distillation weight
# Training
learning_rate: float = 5e-5
batch_size: int = 32
num_epochs: int = 5
warmup_ratio: float = 0.1
weight_decay: float = 0.01
# Output
output_dir: str = "./outputs"
class DistillationTrainer:
"""Trainer for knowledge distillation."""
def __init__(
self,
teacher: Teacher,
student: Student,
train_dataset,
eval_dataset,
config: DistillationConfig,
tokenizer: AutoTokenizer,
device: str = "cuda",
):
self.teacher = teacher.to(device)
self.student = student.to(device)
self.train_dataset = train_dataset
self.eval_dataset = eval_dataset
self.config = config
self.tokenizer = tokenizer
self.device = device
# Initialize losses
self.distill_loss = DistillationLoss(
temperature=config.temperature,
alpha=config.alpha,
)
self.feature_loss = HiddenStateLoss()
self.attention_loss = AttentionDistillationLoss()
# Ensure teacher is in eval mode
self.teacher.eval()
def train(self) -> Dict[str, Any]:
"""Run distillation training."""
# Create dataloaders
train_loader = DataLoader(
self.train_dataset,
batch_size=self.config.batch_size,
shuffle=True,
num_workers=4,
)
eval_loader = DataLoader(
self.eval_dataset,
batch_size=self.config.batch_size,
shuffle=False,
num_workers=4,
)
# Setup optimizer (only student parameters)
optimizer = torch.optim.AdamW(
self.student.parameters(),
lr=self.config.learning_rate,
weight_decay=self.config.weight_decay,
)
# Setup scheduler
total_steps = len(train_loader) * self.config.num_epochs
warmup_steps = int(total_steps * self.config.warmup_ratio)
scheduler = get_linear_schedule_with_warmup(
optimizer,
num_warmup_steps=warmup_steps,
num_training_steps=total_steps,
)
# Training loop
best_eval_acc = 0.0
global_step = 0
for epoch in range(self.config.num_epochs):
self.student.train()
epoch_losses = {"total": 0, "hard": 0, "soft": 0, "feature": 0, "attention": 0}
pbar = tqdm(train_loader, desc=f"Epoch {epoch + 1}")
for batch in pbar:
# Move to device
batch = {k: v.to(self.device) for k, v in batch.items()}
# Teacher forward (no grad)
with torch.no_grad():
teacher_outputs = self.teacher(
input_ids=batch["input_ids"],
attention_mask=batch["attention_mask"],
)
# Student forward
student_outputs = self.student(
input_ids=batch["input_ids"],
attention_mask=batch["attention_mask"],
)
# Compute losses
loss_dict = self._compute_loss(
student_outputs,
teacher_outputs,
batch["labels"],
)
loss = loss_dict["total_loss"]
# Backward
optimizer.zero_grad()
loss.backward()
torch.nn.utils.clip_grad_norm_(self.student.parameters(), 1.0)
optimizer.step()
scheduler.step()
# Update metrics
for k, v in loss_dict.items():
if k.endswith("_loss"):
key = k.replace("_loss", "")
epoch_losses[key] += v.item()
global_step += 1
pbar.set_postfix({"loss": loss.item()})
                # Log scalar values to wandb
                if global_step % 100 == 0:
                    wandb.log({
                        f"train/{k}": v.item() for k, v in loss_dict.items()
                    }, step=global_step)
# Evaluation
eval_metrics = self._evaluate(eval_loader)
print(f"Epoch {epoch + 1} - Eval Accuracy: {eval_metrics['accuracy']:.4f}")
wandb.log({
"eval/accuracy": eval_metrics["accuracy"],
"eval/loss": eval_metrics["loss"],
"epoch": epoch + 1,
}, step=global_step)
# Save best model
if eval_metrics["accuracy"] > best_eval_acc:
best_eval_acc = eval_metrics["accuracy"]
self._save_model(f"{self.config.output_dir}/best")
# Save final model
self._save_model(f"{self.config.output_dir}/final")
return {
"best_accuracy": best_eval_acc,
"final_step": global_step,
}
def _compute_loss(
self,
student_outputs: dict,
teacher_outputs: dict,
labels: torch.Tensor,
) -> dict:
"""Compute combined distillation loss."""
# Response-based loss (soft targets + hard labels)
distill_losses = self.distill_loss(
student_outputs["logits"],
teacher_outputs["logits"],
labels,
)
total_loss = distill_losses["loss"]
losses = {
"hard_loss": distill_losses["hard_loss"],
"soft_loss": distill_losses["soft_loss"],
}
        # Feature-based loss
        if self.config.feature_weight > 0:
            if student_outputs.get("hidden_states") and teacher_outputs.get("hidden_states"):
                # Project student hidden states to the teacher's hidden size
                student_hidden = self.student.get_projected_hidden_states(
                    student_outputs["hidden_states"]
                )
                feat_loss = self.feature_loss(
                    student_hidden,
                    teacher_outputs["hidden_states"],
                )
                total_loss += self.config.feature_weight * feat_loss
                losses["feature_loss"] = feat_loss
# Attention-based loss
if self.config.attention_weight > 0:
if student_outputs.get("attentions") and teacher_outputs.get("attentions"):
attn_loss = self.attention_loss(
student_outputs["attentions"],
teacher_outputs["attentions"],
)
total_loss += self.config.attention_weight * attn_loss
losses["attention_loss"] = attn_loss
losses["total_loss"] = total_loss
return losses
def _evaluate(self, eval_loader: DataLoader) -> dict:
"""Evaluate student model."""
self.student.eval()
total_loss = 0
correct = 0
total = 0
with torch.no_grad():
for batch in eval_loader:
batch = {k: v.to(self.device) for k, v in batch.items()}
outputs = self.student(
input_ids=batch["input_ids"],
attention_mask=batch["attention_mask"],
labels=batch["labels"],
)
total_loss += outputs["loss"].item()
predictions = outputs["logits"].argmax(dim=-1)
correct += (predictions == batch["labels"]).sum().item()
total += batch["labels"].size(0)
return {
"loss": total_loss / len(eval_loader),
"accuracy": correct / total,
}
    def _save_model(self, path: str):
        """Save student model and tokenizer."""
        import os

        os.makedirs(path, exist_ok=True)  # create the output directory if needed
        torch.save({
            "model_state_dict": self.student.state_dict(),
            "config": self.config,
        }, f"{path}/model.pt")
        self.tokenizer.save_pretrained(path)
Training Script
# scripts/train.py
import argparse
import wandb
import torch
from datasets import load_dataset
from transformers import AutoTokenizer
from models.teacher import Teacher
from models.student import Student
from training.trainer import DistillationTrainer, DistillationConfig
def tokenize_function(examples, tokenizer, max_length=128):
"""Tokenize examples."""
return tokenizer(
examples["text"],
truncation=True,
max_length=max_length,
padding="max_length",
)
def main():
parser = argparse.ArgumentParser()
parser.add_argument("--teacher-model", default="bert-base-uncased")
parser.add_argument("--student-model", default="distilbert-base-uncased")
parser.add_argument("--dataset", default="imdb")
parser.add_argument("--output-dir", default="./outputs")
parser.add_argument("--epochs", type=int, default=5)
parser.add_argument("--batch-size", type=int, default=32)
parser.add_argument("--learning-rate", type=float, default=5e-5)
parser.add_argument("--temperature", type=float, default=4.0)
parser.add_argument("--alpha", type=float, default=0.5)
parser.add_argument("--feature-weight", type=float, default=0.1)
parser.add_argument("--attention-weight", type=float, default=0.1)
parser.add_argument("--wandb-project", default="knowledge-distillation")
args = parser.parse_args()
# Initialize wandb
wandb.init(project=args.wandb_project, config=vars(args))
# Load dataset
print(f"Loading dataset: {args.dataset}")
dataset = load_dataset(args.dataset)
# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(args.teacher_model)
# Tokenize dataset
tokenized = dataset.map(
lambda x: tokenize_function(x, tokenizer),
batched=True,
remove_columns=["text"],
)
tokenized.set_format("torch")
# Create models
print("Loading teacher model...")
teacher = Teacher(
model_name=args.teacher_model,
num_labels=2, # Binary classification for IMDB
)
print("Creating student model...")
student = Student(
model_name=args.student_model,
num_labels=2,
)
# Print parameter comparison
teacher_params = sum(p.numel() for p in teacher.parameters())
student_params = sum(p.numel() for p in student.parameters())
print(f"Teacher parameters: {teacher_params:,}")
print(f"Student parameters: {student_params:,}")
print(f"Compression ratio: {teacher_params / student_params:.2f}x")
# Create config
config = DistillationConfig(
alpha=args.alpha,
temperature=args.temperature,
feature_weight=args.feature_weight,
attention_weight=args.attention_weight,
learning_rate=args.learning_rate,
batch_size=args.batch_size,
num_epochs=args.epochs,
output_dir=args.output_dir,
)
# Train
trainer = DistillationTrainer(
teacher=teacher,
student=student,
train_dataset=tokenized["train"],
eval_dataset=tokenized["test"],
config=config,
tokenizer=tokenizer,
)
results = trainer.train()
print(f"Training complete! Best accuracy: {results['best_accuracy']:.4f}")
if __name__ == "__main__":
    main()
Evaluation and Benchmarking
# scripts/benchmark.py
import torch
import time
import argparse
from transformers import AutoTokenizer
from typing import Dict
from models.teacher import Teacher
from models.student import Student
def benchmark_model(
model,
tokenizer,
num_samples: int = 1000,
batch_size: int = 32,
max_length: int = 128,
device: str = "cuda",
) -> Dict[str, float]:
"""Benchmark model latency and throughput."""
model = model.to(device)
model.eval()
# Create dummy input
dummy_text = "This is a sample text for benchmarking." * 5
inputs = tokenizer(
[dummy_text] * batch_size,
truncation=True,
max_length=max_length,
padding="max_length",
return_tensors="pt",
).to(device)
# Warmup
for _ in range(10):
with torch.no_grad():
_ = model(**inputs)
    # Benchmark (synchronize only when running on GPU)
    if device == "cuda":
        torch.cuda.synchronize()
    start_time = time.perf_counter()
num_batches = num_samples // batch_size
for _ in range(num_batches):
with torch.no_grad():
_ = model(**inputs)
    if device == "cuda":
        torch.cuda.synchronize()
total_time = time.perf_counter() - start_time
# Calculate metrics
samples_per_second = num_samples / total_time
latency_ms = (total_time / num_batches) * 1000
latency_per_sample_ms = (total_time / num_samples) * 1000
return {
"samples_per_second": samples_per_second,
"batch_latency_ms": latency_ms,
"sample_latency_ms": latency_per_sample_ms,
"total_time_s": total_time,
}
def count_parameters(model) -> Dict[str, int]:
"""Count model parameters."""
total = sum(p.numel() for p in model.parameters())
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
return {
"total_parameters": total,
"trainable_parameters": trainable,
}
def get_model_size_mb(model) -> float:
"""Get model size in MB."""
param_size = sum(p.nelement() * p.element_size() for p in model.parameters())
buffer_size = sum(b.nelement() * b.element_size() for b in model.buffers())
return (param_size + buffer_size) / (1024 * 1024)
def main():
parser = argparse.ArgumentParser()
parser.add_argument("--teacher-model", default="bert-base-uncased")
parser.add_argument("--student-checkpoint", required=True)
parser.add_argument("--num-samples", type=int, default=1000)
parser.add_argument("--batch-size", type=int, default=32)
args = parser.parse_args()
tokenizer = AutoTokenizer.from_pretrained(args.teacher_model)
# Load models
print("Loading teacher...")
teacher = Teacher(model_name=args.teacher_model, num_labels=2)
print("Loading student...")
student = Student(model_name="distilbert-base-uncased", num_labels=2)
    checkpoint = torch.load(f"{args.student_checkpoint}/model.pt", map_location="cpu")
student.load_state_dict(checkpoint["model_state_dict"])
# Compare parameters
teacher_params = count_parameters(teacher)
student_params = count_parameters(student)
print("\n=== Parameter Comparison ===")
print(f"Teacher: {teacher_params['total_parameters']:,}")
print(f"Student: {student_params['total_parameters']:,}")
print(f"Compression: {teacher_params['total_parameters'] / student_params['total_parameters']:.2f}x")
# Compare model size
teacher_size = get_model_size_mb(teacher)
student_size = get_model_size_mb(student)
print("\n=== Model Size ===")
print(f"Teacher: {teacher_size:.2f} MB")
print(f"Student: {student_size:.2f} MB")
print(f"Size reduction: {teacher_size / student_size:.2f}x")
# Benchmark latency
print("\n=== Latency Benchmark ===")
device = "cuda" if torch.cuda.is_available() else "cpu"
teacher_bench = benchmark_model(teacher, tokenizer, args.num_samples, args.batch_size, device=device)
student_bench = benchmark_model(student, tokenizer, args.num_samples, args.batch_size, device=device)
print(f"Teacher: {teacher_bench['sample_latency_ms']:.2f} ms/sample")
print(f"Student: {student_bench['sample_latency_ms']:.2f} ms/sample")
print(f"Speedup: {teacher_bench['sample_latency_ms'] / student_bench['sample_latency_ms']:.2f}x")
print(f"\nTeacher throughput: {teacher_bench['samples_per_second']:.1f} samples/sec")
print(f"Student throughput: {student_bench['samples_per_second']:.1f} samples/sec")
if __name__ == "__main__":
    main()

Temperature Analysis
# Visualize effect of temperature
import torch
import matplotlib.pyplot as plt
import numpy as np
def visualize_temperature_effect():
"""Show how temperature affects soft labels."""
# Example logits
logits = torch.tensor([2.0, 1.0, 0.5, 0.1])
temperatures = [1.0, 2.0, 4.0, 8.0]
fig, axes = plt.subplots(1, len(temperatures), figsize=(15, 4))
for ax, T in zip(axes, temperatures):
probs = torch.softmax(logits / T, dim=0).numpy()
ax.bar(["A", "B", "C", "D"], probs)
ax.set_title(f"T = {T}")
ax.set_ylim(0, 1)
ax.set_ylabel("Probability")
plt.suptitle("Effect of Temperature on Soft Labels")
plt.tight_layout()
plt.savefig("temperature_effect.png")
plt.show()
# Run visualization
visualize_temperature_effect()

Advanced Techniques
Progressive Distillation
# Progressive layer-by-layer distillation
class ProgressiveDistillation:
"""Train student layer by layer."""
def __init__(self, teacher, student, num_stages: int = 3):
self.teacher = teacher
self.student = student
self.num_stages = num_stages
def train_stage(self, stage: int, train_loader, epochs: int):
"""Train specific layers in a stage."""
# Freeze earlier layers
num_student_layers = len(self.student.encoder.layer)
layers_per_stage = num_student_layers // self.num_stages
for i, layer in enumerate(self.student.encoder.layer):
if i < stage * layers_per_stage:
for param in layer.parameters():
param.requires_grad = False
else:
for param in layer.parameters():
param.requires_grad = True
# Train this stage
        # ... training loop ...

Data Augmentation for Distillation
# Augment training data with teacher predictions
class DataAugmentedDistillation:
"""Use teacher to label unlabeled data."""
    def __init__(self, teacher, tokenizer):
        self.teacher = teacher
        self.teacher.eval()  # inference only; disable dropout during labeling
        self.tokenizer = tokenizer
@torch.no_grad()
def label_unlabeled_data(
self,
texts: list,
batch_size: int = 32,
threshold: float = 0.9,
) -> list:
"""Label high-confidence samples."""
labeled = []
for i in range(0, len(texts), batch_size):
batch_texts = texts[i:i + batch_size]
            inputs = self.tokenizer(
                batch_texts,
                truncation=True,
                padding=True,
                return_tensors="pt",
            ).to(next(self.teacher.parameters()).device)  # match teacher's device
            outputs = self.teacher(**inputs)
probs = torch.softmax(outputs["logits"], dim=-1)
confidence, predictions = probs.max(dim=-1)
# Keep high-confidence predictions
for text, conf, pred in zip(batch_texts, confidence, predictions):
if conf > threshold:
labeled.append({
"text": text,
"label": pred.item(),
"confidence": conf.item(),
})
        return labeled

Key Concepts Recap
| Concept | What It Is | Why It Matters |
|---|---|---|
| Teacher Model | Large, accurate model (frozen during distillation) | Provides soft labels containing "dark knowledge" |
| Student Model | Smaller model being trained | Learns to mimic teacher at fraction of the size |
| Soft Labels | Probability distributions from softmax (not one-hot) | Encode class relationships and model uncertainty |
| Temperature (T) | Scaling factor applied before softmax | Higher T → softer distributions, more knowledge transfer |
| Dark Knowledge | Information in soft labels about class similarities | "Cat looks like tiger" vs just "this is a cat" |
| KL Divergence Loss | Measures difference between teacher and student distributions | Forces student to match teacher's soft predictions |
| Alpha (α) | Weight balancing hard labels vs soft distillation | α=0.5 means equal weight to ground truth and teacher |
| Feature Distillation | Matching intermediate layer representations | Transfers learned features, not just final predictions |
| Attention Distillation | Matching attention patterns between teacher/student | Transfers "where to look" behavior |
| T² Scaling | Multiply KL loss by temperature² | Compensates for gradient magnitude changes at high T |
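The loss-related rows above (temperature, KL divergence, alpha, T² scaling) combine into a single distillation objective. A minimal PyTorch sketch (the function name `distillation_loss` and the example logits are illustrative, not from the project code):

```python
import torch
import torch.nn.functional as F

def distillation_loss(student_logits, teacher_logits, labels, T=4.0, alpha=0.5):
    """Combined hard-label and soft-label distillation loss.

    alpha weights the hard cross-entropy term against ground-truth labels;
    (1 - alpha) weights the KL divergence between temperature-softened
    teacher and student distributions. The KL term is multiplied by T**2
    to keep gradient magnitudes comparable across temperatures.
    """
    hard_loss = F.cross_entropy(student_logits, labels)
    soft_loss = F.kl_div(
        F.log_softmax(student_logits / T, dim=-1),   # student log-probs at T
        F.softmax(teacher_logits / T, dim=-1),       # teacher probs at T
        reduction="batchmean",
    ) * (T ** 2)
    return alpha * hard_loss + (1 - alpha) * soft_loss

# Example: 2 samples, 3 classes
student_logits = torch.tensor([[2.0, 0.5, 0.1], [0.2, 1.5, 0.3]])
teacher_logits = torch.tensor([[3.0, 1.0, 0.2], [0.1, 2.5, 0.4]])
labels = torch.tensor([0, 1])
loss = distillation_loss(student_logits, teacher_logits, labels)
print(f"Distillation loss: {loss.item():.4f}")
```

With alpha=1.0 the soft term vanishes and the objective reduces to plain cross-entropy, which is a quick sanity check when wiring this into a trainer.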
Next Steps
After completing this project, consider:
- Quantization - Further compress distilled models
- LoRA Fine-tuning - Efficient adaptation
- Model Inference API - Deploy distilled models
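As a pointer for the first item: post-training dynamic quantization in PyTorch can shrink a distilled student further with little code. A sketch, assuming a CPU-resident model (the `nn.Sequential` stand-in below is illustrative; `torch.quantization.quantize_dynamic` converts `nn.Linear` weights to int8 while activations stay float):

```python
import torch
import torch.nn as nn

# Stand-in for a distilled student: any model containing Linear layers
model = nn.Sequential(nn.Linear(128, 64), nn.ReLU(), nn.Linear(64, 2))
model.eval()

# Quantize Linear weights to int8 (dynamic quantization, CPU inference)
quantized = torch.quantization.quantize_dynamic(
    model, {nn.Linear}, dtype=torch.qint8
)

x = torch.randn(1, 128)
with torch.no_grad():
    out = quantized(x)
print(out.shape)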