Quantization Pipeline
Optimize models for faster inference
TL;DR
Reduce model precision from FP32 (4 bytes) to INT8 (1 byte) or INT4 for 4x size reduction and 2-4x speedup with minimal accuracy loss. GPTQ uses second-order information for better 4-bit accuracy. AWQ preserves "salient" weights based on activation patterns. BitsAndBytes makes 4-bit loading trivial.
Overview
| Difficulty | Intermediate |
| Time | ~4 hours |
| Prerequisites | PyTorch basics, model inference |
| Learning Outcomes | INT8 quantization, GPTQ, AWQ, benchmarking |
Why Reduce Precision?
A 7B parameter model in FP32 requires 28GB of memory just for weights, making it impossible to serve on most GPUs. Quantization to INT8 cuts this to 7GB; INT4 cuts it to 3.5GB. Beyond memory savings, quantized models run faster because modern GPUs have dedicated INT8 compute units, and reduced memory bandwidth means data moves to the compute cores faster. The trade-off is minimal: well-calibrated INT8 models typically lose less than 0.5% accuracy.
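The arithmetic behind these figures is simple byte counting; a quick sketch (weights only, ignoring activations and the KV cache):

```python
# Weight-memory footprint of a 7B-parameter model at different precisions.
# Pure arithmetic -- reproduces the figures quoted above (1 GB = 1e9 bytes).
PARAMS = 7e9

BYTES_PER_PARAM = {"fp32": 4, "fp16": 2, "int8": 1, "int4": 0.5}

def weights_gb(precision: str, params: float = PARAMS) -> float:
    """Return weight memory in GB for the given precision."""
    return params * BYTES_PER_PARAM[precision] / 1e9

for p in BYTES_PER_PARAM:
    print(f"{p}: {weights_gb(p):.1f} GB")
# fp32: 28.0 GB, fp16: 14.0 GB, int8: 7.0 GB, int4: 3.5 GB
```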
Introduction
Quantization reduces the numerical precision of model weights and activations, typically from FP32 to INT8 or lower. This provides significant benefits for deployment.
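As a concrete illustration of what "reducing precision" means, here is a minimal affine (scale/zero-point) INT8 round trip in NumPy — a sketch of the standard scheme, not any particular library's implementation:

```python
import numpy as np

def quantize_int8(x: np.ndarray):
    """Affine INT8 quantization: q = clip(round(x / scale) + zero_point, -128, 127)."""
    scale = (x.max() - x.min()) / 255.0              # spread the range over 256 levels
    zero_point = int(round(-x.min() / scale)) - 128  # shift so x.min() maps near -128
    q = np.clip(np.round(x / scale) + zero_point, -128, 127).astype(np.int8)
    return q, scale, zero_point

def dequantize(q: np.ndarray, scale: float, zero_point: int) -> np.ndarray:
    """Map INT8 codes back to approximate FP32 values."""
    return (q.astype(np.float32) - zero_point) * scale

rng = np.random.default_rng(0)
x = rng.standard_normal(1000).astype(np.float32)
q, scale, zp = quantize_int8(x)
max_err = np.abs(x - dequantize(q, scale, zp)).max()
print(f"scale={scale:.4f}, max round-trip error={max_err:.4f}")
```

The round-trip error is bounded by one quantization step (`scale`), which is why a well-spread weight distribution quantizes with little accuracy loss.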
Quantization Benefits
FP32 Model
Weight: 4 bytes, Model: 440 MB, Speed: 100 tokens/sec
INT8 Model (Quantized) [Recommended]
Weight: 1 byte, Model: 110 MB, Speed: 300 tokens/sec — 4x smaller, 3x faster, ~0.5% accuracy loss
Benefits of Quantization
| Metric | FP32 | INT8 | Improvement |
|---|---|---|---|
| Model Size | 4x | 1x | 4x smaller |
| Memory Bandwidth | 4x | 1x | 4x faster |
| Compute (INT8 cores) | 1x | 2-4x | 2-4x faster |
| Accuracy | 100% | 99.5%+ | Minimal loss |
Quantization Types
Post-Training Quantization (PTQ): quantize an already-trained model using a small calibration set; no retraining required, but accuracy can degrade at very low bit-widths.
Quantization-Aware Training (QAT): simulate quantization effects during training so the model learns to compensate; best accuracy at low precision, but requires a full training run.
Dynamic vs Static Quantization
| Type | When | Pros | Cons |
|---|---|---|---|
| Dynamic | Runtime | No calibration needed | Higher latency |
| Static | Pre-computed | Faster inference | Requires calibration data |
Project Setup
# Create project directory
mkdir quantization-pipeline && cd quantization-pipeline
# Create virtual environment
python -m venv venv
source venv/bin/activate
# Install dependencies
pip install torch transformers datasets accelerate
pip install bitsandbytes auto-gptq autoawq
pip install optimum onnx onnxruntime
Project Structure
quantization-pipeline/
├── quantization/
│ ├── dynamic.py # Dynamic quantization
│ ├── static.py # Static quantization
│ ├── gptq.py # GPTQ for LLMs
│ └── awq.py # AWQ implementation
├── calibration/
│ └── dataset.py # Calibration data
├── benchmarks/
│ ├── latency.py # Latency benchmarks
│ └── accuracy.py # Accuracy evaluation
├── scripts/
│ ├── quantize.py # Quantization script
│ └── benchmark.py # Benchmarking
└── requirements.txt
Dynamic Quantization
Dynamic quantization quantizes weights ahead of time but computes activation scales at runtime.
# quantization/dynamic.py
import torch
import torch.nn as nn
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from typing import Dict, Any
import time
class DynamicQuantizer:
"""Apply dynamic quantization to models."""
def __init__(self, model: nn.Module):
self.original_model = model
self.quantized_model = None
def quantize(
self,
dtype: torch.dtype = torch.qint8,
modules_to_quantize: set = None,
) -> nn.Module:
"""
Apply dynamic quantization.
Args:
dtype: Target dtype (qint8 or float16)
modules_to_quantize: Specific module types to quantize
"""
if modules_to_quantize is None:
modules_to_quantize = {nn.Linear, nn.LSTM, nn.GRU}
self.quantized_model = torch.quantization.quantize_dynamic(
self.original_model,
modules_to_quantize,
dtype=dtype,
)
return self.quantized_model
def compare_sizes(self) -> Dict[str, float]:
"""Compare model sizes before and after quantization."""
def get_size_mb(model: nn.Module) -> float:
param_size = sum(
p.nelement() * p.element_size()
for p in model.parameters()
)
buffer_size = sum(
b.nelement() * b.element_size()
for b in model.buffers()
)
return (param_size + buffer_size) / (1024 * 1024)
original_size = get_size_mb(self.original_model)
quantized_size = get_size_mb(self.quantized_model) if self.quantized_model else 0
return {
"original_mb": original_size,
"quantized_mb": quantized_size,
"compression_ratio": original_size / quantized_size if quantized_size else 0,
}
def quantize_transformer_dynamic(
model_name: str,
save_path: str = None,
) -> nn.Module:
"""Quantize a HuggingFace transformer model."""
print(f"Loading model: {model_name}")
model = AutoModelForSequenceClassification.from_pretrained(model_name)
model.eval()
# Apply dynamic quantization to Linear layers
quantized_model = torch.quantization.quantize_dynamic(
model,
{nn.Linear},
dtype=torch.qint8,
)
if save_path:
torch.save(quantized_model.state_dict(), save_path)
return quantized_model
# Example usage
if __name__ == "__main__":
model = AutoModelForSequenceClassification.from_pretrained("bert-base-uncased")
quantizer = DynamicQuantizer(model)
quantized = quantizer.quantize()
sizes = quantizer.compare_sizes()
print(f"Original size: {sizes['original_mb']:.2f} MB")
print(f"Quantized size: {sizes['quantized_mb']:.2f} MB")
    print(f"Compression: {sizes['compression_ratio']:.2f}x")
Static Quantization
Static quantization requires calibration data to determine optimal quantization ranges for activations.
# quantization/static.py
import torch
import torch.nn as nn
from torch.quantization import (
get_default_qconfig,
prepare,
convert,
QConfig,
)
from torch.quantization.observer import (
MinMaxObserver,
PerChannelMinMaxObserver,
HistogramObserver,
)
from typing import List, Callable
from tqdm import tqdm
class StaticQuantizer:
"""Static quantization with calibration."""
def __init__(
self,
model: nn.Module,
backend: str = "fbgemm", # or "qnnpack" for ARM
):
self.model = model
self.backend = backend
self.calibrated = False
def prepare(
self,
qconfig: QConfig = None,
) -> nn.Module:
"""Prepare model for quantization."""
self.model.eval()
# Set quantization backend
torch.backends.quantized.engine = self.backend
# Set qconfig
if qconfig is None:
qconfig = get_default_qconfig(self.backend)
self.model.qconfig = qconfig
# Fuse modules for better performance
self.model = self._fuse_modules(self.model)
# Insert observers
self.prepared_model = prepare(self.model)
return self.prepared_model
def calibrate(
self,
calibration_loader,
num_batches: int = 100,
):
"""Run calibration with representative data."""
self.prepared_model.eval()
with torch.no_grad():
for i, batch in enumerate(tqdm(calibration_loader, desc="Calibrating")):
if i >= num_batches:
break
# Handle different batch formats
if isinstance(batch, dict):
inputs = {k: v for k, v in batch.items() if k != "labels"}
else:
inputs = batch[0]
                if isinstance(inputs, dict):
                    self.prepared_model(**inputs)
                else:
                    self.prepared_model(inputs)
self.calibrated = True
def convert(self) -> nn.Module:
"""Convert to quantized model."""
if not self.calibrated:
raise RuntimeError("Model must be calibrated before conversion")
self.quantized_model = convert(self.prepared_model)
return self.quantized_model
def _fuse_modules(self, model: nn.Module) -> nn.Module:
"""Fuse Conv-BN-ReLU and Linear-ReLU modules."""
# Find and fuse common patterns
modules_to_fuse = []
for name, module in model.named_modules():
if isinstance(module, nn.Sequential):
# Look for fusable patterns
children = list(module.named_children())
for i in range(len(children) - 1):
curr_name, curr_mod = children[i]
next_name, next_mod = children[i + 1]
if isinstance(curr_mod, nn.Linear) and isinstance(next_mod, nn.ReLU):
modules_to_fuse.append([f"{name}.{curr_name}", f"{name}.{next_name}"])
if modules_to_fuse:
model = torch.quantization.fuse_modules(model, modules_to_fuse)
return model
class QuantizationObservers:
"""Different observer configurations for quantization."""
@staticmethod
def minmax_qconfig() -> QConfig:
"""MinMax observer - simple but sensitive to outliers."""
return QConfig(
activation=MinMaxObserver.with_args(dtype=torch.quint8),
weight=MinMaxObserver.with_args(dtype=torch.qint8),
)
@staticmethod
def histogram_qconfig() -> QConfig:
"""Histogram observer - better handling of outliers."""
return QConfig(
activation=HistogramObserver.with_args(dtype=torch.quint8),
weight=PerChannelMinMaxObserver.with_args(dtype=torch.qint8),
)
@staticmethod
def per_channel_qconfig() -> QConfig:
"""Per-channel quantization - better accuracy."""
return QConfig(
activation=MinMaxObserver.with_args(dtype=torch.quint8),
weight=PerChannelMinMaxObserver.with_args(
dtype=torch.qint8,
qscheme=torch.per_channel_symmetric,
),
        )
GPTQ for Large Language Models
GPTQ is a one-shot weight quantization method that uses second-order information for better accuracy.
# quantization/gptq.py
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig
from typing import List, Optional
from datasets import load_dataset
class GPTQQuantizer:
"""GPTQ quantization for LLMs."""
def __init__(
self,
model_name: str,
bits: int = 4,
group_size: int = 128,
desc_act: bool = True,
):
self.model_name = model_name
self.bits = bits
self.group_size = group_size
self.desc_act = desc_act
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
if self.tokenizer.pad_token is None:
self.tokenizer.pad_token = self.tokenizer.eos_token
def prepare_calibration_data(
self,
        dataset_name: str = "allenai/c4",  # the bare "c4" script is deprecated on the Hub
num_samples: int = 128,
seq_length: int = 2048,
) -> List[str]:
"""Prepare calibration dataset."""
dataset = load_dataset(dataset_name, "en", split="train", streaming=True)
samples = []
for i, sample in enumerate(dataset):
if i >= num_samples:
break
samples.append(sample["text"][:seq_length])
return samples
def quantize(
self,
calibration_data: List[str],
output_dir: str,
):
"""Quantize model with GPTQ."""
# Create quantization config
quantize_config = BaseQuantizeConfig(
bits=self.bits,
group_size=self.group_size,
desc_act=self.desc_act,
damp_percent=0.01,
)
# Load model
print(f"Loading model: {self.model_name}")
model = AutoGPTQForCausalLM.from_pretrained(
self.model_name,
quantize_config=quantize_config,
)
# Prepare examples
examples = [
self.tokenizer(
text,
return_tensors="pt",
max_length=2048,
truncation=True,
padding=True,
)
for text in calibration_data
]
# Quantize
print("Quantizing model...")
model.quantize(examples)
# Save
print(f"Saving to {output_dir}")
model.save_quantized(output_dir)
self.tokenizer.save_pretrained(output_dir)
return model
def load_gptq_model(
quantized_path: str,
device: str = "cuda",
):
"""Load a GPTQ quantized model."""
model = AutoGPTQForCausalLM.from_quantized(
quantized_path,
device=device,
use_safetensors=True,
inject_fused_attention=True,
inject_fused_mlp=True,
)
tokenizer = AutoTokenizer.from_pretrained(quantized_path)
return model, tokenizer
# Example usage
if __name__ == "__main__":
quantizer = GPTQQuantizer(
model_name="meta-llama/Llama-3.1-8B",
bits=4,
group_size=128,
)
# Prepare calibration data
calibration_data = quantizer.prepare_calibration_data(
num_samples=128,
)
# Quantize
quantizer.quantize(
calibration_data=calibration_data,
output_dir="./llama-3.1-8b-gptq",
    )
AWQ (Activation-aware Weight Quantization)
AWQ preserves salient weights based on activation magnitudes.
# quantization/awq.py
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from awq import AutoAWQForCausalLM
from typing import List, Optional
class AWQQuantizer:
"""AWQ quantization for LLMs."""
def __init__(
self,
model_name: str,
w_bit: int = 4,
q_group_size: int = 128,
zero_point: bool = True,
):
self.model_name = model_name
self.w_bit = w_bit
self.q_group_size = q_group_size
self.zero_point = zero_point
def quantize(
self,
output_dir: str,
calib_data: str = "pileval",
n_samples: int = 128,
seq_len: int = 512,
):
"""Quantize model with AWQ."""
# Load model
model = AutoAWQForCausalLM.from_pretrained(
self.model_name,
safetensors=True,
)
tokenizer = AutoTokenizer.from_pretrained(self.model_name)
# Quantization config
quant_config = {
"zero_point": self.zero_point,
"q_group_size": self.q_group_size,
"w_bit": self.w_bit,
}
# Quantize
print("Running AWQ quantization...")
model.quantize(
tokenizer,
quant_config=quant_config,
calib_data=calib_data,
n_samples=n_samples,
seqlen=seq_len,
)
# Save
model.save_quantized(output_dir)
tokenizer.save_pretrained(output_dir)
print(f"Saved quantized model to {output_dir}")
return model
def load_awq_model(
quantized_path: str,
device: str = "cuda",
fuse_layers: bool = True,
):
"""Load an AWQ quantized model."""
model = AutoAWQForCausalLM.from_quantized(
quantized_path,
fuse_layers=fuse_layers,
)
tokenizer = AutoTokenizer.from_pretrained(quantized_path)
return model, tokenizer
# Comparison of GPTQ vs AWQ
class QuantizationComparison:
"""Compare different quantization methods."""
@staticmethod
def get_comparison_table():
"""Return comparison of quantization methods."""
return """
| Method | Bits | Speed | Accuracy | Memory |
|--------|------|-------|----------|--------|
| FP16 | 16 | 1x | 100% | 2x |
| INT8 | 8 | 2x | 99.5% | 1x |
| GPTQ | 4 | 3x | 99% | 0.5x |
| AWQ | 4 | 3.5x | 99.2% | 0.5x |
| GGML | 4 | 2.5x | 98.5% | 0.5x |
        """
BitsAndBytes Integration
BitsAndBytes provides easy 8-bit and 4-bit quantization for HuggingFace models.
# quantization/bitsandbytes_quant.py
import torch
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
BitsAndBytesConfig,
)
from typing import Optional
class BnBQuantizer:
"""BitsAndBytes quantization wrapper."""
@staticmethod
def load_8bit(
model_name: str,
device_map: str = "auto",
):
"""Load model in 8-bit precision."""
bnb_config = BitsAndBytesConfig(
load_in_8bit=True,
llm_int8_threshold=6.0,
llm_int8_has_fp16_weight=False,
)
model = AutoModelForCausalLM.from_pretrained(
model_name,
quantization_config=bnb_config,
device_map=device_map,
)
tokenizer = AutoTokenizer.from_pretrained(model_name)
return model, tokenizer
@staticmethod
def load_4bit(
model_name: str,
device_map: str = "auto",
compute_dtype: torch.dtype = torch.bfloat16,
quant_type: str = "nf4", # or "fp4"
use_double_quant: bool = True,
):
"""Load model in 4-bit precision (QLoRA style)."""
bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_compute_dtype=compute_dtype,
bnb_4bit_quant_type=quant_type,
bnb_4bit_use_double_quant=use_double_quant,
)
model = AutoModelForCausalLM.from_pretrained(
model_name,
quantization_config=bnb_config,
device_map=device_map,
)
tokenizer = AutoTokenizer.from_pretrained(model_name)
return model, tokenizer
@staticmethod
def get_memory_usage(model) -> dict:
"""Get GPU memory usage."""
if not torch.cuda.is_available():
return {"error": "CUDA not available"}
allocated = torch.cuda.memory_allocated() / (1024 ** 3)
reserved = torch.cuda.memory_reserved() / (1024 ** 3)
return {
"allocated_gb": allocated,
"reserved_gb": reserved,
}
# Example usage comparing precision
def compare_precision(model_name: str):
"""Compare model at different precisions."""
results = {}
# FP16
print("Loading FP16...")
model_fp16 = AutoModelForCausalLM.from_pretrained(
model_name,
torch_dtype=torch.float16,
device_map="auto",
)
results["fp16"] = BnBQuantizer.get_memory_usage(model_fp16)
del model_fp16
torch.cuda.empty_cache()
# 8-bit
print("Loading 8-bit...")
model_8bit, _ = BnBQuantizer.load_8bit(model_name)
results["8bit"] = BnBQuantizer.get_memory_usage(model_8bit)
del model_8bit
torch.cuda.empty_cache()
# 4-bit
print("Loading 4-bit...")
model_4bit, _ = BnBQuantizer.load_4bit(model_name)
results["4bit"] = BnBQuantizer.get_memory_usage(model_4bit)
del model_4bit
torch.cuda.empty_cache()
    return results
Calibration Dataset
# calibration/dataset.py
import torch
from torch.utils.data import DataLoader, Dataset
from transformers import PreTrainedTokenizer
from datasets import load_dataset
from typing import List, Optional, Dict
import random
class CalibrationDataset(Dataset):
"""Dataset for quantization calibration."""
def __init__(
self,
texts: List[str],
tokenizer: PreTrainedTokenizer,
max_length: int = 512,
):
self.texts = texts
self.tokenizer = tokenizer
self.max_length = max_length
def __len__(self):
return len(self.texts)
def __getitem__(self, idx) -> Dict[str, torch.Tensor]:
text = self.texts[idx]
encoding = self.tokenizer(
text,
truncation=True,
max_length=self.max_length,
padding="max_length",
return_tensors="pt",
)
return {
"input_ids": encoding["input_ids"].squeeze(0),
"attention_mask": encoding["attention_mask"].squeeze(0),
}
def create_calibration_loader(
dataset_name: str = "wikitext",
dataset_config: str = "wikitext-2-raw-v1",
tokenizer: PreTrainedTokenizer = None,
num_samples: int = 512,
batch_size: int = 8,
max_length: int = 512,
) -> DataLoader:
"""Create calibration data loader."""
# Load dataset
dataset = load_dataset(dataset_name, dataset_config, split="train")
# Sample texts
texts = []
for sample in dataset:
text = sample["text"].strip()
if len(text) > 100: # Skip very short texts
texts.append(text)
if len(texts) >= num_samples:
break
# Create dataset
calib_dataset = CalibrationDataset(
texts=texts,
tokenizer=tokenizer,
max_length=max_length,
)
# Create loader
loader = DataLoader(
calib_dataset,
batch_size=batch_size,
shuffle=False,
num_workers=4,
)
return loader
def create_task_specific_calibration(
task: str,
tokenizer: PreTrainedTokenizer,
num_samples: int = 256,
) -> List[str]:
"""Create task-specific calibration data."""
    task_datasets = {
        "classification": ("imdb", None, "text"),
        "qa": ("squad", None, "context"),
        "summarization": ("cnn_dailymail", "3.0.0", "article"),
        "translation": ("wmt14", "de-en", "translation"),
    }
    if task not in task_datasets:
        raise ValueError(f"Unknown task: {task}")
    dataset_name, config, text_field = task_datasets[task]
    dataset = load_dataset(dataset_name, config, split="train")
    texts = []
    for sample in dataset:
        value = sample[text_field]
        # wmt14 nests the language strings as a dict under "translation"
        text = value["en"] if isinstance(value, dict) else value
        texts.append(text)
        if len(texts) >= num_samples:
            break
    return texts
Benchmarking
# benchmarks/latency.py
import torch
import time
import numpy as np
from transformers import AutoModelForCausalLM, AutoTokenizer
from typing import Dict, List, Optional
from tqdm import tqdm
class LatencyBenchmark:
"""Benchmark inference latency."""
def __init__(
self,
model,
tokenizer,
device: str = "cuda",
):
self.model = model
self.tokenizer = tokenizer
self.device = device
if hasattr(self.model, 'to'):
self.model = self.model.to(device)
self.model.eval()
def warmup(self, num_runs: int = 5):
"""Warmup the model."""
prompt = "Hello, world!"
inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
with torch.no_grad():
for _ in range(num_runs):
_ = self.model.generate(
**inputs,
max_new_tokens=10,
pad_token_id=self.tokenizer.pad_token_id,
)
def benchmark_prefill(
self,
input_lengths: List[int] = [128, 256, 512, 1024],
num_runs: int = 10,
) -> Dict[int, float]:
"""Benchmark prefill (prompt processing) latency."""
results = {}
for length in input_lengths:
# Create input of specified length
input_ids = torch.randint(
0, self.tokenizer.vocab_size,
(1, length),
device=self.device,
)
latencies = []
for _ in range(num_runs):
torch.cuda.synchronize()
start = time.perf_counter()
with torch.no_grad():
_ = self.model(input_ids)
torch.cuda.synchronize()
latencies.append(time.perf_counter() - start)
results[length] = {
"mean_ms": np.mean(latencies) * 1000,
"std_ms": np.std(latencies) * 1000,
"tokens_per_sec": length / np.mean(latencies),
}
return results
def benchmark_generation(
self,
prompt: str,
output_lengths: List[int] = [32, 64, 128, 256],
num_runs: int = 5,
) -> Dict[int, float]:
"""Benchmark token generation latency."""
inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
results = {}
for length in output_lengths:
latencies = []
tokens_generated = []
for _ in range(num_runs):
torch.cuda.synchronize()
start = time.perf_counter()
with torch.no_grad():
outputs = self.model.generate(
**inputs,
max_new_tokens=length,
do_sample=False,
pad_token_id=self.tokenizer.pad_token_id,
)
torch.cuda.synchronize()
elapsed = time.perf_counter() - start
latencies.append(elapsed)
tokens_generated.append(outputs.shape[1] - inputs["input_ids"].shape[1])
avg_tokens = np.mean(tokens_generated)
avg_latency = np.mean(latencies)
results[length] = {
"total_ms": avg_latency * 1000,
"tokens_per_sec": avg_tokens / avg_latency,
"ms_per_token": (avg_latency / avg_tokens) * 1000,
}
return results
def benchmark_throughput(
self,
batch_sizes: List[int] = [1, 2, 4, 8],
seq_length: int = 128,
num_runs: int = 10,
) -> Dict[int, float]:
"""Benchmark batch throughput."""
results = {}
for batch_size in batch_sizes:
input_ids = torch.randint(
0, self.tokenizer.vocab_size,
(batch_size, seq_length),
device=self.device,
)
latencies = []
for _ in range(num_runs):
torch.cuda.synchronize()
start = time.perf_counter()
with torch.no_grad():
_ = self.model(input_ids)
torch.cuda.synchronize()
latencies.append(time.perf_counter() - start)
avg_latency = np.mean(latencies)
total_tokens = batch_size * seq_length
results[batch_size] = {
"latency_ms": avg_latency * 1000,
"throughput_tokens_per_sec": total_tokens / avg_latency,
"samples_per_sec": batch_size / avg_latency,
}
return results
# benchmarks/accuracy.py
import torch
from datasets import load_dataset
from transformers import AutoTokenizer
from typing import Dict
from tqdm import tqdm
class AccuracyBenchmark:
"""Benchmark model accuracy after quantization."""
def __init__(
self,
model,
tokenizer,
device: str = "cuda",
):
self.model = model
self.tokenizer = tokenizer
self.device = device
def evaluate_perplexity(
self,
dataset_name: str = "wikitext",
dataset_config: str = "wikitext-2-raw-v1",
max_samples: int = 100,
max_length: int = 512,
) -> float:
"""Evaluate perplexity on dataset."""
dataset = load_dataset(dataset_name, dataset_config, split="test")
total_loss = 0
total_tokens = 0
for i, sample in enumerate(tqdm(dataset, desc="Evaluating")):
if i >= max_samples:
break
text = sample["text"].strip()
if not text:
continue
inputs = self.tokenizer(
text,
return_tensors="pt",
truncation=True,
max_length=max_length,
).to(self.device)
with torch.no_grad():
outputs = self.model(**inputs, labels=inputs["input_ids"])
loss = outputs.loss
total_loss += loss.item() * inputs["input_ids"].numel()
total_tokens += inputs["input_ids"].numel()
avg_loss = total_loss / total_tokens
perplexity = torch.exp(torch.tensor(avg_loss)).item()
return perplexity
def evaluate_classification(
self,
eval_dataset,
label_key: str = "label",
) -> Dict[str, float]:
"""Evaluate classification accuracy."""
correct = 0
total = 0
for sample in tqdm(eval_dataset, desc="Evaluating"):
inputs = self.tokenizer(
sample["text"],
return_tensors="pt",
truncation=True,
max_length=512,
).to(self.device)
with torch.no_grad():
outputs = self.model(**inputs)
prediction = outputs.logits.argmax(dim=-1).item()
if prediction == sample[label_key]:
correct += 1
total += 1
return {
"accuracy": correct / total,
"correct": correct,
"total": total,
        }
Quantization Script
# scripts/quantize.py
import argparse
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from quantization.dynamic import DynamicQuantizer
from quantization.static import StaticQuantizer
from quantization.gptq import GPTQQuantizer
from quantization.awq import AWQQuantizer
from quantization.bitsandbytes_quant import BnBQuantizer
from calibration.dataset import create_calibration_loader
from benchmarks.latency import LatencyBenchmark
from benchmarks.accuracy import AccuracyBenchmark
def main():
parser = argparse.ArgumentParser()
parser.add_argument("--model", required=True, help="Model name or path")
parser.add_argument("--method", choices=["dynamic", "static", "gptq", "awq", "bnb"], default="dynamic")
parser.add_argument("--bits", type=int, default=4, help="Quantization bits (for GPTQ/AWQ)")
parser.add_argument("--output-dir", required=True)
parser.add_argument("--benchmark", action="store_true")
args = parser.parse_args()
print(f"Loading model: {args.model}")
tokenizer = AutoTokenizer.from_pretrained(args.model)
if args.method == "dynamic":
model = AutoModelForCausalLM.from_pretrained(args.model)
quantizer = DynamicQuantizer(model)
quantized_model = quantizer.quantize()
sizes = quantizer.compare_sizes()
print(f"Compression: {sizes['compression_ratio']:.2f}x")
elif args.method == "static":
model = AutoModelForCausalLM.from_pretrained(args.model)
quantizer = StaticQuantizer(model)
prepared = quantizer.prepare()
calib_loader = create_calibration_loader(
tokenizer=tokenizer,
num_samples=256,
)
quantizer.calibrate(calib_loader)
quantized_model = quantizer.convert()
elif args.method == "gptq":
quantizer = GPTQQuantizer(
model_name=args.model,
bits=args.bits,
)
calib_data = quantizer.prepare_calibration_data()
quantized_model = quantizer.quantize(calib_data, args.output_dir)
elif args.method == "awq":
quantizer = AWQQuantizer(
model_name=args.model,
w_bit=args.bits,
)
quantized_model = quantizer.quantize(args.output_dir)
elif args.method == "bnb":
if args.bits == 8:
quantized_model, _ = BnBQuantizer.load_8bit(args.model)
else:
quantized_model, _ = BnBQuantizer.load_4bit(args.model)
# Benchmark if requested
if args.benchmark:
print("\n=== Benchmarking ===")
benchmark = LatencyBenchmark(quantized_model, tokenizer)
benchmark.warmup()
prefill_results = benchmark.benchmark_prefill()
print("\nPrefill latency:")
for length, metrics in prefill_results.items():
print(f" {length} tokens: {metrics['mean_ms']:.2f}ms ({metrics['tokens_per_sec']:.0f} tok/s)")
gen_results = benchmark.benchmark_generation(
"Once upon a time,",
output_lengths=[32, 64, 128],
)
print("\nGeneration speed:")
for length, metrics in gen_results.items():
print(f" {length} tokens: {metrics['tokens_per_sec']:.1f} tok/s")
print(f"\nModel saved to: {args.output_dir}")
if __name__ == "__main__":
    main()
Key Concepts Recap
| Concept | What It Is | Why It Matters |
|---|---|---|
| Quantization | Reducing numerical precision (FP32 → INT8/INT4) | 4x size reduction, 2-4x speedup, minimal accuracy loss |
| Dynamic Quantization | Quantize weights ahead, compute activation scales at runtime | Simple, no calibration needed, good for CPU |
| Static Quantization | Pre-compute both weight and activation quantization ranges | Faster inference, requires calibration data |
| Calibration Data | Representative samples to determine optimal ranges | Ensures quantization doesn't clip important values |
| GPTQ | Second-order weight quantization for LLMs | Uses Hessian info for better 4-bit accuracy |
| AWQ | Activation-aware weight quantization | Preserves "salient" weights based on activation patterns |
| BitsAndBytes | Library for easy 8-bit and 4-bit HuggingFace loading | One-line 4-bit quantization with load_in_4bit=True |
| NF4 (NormalFloat4) | Optimal 4-bit dtype for normally-distributed weights | Better than uniform INT4 for neural network weights |
| Per-Channel Quantization | Different scale/zero-point per output channel | Better accuracy than per-tensor, especially for weights |
| Observer | Module that records min/max values during calibration | Determines optimal quantization parameters |
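To see why per-channel scales beat a single per-tensor scale, here is a small NumPy sketch using synthetic weights with deliberately mismatched channel magnitudes (the situation per-channel quantization is designed for):

```python
import numpy as np

rng = np.random.default_rng(0)
# Weight matrix whose rows (output channels) have very different magnitudes --
# the case where one shared per-tensor scale wastes precision on small channels.
W = rng.standard_normal((4, 256)).astype(np.float32)
W *= np.array([[0.01], [0.1], [1.0], [10.0]], dtype=np.float32)

def sym_quant_error(w: np.ndarray, scale: float) -> float:
    """Mean absolute error of a symmetric INT8 round trip at the given scale."""
    q = np.clip(np.round(w / scale), -127, 127)
    return float(np.abs(w - q * scale).mean())

per_tensor = sym_quant_error(W, np.abs(W).max() / 127)
per_channel = float(np.mean(
    [sym_quant_error(row, np.abs(row).max() / 127) for row in W]
))
print(f"per-tensor MAE:  {per_tensor:.6f}")
print(f"per-channel MAE: {per_channel:.6f}")
```

The per-tensor scale is dictated by the largest channel, so the small channels round almost entirely to zero; per-channel scales keep each row's error proportional to its own magnitude.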
Next Steps
After completing this project, consider:
- Knowledge Distillation - Combine with quantization
- Model Inference API - Deploy quantized models
- LoRA Fine-tuning - QLoRA for efficient training