Distributed Training
Scale training across multiple GPUs with PyTorch DDP, FSDP, and DeepSpeed
Train large models efficiently across multiple GPUs and nodes using PyTorch's distributed training capabilities.
TL;DR
Distributed training splits work across GPUs to train faster or to fit larger models. DDP (DistributedDataParallel) replicates the model on each GPU and synchronizes gradients via AllReduce. FSDP and DeepSpeed ZeRO shard model parameters, gradients, and optimizer states across GPUs, enabling training of models that don't fit on a single GPU. Key formula: effective_batch_size = batch_per_gpu × gradient_accumulation_steps × num_gpus.
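A quick sanity check of that formula, using the numbers from the DDP config later in this guide (16 per GPU, 4 accumulation steps, 4 GPUs); the values are illustrative:

```python
def effective_batch_size(batch_per_gpu: int,
                         grad_accum_steps: int,
                         num_gpus: int) -> int:
    # Samples contributing to each optimizer step across all GPUs.
    return batch_per_gpu * grad_accum_steps * num_gpus

print(effective_batch_size(16, 4, 4))  # 256
```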
Overview
| Aspect | Details |
|---|---|
| Difficulty | Advanced |
| Time | 5 days |
| Code | ~900 lines |
| Prerequisites | PyTorch, multi-GPU access |
What You'll Build
A scalable distributed training system that:
- Trains across multiple GPUs with DistributedDataParallel (DDP)
- Handles large models with Fully Sharded Data Parallel (FSDP)
- Integrates DeepSpeed for memory-efficient training
- Implements gradient checkpointing for memory savings
- Supports multi-node cluster training
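In the diagram below, every GPU process gets a unique global rank. torchrun numbers them as rank = node_rank × nproc_per_node + local_rank; a quick sketch of that rule:

```python
def global_rank(node_rank: int, nproc_per_node: int, local_rank: int) -> int:
    # torchrun's numbering: node 0 hosts ranks 0..nproc-1,
    # node 1 the next contiguous block, and so on.
    return node_rank * nproc_per_node + local_rank

nnodes, nproc = 2, 2                       # matches the 2-node, 2-GPU-per-node diagram
world_size = nnodes * nproc                # 4 processes in total
ranks = [global_rank(n, nproc, l) for n in range(nnodes) for l in range(nproc)]
print(world_size, ranks)  # 4 [0, 1, 2, 3]
```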
┌─────────────────────────────────────────────────────────────────────────────┐
│ Multi-Node Distributed Training │
├─────────────────────────────────────────────────────────────────────────────┤
│ │
│ ┌─────────────────┐ │
│ │ Training Data │ │
│ └────────┬────────┘ │
│ │ │
│ ▼ │
│ ┌─────────────────────────┐ │
│ │ Distributed Sampler │ │
│ │ (shards data by rank) │ │
│ └──┬─────┬─────┬─────┬────┘ │
│ │ │ │ │ │
│ ▼ ▼ ▼ ▼ │
│ ┌─────────────────────┐ ┌─────────────────────┐ │
│ │ Node 1 │ │ Node 2 │ │
│ │ ┌───────┐ ┌───────┐ │ │ ┌───────┐ ┌───────┐ │ │
│ │ │ GPU 0 │ │ GPU 1 │ │ │ │ GPU 2 │ │ GPU 3 │ │ │
│ │ │Rank 0 │ │Rank 1 │ │ │ │Rank 2 │ │Rank 3 │ │ │
│ │ │Shard 0│ │Shard 1│ │ │ │Shard 2│ │Shard 3│ │ │
│ │ └───┬───┘ └───┬───┘ │ │ └───┬───┘ └───┬───┘ │ │
│ └─────┼─────────┼─────┘ └─────┼─────────┼─────┘ │
│ │ │ │ │ │
│ └─────────┴───────────────┴─────────┘ │
│ │ │
│ ▼ │
│ ┌─────────────────┐ │
│ │ AllReduce │ │
│ │ (sync gradients │ │
│ │ across GPUs) │ │
│ └─────────────────┘ │
│ │
└─────────────────────────────────────────────────────────────────────────────┘

Understanding Distributed Training
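The AllReduce stage in the diagram above can be simulated in plain Python: every rank contributes its local gradients and all ranks end up with the elementwise mean. A minimal sketch (ranks modeled as lists, no GPUs required):

```python
def allreduce_mean(per_rank_grads):
    # Elementwise sum across ranks, divided by world size:
    # the averaging DDP performs after every backward pass.
    num_ranks = len(per_rank_grads)
    num_params = len(per_rank_grads[0])
    return [sum(rank[i] for rank in per_rank_grads) / num_ranks
            for i in range(num_params)]

grads = [
    [1.0, 2.0],  # rank 0's local gradients
    [3.0, 2.0],  # rank 1
    [1.0, 6.0],  # rank 2
    [5.0, 2.0],  # rank 3
]
print(allreduce_mean(grads))  # [2.5, 3.0]
```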
Why Distribute?
┌─────────────────────────────────────────────────────────────────────────────┐
│ Why Distributed Training? │
├─────────────────────────────────────────────────────────────────────────────┤
│ │
│ ┌─────────────────┐ │
│ │ Distributed │ │
│ │ Training │ │
│ └────────┬────────┘ │
│ ┌──────────────────────┼──────────────────────┐ │
│ │ │ │ │ │ │
│ ▼ ▼ ▼ ▼ ▼ │
│ ┌──────────┐ ┌──────────┐ ┌──────────┐ ┌──────────┐ │
│ │ Speed │ │ Scale │ │ Memory │ │Production│ │
│ ├──────────┤ ├──────────┤ ├──────────┤ ├──────────┤ │
│ │• Larger │ │• Train │ │• Model │ │• Fault │ │
│ │ batches │ │ larger │ │ parallel│ │ tolerant│ │
│ │• Parallel│ │ models │ │• Gradient│ │• Check- │ │
│ │ compute │ │• More │ │ sharding│ │ pointing│ │
│ │• Reduced │ │ data/ │ │• Optimizer│ │• Multi- │ │
│ │ wall │ │ step │ │ sharding│ │ node │ │
│ │ time │ │• Better │ │ │ │ clusters│ │
│ │ │ │ converg.│ │ │ │ │ │
│ └──────────┘ └──────────┘ └──────────┘ └──────────┘ │
│ │
└─────────────────────────────────────────────────────────────────────────────┘

Parallelism Strategies
┌─────────────────────────────────────────────────────────────────────────────┐
│ Parallelism Strategies │
├─────────────────────────────────────────────────────────────────────────────┤
│ │
│ ┌─────────────────────┐ ┌─────────────────────┐ ┌─────────────────────┐ │
│ │ Data Parallelism │ │ Model Parallelism │ │ Fully Sharded │ │
│ │ (DDP) │ │ (Pipeline) │ │ (FSDP) │ │
│ ├─────────────────────┤ ├─────────────────────┤ ├─────────────────────┤ │
│ │ │ │ │ │ │ │
│ │ ┌───────┐ ┌───────┐│ │ ┌───────────┐ │ │ ┌───────┐ ┌───────┐ │ │
│ │ │Model │ │Model ││ │ │ Layers 1-6│ │ │ │Params │ │Params │ │ │
│ │ │Copy 1 │ │Copy 2 ││ │ │ GPU 0 │ │ │ │Shard 1│ │Shard 2│ │ │
│ │ └───┬───┘ └───┬───┘│ │ └─────┬─────┘ │ │ └───┬───┘ └───┬───┘ │ │
│ │ │ │ │ │ │ │ │ │ │ │ │
│ │ │ │ │ │ ▼ │ │ └────┬────┘ │ │
│ │ ┌───┴───┐ ┌───┴───┐│ │ ┌───────────┐ │ │ │ │ │
│ │ │ Data │ │ Data ││ │ │Layers 7-12│ │ │ ▼ │ │
│ │ │Batch 1│ │Batch 2││ │ │ GPU 1 │ │ │ ┌─────────────┐ │ │
│ │ └───────┘ └───────┘│ │ └───────────┘ │ │ │ Gather │ │ │
│ │ │ │ │ │ │ Compute │ │ │
│ │ Same model, diff │ │ Diff layers on │ │ │ Scatter │ │ │
│ │ data batches │ │ different GPUs │ │ └─────────────┘ │ │
│ └─────────────────────┘ └─────────────────────┘ └─────────────────────┘ │
│ │
└─────────────────────────────────────────────────────────────────────────────┘

Project Setup
Environment Setup
# Create project directory
mkdir distributed-training && cd distributed-training
# Create virtual environment
python -m venv venv
source venv/bin/activate # On Windows: venv\Scripts\activate
# Install dependencies
pip install torch torchvision
pip install transformers datasets accelerate
pip install deepspeed
pip install tensorboard wandb
pip install pynvml psutil

Project Structure
distributed-training/
├── configs/
│ ├── ddp_config.yaml
│ ├── fsdp_config.yaml
│ └── deepspeed_config.json
├── src/
│ ├── __init__.py
│ ├── ddp_trainer.py
│ ├── fsdp_trainer.py
│ ├── deepspeed_trainer.py
│ ├── utils.py
│ └── monitoring.py
├── scripts/
│ ├── launch_ddp.sh
│ ├── launch_fsdp.sh
│ └── launch_multinode.sh
├── train.py
└── requirements.txt

Requirements
# requirements.txt
torch>=2.0.0
transformers>=4.35.0
datasets>=2.14.0
accelerate>=0.24.0
deepspeed>=0.12.0
tensorboard>=2.15.0
wandb>=0.16.0
pynvml>=11.5.0
psutil>=5.9.0
pyyaml>=6.0.0

Part 1: PyTorch DDP (DistributedDataParallel)
Understanding DDP
┌─────────────────────────────────────────────────────────────────────────────┐
│ DDP Training Sequence │
├─────────────────────────────────────────────────────────────────────────────┤
│ │
│ GPU 0 GPU 1 GPU 2 GPU 3 │
│ │ │ │ │ │
│ ═════╪══════════════╪══════════════╪══════════════╪═════ │
│ │ FORWARD PASS (parallel) │ │ │
│ ═════╪══════════════╪══════════════╪══════════════╪═════ │
│ │ │ │ │ │
│ ▼ ▼ ▼ ▼ │
│ Compute Compute Compute Compute │
│ loss_0 loss_1 loss_2 loss_3 │
│ │ │ │ │ │
│ ═════╪══════════════╪══════════════╪══════════════╪═════ │
│ │ BACKWARD PASS (parallel) │ │ │
│ ═════╪══════════════╪══════════════╪══════════════╪═════ │
│ │ │ │ │ │
│ ▼ ▼ ▼ ▼ │
│ Compute Compute Compute Compute │
│ grad_0 grad_1 grad_2 grad_3 │
│ │ │ │ │ │
│ ═════╪══════════════╪══════════════╪══════════════╪═════ │
│ │ ALLREDUCE (synchronized) │ │ │
│ ═════╪══════════════╪══════════════╪══════════════╪═════ │
│ │ │ │ │ │
│ └──────────────┴──────────────┴──────────────┘ │
│ │ │
│ ▼ │
│ avg_grad = sum(grads) / 4 │
│ │ │
│ ┌──────────────┬──────────────┬──────────────┐ │
│ │ │ │ │ │
│ ═════╪══════════════╪══════════════╪══════════════╪═════ │
│ │ OPTIMIZER STEP (parallel) │ │ │
│ ═════╪══════════════╪══════════════╪══════════════╪═════ │
│ │ │ │ │ │
│ ▼ ▼ ▼ ▼ │
│ Update Update Update Update │
│ params params params params │
│ │
└─────────────────────────────────────────────────────────────────────────────┘

DDP Configuration
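The YAML below can be parsed with PyYAML (already in requirements.txt) before populating the trainer config; a minimal sketch using an inline string in place of the real file:

```python
import yaml  # PyYAML, listed in requirements.txt

# Inline stand-in for configs/ddp_config.yaml
raw = """
training:
  batch_size_per_gpu: 16
  gradient_accumulation_steps: 4
distributed:
  backend: nccl
"""
cfg = yaml.safe_load(raw)
print(cfg["training"]["batch_size_per_gpu"], cfg["distributed"]["backend"])  # 16 nccl
```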
# configs/ddp_config.yaml
training:
model_name: "bert-base-uncased"
batch_size_per_gpu: 16
gradient_accumulation_steps: 4
learning_rate: 2e-5
num_epochs: 3
warmup_ratio: 0.1
weight_decay: 0.01
max_grad_norm: 1.0
distributed:
backend: "nccl" # Use "gloo" for CPU or Windows
find_unused_parameters: false
gradient_as_bucket_view: true
static_graph: true
checkpointing:
save_every_n_steps: 1000
checkpoint_dir: "./checkpoints"
keep_last_n: 3
logging:
log_every_n_steps: 100
use_tensorboard: true
use_wandb: false

Core DDP Trainer Implementation
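The trainer below reads its rank and world size from environment variables that torchrun exports for every worker; a standalone sketch of that pattern, with fallbacks that keep it runnable as an ordinary single process:

```python
import os

# torchrun sets these for each spawned worker; the defaults make
# the same script work when launched without torchrun.
local_rank = int(os.environ.get("LOCAL_RANK", 0))   # GPU index on this node
rank = int(os.environ.get("RANK", 0))               # unique rank across all nodes
world_size = int(os.environ.get("WORLD_SIZE", 1))   # total number of processes

print(local_rank, rank, world_size)
```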
# src/ddp_trainer.py
"""
Distributed Data Parallel Trainer for PyTorch.
Handles multi-GPU training with gradient synchronization.
"""
import os
import time
from typing import Optional, Dict, Any
from dataclasses import dataclass
import torch
import torch.nn as nn
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, DistributedSampler
from torch.cuda.amp import GradScaler, autocast
from transformers import (
AutoModelForSequenceClassification,
AutoTokenizer,
get_linear_schedule_with_warmup
)
from datasets import load_dataset
@dataclass
class DDPConfig:
"""Configuration for DDP training."""
model_name: str = "bert-base-uncased"
batch_size_per_gpu: int = 16
gradient_accumulation_steps: int = 4
learning_rate: float = 2e-5
num_epochs: int = 3
warmup_ratio: float = 0.1
weight_decay: float = 0.01
max_grad_norm: float = 1.0
backend: str = "nccl"
find_unused_parameters: bool = False
use_amp: bool = True
checkpoint_dir: str = "./checkpoints"
log_every_n_steps: int = 100
class DDPTrainer:
"""
Trainer for Distributed Data Parallel training.
Handles:
- Process group initialization
- Model wrapping with DDP
- Distributed data loading
- Gradient synchronization
- Mixed precision training
- Checkpointing
"""
def __init__(self, config: DDPConfig):
self.config = config
self.local_rank = int(os.environ.get("LOCAL_RANK", 0))
self.world_size = int(os.environ.get("WORLD_SIZE", 1))
self.global_rank = int(os.environ.get("RANK", 0))
# Initialize process group
self._init_distributed()
# Set device
self.device = torch.device(f"cuda:{self.local_rank}")
torch.cuda.set_device(self.device)
# Initialize components
self.model = None
self.optimizer = None
self.scheduler = None
self.scaler = GradScaler() if config.use_amp else None
# Tracking
self.global_step = 0
self.epoch = 0
def _init_distributed(self):
"""Initialize the distributed process group."""
if not dist.is_initialized():
dist.init_process_group(
backend=self.config.backend,
init_method="env://",
)
if self.is_main_process:
print(f"Initialized distributed training:")
print(f" World size: {self.world_size}")
print(f" Backend: {self.config.backend}")
@property
def is_main_process(self) -> bool:
"""Check if this is the main process (rank 0)."""
return self.global_rank == 0
def setup_model(self, num_labels: int = 2):
"""Initialize and wrap model with DDP."""
# Load model on correct device
self.model = AutoModelForSequenceClassification.from_pretrained(
self.config.model_name,
num_labels=num_labels
).to(self.device)
# Wrap with DDP
self.model = DDP(
self.model,
device_ids=[self.local_rank],
output_device=self.local_rank,
find_unused_parameters=self.config.find_unused_parameters,
gradient_as_bucket_view=True,
static_graph=True,
)
if self.is_main_process:
total_params = sum(p.numel() for p in self.model.parameters())
print(f"Model loaded: {total_params:,} parameters")
def setup_optimizer(self, num_training_steps: int):
"""Initialize optimizer and scheduler."""
# Separate parameters for weight decay
no_decay = ["bias", "LayerNorm.weight"]
optimizer_grouped_parameters = [
{
"params": [
p for n, p in self.model.named_parameters()
if not any(nd in n for nd in no_decay)
],
"weight_decay": self.config.weight_decay,
},
{
"params": [
p for n, p in self.model.named_parameters()
if any(nd in n for nd in no_decay)
],
"weight_decay": 0.0,
},
]
self.optimizer = torch.optim.AdamW(
optimizer_grouped_parameters,
lr=self.config.learning_rate,
)
# Linear warmup scheduler
num_warmup_steps = int(num_training_steps * self.config.warmup_ratio)
self.scheduler = get_linear_schedule_with_warmup(
self.optimizer,
num_warmup_steps=num_warmup_steps,
num_training_steps=num_training_steps,
)
def get_dataloader(
self,
dataset,
tokenizer,
is_training: bool = True
) -> DataLoader:
"""Create distributed dataloader."""
def collate_fn(batch):
texts = [item["text"] for item in batch]
labels = torch.tensor([item["label"] for item in batch])
encodings = tokenizer(
texts,
padding=True,
truncation=True,
max_length=512,
return_tensors="pt"
)
return {
"input_ids": encodings["input_ids"],
"attention_mask": encodings["attention_mask"],
"labels": labels,
}
# Distributed sampler ensures each process gets different data
sampler = DistributedSampler(
dataset,
num_replicas=self.world_size,
rank=self.global_rank,
shuffle=is_training,
)
dataloader = DataLoader(
dataset,
batch_size=self.config.batch_size_per_gpu,
sampler=sampler,
collate_fn=collate_fn,
num_workers=4,
pin_memory=True,
)
return dataloader
def train_step(self, batch: Dict[str, torch.Tensor]) -> float:
"""Execute single training step with gradient accumulation."""
# Move batch to device
batch = {k: v.to(self.device) for k, v in batch.items()}
# Forward pass with optional AMP
if self.config.use_amp:
with autocast():
outputs = self.model(**batch)
loss = outputs.loss / self.config.gradient_accumulation_steps
else:
outputs = self.model(**batch)
loss = outputs.loss / self.config.gradient_accumulation_steps
# Backward pass
if self.config.use_amp:
self.scaler.scale(loss).backward()
else:
loss.backward()
return loss.item() * self.config.gradient_accumulation_steps
def optimizer_step(self):
"""Execute optimizer step with gradient clipping."""
if self.config.use_amp:
# Unscale gradients for clipping
self.scaler.unscale_(self.optimizer)
# Gradient clipping
torch.nn.utils.clip_grad_norm_(
self.model.parameters(),
self.config.max_grad_norm
)
# Optimizer step
if self.config.use_amp:
self.scaler.step(self.optimizer)
self.scaler.update()
else:
self.optimizer.step()
self.scheduler.step()
self.optimizer.zero_grad()
self.global_step += 1
def train_epoch(self, dataloader: DataLoader) -> float:
"""Train for one epoch."""
self.model.train()
dataloader.sampler.set_epoch(self.epoch) # Important for shuffling
total_loss = 0.0
num_batches = 0
accumulation_loss = 0.0
for step, batch in enumerate(dataloader):
loss = self.train_step(batch)
accumulation_loss += loss
# Gradient accumulation
if (step + 1) % self.config.gradient_accumulation_steps == 0:
self.optimizer_step()
# Logging
if self.global_step % self.config.log_every_n_steps == 0:
if self.is_main_process:
print(
f"Step {self.global_step} | "
f"Loss: {accumulation_loss:.4f} | "
f"LR: {self.scheduler.get_last_lr()[0]:.2e}"
)
total_loss += accumulation_loss
accumulation_loss = 0.0
num_batches += 1
return total_loss / max(num_batches, 1)
@torch.no_grad()
def evaluate(self, dataloader: DataLoader) -> Dict[str, float]:
"""Evaluate model on validation set."""
self.model.eval()
total_loss = 0.0
correct = 0
total = 0
for batch in dataloader:
batch = {k: v.to(self.device) for k, v in batch.items()}
outputs = self.model(**batch)
total_loss += outputs.loss.item()
predictions = outputs.logits.argmax(dim=-1)
correct += (predictions == batch["labels"]).sum().item()
total += batch["labels"].size(0)
# Aggregate metrics across all processes
metrics = torch.tensor(
[total_loss, correct, total],
device=self.device
)
dist.all_reduce(metrics, op=dist.ReduceOp.SUM)
total_loss, correct, total = metrics.tolist()
return {
"eval_loss": total_loss / len(dataloader) / self.world_size,
"accuracy": correct / total,
}
def save_checkpoint(self, path: str):
"""Save checkpoint (only on main process)."""
if not self.is_main_process:
return
# Save model state (unwrap DDP)
model_state = self.model.module.state_dict()
checkpoint = {
"model_state_dict": model_state,
"optimizer_state_dict": self.optimizer.state_dict(),
"scheduler_state_dict": self.scheduler.state_dict(),
"scaler_state_dict": self.scaler.state_dict() if self.scaler else None,
"global_step": self.global_step,
"epoch": self.epoch,
"config": self.config,
}
os.makedirs(os.path.dirname(path), exist_ok=True)
torch.save(checkpoint, path)
print(f"Saved checkpoint to {path}")
def load_checkpoint(self, path: str):
"""Load checkpoint on all processes."""
# Load on CPU first, then move to device
checkpoint = torch.load(path, map_location="cpu")
# Load model state
self.model.module.load_state_dict(checkpoint["model_state_dict"])
self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
self.scheduler.load_state_dict(checkpoint["scheduler_state_dict"])
if self.scaler and checkpoint["scaler_state_dict"]:
self.scaler.load_state_dict(checkpoint["scaler_state_dict"])
self.global_step = checkpoint["global_step"]
self.epoch = checkpoint["epoch"]
if self.is_main_process:
print(f"Loaded checkpoint from {path}")
def cleanup(self):
"""Clean up distributed resources."""
dist.destroy_process_group()
def train_ddp(config: DDPConfig):
"""Main DDP training function."""
trainer = DDPTrainer(config)
# Load tokenizer and datasets
tokenizer = AutoTokenizer.from_pretrained(config.model_name)
dataset = load_dataset("imdb")
# Setup model and optimizer
trainer.setup_model(num_labels=2)
train_dataloader = trainer.get_dataloader(
dataset["train"], tokenizer, is_training=True
)
eval_dataloader = trainer.get_dataloader(
dataset["test"], tokenizer, is_training=False
)
num_training_steps = (
len(train_dataloader) // config.gradient_accumulation_steps
* config.num_epochs
)
trainer.setup_optimizer(num_training_steps)
# Training loop
for epoch in range(config.num_epochs):
trainer.epoch = epoch
if trainer.is_main_process:
print(f"\n{'='*50}")
print(f"Epoch {epoch + 1}/{config.num_epochs}")
print(f"{'='*50}")
# Train
train_loss = trainer.train_epoch(train_dataloader)
# Evaluate
metrics = trainer.evaluate(eval_dataloader)
if trainer.is_main_process:
print(f"\nEpoch {epoch + 1} Results:")
print(f" Train Loss: {train_loss:.4f}")
print(f" Eval Loss: {metrics['eval_loss']:.4f}")
print(f" Accuracy: {metrics['accuracy']:.4f}")
# Save checkpoint
checkpoint_path = os.path.join(
config.checkpoint_dir,
f"checkpoint_epoch_{epoch + 1}.pt"
)
trainer.save_checkpoint(checkpoint_path)
trainer.cleanup()
return metrics
if __name__ == "__main__":
config = DDPConfig()
train_ddp(config)

Launch Scripts for DDP
#!/bin/bash
# scripts/launch_ddp.sh
# Single node, multiple GPUs
NUM_GPUS=4
torchrun \
--nproc_per_node=$NUM_GPUS \
--master_port=29500 \
train.py \
--mode ddp \
--config configs/ddp_config.yaml

#!/bin/bash
# scripts/launch_multinode.sh
# Multi-node training
# Run this on each node, changing NODE_RANK
MASTER_ADDR="192.168.1.100" # IP of node 0
MASTER_PORT=29500
NUM_NODES=2
NUM_GPUS_PER_NODE=4
NODE_RANK=$1 # Pass as argument: 0 for master, 1 for worker
torchrun \
--nproc_per_node=$NUM_GPUS_PER_NODE \
--nnodes=$NUM_NODES \
--node_rank=$NODE_RANK \
--master_addr=$MASTER_ADDR \
--master_port=$MASTER_PORT \
train.py \
--mode ddp \
--config configs/ddp_config.yaml

Part 2: Fully Sharded Data Parallel (FSDP)
Understanding FSDP
┌─────────────────────────────────────────────────────────────────────────────┐
│ DDP vs FSDP Memory Comparison │
├─────────────────────────────────────────────────────────────────────────────┤
│ │
│ Standard DDP (each GPU holds full model + optimizer) │
│ ──────────────────────────────────────────────────── │
│ │
│ ┌─────────────────┐ ┌─────────────────┐ ┌─────────────────┐ │
│ │ GPU 0 │ │ GPU 1 │ │ GPU 2 │ │
│ │ ┌───────────┐ │ │ ┌───────────┐ │ │ ┌───────────┐ │ │
│ │ │Full Model │ │ │ │Full Model │ │ │ │Full Model │ │ │
│ │ │ (100%) │ │ │ │ (100%) │ │ │ │ (100%) │ │ │
│ │ ├───────────┤ │ │ ├───────────┤ │ │ ├───────────┤ │ │
│ │ │Full Optim │ │ │ │Full Optim │ │ │ │Full Optim │ │ │
│ │ │ (100%) │ │ │ │ (100%) │ │ │ │ (100%) │ │ │
│ │ └───────────┘ │ │ └───────────┘ │ │ └───────────┘ │ │
│ └─────────────────┘ └─────────────────┘ └─────────────────┘ │
│ │
│ Total Memory: 3x Model ────────────────────────▶ ✗ HIGH MEMORY │
│ │
│ FSDP (sharded across GPUs) │
│ ────────────────────────── │
│ │
│ ┌─────────────────┐ ┌─────────────────┐ ┌─────────────────┐ │
│ │ GPU 0 │ │ GPU 1 │ │ GPU 2 │ │
│ │ ┌───────────┐ │ │ ┌───────────┐ │ │ ┌───────────┐ │ │
│ │ │Params 0-33│ │ │ │Params 34-66│ │ │ │Params 67-100│ │ │
│ │ │ (33%) │ │ │ │ (33%) │ │ │ │ (34%) │ │ │
│ │ ├───────────┤ │ │ ├───────────┤ │ │ ├───────────┤ │ │
│ │ │Optim 0-33 │ │ │ │Optim 34-66│ │ │ │Optim 67-100│ │ │
│ │ │ (33%) │ │ │ │ (33%) │ │ │ │ (34%) │ │ │
│ │ └───────────┘ │ │ └───────────┘ │ │ └───────────┘ │ │
│ └─────────────────┘ └─────────────────┘ └─────────────────┘ │
│ │
│ Total Memory: ~1x Model ────────────────────────▶ ✓ LOW MEMORY │
│ │
└─────────────────────────────────────────────────────────────────────────────┘

FSDP Configuration
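A toy illustration of what "sharding" in the comparison above means: FSDP flattens parameters and gives each rank one contiguous, evenly padded chunk. A pure-Python sketch of that partitioning (the padding here is a simplification of FSDP's internals):

```python
def shard_params(flat_params, num_ranks):
    # Pad so the flat tensor divides evenly across ranks,
    # then hand each rank one contiguous chunk.
    n = len(flat_params)
    per_rank = -(-n // num_ranks)  # ceiling division
    padded = flat_params + [0.0] * (per_rank * num_ranks - n)
    return [padded[r * per_rank:(r + 1) * per_rank]
            for r in range(num_ranks)]

shards = shard_params([1.0, 2.0, 3.0, 4.0, 5.0], num_ranks=2)
print(shards)  # [[1.0, 2.0, 3.0], [4.0, 5.0, 0.0]]
```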
# configs/fsdp_config.yaml
training:
model_name: "meta-llama/Llama-2-7b-hf"
batch_size_per_gpu: 2
gradient_accumulation_steps: 8
learning_rate: 1e-5
num_epochs: 1
max_length: 2048
fsdp:
sharding_strategy: "FULL_SHARD" # FULL_SHARD, SHARD_GRAD_OP, NO_SHARD
cpu_offload: false
backward_prefetch: "BACKWARD_PRE"
forward_prefetch: true
auto_wrap_policy: "transformer_based"
min_num_params: 100000000 # 100M params for wrapping
use_orig_params: true
sync_module_states: true
activation_checkpointing:
enabled: true
checkpoint_every_n_layers: 1
mixed_precision:
param_dtype: "bfloat16"
reduce_dtype: "float32"
buffer_dtype: "float32"

FSDP Trainer Implementation
# src/fsdp_trainer.py
"""
Fully Sharded Data Parallel (FSDP) Trainer.
Enables training of large models that don't fit on a single GPU.
"""
import os
import functools
from typing import Optional, Dict, Any, Set
from dataclasses import dataclass
import torch
import torch.nn as nn
import torch.distributed as dist
from torch.distributed.fsdp import (
FullyShardedDataParallel as FSDP,
MixedPrecision,
BackwardPrefetch,
ShardingStrategy,
CPUOffload,
StateDictType,
FullStateDictConfig,
ShardedStateDictConfig,
)
from torch.distributed.fsdp.wrap import (
transformer_auto_wrap_policy,
size_based_auto_wrap_policy,
enable_wrap,
wrap,
)
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
checkpoint_wrapper,
CheckpointImpl,
apply_activation_checkpointing,
)
from torch.utils.data import DataLoader, DistributedSampler
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
AutoConfig,
)
from transformers.models.llama.modeling_llama import LlamaDecoderLayer
@dataclass
class FSDPConfig:
"""Configuration for FSDP training."""
model_name: str = "meta-llama/Llama-2-7b-hf"
batch_size_per_gpu: int = 2
gradient_accumulation_steps: int = 8
learning_rate: float = 1e-5
num_epochs: int = 1
max_length: int = 2048
# FSDP settings
sharding_strategy: str = "FULL_SHARD"
cpu_offload: bool = False
backward_prefetch: str = "BACKWARD_PRE"
use_activation_checkpointing: bool = True
# Mixed precision
param_dtype: str = "bfloat16"
reduce_dtype: str = "float32"
checkpoint_dir: str = "./checkpoints"
class FSDPTrainer:
"""
Trainer for Fully Sharded Data Parallel training.
FSDP shards model parameters, gradients, and optimizer states
across GPUs, enabling training of models larger than GPU memory.
"""
SHARDING_STRATEGIES = {
"FULL_SHARD": ShardingStrategy.FULL_SHARD,
"SHARD_GRAD_OP": ShardingStrategy.SHARD_GRAD_OP,
"NO_SHARD": ShardingStrategy.NO_SHARD,
"HYBRID_SHARD": ShardingStrategy.HYBRID_SHARD,
}
DTYPE_MAP = {
"float32": torch.float32,
"float16": torch.float16,
"bfloat16": torch.bfloat16,
}
def __init__(self, config: FSDPConfig):
self.config = config
self.local_rank = int(os.environ.get("LOCAL_RANK", 0))
self.world_size = int(os.environ.get("WORLD_SIZE", 1))
self.global_rank = int(os.environ.get("RANK", 0))
# Initialize distributed
self._init_distributed()
# Set device
self.device = torch.device(f"cuda:{self.local_rank}")
torch.cuda.set_device(self.device)
self.model = None
self.optimizer = None
def _init_distributed(self):
"""Initialize distributed process group."""
if not dist.is_initialized():
dist.init_process_group(backend="nccl")
@property
def is_main_process(self) -> bool:
return self.global_rank == 0
def _get_mixed_precision_policy(self) -> MixedPrecision:
"""Create mixed precision policy."""
return MixedPrecision(
param_dtype=self.DTYPE_MAP[self.config.param_dtype],
reduce_dtype=self.DTYPE_MAP[self.config.reduce_dtype],
buffer_dtype=self.DTYPE_MAP[self.config.reduce_dtype],
)
def _get_auto_wrap_policy(self, model_config):
"""Create auto wrap policy for transformer layers."""
# Get the decoder layer class for the model
# This wraps each transformer layer in FSDP
# For Llama models
transformer_layer_cls: Set[type] = {LlamaDecoderLayer}
auto_wrap_policy = functools.partial(
transformer_auto_wrap_policy,
transformer_layer_cls=transformer_layer_cls,
)
return auto_wrap_policy
def setup_model(self):
"""Initialize model with FSDP wrapping."""
if self.is_main_process:
print(f"Loading model: {self.config.model_name}")
# Load config first
model_config = AutoConfig.from_pretrained(self.config.model_name)
# Load model with empty weights (FSDP will handle initialization)
with torch.device("meta"):
model = AutoModelForCausalLM.from_config(model_config)
# Get FSDP configuration
sharding_strategy = self.SHARDING_STRATEGIES[self.config.sharding_strategy]
mixed_precision = self._get_mixed_precision_policy()
auto_wrap_policy = self._get_auto_wrap_policy(model_config)
# CPU offload configuration
cpu_offload = CPUOffload(offload_params=True) if self.config.cpu_offload else None
# Backward prefetch
backward_prefetch = {
"BACKWARD_PRE": BackwardPrefetch.BACKWARD_PRE,
"BACKWARD_POST": BackwardPrefetch.BACKWARD_POST,
}.get(self.config.backward_prefetch)
# Wrap model with FSDP
self.model = FSDP(
model,
sharding_strategy=sharding_strategy,
mixed_precision=mixed_precision,
auto_wrap_policy=auto_wrap_policy,
cpu_offload=cpu_offload,
backward_prefetch=backward_prefetch,
device_id=self.local_rank,
use_orig_params=True,
sync_module_states=True,
param_init_fn=self._param_init_fn,
)
# Apply activation checkpointing
if self.config.use_activation_checkpointing:
self._apply_activation_checkpointing()
if self.is_main_process:
self._print_memory_stats()
def _param_init_fn(self, module: nn.Module):
"""Materialize meta-device parameters on this rank's GPU.
NOTE: this uses random init as a placeholder; in production,
load the pretrained weights here instead of reinitializing.
"""
module.to_empty(device=self.device)
# Initialize with small random values
for param in module.parameters():
if param.requires_grad:
nn.init.normal_(param, mean=0.0, std=0.02)
def _apply_activation_checkpointing(self):
"""Apply activation checkpointing to save memory."""
check_fn = lambda submodule: isinstance(submodule, LlamaDecoderLayer)
apply_activation_checkpointing(
self.model,
checkpoint_wrapper_fn=checkpoint_wrapper,
check_fn=check_fn,
)
if self.is_main_process:
print("Applied activation checkpointing")
def _print_memory_stats(self):
"""Print GPU memory statistics."""
allocated = torch.cuda.memory_allocated() / 1e9
reserved = torch.cuda.memory_reserved() / 1e9
print(f"GPU Memory - Allocated: {allocated:.2f}GB, Reserved: {reserved:.2f}GB")
def setup_optimizer(self):
"""Setup optimizer for FSDP training."""
self.optimizer = torch.optim.AdamW(
self.model.parameters(),
lr=self.config.learning_rate,
betas=(0.9, 0.95),
weight_decay=0.1,
)
def get_dataloader(self, dataset, tokenizer) -> DataLoader:
"""Create distributed dataloader for causal LM."""
def collate_fn(batch):
texts = [item["text"] for item in batch]
encodings = tokenizer(
texts,
padding="max_length",
truncation=True,
max_length=self.config.max_length,
return_tensors="pt"
)
# For causal LM, labels are shifted input_ids
encodings["labels"] = encodings["input_ids"].clone()
return encodings
sampler = DistributedSampler(
dataset,
num_replicas=self.world_size,
rank=self.global_rank,
shuffle=True,
)
return DataLoader(
dataset,
batch_size=self.config.batch_size_per_gpu,
sampler=sampler,
collate_fn=collate_fn,
num_workers=4,
pin_memory=True,
)
def train_step(self, batch: Dict[str, torch.Tensor]) -> float:
"""Execute single training step."""
batch = {k: v.to(self.device) for k, v in batch.items()}
outputs = self.model(**batch)
loss = outputs.loss / self.config.gradient_accumulation_steps
loss.backward()
return loss.item() * self.config.gradient_accumulation_steps
def optimizer_step(self):
"""Execute optimizer step."""
# Gradient clipping with FSDP
self.model.clip_grad_norm_(1.0)
self.optimizer.step()
self.optimizer.zero_grad()
def save_checkpoint(self, path: str, full_state: bool = False):
"""
Save FSDP checkpoint.
Args:
path: Save path
full_state: If True, gather full state to rank 0 (memory intensive)
If False, save sharded state (recommended for large models)
"""
if full_state:
# Gather full state dict to rank 0
save_policy = FullStateDictConfig(
offload_to_cpu=True,
rank0_only=True,
)
state_dict_type = StateDictType.FULL_STATE_DICT
else:
# Save sharded state dict
save_policy = ShardedStateDictConfig(
offload_to_cpu=True,
)
state_dict_type = StateDictType.SHARDED_STATE_DICT
with FSDP.state_dict_type(self.model, state_dict_type, save_policy):
state_dict = self.model.state_dict()
# Ensure the target directory exists before saving
os.makedirs(os.path.dirname(path) or ".", exist_ok=True)
if full_state:
if self.is_main_process:
torch.save({"model_state_dict": state_dict}, path)
print(f"Saved full checkpoint to {path}")
else:
# Each rank saves its shard
shard_path = f"{path}_rank{self.global_rank}.pt"
torch.save({"model_state_dict": state_dict}, shard_path)
if self.is_main_process:
print(f"Saved sharded checkpoint to {path}_rank*.pt")
def cleanup(self):
"""Cleanup distributed resources."""
dist.destroy_process_group()
def train_fsdp(config: FSDPConfig):
"""Main FSDP training function."""
trainer = FSDPTrainer(config)
# Setup
tokenizer = AutoTokenizer.from_pretrained(config.model_name)
tokenizer.pad_token = tokenizer.eos_token
trainer.setup_model()
trainer.setup_optimizer()
# For demo, use a small dataset
from datasets import load_dataset
dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="train[:1000]")
dataloader = trainer.get_dataloader(dataset, tokenizer)
# Training loop
global_step = 0
for epoch in range(config.num_epochs):
if trainer.is_main_process:
print(f"\nEpoch {epoch + 1}/{config.num_epochs}")
trainer.model.train()
dataloader.sampler.set_epoch(epoch)
accumulation_loss = 0.0
for step, batch in enumerate(dataloader):
loss = trainer.train_step(batch)
accumulation_loss += loss
if (step + 1) % config.gradient_accumulation_steps == 0:
trainer.optimizer_step()
global_step += 1
if trainer.is_main_process and global_step % 10 == 0:
print(f"Step {global_step} | Loss: {accumulation_loss:.4f}")
trainer._print_memory_stats()
accumulation_loss = 0.0
# Save checkpoint
trainer.save_checkpoint(
os.path.join(config.checkpoint_dir, "fsdp_model"),
full_state=False # Use sharded saving for large models
)
trainer.cleanup()

FSDP Launch Script
#!/bin/bash
# scripts/launch_fsdp.sh
NUM_GPUS=8
# Set for better NCCL performance
export NCCL_DEBUG=WARN
export CUDA_DEVICE_ORDER=PCI_BUS_ID
torchrun \
--nproc_per_node=$NUM_GPUS \
--master_port=29501 \
train.py \
--mode fsdp \
--config configs/fsdp_config.yaml

Part 3: DeepSpeed Integration
DeepSpeed ZeRO Stages
┌─────────────────────────────────────────────────────────────────────────────┐
│ DeepSpeed ZeRO Stages │
├─────────────────────────────────────────────────────────────────────────────┤
│ │
│ No ZeRO (Baseline) │ Memory: 4x (per GPU) │ Highest Memory │
│ ─────────────────────────────┼───────────────────────┼─────────────────── │
│ GPU 0: Full Model │ │ │
│ + Full Optimizer │ │ │
│ + Full Gradients │ │ │
│ GPU 1: Full Model │ │ │
│ + Full Optimizer │ │ │
│ + Full Gradients │ │ │
│ │ │ │
├───────────────────────────────┼───────────────────────┼─────────────────────┤
│ ZeRO Stage 1 │ Memory: ~2x │ Lower Memory │
│ ─────────────────────────────┼───────────────────────┼─────────────────── │
│ GPU 0: Full Model │ Optimizer states │ │
│ + Optimizer Shard │ sharded across │ │
│ + Full Gradients │ GPUs │ │
│ GPU 1: Full Model │ │ │
│ + Optimizer Shard │ │ │
│ + Full Gradients │ │ │
│ │ │ │
├───────────────────────────────┼───────────────────────┼─────────────────────┤
│ ZeRO Stage 2 │ Memory: ~1.5x │ Even Lower │
│ ─────────────────────────────┼───────────────────────┼─────────────────── │
│ GPU 0: Full Model │ + Gradients │ │
│ + Optimizer Shard │ sharded │ │
│ + Gradient Shard │ │ │
│ GPU 1: Full Model │ │ │
│ + Optimizer Shard │ │ │
│ + Gradient Shard │ │ │
│ │ │ │
├───────────────────────────────┼───────────────────────┼─────────────────────┤
│ ZeRO Stage 3 │ Memory: ~1x │ Lowest Memory ✓ │
│ ─────────────────────────────┼───────────────────────┼─────────────────── │
│ GPU 0: Model Shard │ + Parameters │ │
│ + Optimizer Shard │ sharded │ │
│ + Gradient Shard │ (gather on demand) │ │
│ GPU 1: Model Shard │ │ │
│ + Optimizer Shard │ │ │
│ + Gradient Shard │ │ │
│ │ │ │
└─────────────────────────────────────────────────────────────────────────────┘
DeepSpeed Configuration
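The per-GPU footprints in the diagram can be put in rough numbers. The sketch below follows the standard mixed-precision accounting from the ZeRO paper (2 bytes for fp16 parameters, 2 for fp16 gradients, 12 for the fp32 optimizer states of AdamW: master weights, momentum, variance) and ignores activations, buffers, and fragmentation, so treat it as a back-of-envelope estimate, not a measurement:

```python
def zero_memory_gb(num_params: float, num_gpus: int) -> dict:
    """Rough per-GPU model-state memory (GB) for each ZeRO stage."""
    # Bytes per model state, mixed-precision AdamW: 2 (fp16 params),
    # 2 (fp16 grads), 12 (fp32 master copy + momentum + variance).
    p, g, o = 2 * num_params, 2 * num_params, 12 * num_params
    per_gpu = {
        "baseline": p + g + o,                       # everything replicated
        "stage1": p + g + o / num_gpus,              # optimizer sharded
        "stage2": p + (g + o) / num_gpus,            # + gradients sharded
        "stage3": (p + g + o) / num_gpus,            # + parameters sharded
    }
    return {k: v / 1e9 for k, v in per_gpu.items()}

# A 7B-parameter model on 8 GPUs: ~112 GB per GPU at baseline
# vs ~14 GB at ZeRO-3, before activations are counted.
print(zero_memory_gb(7e9, 8))
```

Numbers like these explain why the configuration below enables stage 3 with CPU offload for large models.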
{
"train_batch_size": "auto",
"train_micro_batch_size_per_gpu": "auto",
"gradient_accumulation_steps": "auto",
"optimizer": {
"type": "AdamW",
"params": {
"lr": "auto",
"betas": [0.9, 0.95],
"eps": 1e-8,
"weight_decay": "auto"
}
},
"scheduler": {
"type": "WarmupDecayLR",
"params": {
"warmup_min_lr": 0,
"warmup_max_lr": "auto",
"warmup_num_steps": "auto",
"total_num_steps": "auto"
}
},
"zero_optimization": {
"stage": 3,
"offload_optimizer": {
"device": "cpu",
"pin_memory": true
},
"offload_param": {
"device": "cpu",
"pin_memory": true
},
"overlap_comm": true,
"contiguous_gradients": true,
"sub_group_size": 1e9,
"reduce_bucket_size": "auto",
"stage3_prefetch_bucket_size": "auto",
"stage3_param_persistence_threshold": "auto",
"stage3_max_live_parameters": 1e9,
"stage3_max_reuse_distance": 1e9,
"stage3_gather_16bit_weights_on_model_save": true
},
"bf16": {
"enabled": true
},
"gradient_clipping": 1.0,
"steps_per_print": 100,
"wall_clock_breakdown": false,
"activation_checkpointing": {
"partition_activations": true,
"cpu_checkpointing": true,
"contiguous_memory_optimization": true,
"number_checkpoints": null
}
}
DeepSpeed Trainer with HuggingFace
# src/deepspeed_trainer.py
"""
DeepSpeed integration for efficient large model training.
Supports ZeRO stages 1-3 with CPU/NVMe offloading.
"""
import os
import json
from typing import Optional, Dict, Any
from dataclasses import dataclass, asdict
import torch
from torch.utils.data import DataLoader
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
TrainingArguments,
Trainer,
DataCollatorForLanguageModeling,
)
from datasets import load_dataset
import deepspeed
@dataclass
class DeepSpeedConfig:
"""Configuration for DeepSpeed training."""
model_name: str = "meta-llama/Llama-2-7b-hf"
output_dir: str = "./output"
# Training hyperparameters
num_train_epochs: int = 1
per_device_train_batch_size: int = 1
per_device_eval_batch_size: int = 1
gradient_accumulation_steps: int = 16
learning_rate: float = 1e-5
weight_decay: float = 0.1
warmup_ratio: float = 0.1
max_length: int = 2048
# DeepSpeed settings
zero_stage: int = 3
offload_optimizer: bool = True
offload_param: bool = True
# Precision
bf16: bool = True
fp16: bool = False
# Logging
logging_steps: int = 100
save_steps: int = 500
eval_steps: int = 500
def create_deepspeed_config(config: DeepSpeedConfig) -> Dict[str, Any]:
"""Generate DeepSpeed configuration dictionary."""
ds_config = {
"train_batch_size": "auto",
"train_micro_batch_size_per_gpu": "auto",
"gradient_accumulation_steps": "auto",
"optimizer": {
"type": "AdamW",
"params": {
"lr": "auto",
"betas": [0.9, 0.95],
"eps": 1e-8,
"weight_decay": "auto"
}
},
"scheduler": {
"type": "WarmupDecayLR",
"params": {
"warmup_min_lr": 0,
"warmup_max_lr": "auto",
"warmup_num_steps": "auto",
"total_num_steps": "auto"
}
},
"gradient_clipping": 1.0,
"steps_per_print": config.logging_steps,
}
# ZeRO configuration
zero_config = {
"stage": config.zero_stage,
"overlap_comm": True,
"contiguous_gradients": True,
"reduce_bucket_size": "auto",
}
if config.zero_stage >= 2:
zero_config["allgather_bucket_size"] = 5e8
zero_config["reduce_scatter"] = True
if config.zero_stage == 3:
zero_config["stage3_prefetch_bucket_size"] = "auto"
zero_config["stage3_param_persistence_threshold"] = "auto"
zero_config["stage3_max_live_parameters"] = 1e9
zero_config["stage3_max_reuse_distance"] = 1e9
zero_config["stage3_gather_16bit_weights_on_model_save"] = True
# CPU offloading
if config.offload_optimizer:
zero_config["offload_optimizer"] = {
"device": "cpu",
"pin_memory": True
}
if config.offload_param and config.zero_stage == 3:
zero_config["offload_param"] = {
"device": "cpu",
"pin_memory": True
}
ds_config["zero_optimization"] = zero_config
# Precision
if config.bf16:
ds_config["bf16"] = {"enabled": True}
elif config.fp16:
ds_config["fp16"] = {
"enabled": True,
"auto_cast": True,
"loss_scale": 0,
"initial_scale_power": 16,
}
# Activation checkpointing for memory efficiency
ds_config["activation_checkpointing"] = {
"partition_activations": True,
"cpu_checkpointing": config.offload_param,
"contiguous_memory_optimization": True,
}
return ds_config
def train_with_deepspeed(config: DeepSpeedConfig):
"""Train model using DeepSpeed with HuggingFace Trainer."""
# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(config.model_name)
tokenizer.pad_token = tokenizer.eos_token
# Load model
model = AutoModelForCausalLM.from_pretrained(
config.model_name,
torch_dtype=torch.bfloat16 if config.bf16 else torch.float16,
use_cache=False, # Required for gradient checkpointing
)
# Enable gradient checkpointing
model.gradient_checkpointing_enable()
# Load dataset
dataset = load_dataset("wikitext", "wikitext-2-raw-v1")
def tokenize_function(examples):
return tokenizer(
examples["text"],
truncation=True,
max_length=config.max_length,
padding="max_length",
)
tokenized_dataset = dataset.map(
tokenize_function,
batched=True,
remove_columns=dataset["train"].column_names,
)
# Data collator
data_collator = DataCollatorForLanguageModeling(
tokenizer=tokenizer,
mlm=False, # Causal LM, not masked LM
)
# Generate DeepSpeed config
ds_config = create_deepspeed_config(config)
# Save config to file (required by Trainer)
ds_config_path = os.path.join(config.output_dir, "ds_config.json")
os.makedirs(config.output_dir, exist_ok=True)
with open(ds_config_path, "w") as f:
json.dump(ds_config, f, indent=2)
# Training arguments
training_args = TrainingArguments(
output_dir=config.output_dir,
num_train_epochs=config.num_train_epochs,
per_device_train_batch_size=config.per_device_train_batch_size,
per_device_eval_batch_size=config.per_device_eval_batch_size,
gradient_accumulation_steps=config.gradient_accumulation_steps,
learning_rate=config.learning_rate,
weight_decay=config.weight_decay,
warmup_ratio=config.warmup_ratio,
# DeepSpeed
deepspeed=ds_config_path,
# Precision
bf16=config.bf16,
fp16=config.fp16,
# Logging
logging_steps=config.logging_steps,
save_steps=config.save_steps,
eval_strategy="steps",
eval_steps=config.eval_steps,
# Other
save_total_limit=3,
load_best_model_at_end=True,
report_to=["tensorboard"],
gradient_checkpointing=True,
)
# Initialize trainer
trainer = Trainer(
model=model,
args=training_args,
train_dataset=tokenized_dataset["train"],
eval_dataset=tokenized_dataset["validation"],
data_collator=data_collator,
tokenizer=tokenizer,
)
# Train
trainer.train()
# Save final model
trainer.save_model(os.path.join(config.output_dir, "final_model"))
return trainer
class DeepSpeedTrainerManual:
"""
Manual DeepSpeed training loop for more control.
Use this when you need custom training logic.
"""
def __init__(self, config: DeepSpeedConfig):
self.config = config
self.local_rank = int(os.environ.get("LOCAL_RANK", 0))
def setup(self):
"""Initialize DeepSpeed engine."""
# Load model
model = AutoModelForCausalLM.from_pretrained(
self.config.model_name,
torch_dtype=torch.bfloat16,
)
# Generate config. Note: the "auto" placeholders in this config are
# resolved by the HuggingFace Trainer; for manual initialization they
# must be replaced with concrete values first.
ds_config = create_deepspeed_config(self.config)
# Initialize DeepSpeed (builds optimizer and scheduler from the config)
self.model_engine, self.optimizer, _, self.lr_scheduler = deepspeed.initialize(
model=model,
model_parameters=model.parameters(),
config=ds_config,
)
def train_step(self, batch: Dict[str, torch.Tensor]) -> float:
"""Execute single training step."""
# Move to device
batch = {k: v.to(self.model_engine.device) for k, v in batch.items()}
# Forward pass
outputs = self.model_engine(**batch)
loss = outputs.loss
# Backward pass (DeepSpeed handles gradient accumulation)
self.model_engine.backward(loss)
# Optimizer step (DeepSpeed handles when to step based on accumulation)
self.model_engine.step()
return loss.item()
def save_checkpoint(self, path: str):
"""Save DeepSpeed checkpoint."""
self.model_engine.save_checkpoint(path)
def load_checkpoint(self, path: str):
"""Load DeepSpeed checkpoint."""
self.model_engine.load_checkpoint(path)
if __name__ == "__main__":
config = DeepSpeedConfig()
train_with_deepspeed(config)
DeepSpeed Launch Script
#!/bin/bash
# scripts/launch_deepspeed.sh
NUM_GPUS=8
deepspeed \
--num_gpus=$NUM_GPUS \
train.py \
--mode deepspeed \
--config configs/deepspeed_config.json
Part 4: Gradient Checkpointing
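At its core, gradient checkpointing is a single API call: `torch.utils.checkpoint.checkpoint` discards a segment's intermediate activations after the forward pass and recomputes them when the backward pass needs them. A minimal CPU-runnable sketch (the layer sizes are arbitrary) before the tradeoff diagram and full implementation:

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

# Two stacked blocks; the first is checkpointed, so its intermediate
# activations are not stored — they are recomputed during backward.
block1 = nn.Sequential(nn.Linear(64, 64), nn.ReLU())
block2 = nn.Sequential(nn.Linear(64, 64), nn.ReLU())

x = torch.randn(8, 64, requires_grad=True)
h = checkpoint(block1, x, use_reentrant=False)  # recomputed in backward
y = block2(h).sum()                             # stored normally
y.backward()
assert x.grad is not None  # gradients flow through the recomputed segment
```

The savings come from applying this to every transformer block, which is exactly what the utilities below automate.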
Memory-Compute Tradeoff
┌─────────────────────────────────────────────────────────────────────────────┐
│ Gradient Checkpointing: Memory vs Compute │
├─────────────────────────────────────────────────────────────────────────────┤
│ │
│ Standard Training Gradient Checkpointing │
│ ───────────────── ─────────────────────── │
│ │
│ ┌─────────────────────┐ ┌─────────────────────┐ │
│ │ Forward Pass │ │ Forward Pass │ │
│ │ Store ALL │ │ Store CHECKPOINT │ │
│ │ activations │ │ activations only │ │
│ └─────────┬───────────┘ └─────────┬───────────┘ │
│ │ │ │
│ ▼ ▼ │
│ ┌─────────────────────┐ ┌─────────────────────┐ │
│ │ Memory: O(n) │ │ Memory: O(√n) │ │
│ │ (n = num layers) │ │ (only checkpoints) │ │
│ └─────────┬───────────┘ └─────────┬───────────┘ │
│ │ │ │
│ ▼ ▼ │
│ ┌─────────────────────┐ ┌─────────────────────┐ │
│ │ Backward Pass │ │ Backward Pass │ │
│ │ Compute: 1x forward │ │ Compute: ~1.33x │ │
│ │ │ │ (recompute between │ │
│ │ │ │ checkpoints) │ │
│ └─────────┬───────────┘ └─────────┬───────────┘ │
│ │ │ │
│ ▼ ▼ │
│ ┌─────────────────────┐ ┌─────────────────────┐ │
│ │ Result: │ │ Result: │ │
│ │ ✗ High memory │ │ ✓ Low memory │ │
│ │ ✓ Fast training │ │ ✓ Scalable │ │
│ │ Limited batch size │ │ ~33% slower │ │
│ └─────────────────────┘ └─────────────────────┘ │
│ │
└─────────────────────────────────────────────────────────────────────────────┘
Gradient Checkpointing Implementation
# src/gradient_checkpointing.py
"""
Gradient checkpointing utilities for memory-efficient training.
Trades compute for memory by recomputing activations during backward pass.
"""
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint, checkpoint_sequential
from typing import List, Optional, Callable
class CheckpointedModule(nn.Module):
"""
Wrapper that applies gradient checkpointing to a module.
Saves memory by not storing intermediate activations.
"""
def __init__(
self,
module: nn.Module,
use_checkpoint: bool = True,
preserve_rng_state: bool = True,
):
super().__init__()
self.module = module
self.use_checkpoint = use_checkpoint
self.preserve_rng_state = preserve_rng_state
def forward(self, *args, **kwargs):
if self.use_checkpoint and self.training:
# With use_reentrant=False (PyTorch >= 2.0), checkpoint passes
# kwargs straight through to the wrapped callable.
return checkpoint(
self.module,
*args,
use_reentrant=False,
preserve_rng_state=self.preserve_rng_state,
**kwargs,
)
return self.module(*args, **kwargs)
def apply_gradient_checkpointing_to_model(
model: nn.Module,
checkpoint_layer_types: List[type],
every_n_layers: int = 1,
) -> nn.Module:
"""
Apply gradient checkpointing to specific layer types in a model.
Args:
model: The model to modify
checkpoint_layer_types: Layer types to checkpoint (e.g., TransformerBlock)
every_n_layers: Checkpoint every N layers (default: every layer)
Returns:
Modified model with checkpointing enabled
"""
layer_count = 0
def apply_checkpoint(module: nn.Module):
nonlocal layer_count
for name, child in module.named_children():
if any(isinstance(child, layer_type) for layer_type in checkpoint_layer_types):
layer_count += 1
if layer_count % every_n_layers == 0:
setattr(module, name, CheckpointedModule(child))
else:
apply_checkpoint(child)
apply_checkpoint(model)
print(f"Applied checkpointing to {layer_count} layers")
return model
class SequentialCheckpoint(nn.Module):
"""
Sequential model with automatic gradient checkpointing.
Divides layers into segments and checkpoints each segment.
"""
def __init__(
self,
layers: nn.ModuleList,
num_segments: int = 4,
):
super().__init__()
self.layers = layers
self.num_segments = num_segments
def forward(self, x: torch.Tensor) -> torch.Tensor:
if self.training:
return checkpoint_sequential(
self.layers,
self.num_segments,
x,
use_reentrant=False,
)
for layer in self.layers:
x = layer(x)
return x
class SelectiveCheckpointing:
"""
Selectively apply checkpointing based on memory pressure.
Useful for dynamic memory management during training.
"""
def __init__(
self,
memory_threshold: float = 0.8,  # fraction of total GPU memory in use
check_frequency: int = 100,
):
self.memory_threshold = memory_threshold
self.check_frequency = check_frequency
self.step_count = 0
self.checkpointing_enabled = False
def should_checkpoint(self) -> bool:
"""Check if checkpointing should be enabled based on memory."""
self.step_count += 1
if self.step_count % self.check_frequency != 0:
return self.checkpointing_enabled
# Check current memory usage
if torch.cuda.is_available():
device = torch.cuda.current_device()
allocated = torch.cuda.memory_allocated(device) / 1e9
total = torch.cuda.get_device_properties(device).total_memory / 1e9
usage_ratio = allocated / total
self.checkpointing_enabled = usage_ratio > self.memory_threshold
return self.checkpointing_enabled
def get_checkpoint_fn(self) -> Callable:
"""Get the appropriate forward function based on memory state."""
if self.should_checkpoint():
return lambda fn, *args: checkpoint(fn, *args, use_reentrant=False)
return lambda fn, *args: fn(*args)
# Example: Custom Transformer with Checkpointing
class CheckpointedTransformerBlock(nn.Module):
"""Transformer block with optional gradient checkpointing."""
def __init__(
self,
hidden_size: int,
num_heads: int,
mlp_ratio: float = 4.0,
use_checkpoint: bool = True,
):
super().__init__()
self.use_checkpoint = use_checkpoint
self.attention = nn.MultiheadAttention(
hidden_size,
num_heads,
batch_first=True,
)
self.norm1 = nn.LayerNorm(hidden_size)
self.norm2 = nn.LayerNorm(hidden_size)
mlp_hidden = int(hidden_size * mlp_ratio)
self.mlp = nn.Sequential(
nn.Linear(hidden_size, mlp_hidden),
nn.GELU(),
nn.Linear(mlp_hidden, hidden_size),
)
def _attention_block(self, x: torch.Tensor) -> torch.Tensor:
"""Attention sub-block."""
normed = self.norm1(x)
attn_out, _ = self.attention(normed, normed, normed)
return x + attn_out
def _mlp_block(self, x: torch.Tensor) -> torch.Tensor:
"""MLP sub-block."""
return x + self.mlp(self.norm2(x))
def forward(self, x: torch.Tensor) -> torch.Tensor:
if self.use_checkpoint and self.training:
x = checkpoint(self._attention_block, x, use_reentrant=False)
x = checkpoint(self._mlp_block, x, use_reentrant=False)
else:
x = self._attention_block(x)
x = self._mlp_block(x)
return x
class CheckpointedTransformer(nn.Module):
"""Full transformer with gradient checkpointing."""
def __init__(
self,
vocab_size: int,
hidden_size: int,
num_layers: int,
num_heads: int,
max_length: int = 2048,
use_checkpoint: bool = True,
):
super().__init__()
self.embedding = nn.Embedding(vocab_size, hidden_size)
self.pos_embedding = nn.Embedding(max_length, hidden_size)
self.layers = nn.ModuleList([
CheckpointedTransformerBlock(
hidden_size,
num_heads,
use_checkpoint=use_checkpoint,
)
for _ in range(num_layers)
])
self.norm = nn.LayerNorm(hidden_size)
self.head = nn.Linear(hidden_size, vocab_size, bias=False)
def forward(
self,
input_ids: torch.Tensor,
labels: Optional[torch.Tensor] = None,
):
batch_size, seq_len = input_ids.shape
# Embeddings
positions = torch.arange(seq_len, device=input_ids.device)
x = self.embedding(input_ids) + self.pos_embedding(positions)
# Transformer layers
for layer in self.layers:
x = layer(x)
x = self.norm(x)
logits = self.head(x)
loss = None
if labels is not None:
loss = nn.functional.cross_entropy(
logits.view(-1, logits.size(-1)),
labels.view(-1),
)
return {"loss": loss, "logits": logits}
def memory_efficient_training_example():
"""Example of memory-efficient training with checkpointing."""
# Model configuration
config = {
"vocab_size": 32000,
"hidden_size": 4096,
"num_layers": 32,
"num_heads": 32,
"use_checkpoint": True,
}
# Create model
model = CheckpointedTransformer(**config).cuda()
# Compare memory usage
def measure_memory(use_ckpt: bool):
model.train()
model.zero_grad(set_to_none=True)  # clear grads from the previous run
for layer in model.layers:
layer.use_checkpoint = use_ckpt
# Dummy input
x = torch.randint(0, config["vocab_size"], (4, 512)).cuda()
labels = torch.randint(0, config["vocab_size"], (4, 512)).cuda()
torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats()
outputs = model(x, labels=labels)
outputs["loss"].backward()
peak_memory = torch.cuda.max_memory_allocated() / 1e9
return peak_memory
# Without checkpointing
mem_no_ckpt = measure_memory(use_ckpt=False)
print(f"Memory without checkpointing: {mem_no_ckpt:.2f} GB")
# With checkpointing
mem_with_ckpt = measure_memory(use_ckpt=True)
print(f"Memory with checkpointing: {mem_with_ckpt:.2f} GB")
print(f"Memory savings: {(1 - mem_with_ckpt/mem_no_ckpt)*100:.1f}%")
Part 5: Monitoring and Utilities
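A pattern the utilities in this part rely on repeatedly: to average a metric consistently across ranks, pack (sum, count) into a single tensor, all-reduce it, then divide. Averaging per-rank means instead would be wrong whenever ranks process different numbers of samples. A single-process-safe sketch of the idea (the helper name `global_mean` is illustrative, not from a library):

```python
import torch
import torch.distributed as dist

def global_mean(local_sum: float, local_count: int, device: str = "cpu") -> float:
    """Average a metric across all ranks (no-op when not distributed)."""
    t = torch.tensor([local_sum, float(local_count)], device=device)
    if dist.is_available() and dist.is_initialized():
        # One collective for both values avoids a second synchronization.
        dist.all_reduce(t, op=dist.ReduceOp.SUM)
    total, count = t.tolist()
    return total / max(count, 1.0)

# Single process: total loss 30.0 over 4 samples -> mean 7.5
print(global_mean(30.0, 4))
```

`MetricsAggregator` below applies the same two-element all-reduce per metric name.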
Distributed Monitoring
# src/monitoring.py
"""
Monitoring utilities for distributed training.
Handles logging, metrics, and debugging across processes.
"""
import os
import time
import json
from typing import Dict, Any, Optional
from datetime import datetime
from collections import defaultdict
import torch
import torch.distributed as dist
from torch.utils.tensorboard import SummaryWriter
class DistributedLogger:
"""
Logger that handles distributed training logging.
Only logs from rank 0 by default to avoid duplicate logs.
"""
def __init__(
self,
log_dir: str = "./logs",
experiment_name: Optional[str] = None,
rank: int = 0,
log_all_ranks: bool = False,
):
self.rank = rank
self.log_all_ranks = log_all_ranks
self.experiment_name = experiment_name or datetime.now().strftime("%Y%m%d_%H%M%S")
if self.should_log:
self.log_dir = os.path.join(log_dir, self.experiment_name)
os.makedirs(self.log_dir, exist_ok=True)
self.log_file = os.path.join(self.log_dir, f"training_rank{rank}.log")
self.writer = SummaryWriter(self.log_dir)
else:
self.writer = None
@property
def should_log(self) -> bool:
"""Check if this process should log."""
return self.log_all_ranks or self.rank == 0
def log(self, message: str, level: str = "INFO"):
"""Log a message."""
if not self.should_log:
return
timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
log_message = f"[{timestamp}] [{level}] [Rank {self.rank}] {message}"
print(log_message)
with open(self.log_file, "a") as f:
f.write(log_message + "\n")
def log_metrics(
self,
metrics: Dict[str, float],
step: int,
prefix: str = "train",
):
"""Log metrics to TensorBoard."""
if not self.should_log or self.writer is None:
return
for key, value in metrics.items():
self.writer.add_scalar(f"{prefix}/{key}", value, step)
def log_histogram(
self,
name: str,
values: torch.Tensor,
step: int,
):
"""Log histogram to TensorBoard."""
if not self.should_log or self.writer is None:
return
self.writer.add_histogram(name, values, step)
def close(self):
"""Close the logger."""
if self.writer is not None:
self.writer.close()
class MetricsAggregator:
"""
Aggregates metrics across distributed processes.
Handles reduction operations for consistent metrics.
"""
def __init__(self, device: torch.device):
self.device = device
self.reset()
def reset(self):
"""Reset accumulated metrics."""
self.metrics = defaultdict(list)
self.counts = defaultdict(int)
def update(self, name: str, value: float, count: int = 1):
"""Add a metric value."""
self.metrics[name].append(value * count)
self.counts[name] += count
def compute(self, reduce_across_processes: bool = True) -> Dict[str, float]:
"""Compute averaged metrics, optionally reducing across processes."""
results = {}
for name in self.metrics:
total = sum(self.metrics[name])
count = self.counts[name]
if reduce_across_processes and dist.is_initialized():
# Create tensor for all-reduce
tensor = torch.tensor([total, count], device=self.device)
dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
total, count = tensor.tolist()
results[name] = total / max(count, 1)
return results
class GPUMonitor:
"""Monitor GPU utilization during training."""
def __init__(self, device_id: int = 0):
self.device_id = device_id
self.history = []
def get_stats(self) -> Dict[str, float]:
"""Get current GPU statistics."""
if not torch.cuda.is_available():
return {}
stats = {
"memory_allocated_gb": torch.cuda.memory_allocated(self.device_id) / 1e9,
"memory_reserved_gb": torch.cuda.memory_reserved(self.device_id) / 1e9,
"max_memory_allocated_gb": torch.cuda.max_memory_allocated(self.device_id) / 1e9,
}
# Try to get utilization (requires pynvml)
try:
import pynvml
pynvml.nvmlInit()
handle = pynvml.nvmlDeviceGetHandleByIndex(self.device_id)
util = pynvml.nvmlDeviceGetUtilizationRates(handle)
stats["gpu_utilization"] = util.gpu
stats["memory_utilization"] = util.memory
mem_info = pynvml.nvmlDeviceGetMemoryInfo(handle)
stats["memory_total_gb"] = mem_info.total / 1e9
stats["memory_free_gb"] = mem_info.free / 1e9
except Exception:
# pynvml missing or NVML unavailable; skip utilization stats
pass
self.history.append(stats)
return stats
def get_summary(self) -> Dict[str, float]:
"""Get summary statistics."""
if not self.history:
return {}
summary = {}
for key in self.history[0]:
values = [h[key] for h in self.history]
summary[f"{key}_mean"] = sum(values) / len(values)
summary[f"{key}_max"] = max(values)
return summary
class ThroughputTracker:
"""Track training throughput (samples/second, tokens/second)."""
def __init__(self):
self.reset()
def reset(self):
"""Reset tracking."""
self.start_time = None
self.total_samples = 0
self.total_tokens = 0
self.step_times = []
def start_step(self):
"""Mark start of a training step."""
if self.start_time is None:
self.start_time = time.time()  # include the first step in elapsed time
self.step_start = time.time()
def end_step(self, num_samples: int, num_tokens: Optional[int] = None):
"""Mark end of a training step."""
step_time = time.time() - self.step_start
self.step_times.append(step_time)
self.total_samples += num_samples
if num_tokens is not None:
self.total_tokens += num_tokens
def get_throughput(self) -> Dict[str, float]:
"""Get throughput metrics."""
if not self.step_times:
return {}
elapsed = time.time() - self.start_time
metrics = {
"samples_per_second": self.total_samples / elapsed,
"step_time_mean": sum(self.step_times) / len(self.step_times),
"step_time_std": (
sum((t - sum(self.step_times)/len(self.step_times))**2
for t in self.step_times) / len(self.step_times)
) ** 0.5,
}
if self.total_tokens > 0:
metrics["tokens_per_second"] = self.total_tokens / elapsed
return metrics
def all_gather_object(obj: Any, world_size: int) -> list:
"""
Gather arbitrary Python objects from all processes.
Useful for gathering non-tensor data like strings or dicts.
"""
if not dist.is_initialized():
return [obj]
output = [None] * world_size
dist.all_gather_object(output, obj)
return output
def broadcast_object(obj: Any, src: int = 0) -> Any:
"""
Broadcast an arbitrary Python object from src to all processes.
"""
if not dist.is_initialized():
return obj
object_list = [obj]
dist.broadcast_object_list(object_list, src=src)
return object_list[0]
def print_rank_0(message: str, rank: Optional[int] = None):
"""Print only from rank 0."""
if rank is None:
rank = dist.get_rank() if dist.is_initialized() else 0
if rank == 0:
print(message)
Utility Functions
# src/utils.py
"""
Utility functions for distributed training.
"""
import os
import random
import numpy as np
import torch
import torch.distributed as dist
from typing import Optional, Dict, Any
import yaml
def set_seed(seed: int, rank: int = 0):
"""Set random seeds for reproducibility."""
seed = seed + rank # Different seed per process
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
# For deterministic operations (may slow down training)
# torch.backends.cudnn.deterministic = True
# torch.backends.cudnn.benchmark = False
def get_world_info() -> Dict[str, int]:
"""Get distributed training world information."""
if dist.is_initialized():
return {
"world_size": dist.get_world_size(),
"rank": dist.get_rank(),
"local_rank": int(os.environ.get("LOCAL_RANK", 0)),
}
return {
"world_size": 1,
"rank": 0,
"local_rank": 0,
}
def setup_distributed(backend: str = "nccl") -> Dict[str, int]:
"""Initialize distributed training."""
if "RANK" in os.environ and "WORLD_SIZE" in os.environ:
dist.init_process_group(backend=backend)
world_info = get_world_info()
# Set device
if torch.cuda.is_available():
torch.cuda.set_device(world_info["local_rank"])
return world_info
def cleanup_distributed():
"""Clean up distributed resources."""
if dist.is_initialized():
dist.destroy_process_group()
def load_config(config_path: str) -> Dict[str, Any]:
"""Load configuration from YAML file."""
with open(config_path, "r") as f:
config = yaml.safe_load(f)
return config
def save_config(config: Dict[str, Any], path: str):
"""Save configuration to YAML file."""
with open(path, "w") as f:
yaml.dump(config, f, default_flow_style=False)
def get_linear_schedule_with_warmup(
optimizer,
num_warmup_steps: int,
num_training_steps: int,
last_epoch: int = -1,
):
"""Create linear warmup scheduler."""
from torch.optim.lr_scheduler import LambdaLR
def lr_lambda(current_step: int):
if current_step < num_warmup_steps:
return float(current_step) / float(max(1, num_warmup_steps))
return max(
0.0,
float(num_training_steps - current_step) /
float(max(1, num_training_steps - num_warmup_steps))
)
return LambdaLR(optimizer, lr_lambda, last_epoch=last_epoch)
def count_parameters(model: torch.nn.Module) -> Dict[str, int]:
"""Count model parameters."""
total = sum(p.numel() for p in model.parameters())
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
return {
"total_parameters": total,
"trainable_parameters": trainable,
"frozen_parameters": total - trainable,
}
def get_parameter_dtype(model: torch.nn.Module) -> torch.dtype:
"""Get the dtype of model parameters."""
for param in model.parameters():
return param.dtype
return torch.float32
def compute_effective_batch_size(
batch_size_per_gpu: int,
gradient_accumulation_steps: int,
world_size: int,
) -> int:
"""Compute effective batch size across all GPUs and accumulation steps."""
return batch_size_per_gpu * gradient_accumulation_steps * world_size
def estimate_memory_usage(
model: torch.nn.Module,
batch_size: int,
seq_length: int,
dtype: torch.dtype = torch.float32,
) -> Dict[str, float]:
"""Estimate memory usage for training."""
# Parameter memory
param_bytes = sum(p.numel() * p.element_size() for p in model.parameters())
# Gradient memory (same as parameters)
grad_bytes = param_bytes
# Optimizer state (AdamW: 2x for momentum and variance)
optimizer_bytes = 2 * param_bytes
# Activation memory (rough estimate; highly model-dependent).
# Fall back to defaults when the model has no HF-style config attribute.
model_config = getattr(model, "config", None)
hidden_size = getattr(model_config, "hidden_size", 768)
num_layers = getattr(model_config, "num_hidden_layers", 12)
dtype_size = {
torch.float32: 4,
torch.float16: 2,
torch.bfloat16: 2,
}.get(dtype, 4)
activation_bytes = (
batch_size * seq_length * hidden_size * num_layers * dtype_size * 2
)
return {
"parameters_gb": param_bytes / 1e9,
"gradients_gb": grad_bytes / 1e9,
"optimizer_gb": optimizer_bytes / 1e9,
"activations_gb": activation_bytes / 1e9,
"total_gb": (param_bytes + grad_bytes + optimizer_bytes + activation_bytes) / 1e9,
}
Part 6: Main Training Script
# train.py
"""
Main training script supporting DDP, FSDP, and DeepSpeed.
"""
import argparse
import os
import yaml
import torch
from src.ddp_trainer import DDPConfig, train_ddp
from src.fsdp_trainer import FSDPConfig, train_fsdp
from src.deepspeed_trainer import DeepSpeedConfig, train_with_deepspeed
from src.utils import setup_distributed, cleanup_distributed, set_seed
def parse_args():
parser = argparse.ArgumentParser(description="Distributed Training")
parser.add_argument(
"--mode",
type=str,
choices=["ddp", "fsdp", "deepspeed"],
default="ddp",
help="Training mode",
)
parser.add_argument(
"--config",
type=str,
default=None,
help="Path to config file",
)
parser.add_argument(
"--model_name",
type=str,
default=None,
help="Override model name from config",
)
parser.add_argument(
"--output_dir",
type=str,
default="./output",
help="Output directory",
)
parser.add_argument(
"--seed",
type=int,
default=42,
help="Random seed",
)
return parser.parse_args()
def load_config(args):
"""Load configuration based on mode and config file."""
config = {}
if args.config and os.path.exists(args.config):
with open(args.config, "r") as f:
if args.config.endswith(".json"):
import json
config = json.load(f)
else:
config = yaml.safe_load(f)
# Override with command line args
if args.model_name:
config["model_name"] = args.model_name
if args.output_dir:
config["output_dir"] = args.output_dir
config["checkpoint_dir"] = os.path.join(args.output_dir, "checkpoints")
return config
def main():
args = parse_args()
# Setup distributed
world_info = setup_distributed()
# Set seed
set_seed(args.seed, rank=world_info["rank"])
# Load config
config_dict = load_config(args)
try:
if args.mode == "ddp":
config = DDPConfig(**config_dict)
train_ddp(config)
elif args.mode == "fsdp":
config = FSDPConfig(**config_dict)
train_fsdp(config)
elif args.mode == "deepspeed":
config = DeepSpeedConfig(**config_dict)
train_with_deepspeed(config)
finally:
cleanup_distributed()
if world_info["rank"] == 0:
print("\nTraining completed successfully!")
if __name__ == "__main__":
main()
Production Deployment
Docker Setup
# Dockerfile
FROM pytorch/pytorch:2.1.0-cuda12.1-cudnn8-devel
# Install system dependencies
RUN apt-get update && apt-get install -y \
git \
wget \
openssh-server \
&& rm -rf /var/lib/apt/lists/*
# Setup SSH for multi-node training
RUN mkdir /var/run/sshd
RUN echo 'root:distributed' | chpasswd
RUN sed -i 's/#PermitRootLogin prohibit-password/PermitRootLogin yes/' /etc/ssh/sshd_config
# Install Python dependencies
COPY requirements.txt /app/requirements.txt
RUN pip install --no-cache-dir -r /app/requirements.txt
# Copy application code
WORKDIR /app
COPY . /app
# Environment variables for distributed training
ENV NCCL_DEBUG=WARN
ENV NCCL_IB_DISABLE=0
ENV NCCL_SOCKET_IFNAME=eth0
EXPOSE 22 29500
CMD ["/usr/sbin/sshd", "-D"]
# docker-compose.yml
version: '3.8'
services:
trainer-node-0:
build: .
image: distributed-training:latest
runtime: nvidia
environment:
- NVIDIA_VISIBLE_DEVICES=all
- MASTER_ADDR=trainer-node-0
- MASTER_PORT=29500
- WORLD_SIZE=2
- RANK=0
- LOCAL_RANK=0
volumes:
- ./data:/app/data
- ./checkpoints:/app/checkpoints
- ./logs:/app/logs
networks:
- training-network
deploy:
resources:
reservations:
devices:
- driver: nvidia
count: all
capabilities: [gpu]
command: >
torchrun
--nproc_per_node=4
--nnodes=2
--node_rank=0
--master_addr=trainer-node-0
--master_port=29500
train.py --mode fsdp
trainer-node-1:
build: .
image: distributed-training:latest
runtime: nvidia
environment:
- NVIDIA_VISIBLE_DEVICES=all
- MASTER_ADDR=trainer-node-0
- MASTER_PORT=29500
- WORLD_SIZE=2
- RANK=1
- LOCAL_RANK=0
volumes:
- ./data:/app/data
- ./checkpoints:/app/checkpoints
- ./logs:/app/logs
networks:
- training-network
deploy:
resources:
reservations:
devices:
- driver: nvidia
count: all
capabilities: [gpu]
depends_on:
- trainer-node-0
command: >
torchrun
--nproc_per_node=4
--nnodes=2
--node_rank=1
--master_addr=trainer-node-0
--master_port=29500
train.py --mode fsdp
networks:
training-network:
driver: bridge
Cloud Deployment (AWS)
# scripts/launch_aws.py
"""
Launch distributed training on AWS EC2 instances.
"""
import boto3


def launch_training_cluster(
    num_nodes: int = 2,
    instance_type: str = "p4d.24xlarge",
    ami_id: str = "ami-xxx",  # Deep Learning AMI
    key_name: str = "your-key",
    security_group: str = "sg-xxx",
    subnet_id: str = "subnet-xxx",
):
    """Launch EC2 instances for distributed training."""
    ec2 = boto3.client("ec2")

    # User data script for each node
    def get_user_data(node_rank: int, master_ip: str = None) -> str:
        if node_rank == 0:
            return f"""#!/bin/bash
cd /home/ubuntu/training
source activate pytorch
export MASTER_ADDR=$(hostname -I | awk '{{print $1}}')
export MASTER_PORT=29500
# Start training (torchrun sets RANK/WORLD_SIZE for each process)
torchrun --nproc_per_node=8 --nnodes={num_nodes} --node_rank=0 \\
    --master_addr=$MASTER_ADDR --master_port=29500 \\
    train.py --mode fsdp
"""
        else:
            return f"""#!/bin/bash
cd /home/ubuntu/training
source activate pytorch
export MASTER_ADDR={master_ip}
export MASTER_PORT=29500
# Wait for master
sleep 60
# Start training (torchrun sets RANK/WORLD_SIZE for each process)
torchrun --nproc_per_node=8 --nnodes={num_nodes} --node_rank={node_rank} \\
    --master_addr=$MASTER_ADDR --master_port=29500 \\
    train.py --mode fsdp
"""

    # Launch master node first
    master_response = ec2.run_instances(
        ImageId=ami_id,
        InstanceType=instance_type,
        MinCount=1,
        MaxCount=1,
        KeyName=key_name,
        SecurityGroupIds=[security_group],
        SubnetId=subnet_id,
        UserData=get_user_data(0),
        TagSpecifications=[{
            "ResourceType": "instance",
            "Tags": [{"Key": "Name", "Value": "training-master"}]
        }],
    )
    master_instance_id = master_response["Instances"][0]["InstanceId"]

    # Wait for master to be running
    waiter = ec2.get_waiter("instance_running")
    waiter.wait(InstanceIds=[master_instance_id])

    # Get master IP
    master_info = ec2.describe_instances(InstanceIds=[master_instance_id])
    master_ip = master_info["Reservations"][0]["Instances"][0]["PrivateIpAddress"]
    print(f"Master node started: {master_instance_id} ({master_ip})")

    # Launch worker nodes
    worker_ids = []
    for rank in range(1, num_nodes):
        worker_response = ec2.run_instances(
            ImageId=ami_id,
            InstanceType=instance_type,
            MinCount=1,
            MaxCount=1,
            KeyName=key_name,
            SecurityGroupIds=[security_group],
            SubnetId=subnet_id,
            UserData=get_user_data(rank, master_ip),
            TagSpecifications=[{
                "ResourceType": "instance",
                "Tags": [{"Key": "Name", "Value": f"training-worker-{rank}"}]
            }],
        )
        worker_id = worker_response["Instances"][0]["InstanceId"]
        worker_ids.append(worker_id)
        print(f"Worker node {rank} started: {worker_id}")

    return master_instance_id, worker_ids


if __name__ == "__main__":
    master_id, worker_ids = launch_training_cluster(num_nodes=2)
    print("\nCluster launched successfully!")
    print(f"Master: {master_id}")
    print(f"Workers: {worker_ids}")

Common Issues and Solutions
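Many "wrong rank" bugs come down to mixing up the three rank-related quantities. torchrun derives each process's global rank from its node rank and local rank; a small illustrative helper (`global_rank` is not part of the scripts in this guide, just the arithmetic spelled out):

```python
def global_rank(node_rank: int, nproc_per_node: int, local_rank: int) -> int:
    """torchrun assigns RANK = node_rank * nproc_per_node + local_rank."""
    return node_rank * nproc_per_node + local_rank

# 2 nodes x 4 GPUs each: WORLD_SIZE = 8
nnodes, nproc_per_node = 2, 4
world_size = nnodes * nproc_per_node
print(world_size)            # 8
print(global_rank(1, 4, 2))  # node 1, GPU 2 -> global rank 6
```

If the printed RANK/LOCAL_RANK pairs from the debugging script below don't satisfy this relationship, the launch flags (usually `--node_rank`) are misconfigured.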
Debugging Distributed Training
# scripts/debug_distributed.py
"""
Debugging utilities for distributed training issues.
"""
import os
import socket

import torch
import torch.distributed as dist


def check_environment():
    """Check distributed environment variables."""
    env_vars = [
        "RANK", "WORLD_SIZE", "LOCAL_RANK",
        "MASTER_ADDR", "MASTER_PORT",
        "NCCL_DEBUG", "NCCL_SOCKET_IFNAME",
    ]
    print("Environment Variables:")
    for var in env_vars:
        value = os.environ.get(var, "NOT SET")
        print(f"  {var}: {value}")
    print(f"\nHostname: {socket.gethostname()}")
    print(f"IP Address: {socket.gethostbyname(socket.gethostname())}")


def check_gpu_connectivity():
    """Check GPU availability and CUDA status."""
    print("\nGPU Status:")
    print(f"  CUDA available: {torch.cuda.is_available()}")
    if torch.cuda.is_available():
        print(f"  CUDA version: {torch.version.cuda}")
        print(f"  GPU count: {torch.cuda.device_count()}")
        for i in range(torch.cuda.device_count()):
            props = torch.cuda.get_device_properties(i)
            print(f"  GPU {i}: {props.name}")
            print(f"    Memory: {props.total_memory / 1e9:.1f} GB")
            print(f"    Compute capability: {props.major}.{props.minor}")


def check_nccl_connectivity():
    """Test NCCL communication between GPUs."""
    if not torch.cuda.is_available():
        print("CUDA not available, skipping NCCL test")
        return
    print("\nNCCL Connectivity Test:")
    try:
        # Simple all-reduce test
        if dist.is_initialized():
            rank = dist.get_rank()
            world_size = dist.get_world_size()
            # Create tensor on this process's GPU (prefer LOCAL_RANK if set)
            local_rank = int(os.environ.get("LOCAL_RANK", rank % torch.cuda.device_count()))
            tensor = torch.ones(1000, device=f"cuda:{local_rank}")
            tensor *= rank
            # All-reduce: each element becomes 0 + 1 + ... + (world_size - 1)
            dist.all_reduce(tensor)
            expected = sum(range(world_size)) * 1000
            actual = tensor.sum().item()
            if abs(actual - expected) < 1e-5:
                print(f"  Rank {rank}: NCCL all-reduce OK")
            else:
                print(f"  Rank {rank}: NCCL all-reduce FAILED")
                print(f"    Expected: {expected}, Got: {actual}")
        else:
            print("  Distributed not initialized")
    except Exception as e:
        print(f"  NCCL test failed: {e}")


def diagnose_common_issues():
    """Print common issues and solutions."""
    print("\n" + "=" * 60)
    print("Common Issues and Solutions")
    print("=" * 60)
    issues = [
        {
            "issue": "NCCL timeout",
            "symptoms": "Process hangs during all_reduce",
            "solutions": [
                "Check firewall rules (port 29500 and NCCL ports)",
                "Set NCCL_DEBUG=INFO for more details",
                "Try NCCL_SOCKET_IFNAME=eth0 (or correct interface)",
                "Increase the timeout: init_process_group(timeout=timedelta(minutes=30))",
            ]
        },
        {
            "issue": "OOM (Out of Memory)",
            "symptoms": "CUDA out of memory error",
            "solutions": [
                "Reduce batch size per GPU",
                "Enable gradient checkpointing",
                "Use FSDP or DeepSpeed ZeRO-3",
                "Enable CPU offloading",
            ]
        },
        {
            "issue": "Slow training",
            "symptoms": "Low GPU utilization",
            "solutions": [
                "Increase batch size",
                "Use more data loader workers",
                "Enable pin_memory=True",
                "Check network bandwidth between nodes",
            ]
        },
        {
            "issue": "Loss NaN",
            "symptoms": "NaN loss values",
            "solutions": [
                "Lower learning rate",
                "Add gradient clipping",
                "Check for inf/nan in inputs",
                "Use fp32 for loss computation",
            ]
        },
    ]
    for item in issues:
        print(f"\n{item['issue']}")
        print(f"  Symptoms: {item['symptoms']}")
        print("  Solutions:")
        for solution in item["solutions"]:
            print(f"    - {solution}")


if __name__ == "__main__":
    check_environment()
    check_gpu_connectivity()
    # Initialize for NCCL test
    if "RANK" in os.environ:
        dist.init_process_group(backend="nccl")
        check_nccl_connectivity()
        dist.destroy_process_group()
    diagnose_common_issues()

Performance Comparison
| Method | Memory per GPU | Communication | Best For |
|---|---|---|---|
| DDP | Full model | Gradient AllReduce | Models that fit on 1 GPU |
| FSDP | ~1/N of model state | Param AllGather + grad ReduceScatter | Large models, PyTorch-native |
| DeepSpeed ZeRO-1 | Params + grads full, optimizer sharded | Gradient AllReduce | Large batch training |
| DeepSpeed ZeRO-2 | Params full, grads + optimizer sharded | Gradient ReduceScatter | Moderate memory savings |
| DeepSpeed ZeRO-3 | ~1/N of model state | Param AllGather + grad ReduceScatter | Very large models |
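The memory column can be made concrete with the usual ZeRO accounting for mixed-precision Adam: 2 bytes/param for fp16 weights, 2 for fp16 gradients, and 12 for fp32 optimizer state (master weights plus two moments). A back-of-envelope sketch, not a profiler measurement (activations and buffers are excluded):

```python
def per_gpu_memory_gb(num_params: float, num_gpus: int, mode: str) -> float:
    """Rough per-GPU memory for model state (weights + grads + Adam), in GB.

    Assumes fp16 weights and grads (2 bytes/param each) and fp32 Adam state
    (master weights + 2 moments = 12 bytes/param). Activations excluded.
    """
    P, N = num_params, num_gpus
    if mode == "ddp":                # everything replicated on every GPU
        total = P * (2 + 2 + 12)
    elif mode == "zero1":            # optimizer state sharded
        total = P * (2 + 2) + P * 12 / N
    elif mode == "zero2":            # + gradients sharded
        total = P * 2 + P * (2 + 12) / N
    elif mode in ("zero3", "fsdp"):  # + parameters sharded
        total = P * (2 + 2 + 12) / N
    else:
        raise ValueError(f"unknown mode: {mode}")
    return total / 1e9

# 7B parameters on 8 GPUs
print(f"DDP:    {per_gpu_memory_gb(7e9, 8, 'ddp'):.1f} GB")    # 112.0 GB
print(f"ZeRO-1: {per_gpu_memory_gb(7e9, 8, 'zero1'):.1f} GB")  # 38.5 GB
print(f"ZeRO-3: {per_gpu_memory_gb(7e9, 8, 'zero3'):.1f} GB")  # 14.0 GB
```

The 112 GB DDP figure is why a 7B model cannot train on a single 80 GB GPU without sharding, even before counting activations.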
Key Concepts Recap
| Concept | What It Is | Why It Matters |
|---|---|---|
| DDP | DistributedDataParallel - replicate model on each GPU, sync gradients | Simplest approach when model fits on one GPU |
| FSDP | Fully Sharded Data Parallel - shard params, grads, optimizer across GPUs | Train models larger than single GPU memory |
| AllReduce | Collective op that sums tensors across GPUs | Synchronizes gradients so all GPUs have same weights |
| World Size | Total number of GPUs across all nodes | Determines effective batch size and sharding |
| Local Rank | GPU index within a single node (0, 1, 2...) | Used to assign each process to correct GPU |
| ZeRO | Zero Redundancy Optimizer - DeepSpeed's memory optimization | Stage 1-3 progressively shard more state |
| Gradient Checkpointing | Recompute activations during backward instead of storing | Trade ~33% more compute for O(√n) memory |
| Gradient Accumulation | Accumulate gradients over mini-batches before optimizer step | Simulate larger batch sizes with limited memory |
| Mixed Precision | Use FP16/BF16 for compute, FP32 for accumulation | 2x memory savings, faster on modern GPUs |
| CPU Offload | Move optimizer states or params to CPU RAM | Train even larger models at cost of speed |
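The gradient-accumulation and world-size entries combine into the formula from the TL;DR, effective_batch_size = batch_per_gpu × gradient_accum × num_gpus. A quick sanity check in plain Python (the numbers are illustrative):

```python
def effective_batch_size(batch_per_gpu: int, grad_accum: int, num_gpus: int) -> int:
    """effective_batch_size = batch_per_gpu * gradient_accum * num_gpus."""
    return batch_per_gpu * grad_accum * num_gpus

# Hitting a target global batch of 1024 on 8 GPUs that fit 16 samples each:
target, per_gpu, gpus = 1024, 16, 8
accum = target // (per_gpu * gpus)                 # accumulation steps needed
print(accum)                                       # 8
print(effective_batch_size(per_gpu, accum, gpus))  # 1024
```

Solving for the accumulation steps this way is the standard trick for keeping the global batch (and thus the learning-rate schedule) fixed when the GPU count changes.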
Next Steps
After mastering distributed training:
- Custom Transformer - Build transformer architecture from scratch
- Production Deployment - Deploy to Kubernetes clusters
- Mixed Precision Training - Advanced AMP strategies
- Pipeline Parallelism - For extremely large models