Speculative Decoding
Accelerate large language models using small draft models
TL;DR
Accelerate LLM inference by having a fast draft model generate K tokens, which the target model then verifies in a single forward pass; because inference is memory-bound rather than compute-bound, that verification costs about as much as generating one token. Accept the draft tokens that are consistent with the target distribution and resample on rejection. Speedup scales roughly with K × acceptance_rate. Adapt K to the observed acceptance rate for the best throughput.
Use small language models to dramatically accelerate large model inference while maintaining output quality.
Project Overview
| Aspect | Details |
|---|---|
| Difficulty | Advanced |
| Time | 3-4 days |
| Prerequisites | PyTorch, LLM inference, probability theory |
| Learning Outcomes | Speculative decoding, draft models, verification, performance optimization |
What You'll Learn
- Understand speculative decoding theory and algorithms
- Implement draft model selection strategies
- Build verification and rejection sampling systems
- Optimize for different hardware configurations
- Benchmark and tune for production workloads
How Speculative Decoding Works
┌─────────────────────────────────────────────────────────────────────────────┐
│ SPECULATIVE DECODING SEQUENCE │
├─────────────────────────────────────────────────────────────────────────────┤
│ │
│ User Draft Model (SLM) Target Model (LLM) │
│ │ │ │ │
│ │ Input prompt │ │ │
│ │─────────────────────►│ │ │
│ │ │ │ │
│ │ ┌──────┴──────┐ │ │
│ │ │ Generate K │ │ │
│ │ │ tokens fast │ │ │
│ │ │ (autoregr.) │ │ │
│ │ └──────┬──────┘ │ │
│ │ │ │ │
│ │ │ K draft tokens │ │
│ │ │──────────────────────────►│ │
│ │ │ │ │
│ │ │ ┌──────┴──────┐ │
│ │ │ │ Verify ALL │ │
│ │ │ │ K tokens in │ │
│ │ │ │ ONE forward │ │
│ │ │ │ pass │ │
│ │ │ └──────┬──────┘ │
│ │ │ │ │
│ │◄─────────────────────────────────────────────────│ │
│ │ Accepted tokens + correction │ │
│ │ │ │
├─────────────────────────────────────────────────────────────────────────────┤
│ KEY INSIGHT: Verifying K tokens costs ~same as generating 1 (memory-bound) │
└─────────────────────────────────────────────────────────────────────────────┘
Architecture Overview
┌─────────────────────────────────────────────────────────────────────────────┐
│ SPECULATIVE DECODING ARCHITECTURE │
├─────────────────────────────────────────────────────────────────────────────┤
│ │
│ ┌─────────────────────────────────────────────────────────────────────┐ │
│ │ INPUT PROCESSING │ │
│ │ User Prompt ──► Context Window │ │
│ └──────────────────────────────┬──────────────────────────────────────┘ │
│ │ │
│ ┌─────────────────────┴─────────────────────┐ │
│ │ │ │
│ ▼ ▼ │
│ ┌─────────────────────────┐ ┌─────────────────────────┐ │
│ │ DRAFT MODEL (Fast) │ │ TARGET MODEL (Accurate) │ │
│ │ │ │ │ │
│ │ Small LLM (1-3B) │ K tokens │ Large LLM (7B-70B) │ │
│ │ │ │────────────►│ │ │ │
│ │ ▼ │ │ ▼ │ │
│ │ Generate K tokens │ │ Parallel verification │ │
│ │ (autoregressive) │ │ (single forward pass) │ │
│ │ │ │ │ │ │ │
│ │ ▼ │ │ ▼ │ │
│ │ Draft probs: p_d(x) │ │ Target probs: p_t(x) │ │
│ │ │ │ │ │
│ └────────────┬────────────┘ └────────────┬────────────┘ │
│ │ │ │
│ └───────────────┬───────────────────────┘ │
│ ▼ │
│ ┌─────────────────────────────────────────────────────────────────────┐ │
│ │ VERIFICATION (Rejection Sampling) │ │
│ │ │ │
│ │ Compare: acceptance_prob = min(1, p_t(x) / p_d(x)) │ │
│ │ │ │ │
│ │ ┌───────────────┴───────────────┐ │ │
│ │ ▼ ▼ │ │
│ │ r < acceptance r >= acceptance │ │
│ │ │ │ │ │
│ │ ▼ ▼ │ │
│ │ ACCEPT token REJECT & resample from │ │
│ │ (continue) p_t - p_d (adjusted dist) │ │
│ │ │ │ │ │
│ │ └───────────────┬───────────────┘ │ │
│ │ ▼ │ │
│ │ Output Tokens │ │
│ └─────────────────────────────────────────────────────────────────────┘ │
│ │
└─────────────────────────────────────────────────────────────────────────────┘
Project Setup
Dependencies
# Create project directory
mkdir speculative-decoding && cd speculative-decoding
# Create virtual environment
python -m venv venv
source venv/bin/activate # On Windows: venv\Scripts\activate
# Install PyTorch with CUDA
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121
# Transformers and accelerate
pip install transformers accelerate
# For llama.cpp implementation
pip install llama-cpp-python
# Visualization and benchmarking
pip install plotly pandas tqdm numpy
Model Pairs for Speculative Decoding
| Target Model | Draft Model | Expected Speedup |
|---|---|---|
| Llama 3 70B | Llama 3 8B | 2-3x |
| Qwen2.5 72B | Qwen2.5 7B | 2-3x |
| Phi-3 Medium | Phi-3 Mini | 1.5-2x |
| Mistral 7B | SmolLM 360M | 1.5-2x |
| Qwen2.5 7B | Qwen2.5 0.5B | 1.3-1.8x |
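Before implementing anything from scratch, you can try one of these pairs with the transformers library's built-in assisted generation. This is a minimal sketch; the Qwen model names are examples, and any target/draft pair that shares a tokenizer should work.
# examples/quick_try_assisted.py
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

device = "cuda" if torch.cuda.is_available() else "cpu"
target_name = "Qwen/Qwen2.5-7B-Instruct"    # example target model
draft_name = "Qwen/Qwen2.5-0.5B-Instruct"   # example draft from the same family

tokenizer = AutoTokenizer.from_pretrained(target_name)
target = AutoModelForCausalLM.from_pretrained(target_name, torch_dtype=torch.float16).to(device)
draft = AutoModelForCausalLM.from_pretrained(draft_name, torch_dtype=torch.float16).to(device)

inputs = tokenizer("The key to efficient inference is", return_tensors="pt").to(device)
# Passing assistant_model switches generate() to speculative (assisted) decoding
output = target.generate(**inputs, assistant_model=draft, max_new_tokens=64)
print(tokenizer.decode(output[0], skip_special_tokens=True))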
Part 1: Understanding Speculative Decoding
The core algorithm with mathematical foundations.
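The verification rule the code below implements is the standard speculative sampling rule. For a drafted token x, with p_d(x) the draft model's probability and p_t(x) the target model's probability:
    accept x with probability  min(1, p_t(x) / p_d(x))
    on rejection, resample from  p'(x) = max(0, p_t(x) - p_d(x)) / Σ_x' max(0, p_t(x') - p_d(x'))
Drafting from p_d and applying this accept/resample rule yields samples distributed exactly according to p_t, which is why speculative decoding preserves output quality.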
# core/speculative_decoding.py
"""
Core speculative decoding implementation.
The key insight: LLM inference is memory-bound, not compute-bound.
A single forward pass can verify K tokens as cheaply as generating 1.
"""
import torch
import torch.nn.functional as F
from typing import Tuple, Optional
from dataclasses import dataclass
@dataclass
class SpeculativeConfig:
"""Configuration for speculative decoding."""
num_speculative_tokens: int = 5 # K - number of draft tokens per iteration
temperature: float = 1.0
top_p: float = 0.9
top_k: int = 50
class SpeculativeDecoder:
"""
Speculative decoding implementation.
Algorithm:
1. Draft model generates K tokens autoregressively
2. Target model verifies all K tokens in one forward pass
3. Accept tokens that match target distribution
4. Resample rejected token from adjusted distribution
"""
def __init__(
self,
target_model,
draft_model,
tokenizer,
config: SpeculativeConfig = None
):
self.target = target_model
self.draft = draft_model
self.tokenizer = tokenizer
self.config = config or SpeculativeConfig()
# Ensure models are in eval mode
self.target.eval()
self.draft.eval()
# Device
self.device = next(target_model.parameters()).device
@torch.no_grad()
def generate(
self,
input_ids: torch.Tensor,
max_new_tokens: int = 100
) -> Tuple[torch.Tensor, dict]:
"""
Generate tokens using speculative decoding.
Args:
input_ids: Input token IDs [batch_size, seq_len]
max_new_tokens: Maximum tokens to generate
Returns:
Tuple of (generated_ids, metrics)
"""
batch_size = input_ids.shape[0]
assert batch_size == 1, "Batch size must be 1 for speculative decoding"
# Metrics tracking
metrics = {
"total_tokens": 0,
"accepted_tokens": 0,
"draft_calls": 0,
"target_calls": 0,
}
generated = input_ids.clone()
tokens_generated = 0
while tokens_generated < max_new_tokens:
# Step 1: Generate K draft tokens
draft_tokens, draft_probs = self._draft_generate(
generated,
self.config.num_speculative_tokens
)
metrics["draft_calls"] += self.config.num_speculative_tokens
# Step 2: Verify with target model (single forward pass!)
target_probs = self._target_verify(generated, draft_tokens)
metrics["target_calls"] += 1
# Step 3: Accept/reject tokens
accepted, new_token = self._verify_and_accept(
draft_tokens,
draft_probs,
target_probs
)
# Step 4: Update generated sequence
num_accepted = len(accepted)
metrics["accepted_tokens"] += num_accepted
metrics["total_tokens"] += num_accepted + 1 # +1 for resampled token
if num_accepted > 0:
generated = torch.cat([generated, accepted.unsqueeze(0)], dim=1)
generated = torch.cat([generated, new_token.unsqueeze(0).unsqueeze(0)], dim=1)
tokens_generated += num_accepted + 1
# Check for EOS
if new_token.item() == self.tokenizer.eos_token_id:
break
        # Acceptance rate = accepted draft tokens / proposed draft tokens
        metrics["acceptance_rate"] = metrics["accepted_tokens"] / max(
            metrics["draft_calls"], 1
        )
return generated, metrics
def _draft_generate(
self,
input_ids: torch.Tensor,
num_tokens: int
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Generate K tokens from draft model.
Returns:
Tuple of (token_ids, probabilities)
"""
draft_tokens = []
draft_probs = []
current_ids = input_ids.clone()
for _ in range(num_tokens):
outputs = self.draft(current_ids)
logits = outputs.logits[:, -1, :]
# Apply temperature
if self.config.temperature != 1.0:
logits = logits / self.config.temperature
# Apply top-k/top-p filtering
probs = self._apply_sampling(logits)
# Sample
token = torch.multinomial(probs, num_samples=1)
draft_tokens.append(token.item())
draft_probs.append(probs[0, token.item()].item())
current_ids = torch.cat([current_ids, token], dim=1)
return (
torch.tensor(draft_tokens, device=self.device),
torch.tensor(draft_probs, device=self.device)
)
def _target_verify(
self,
prefix: torch.Tensor,
draft_tokens: torch.Tensor
) -> torch.Tensor:
"""
Get target model probabilities for all draft tokens in ONE forward pass.
This is the key efficiency gain - we verify K tokens with cost of 1.
"""
# Concatenate prefix with draft tokens
full_input = torch.cat([prefix, draft_tokens.unsqueeze(0)], dim=1)
# Single forward pass
outputs = self.target(full_input)
logits = outputs.logits
# Extract logits for positions after the prefix
# We need probabilities at positions [prefix_len-1, prefix_len, ..., prefix_len+K-1]
prefix_len = prefix.shape[1]
relevant_logits = logits[:, prefix_len-1:prefix_len+len(draft_tokens), :]
# Apply temperature
if self.config.temperature != 1.0:
relevant_logits = relevant_logits / self.config.temperature
# Get probabilities
probs = F.softmax(relevant_logits, dim=-1)
# Extract probability of each draft token at its position
target_probs = []
for i, token in enumerate(draft_tokens):
target_probs.append(probs[0, i, token].item())
# Also get full distribution for resampling
self._last_target_probs = probs[0, -1, :] # For rejection resampling
return torch.tensor(target_probs, device=self.device)
def _verify_and_accept(
self,
draft_tokens: torch.Tensor,
draft_probs: torch.Tensor,
target_probs: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Verify draft tokens and accept/reject based on probability ratio.
Uses rejection sampling to maintain exact target distribution.
"""
accepted = []
for i, (token, p_draft, p_target) in enumerate(
zip(draft_tokens, draft_probs, target_probs)
):
# Acceptance probability: min(1, p_target / p_draft)
acceptance_prob = min(1.0, p_target.item() / max(p_draft.item(), 1e-10))
# Random acceptance
if torch.rand(1).item() < acceptance_prob:
accepted.append(token.item())
else:
# Rejection - need to resample from adjusted distribution
break
accepted_tokens = torch.tensor(accepted, device=self.device)
# Resample from adjusted distribution
# p_adjusted = max(0, p_target - p_draft) / sum(max(0, p_target - p_draft))
new_token = self._resample_token(draft_probs, target_probs, len(accepted))
return accepted_tokens, new_token
def _resample_token(
self,
draft_probs: torch.Tensor,
target_probs: torch.Tensor,
position: int
) -> torch.Tensor:
"""
Resample token from adjusted distribution after rejection.
"""
# Get full target distribution at rejection position
# This was cached in _target_verify
p_target = self._last_target_probs
if position < len(draft_probs):
# Calculate adjusted distribution
# For simplicity, just sample from target distribution
# Full implementation would use: max(0, p_target - p_draft)
probs = self._apply_sampling(p_target.unsqueeze(0))
new_token = torch.multinomial(probs, num_samples=1)
return new_token.squeeze()
else:
# All draft tokens accepted, sample next from target
probs = self._apply_sampling(p_target.unsqueeze(0))
new_token = torch.multinomial(probs, num_samples=1)
return new_token.squeeze()
def _apply_sampling(self, logits: torch.Tensor) -> torch.Tensor:
"""Apply top-k and top-p filtering."""
# Top-k
if self.config.top_k > 0:
indices_to_remove = logits < torch.topk(logits, self.config.top_k)[0][..., -1, None]
logits[indices_to_remove] = float("-inf")
# Top-p (nucleus sampling)
if self.config.top_p < 1.0:
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
sorted_indices_to_remove = cumulative_probs > self.config.top_p
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
sorted_indices_to_remove[..., 0] = 0
indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
logits[indices_to_remove] = float("-inf")
return F.softmax(logits, dim=-1)
# Example usage
if __name__ == "__main__":
from transformers import AutoModelForCausalLM, AutoTokenizer
# Load models (use smaller ones for demo)
print("Loading models...")
# For demo, using the same model as both draft and target
# In practice, use different sized models from the same family
model_name = "HuggingFaceTB/SmolLM-135M-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_name)
draft_model = AutoModelForCausalLM.from_pretrained(model_name)
target_model = AutoModelForCausalLM.from_pretrained(model_name)
device = "cuda" if torch.cuda.is_available() else "cpu"
draft_model.to(device)
target_model.to(device)
# Initialize decoder
config = SpeculativeConfig(num_speculative_tokens=4)
decoder = SpeculativeDecoder(target_model, draft_model, tokenizer, config)
# Generate
prompt = "The key to efficient inference is"
input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
output_ids, metrics = decoder.generate(input_ids, max_new_tokens=50)
print(f"\nPrompt: {prompt}")
print(f"Output: {tokenizer.decode(output_ids[0], skip_special_tokens=True)}")
print(f"\nMetrics:")
print(f" Acceptance rate: {metrics['acceptance_rate']:.2%}")
print(f" Draft calls: {metrics['draft_calls']}")
    print(f"  Target calls: {metrics['target_calls']}")
★ Insight ─────────────────────────────────────
Why Speculative Decoding Works: LLM inference is memory-bandwidth limited, not compute-limited; at small batch sizes, streaming the model weights (and KV cache) from memory dominates per-token latency. Verifying K tokens requires roughly the same memory reads as generating 1 token, because the K positions run in parallel through a single forward pass. The speedup is roughly proportional to K × acceptance_rate.
─────────────────────────────────────────────────
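To put a number on this: if each draft token is accepted independently with probability α (a simplifying assumption), the expected number of tokens produced per target forward pass with K drafted tokens is
    E[tokens per target pass] = (1 - α^(K+1)) / (1 - α)
For example, K = 5 and α = 0.7 gives about 2.9 tokens per target pass, i.e. roughly 3x fewer target forward passes than plain autoregressive decoding, before accounting for the draft model's own cost.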
Part 2: Draft Model Selection
Choosing and training optimal draft models.
# draft/selection.py
"""
Draft model selection and evaluation for speculative decoding.
"""
import torch
import torch.nn.functional as F
from typing import List, Tuple, Dict
from dataclasses import dataclass
import numpy as np
from tqdm import tqdm
@dataclass
class DraftModelMetrics:
"""Metrics for evaluating draft model quality."""
model_name: str
parameters: int
acceptance_rate: float
tokens_per_second: float
theoretical_speedup: float
memory_mb: float
class DraftModelEvaluator:
"""
Evaluate and select draft models for speculative decoding.
Key metrics:
1. Acceptance rate - how often draft tokens match target
2. Speed - draft model inference latency
3. Memory - additional memory required
4. Theoretical speedup - K * acceptance_rate
"""
def __init__(self, target_model, tokenizer, device: str = "cuda"):
self.target = target_model
self.tokenizer = tokenizer
self.device = device
self.target.to(device)
self.target.eval()
def evaluate_draft_model(
self,
draft_model,
eval_prompts: List[str],
num_tokens: int = 100,
k_values: List[int] = [3, 5, 7, 10]
) -> Dict[int, DraftModelMetrics]:
"""
Evaluate a draft model with different K values.
Args:
draft_model: The draft model to evaluate
eval_prompts: Prompts for evaluation
num_tokens: Tokens to generate per prompt
k_values: Different K values to test
Returns:
Dictionary mapping K to metrics
"""
draft_model.to(self.device)
draft_model.eval()
results = {}
for k in k_values:
print(f"\nEvaluating with K={k}...")
all_acceptance_rates = []
all_latencies = []
for prompt in tqdm(eval_prompts):
input_ids = self.tokenizer.encode(prompt, return_tensors="pt").to(self.device)
acceptance_rate, latency = self._measure_acceptance(
draft_model, input_ids, k, num_tokens
)
all_acceptance_rates.append(acceptance_rate)
all_latencies.append(latency)
avg_acceptance = np.mean(all_acceptance_rates)
avg_latency = np.mean(all_latencies)
            # Calculate theoretical speedup
            # Speedup = (k * acceptance_rate + 1) / (k * t_draft/t_target + 1)
            # Simplified to k * acceptance_rate + 1 when the draft cost is negligible
            theoretical_speedup = k * avg_acceptance + 1
# Get memory usage
memory_mb = self._get_memory_usage(draft_model)
results[k] = DraftModelMetrics(
model_name=getattr(draft_model.config, '_name_or_path', 'unknown'),
parameters=sum(p.numel() for p in draft_model.parameters()),
acceptance_rate=avg_acceptance,
                tokens_per_second=k / avg_latency if avg_latency > 0 else 0,  # k drafted tokens per measured draft pass
theoretical_speedup=theoretical_speedup,
memory_mb=memory_mb
)
return results
@torch.no_grad()
def _measure_acceptance(
self,
draft_model,
input_ids: torch.Tensor,
k: int,
num_tokens: int
) -> Tuple[float, float]:
"""Measure acceptance rate and latency."""
import time
total_accepted = 0
total_proposed = 0
total_time = 0
current_ids = input_ids.clone()
tokens_generated = 0
while tokens_generated < num_tokens:
# Generate K draft tokens
start_time = time.time()
draft_tokens, draft_probs = self._generate_draft_tokens(
draft_model, current_ids, k
)
draft_time = time.time() - start_time
# Get target probabilities
target_probs = self._get_target_probs(current_ids, draft_tokens)
            # Calculate acceptance: every drafted token counts as a proposal
            for d_prob, t_prob in zip(draft_probs, target_probs):
                acceptance_prob = min(1.0, t_prob / max(d_prob, 1e-10))
                if torch.rand(1).item() < acceptance_prob:
                    total_accepted += 1
                else:
                    break
            total_proposed += k
total_time += draft_time
# Update for next iteration (simplified)
tokens_generated += k
current_ids = torch.cat([current_ids, draft_tokens.unsqueeze(0)], dim=1)
acceptance_rate = total_accepted / max(total_proposed, 1)
avg_latency = total_time / max(tokens_generated / k, 1)
return acceptance_rate, avg_latency
def _generate_draft_tokens(
self,
model,
input_ids: torch.Tensor,
k: int
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Generate K tokens from draft model."""
tokens = []
probs = []
current = input_ids.clone()
for _ in range(k):
outputs = model(current)
logits = outputs.logits[:, -1, :]
prob_dist = F.softmax(logits, dim=-1)
token = torch.argmax(prob_dist, dim=-1)
tokens.append(token.item())
probs.append(prob_dist[0, token].item())
current = torch.cat([current, token.unsqueeze(0)], dim=1)
return (
torch.tensor(tokens, device=self.device),
torch.tensor(probs, device=self.device)
)
def _get_target_probs(
self,
prefix: torch.Tensor,
draft_tokens: torch.Tensor
) -> torch.Tensor:
"""Get target model probabilities for draft tokens."""
full_input = torch.cat([prefix, draft_tokens.unsqueeze(0)], dim=1)
outputs = self.target(full_input)
logits = outputs.logits
prefix_len = prefix.shape[1]
relevant_logits = logits[:, prefix_len-1:prefix_len+len(draft_tokens), :]
probs = F.softmax(relevant_logits, dim=-1)
target_probs = []
for i, token in enumerate(draft_tokens):
target_probs.append(probs[0, i, token].item())
return torch.tensor(target_probs, device=self.device)
def _get_memory_usage(self, model) -> float:
"""Estimate model memory usage in MB."""
total_bytes = sum(
p.numel() * p.element_size()
for p in model.parameters()
)
return total_bytes / (1024 * 1024)
def recommend_draft_model(
self,
candidates: List,
eval_prompts: List[str],
target_speedup: float = 2.0
) -> Tuple[any, int]:
"""
Recommend best draft model and K value.
Args:
candidates: List of draft model candidates
eval_prompts: Prompts for evaluation
target_speedup: Desired speedup target
Returns:
Tuple of (best_model, best_k)
"""
best_model = None
best_k = None
best_efficiency = 0
for candidate in candidates:
results = self.evaluate_draft_model(candidate, eval_prompts)
for k, metrics in results.items():
# Efficiency = speedup / memory_cost
efficiency = metrics.theoretical_speedup / max(metrics.memory_mb, 1)
if metrics.theoretical_speedup >= target_speedup and efficiency > best_efficiency:
best_efficiency = efficiency
best_model = candidate
best_k = k
return best_model, best_k
def compare_draft_models(evaluator: DraftModelEvaluator, models: Dict[str, any], prompts: List[str]):
"""Compare multiple draft model candidates."""
import pandas as pd
all_results = []
for name, model in models.items():
print(f"\n{'='*50}")
print(f"Evaluating: {name}")
print("="*50)
results = evaluator.evaluate_draft_model(model, prompts, k_values=[4, 6, 8])
for k, metrics in results.items():
all_results.append({
"Model": name,
"K": k,
"Params (M)": metrics.parameters / 1e6,
"Accept Rate": f"{metrics.acceptance_rate:.2%}",
"Speedup": f"{metrics.theoretical_speedup:.2f}x",
"Memory (MB)": f"{metrics.memory_mb:.0f}"
})
df = pd.DataFrame(all_results)
print("\n" + "="*60)
print("COMPARISON RESULTS")
print("="*60)
print(df.to_string(index=False))
return df
# Example usage
if __name__ == "__main__":
print("Draft Model Selection Demo")
    print("(Use with actual models for meaningful results)")
Part 3: Production Implementation
A complete, optimized implementation for production use.
# production/speculative_engine.py
"""
Production-ready speculative decoding engine.
"""
import torch
import torch.nn.functional as F
from typing import Optional, List, Generator
from dataclasses import dataclass
import time
from threading import Thread
from queue import Queue
@dataclass
class GenerationConfig:
"""Configuration for text generation."""
max_new_tokens: int = 256
num_speculative_tokens: int = 5
temperature: float = 0.7
top_p: float = 0.9
top_k: int = 50
repetition_penalty: float = 1.1
do_sample: bool = True
@dataclass
class GenerationResult:
"""Result of text generation."""
text: str
tokens_generated: int
total_time_s: float
tokens_per_second: float
acceptance_rate: float
draft_efficiency: float
class SpeculativeEngine:
"""
Production speculative decoding engine.
Features:
- KV cache management for both models
- Batched draft generation
- Adaptive K selection
- Memory-efficient implementation
"""
def __init__(
self,
target_model,
draft_model,
tokenizer,
device: str = "cuda"
):
self.target = target_model.to(device)
self.draft = draft_model.to(device)
self.tokenizer = tokenizer
self.device = device
# Enable evaluation mode
self.target.eval()
self.draft.eval()
# KV cache will be managed per generation
self._draft_cache = None
self._target_cache = None
@torch.no_grad()
def generate(
self,
prompt: str,
config: GenerationConfig = None
) -> GenerationResult:
"""
Generate text using speculative decoding.
Args:
prompt: Input text prompt
config: Generation configuration
Returns:
GenerationResult with generated text and metrics
"""
config = config or GenerationConfig()
# Encode prompt
input_ids = self.tokenizer.encode(prompt, return_tensors="pt").to(self.device)
# Initialize caches
self._init_caches(input_ids)
# Tracking
start_time = time.time()
tokens_generated = 0
accepted_tokens = 0
total_proposed = 0
generated_ids = input_ids.clone()
while tokens_generated < config.max_new_tokens:
# Adaptive K based on recent acceptance rate
k = self._get_adaptive_k(config, accepted_tokens, total_proposed)
# Step 1: Draft generation with cache
draft_tokens, draft_probs, draft_cache = self._cached_draft_generate(
generated_ids, k, config
)
# Step 2: Target verification (parallel)
target_probs, target_cache = self._cached_target_verify(
generated_ids, draft_tokens
)
# Step 3: Accept/reject
accepted, new_token, n_accepted = self._verify_tokens(
draft_tokens, draft_probs, target_probs, config
)
# Update tracking
accepted_tokens += n_accepted
total_proposed += k
tokens_generated += n_accepted + 1
# Update generated sequence
if n_accepted > 0:
generated_ids = torch.cat([generated_ids, accepted.unsqueeze(0)], dim=1)
generated_ids = torch.cat([generated_ids, new_token.unsqueeze(0).unsqueeze(0)], dim=1)
# Update caches (truncate to accepted length)
self._update_caches(draft_cache, target_cache, n_accepted)
# Check for EOS
if new_token.item() == self.tokenizer.eos_token_id:
break
total_time = time.time() - start_time
return GenerationResult(
text=self.tokenizer.decode(generated_ids[0], skip_special_tokens=True),
tokens_generated=tokens_generated,
total_time_s=total_time,
tokens_per_second=tokens_generated / total_time if total_time > 0 else 0,
acceptance_rate=accepted_tokens / max(total_proposed, 1),
draft_efficiency=accepted_tokens / max(total_proposed, 1) * config.num_speculative_tokens
)
def generate_stream(
self,
prompt: str,
config: GenerationConfig = None
) -> Generator[str, None, None]:
"""
Stream generated tokens.
Yields:
Individual tokens as they're generated
"""
config = config or GenerationConfig()
input_ids = self.tokenizer.encode(prompt, return_tensors="pt").to(self.device)
self._init_caches(input_ids)
generated_ids = input_ids.clone()
tokens_generated = 0
accepted_tokens = 0
total_proposed = 0
# Yield prompt
yield prompt
while tokens_generated < config.max_new_tokens:
k = self._get_adaptive_k(config, accepted_tokens, total_proposed)
draft_tokens, draft_probs, draft_cache = self._cached_draft_generate(
generated_ids, k, config
)
target_probs, target_cache = self._cached_target_verify(
generated_ids, draft_tokens
)
accepted, new_token, n_accepted = self._verify_tokens(
draft_tokens, draft_probs, target_probs, config
)
accepted_tokens += n_accepted
total_proposed += k
tokens_generated += n_accepted + 1
# Yield accepted tokens
if n_accepted > 0:
for token_id in accepted:
token_text = self.tokenizer.decode([token_id.item()])
yield token_text
generated_ids = torch.cat([generated_ids, accepted.unsqueeze(0)], dim=1)
# Yield new token
new_token_text = self.tokenizer.decode([new_token.item()])
yield new_token_text
generated_ids = torch.cat([generated_ids, new_token.unsqueeze(0).unsqueeze(0)], dim=1)
self._update_caches(draft_cache, target_cache, n_accepted)
if new_token.item() == self.tokenizer.eos_token_id:
break
def _init_caches(self, input_ids: torch.Tensor):
"""Initialize KV caches for both models."""
        # Caches start empty; they are built lazily on the first forward pass
self._draft_cache = None
self._target_cache = None
def _get_adaptive_k(
self,
config: GenerationConfig,
accepted: int,
total: int
) -> int:
"""Adaptively adjust K based on acceptance rate."""
if total < 10:
return config.num_speculative_tokens
acceptance_rate = accepted / total
# Increase K if acceptance is high, decrease if low
if acceptance_rate > 0.8:
return min(config.num_speculative_tokens + 2, 12)
elif acceptance_rate < 0.5:
return max(config.num_speculative_tokens - 2, 2)
else:
return config.num_speculative_tokens
def _cached_draft_generate(
self,
input_ids: torch.Tensor,
k: int,
config: GenerationConfig
):
"""Generate K draft tokens with KV cache."""
draft_tokens = []
draft_probs = []
current_ids = input_ids.clone()
cache = self._draft_cache
for _ in range(k):
outputs = self.draft(
current_ids if cache is None else current_ids[:, -1:],
past_key_values=cache,
use_cache=True
)
logits = outputs.logits[:, -1, :]
cache = outputs.past_key_values
# Apply sampling
probs = self._apply_sampling(logits, config)
if config.do_sample:
token = torch.multinomial(probs, num_samples=1)
else:
token = torch.argmax(probs, dim=-1, keepdim=True)
draft_tokens.append(token.item())
draft_probs.append(probs[0, token.item()].item())
current_ids = torch.cat([current_ids, token], dim=1)
return (
torch.tensor(draft_tokens, device=self.device),
torch.tensor(draft_probs, device=self.device),
cache
)
def _cached_target_verify(
self,
prefix: torch.Tensor,
draft_tokens: torch.Tensor
):
"""Verify draft tokens with target model using cache."""
full_input = torch.cat([prefix, draft_tokens.unsqueeze(0)], dim=1)
# Use cache if available
if self._target_cache is not None:
# Only need to process new tokens
input_to_process = full_input[:, -len(draft_tokens)-1:]
outputs = self.target(
input_to_process,
past_key_values=self._target_cache,
use_cache=True
)
else:
outputs = self.target(full_input, use_cache=True)
logits = outputs.logits
cache = outputs.past_key_values
# Extract relevant probabilities
probs = F.softmax(logits[:, -len(draft_tokens)-1:, :], dim=-1)
target_probs = []
for i, token in enumerate(draft_tokens):
target_probs.append(probs[0, i, token].item())
# Store last distribution for resampling
self._last_target_dist = probs[0, -1, :]
return torch.tensor(target_probs, device=self.device), cache
def _verify_tokens(
self,
draft_tokens: torch.Tensor,
draft_probs: torch.Tensor,
target_probs: torch.Tensor,
config: GenerationConfig
):
"""Verify and accept/reject draft tokens."""
accepted = []
n_accepted = 0
for i, (token, d_prob, t_prob) in enumerate(
zip(draft_tokens, draft_probs, target_probs)
):
# Rejection sampling
r = torch.rand(1, device=self.device).item()
acceptance_prob = min(1.0, t_prob.item() / max(d_prob.item(), 1e-10))
if r < acceptance_prob:
accepted.append(token)
n_accepted += 1
else:
break
accepted_tensor = torch.tensor(
[t.item() for t in accepted], device=self.device
) if accepted else torch.tensor([], device=self.device)
        # Resample. For simplicity, sample from the target distribution at the
        # last verified position; a full implementation would use the adjusted
        # distribution max(0, p_target - p_draft) at the rejection position.
adjusted_probs = self._apply_sampling(
self._last_target_dist.unsqueeze(0), config
)
if config.do_sample:
new_token = torch.multinomial(adjusted_probs, num_samples=1).squeeze()
else:
new_token = torch.argmax(adjusted_probs).squeeze()
return accepted_tensor, new_token, n_accepted
def _apply_sampling(
self,
logits: torch.Tensor,
config: GenerationConfig
) -> torch.Tensor:
"""Apply temperature, top-k, top-p sampling."""
# Temperature
if config.temperature != 1.0:
logits = logits / config.temperature
# Top-k
if config.top_k > 0:
top_k_values = torch.topk(logits, config.top_k)[0]
min_top_k = top_k_values[..., -1, None]
logits = torch.where(logits < min_top_k, float("-inf"), logits)
# Top-p
if config.top_p < 1.0:
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
sorted_indices_to_remove = cumulative_probs > config.top_p
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
sorted_indices_to_remove[..., 0] = False
indices_to_remove = sorted_indices_to_remove.scatter(
dim=-1, index=sorted_indices, src=sorted_indices_to_remove
)
logits = logits.masked_fill(indices_to_remove, float("-inf"))
return F.softmax(logits, dim=-1)
def _update_caches(self, draft_cache, target_cache, n_accepted: int):
"""Update caches after acceptance/rejection."""
# In a full implementation, we'd truncate caches to accepted length
# For simplicity, we reset caches (production would optimize this)
self._draft_cache = None
self._target_cache = None
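# NOTE (sketch): a production version would roll the caches back to the accepted
# length instead of resetting them. Assuming the legacy tuple cache format,
# where past_key_values is a per-layer tuple of (key, value) tensors shaped
# [batch, heads, seq_len, head_dim], a hypothetical helper could look like:
#
#   def truncate_kv_cache(past_key_values, keep_len):
#       return tuple(
#           (k[:, :, :keep_len, :], v[:, :, :keep_len, :])
#           for k, v in past_key_values
#       )
#
# Recent transformers versions return Cache objects rather than tuples; the same
# rollback idea applies through their truncation utilities.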
# FastAPI Server for production
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from contextlib import asynccontextmanager
import asyncio
app_state = {}
class GenerateRequest(BaseModel):
prompt: str
max_tokens: int = 256
temperature: float = 0.7
stream: bool = False
class GenerateResponse(BaseModel):
text: str
tokens_generated: int
tokens_per_second: float
acceptance_rate: float
@asynccontextmanager
async def lifespan(app: FastAPI):
# Initialize models on startup
print("Loading models...")
# In production, load actual models here
app_state["engine"] = None # Placeholder
yield
app_state.clear()
app = FastAPI(
title="Speculative Decoding API",
description="Fast LLM inference using speculative decoding",
version="1.0.0",
lifespan=lifespan
)
@app.post("/generate", response_model=GenerateResponse)
async def generate(request: GenerateRequest):
"""Generate text using speculative decoding."""
engine = app_state.get("engine")
if engine is None:
raise HTTPException(status_code=503, detail="Models not loaded")
config = GenerationConfig(
max_new_tokens=request.max_tokens,
temperature=request.temperature
)
    loop = asyncio.get_running_loop()
result = await loop.run_in_executor(
None, engine.generate, request.prompt, config
)
return GenerateResponse(
text=result.text,
tokens_generated=result.tokens_generated,
tokens_per_second=result.tokens_per_second,
acceptance_rate=result.acceptance_rate
)
if __name__ == "__main__":
import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)
★ Insight ─────────────────────────────────────
Adaptive K Selection: The optimal K depends on acceptance rate. High acceptance (>80%) means we can speculatively generate more tokens. Low acceptance (<50%) means we're wasting compute on rejected drafts. Adaptive K maximizes throughput across different prompt types and model pairs.
─────────────────────────────────────────────────
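One way to make the adjustment less jumpy than the cumulative-rate heuristic used in SpeculativeEngine is to track an exponential moving average of per-round acceptance. The sketch below is standalone and illustrative (class name and thresholds are assumptions, not part of the engine above):
# production/adaptive_k.py
class AdaptiveK:
    """Pick K from an exponential moving average of per-round acceptance rate."""

    def __init__(self, k_init: int = 5, k_min: int = 2, k_max: int = 12, beta: float = 0.9):
        self.k = k_init
        self.k_min, self.k_max = k_min, k_max
        self.beta = beta   # smoothing factor for the moving average
        self.ema = None    # current acceptance estimate

    def update(self, accepted: int, proposed: int) -> int:
        """Call once per draft/verify round; returns K for the next round."""
        rate = accepted / max(proposed, 1)
        self.ema = rate if self.ema is None else self.beta * self.ema + (1 - self.beta) * rate
        if self.ema > 0.8:      # drafts are usually right: speculate further ahead
            self.k = min(self.k + 1, self.k_max)
        elif self.ema < 0.5:    # drafts are often rejected: shorten the speculation
            self.k = max(self.k - 1, self.k_min)
        return self.k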
Part 4: Benchmarking and Optimization
Comprehensive performance measurement.
# benchmark/benchmark.py
"""
Benchmarking suite for speculative decoding.
"""
import torch
import time
import json
from typing import List, Dict
from dataclasses import dataclass, asdict
import numpy as np
from tqdm import tqdm
import plotly.graph_objects as go
from plotly.subplots import make_subplots
@dataclass
class BenchmarkConfig:
"""Configuration for benchmarking."""
num_prompts: int = 50
tokens_per_prompt: int = 100
warmup_runs: int = 3
k_values: List[int] = None
def __post_init__(self):
if self.k_values is None:
self.k_values = [3, 5, 7, 10]
@dataclass
class BenchmarkResult:
"""Results from a benchmark run."""
method: str
k_value: int
avg_tokens_per_second: float
std_tokens_per_second: float
avg_acceptance_rate: float
avg_latency_ms: float
speedup_vs_baseline: float
memory_mb: float
class SpeculativeBenchmark:
"""Benchmark speculative decoding performance."""
def __init__(
self,
target_model,
draft_model,
tokenizer,
device: str = "cuda"
):
self.target = target_model.to(device)
self.draft = draft_model.to(device)
self.tokenizer = tokenizer
self.device = device
self.target.eval()
self.draft.eval()
def run_benchmark(
self,
prompts: List[str],
config: BenchmarkConfig = None
) -> Dict[str, List[BenchmarkResult]]:
"""
Run comprehensive benchmark.
Args:
prompts: Test prompts
config: Benchmark configuration
Returns:
Dictionary of results
"""
config = config or BenchmarkConfig()
results = {
"baseline": [],
"speculative": []
}
# Warmup
print("Warming up...")
self._warmup(prompts[:config.warmup_runs])
# Baseline (autoregressive)
print("\nRunning baseline (autoregressive)...")
baseline_result = self._benchmark_baseline(prompts, config)
results["baseline"].append(baseline_result)
# Speculative decoding with different K values
for k in config.k_values:
print(f"\nRunning speculative decoding (K={k})...")
spec_result = self._benchmark_speculative(prompts, config, k, baseline_result.avg_tokens_per_second)
results["speculative"].append(spec_result)
return results
def _warmup(self, prompts: List[str]):
"""Warmup models."""
for prompt in prompts:
input_ids = self.tokenizer.encode(prompt, return_tensors="pt").to(self.device)
with torch.no_grad():
self.target.generate(input_ids, max_new_tokens=10)
@torch.no_grad()
def _benchmark_baseline(
self,
prompts: List[str],
config: BenchmarkConfig
) -> BenchmarkResult:
"""Benchmark baseline autoregressive generation."""
all_tps = []
all_latencies = []
for prompt in tqdm(prompts[:config.num_prompts]):
input_ids = self.tokenizer.encode(prompt, return_tensors="pt").to(self.device)
start_time = time.time()
outputs = self.target.generate(
input_ids,
max_new_tokens=config.tokens_per_prompt,
do_sample=True,
temperature=0.7
)
elapsed = time.time() - start_time
tokens = outputs.shape[1] - input_ids.shape[1]
tps = tokens / elapsed
latency = elapsed * 1000 / tokens # ms per token
all_tps.append(tps)
all_latencies.append(latency)
return BenchmarkResult(
method="baseline",
k_value=1,
avg_tokens_per_second=np.mean(all_tps),
std_tokens_per_second=np.std(all_tps),
avg_acceptance_rate=1.0,
avg_latency_ms=np.mean(all_latencies),
speedup_vs_baseline=1.0,
memory_mb=self._get_memory_usage(self.target)
)
@torch.no_grad()
def _benchmark_speculative(
self,
prompts: List[str],
config: BenchmarkConfig,
k: int,
baseline_tps: float
) -> BenchmarkResult:
"""Benchmark speculative decoding."""
all_tps = []
all_latencies = []
all_acceptance = []
for prompt in tqdm(prompts[:config.num_prompts]):
input_ids = self.tokenizer.encode(prompt, return_tensors="pt").to(self.device)
start_time = time.time()
output, metrics = self._speculative_generate(input_ids, config.tokens_per_prompt, k)
elapsed = time.time() - start_time
tokens = output.shape[1] - input_ids.shape[1]
tps = tokens / elapsed
latency = elapsed * 1000 / tokens
all_tps.append(tps)
all_latencies.append(latency)
all_acceptance.append(metrics["acceptance_rate"])
return BenchmarkResult(
method="speculative",
k_value=k,
avg_tokens_per_second=np.mean(all_tps),
std_tokens_per_second=np.std(all_tps),
avg_acceptance_rate=np.mean(all_acceptance),
avg_latency_ms=np.mean(all_latencies),
speedup_vs_baseline=np.mean(all_tps) / baseline_tps if baseline_tps > 0 else 0,
memory_mb=self._get_memory_usage(self.target) + self._get_memory_usage(self.draft)
)
def _speculative_generate(
self,
input_ids: torch.Tensor,
max_tokens: int,
k: int
) -> tuple:
"""Simple speculative generation for benchmarking."""
generated = input_ids.clone()
tokens_generated = 0
accepted_total = 0
proposed_total = 0
while tokens_generated < max_tokens:
# Draft K tokens
current = generated.clone()
draft_tokens = []
draft_probs = []
for _ in range(k):
out = self.draft(current)
logits = out.logits[:, -1, :]
probs = torch.softmax(logits, dim=-1)
token = torch.multinomial(probs, 1)
draft_tokens.append(token.item())
draft_probs.append(probs[0, token.item()].item())
current = torch.cat([current, token], dim=1)
# Verify with target
draft_tensor = torch.tensor(draft_tokens, device=self.device)
full_input = torch.cat([generated, draft_tensor.unsqueeze(0)], dim=1)
target_out = self.target(full_input)
target_logits = target_out.logits[:, generated.shape[1]-1:generated.shape[1]+k, :]
target_probs = torch.softmax(target_logits, dim=-1)
            # Accept/reject: every drafted token counts as a proposal
            accepted = []
            for i, (token, d_prob) in enumerate(zip(draft_tokens, draft_probs)):
                t_prob = target_probs[0, i, token].item()
                if torch.rand(1).item() < min(1.0, t_prob / max(d_prob, 1e-10)):
                    accepted.append(token)
                    accepted_total += 1
                else:
                    break
            proposed_total += k
# Add accepted tokens
if accepted:
accepted_tensor = torch.tensor(accepted, device=self.device).unsqueeze(0)
generated = torch.cat([generated, accepted_tensor], dim=1)
# Sample new token
last_probs = target_probs[0, len(accepted), :]
new_token = torch.multinomial(last_probs.unsqueeze(0), 1)
generated = torch.cat([generated, new_token], dim=1)
            tokens_generated += len(accepted) + 1
metrics = {
"acceptance_rate": accepted_total / max(proposed_total, 1)
}
return generated, metrics
def _get_memory_usage(self, model) -> float:
"""Get model memory usage in MB."""
return sum(p.numel() * p.element_size() for p in model.parameters()) / (1024**2)
def plot_results(self, results: Dict[str, List[BenchmarkResult]], save_path: str = None):
"""Create visualization of benchmark results."""
fig = make_subplots(
rows=2, cols=2,
subplot_titles=(
"Tokens per Second",
"Speedup vs Baseline",
"Acceptance Rate",
"Latency (ms/token)"
)
)
# Extract data
baseline = results["baseline"][0]
spec_results = results["speculative"]
k_values = [r.k_value for r in spec_results]
tps_values = [r.avg_tokens_per_second for r in spec_results]
speedups = [r.speedup_vs_baseline for r in spec_results]
acceptance = [r.avg_acceptance_rate for r in spec_results]
latencies = [r.avg_latency_ms for r in spec_results]
# Tokens per second
fig.add_trace(
go.Bar(x=["Baseline"] + [f"K={k}" for k in k_values],
y=[baseline.avg_tokens_per_second] + tps_values,
name="TPS"),
row=1, col=1
)
# Speedup
fig.add_trace(
go.Bar(x=[f"K={k}" for k in k_values],
y=speedups,
name="Speedup"),
row=1, col=2
)
fig.add_hline(y=1.0, line_dash="dash", row=1, col=2)
# Acceptance rate
fig.add_trace(
go.Bar(x=[f"K={k}" for k in k_values],
y=acceptance,
name="Acceptance"),
row=2, col=1
)
# Latency
fig.add_trace(
go.Bar(x=["Baseline"] + [f"K={k}" for k in k_values],
y=[baseline.avg_latency_ms] + latencies,
name="Latency"),
row=2, col=2
)
fig.update_layout(
height=700,
title="Speculative Decoding Benchmark Results",
showlegend=False
)
if save_path:
fig.write_html(save_path)
print(f"Results saved to {save_path}")
return fig
# Example usage
if __name__ == "__main__":
# Test prompts
TEST_PROMPTS = [
"Explain the theory of relativity in simple terms.",
"Write a Python function to calculate fibonacci numbers.",
"What are the key principles of machine learning?",
"Describe the process of photosynthesis.",
"How does the internet work?",
]
print("Speculative Decoding Benchmark")
    print("(Run with actual models for meaningful results)")
Docker Configuration
# docker-compose.yml
version: '3.8'
services:
speculative-api:
build:
context: .
dockerfile: Dockerfile
ports:
- "8000:8000"
volumes:
- ./models:/models
environment:
- TARGET_MODEL=/models/llama-7b
- DRAFT_MODEL=/models/tinyllama-1b
- CUDA_VISIBLE_DEVICES=0
deploy:
resources:
reservations:
devices:
- driver: nvidia
count: 1
capabilities: [gpu]
    restart: unless-stopped
# Dockerfile
FROM nvidia/cuda:12.1.0-runtime-ubuntu22.04
WORKDIR /app
# Install Python
RUN apt-get update && apt-get install -y python3 python3-pip
# Install dependencies
COPY requirements.txt .
RUN pip3 install --no-cache-dir -r requirements.txt
# Copy application
COPY . .
EXPOSE 8000
CMD ["python3", "-m", "uvicorn", "production.speculative_engine:app", "--host", "0.0.0.0", "--port", "8000"]
Exercises
Exercise 1: Draft Model Training
Train a small draft model specifically for speculative decoding:
- Use distillation from the target model (a minimal loss sketch follows this list)
- Optimize for high acceptance rate rather than perplexity
- Compare with off-the-shelf small models
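A minimal sketch of the distillation objective for this exercise (it assumes the draft and target share a tokenizer; `draft`, `target`, `batch`, and the temperature are placeholders):
# exercises/distill_draft.py
import torch
import torch.nn.functional as F

def distillation_step(draft, target, batch, temperature: float = 2.0) -> float:
    """One step of KL distillation: pull the draft's next-token distribution
    toward the target's at every position."""
    with torch.no_grad():
        target_logits = target(batch["input_ids"]).logits
    draft_logits = draft(batch["input_ids"]).logits
    loss = F.kl_div(
        F.log_softmax(draft_logits / temperature, dim=-1),
        F.softmax(target_logits / temperature, dim=-1),
        reduction="batchmean",
    ) * (temperature ** 2)
    loss.backward()
    return loss.item()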
Exercise 2: Multi-Draft Speculative Decoding
Implement speculative decoding with multiple draft models:
- Use ensemble of drafts for higher acceptance
- Implement tree-based speculation
- Measure tradeoffs vs single draft
Exercise 3: Hardware Optimization
Optimize for specific hardware:
- Profile memory bandwidth utilization
- Implement custom CUDA kernels for verification
- Compare CPU vs GPU for draft model
Exercise 4: Production Deployment
Build a production-ready system:
- Implement request batching with speculation
- Add monitoring and alerting
- Handle graceful degradation when draft quality drops
Summary
You've learned to implement speculative decoding:
- Core Algorithm: Draft-verify-accept/reject cycle
- Draft Selection: Choosing optimal draft models
- Production Implementation: KV caching, adaptive K
- Benchmarking: Measuring speedup and acceptance
Key insights:
- Speculative decoding exploits memory-boundedness of LLM inference
- Speedup depends on K × acceptance_rate
- Draft model should be fast with high acceptance, not necessarily high quality
- Adaptive K maximizes throughput across different contexts
- KV cache management is critical for efficiency
Key Concepts Recap
| Concept | What It Is | Why It Matters |
|---|---|---|
| Speculative Decoding | Draft model generates K tokens, target verifies in one pass | 1.5-3x speedup while maintaining exact output distribution |
| Memory-Bound Inference | LLM inference limited by memory bandwidth, not compute | Verifying K tokens costs ~same as generating 1 token |
| Draft Model | Small, fast model that approximates target distribution | Higher acceptance rate = more speedup; quality matters less than speed |
| Target Model | Large model whose output distribution we want to match | Single forward pass verifies all K draft tokens in parallel |
| Acceptance Rate | Fraction of draft tokens accepted by target model | Speedup ≈ K × acceptance_rate; key metric to optimize |
| Rejection Sampling | Accept token if random r < min(1, p_target / p_draft) | Mathematically guarantees exact target distribution |
| Adaptive K | Dynamically adjust K based on recent acceptance rate | High acceptance → increase K; low acceptance → decrease K |
| KV Cache Management | Maintain separate caches for draft and target models | Must truncate draft cache on rejection to stay consistent |
| Model Pairing | Match draft/target from same family (Llama 8B/70B) | Same tokenizer and similar distribution improve acceptance |
| Theoretical Speedup | (K × acceptance + 1) / (K × t_draft/t_target + 1) | Approaches K × acceptance when draft is much faster |
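As a quick sanity check on the Theoretical Speedup row, a small helper (the K values, acceptance rate, and draft/target cost ratio below are illustrative assumptions, not measurements):
# speedup_estimate.py
def estimated_speedup(k: int, acceptance: float, draft_cost_ratio: float) -> float:
    """Theoretical speedup ~ (K * acceptance + 1) / (K * t_draft/t_target + 1)."""
    return (k * acceptance + 1) / (k * draft_cost_ratio + 1)

if __name__ == "__main__":
    # e.g. a 0.5B draft paired with a 7B target might cost ~8% of a target step
    for k in (3, 5, 7, 10):
        print(f"K={k}: ~{estimated_speedup(k, acceptance=0.7, draft_cost_ratio=0.08):.2f}x")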
Next Steps
- Production SLM System - Scale speculative decoding
- Training SLM from Scratch - Build custom draft models