Deep Learning · Beginner
Model Inference API
Serve PyTorch models with FastAPI
TL;DR
Deploy ML models as REST APIs with dynamic batching for throughput, GPU memory management, and Prometheus metrics. Learn model warmup to avoid cold-start latency, and async request handling that offloads blocking inference to a thread pool executor.
Deploy PyTorch and HuggingFace models as production-ready REST APIs with batching, GPU management, and monitoring.
What You'll Learn
- Loading HuggingFace models with PyTorch
- FastAPI endpoint design for ML
- Batch inference optimization
- GPU memory management
- Health checks and Prometheus metrics
Tech Stack
| Component | Technology |
|---|---|
| Framework | PyTorch, Transformers |
| API | FastAPI |
| Metrics | Prometheus |
| Async | asyncio, ThreadPool |
Architecture
┌──────────────────────────────────────────────────────────────────────────────┐
│ MODEL INFERENCE API ARCHITECTURE │
├──────────────────────────────────────────────────────────────────────────────┤
│ │
│ ┌─────────────────────────────────────────────────────────────────────────┐ │
│ │ API LAYER │ │
│ │ ┌────────────────┐ ┌─────────────┐ ┌─────────────────────────┐ │ │
│ │ │ Client Request │───▶│ FastAPI │───▶│ Batch Queue │ │ │
│ │ └────────────────┘ │ (async) │ │ (collect requests) │ │ │
│ │ └─────────────┘ └────────────┬────────────┘ │ │
│ └────────────────────────────────────────────────────────┼────────────────┘ │
│ │ │
│ ┌────────────────────────────────────────────────────────▼────────────────┐ │
│ │ INFERENCE ENGINE │ │
│ │ ┌─────────────────┐ ┌─────────────────┐ ┌──────────────────┐ │ │
│ │ │ Dynamic Batcher │───▶│ Model Forward │───▶│ GPU / CPU │ │ │
│ │ │ (wait or full) │ │ (torch.no_grad)│ │ (device_map) │ │ │
│ │ └─────────────────┘ └─────────────────┘ └────────┬─────────┘ │ │
│ └─────────────────────────────────────────────────────────┼──────────────┘ │
│ │ │
│ ▼ │
│ ┌─────────────────────────┐ │
│ │ Response │ │
│ │ (predictions/embeddings)│ │
│ └─────────────────────────┘ │
│ │
├──────────────────────────────────────────────────────────────────────────────┤
│ │
│ MONITORING (Prometheus) │
│ ┌───────────────┐ ┌────────────────┐ ┌────────────────┐ ┌────────────┐ │
│ │ Request Count │ │ Latency (p50, │ │ Batch Size │ │ GPU Memory │ │
│ │ (success/err) │ │ p95, p99) │ │ Distribution │ │ (alloc/res)│ │
│ └───────────────┘ └────────────────┘ └────────────────┘ └────────────┘ │
│ │
└──────────────────────────────────────────────────────────────────────────────┘

Project Structure
inference-api/
├── src/
│ ├── __init__.py
│ ├── model_loader.py # Model loading utilities
│ ├── inference.py # Inference engine
│ ├── batching.py # Request batching
│ └── metrics.py # Prometheus metrics
├── api/
│ └── main.py # FastAPI application
├── config/
│ └── settings.py # Configuration
├── tests/
│ └── test_api.py
├── requirements.txt
└── Dockerfile

Implementation
Step 1: Dependencies
torch>=2.0.0
transformers>=4.30.0
fastapi>=0.100.0
uvicorn>=0.23.0
prometheus-client>=0.17.0
pydantic-settings>=2.0.0

Step 2: Configuration
"""Application configuration."""
from pydantic_settings import BaseSettings
from typing import Optional
import torch
class Settings(BaseSettings):
"""Inference API settings."""
# Model settings
model_name: str = "distilbert-base-uncased-finetuned-sst-2-english"
model_revision: Optional[str] = None
device: str = "cuda" if torch.cuda.is_available() else "cpu"
torch_dtype: str = "float16" if torch.cuda.is_available() else "float32"
# Inference settings
max_batch_size: int = 32
max_sequence_length: int = 512
batch_timeout_ms: int = 50
# Server settings
host: str = "0.0.0.0"
port: int = 8000
workers: int = 1
# Memory settings
max_memory_mb: Optional[int] = None
enable_memory_efficient: bool = True
class Config:
env_file = ".env"
settings = Settings()

Step 3: Model Loader
"""Model loading utilities."""
import torch
from transformers import (
AutoModel,
AutoModelForSequenceClassification,
AutoModelForCausalLM,
AutoTokenizer,
PreTrainedModel,
PreTrainedTokenizer
)
from typing import Tuple, Optional, Dict, Any
import logging
import gc
logger = logging.getLogger(__name__)
class ModelLoader:
"""
Utility class for loading HuggingFace models.
Handles:
- Different model architectures
- Device placement (CPU/GPU)
- Memory optimization
- Model warmup
"""
TASK_MODEL_MAPPING = {
"classification": AutoModelForSequenceClassification,
"generation": AutoModelForCausalLM,
"embedding": AutoModel,
}
def __init__(
self,
model_name: str,
task: str = "classification",
device: str = "cuda",
dtype: str = "float16",
revision: Optional[str] = None
):
self.model_name = model_name
self.task = task
self.device = torch.device(device)
self.dtype = getattr(torch, dtype)
self.revision = revision
self.model: Optional[PreTrainedModel] = None
self.tokenizer: Optional[PreTrainedTokenizer] = None
def load(self) -> Tuple[PreTrainedModel, PreTrainedTokenizer]:
"""Load model and tokenizer."""
logger.info(f"Loading model: {self.model_name}")
# Load tokenizer
self.tokenizer = AutoTokenizer.from_pretrained(
self.model_name,
revision=self.revision
)
# Ensure pad token exists
if self.tokenizer.pad_token is None:
self.tokenizer.pad_token = self.tokenizer.eos_token
# Get model class
model_class = self.TASK_MODEL_MAPPING.get(self.task, AutoModel)
# Load model with appropriate settings
load_kwargs: Dict[str, Any] = {
"pretrained_model_name_or_path": self.model_name,
"revision": self.revision,
}
# Use appropriate dtype for GPU
if self.device.type == "cuda":
load_kwargs["torch_dtype"] = self.dtype
load_kwargs["device_map"] = "auto"
else:
load_kwargs["torch_dtype"] = torch.float32
self.model = model_class.from_pretrained(**load_kwargs)
# Move to device if not using device_map
if "device_map" not in load_kwargs:
self.model = self.model.to(self.device)
# Set to eval mode
self.model.eval()
# Log memory usage
if self.device.type == "cuda":
memory_mb = torch.cuda.memory_allocated() / 1024 / 1024
logger.info(f"GPU memory used: {memory_mb:.2f} MB")
logger.info(f"Model loaded on {self.device}")
return self.model, self.tokenizer
def warmup(self, batch_size: int = 1, seq_length: int = 32) -> None:
"""Warm up the model with dummy inference."""
if self.model is None or self.tokenizer is None:
raise RuntimeError("Model not loaded")
logger.info("Warming up model...")
dummy_text = "This is a warmup text " * (seq_length // 5)
dummy_batch = [dummy_text] * batch_size
inputs = self.tokenizer(
dummy_batch,
return_tensors="pt",
padding=True,
truncation=True,
max_length=seq_length
)
# Move to device
inputs = {k: v.to(self.device) for k, v in inputs.items()}
# Run inference
with torch.no_grad():
_ = self.model(**inputs)
# Clear cache
if self.device.type == "cuda":
torch.cuda.empty_cache()
logger.info("Warmup complete")
def unload(self) -> None:
"""Unload model and free memory."""
self.model = None
self.tokenizer = None
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
logger.info("Model unloaded")

Understanding the Model Loader:
┌─────────────────────────────────────────────────────────────────────────────┐
│ MODEL LOADING PIPELINE │
├─────────────────────────────────────────────────────────────────────────────┤
│ │
│ 1. TOKENIZER LOADING │
│ ┌─────────────────────────────────────────────────────────────────────┐ │
│ │ AutoTokenizer.from_pretrained("model_name") │ │
│ │ │ │
│ │ • Downloads vocab files from HuggingFace Hub │ │
│ │ • Sets up pad_token (required for batching) │ │
│ └─────────────────────────────────────────────────────────────────────┘ │
│ │
│ 2. MODEL LOADING │
│ ┌─────────────────────────────────────────────────────────────────────┐ │
│ │ Task Mapping: │ │
│ │ • classification → AutoModelForSequenceClassification │ │
│ │ • generation → AutoModelForCausalLM │ │
│ │ • embedding → AutoModel │ │
│ │ │ │
│ │ Device Placement: │ │
│ │ • GPU: torch_dtype=float16, device_map="auto" │ │
│ │ • CPU: torch_dtype=float32 │ │
│ └─────────────────────────────────────────────────────────────────────┘ │
│ │
│ 3. EVALUATION MODE │
│ ┌─────────────────────────────────────────────────────────────────────┐ │
│ │ model.eval() │ │
│ │ │ │
│ │ • Disables dropout layers │ │
│ │ • Switches BatchNorm to inference mode │ │
│ │ • MUST be called before inference! │ │
│ └─────────────────────────────────────────────────────────────────────┘ │
│ │
└─────────────────────────────────────────────────────────────────────────────┘

Why Model Warmup Matters:
Without Warmup: With Warmup:
┌─────────────────────────┐ ┌─────────────────────────┐
│ Request 1: 2500ms ⚠️ │ │ Warmup: 2500ms (hidden) │
│ • CUDA initialization │ │ Request 1: 45ms ✓ │
│ • JIT compilation │ │ Request 2: 42ms ✓ │
│ • Memory allocation │ │ Request 3: 44ms ✓ │
│ Request 2: 45ms │ └─────────────────────────┘
│ Request 3: 44ms │
└─────────────────────────┘
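The pattern generalizes beyond PyTorch. A minimal, framework-free sketch of why the first call is slow and how warmup hides the cost (the `ToyModel` class is purely illustrative):

```python
class ToyModel:
    """Simulates lazy initialization: the first call pays a one-time setup cost."""

    def __init__(self):
        self._initialized = False
        self.init_calls = 0  # counts how often the expensive setup ran

    def __call__(self, batch):
        if not self._initialized:
            self.init_calls += 1  # stands in for CUDA init / kernel compilation
            self._initialized = True
        return [x * 2 for x in batch]

model = ToyModel()
model(["warmup"])                # startup: absorbs the one-time cost
result = model([1, 2, 3])        # serving path: no initialization work left
print(result, model.init_calls)  # → [2, 4, 6] 1
```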
First inference triggers PyTorch's lazy initialization.
Warmup moves this cost to startup time, not request time.

Step 4: Request Batching
"""Request batching for efficient inference."""
import asyncio
import time
from typing import List, Dict, Any, Callable, Optional
from dataclasses import dataclass
import logging
logger = logging.getLogger(__name__)
@dataclass
class BatchRequest:
"""A single request in a batch."""
id: str
data: Any
future: asyncio.Future
timestamp: float
class DynamicBatcher:
"""
Dynamic request batcher for efficient GPU utilization.
Collects requests and processes them in batches to maximize
throughput while maintaining low latency.
"""
def __init__(
self,
process_fn: Callable,
max_batch_size: int = 32,
max_wait_ms: int = 50
):
self.process_fn = process_fn
self.max_batch_size = max_batch_size
self.max_wait_ms = max_wait_ms / 1000 # Convert to seconds
self.queue: List[BatchRequest] = []
self.lock = asyncio.Lock()
self.processing = False
self._batch_task: Optional[asyncio.Task] = None
async def submit(self, request_id: str, data: Any) -> Any:
"""
Submit a request for batched processing.
Args:
request_id: Unique request identifier
data: Request data
Returns:
Processing result
"""
loop = asyncio.get_running_loop()
future = loop.create_future()
request = BatchRequest(
id=request_id,
data=data,
future=future,
timestamp=time.time()
)
async with self.lock:
self.queue.append(request)
# Start batch processing if not running
if not self.processing:
self.processing = True
self._batch_task = asyncio.create_task(
self._process_batches()
)
return await future
async def _process_batches(self) -> None:
"""Process batches from the queue."""
while True:
# Wait for batch to fill or timeout
await asyncio.sleep(self.max_wait_ms)
async with self.lock:
if not self.queue:
self.processing = False
return
# Get batch
batch = self.queue[:self.max_batch_size]
self.queue = self.queue[self.max_batch_size:]
# Process batch
try:
results = await self._process_batch(batch)
# Resolve futures
for request, result in zip(batch, results):
if not request.future.done():
request.future.set_result(result)
except Exception as e:
logger.error(f"Batch processing failed: {e}")
for request in batch:
if not request.future.done():
request.future.set_exception(e)
async def _process_batch(
self,
batch: List[BatchRequest]
) -> List[Any]:
"""Process a single batch."""
batch_data = [r.data for r in batch]
# Run in thread pool to avoid blocking
loop = asyncio.get_running_loop()
results = await loop.run_in_executor(
None,
self.process_fn,
batch_data
)
return results
class TokenBatcher:
"""
Token-aware batcher that groups requests by similar lengths.
More efficient for transformer models as it reduces padding.
"""
def __init__(
self,
process_fn: Callable,
max_batch_tokens: int = 4096,
max_batch_size: int = 32,
max_wait_ms: int = 50
):
self.process_fn = process_fn
self.max_batch_tokens = max_batch_tokens
self.max_batch_size = max_batch_size
self.max_wait_ms = max_wait_ms / 1000
self.queue: List[BatchRequest] = []
self.lock = asyncio.Lock()
async def submit(
self,
request_id: str,
data: Any,
num_tokens: int
) -> Any:
"""Submit request with token count."""
loop = asyncio.get_running_loop()
future = loop.create_future()
request = BatchRequest(
id=request_id,
data=(data, num_tokens),
future=future,
timestamp=time.time()
)
async with self.lock:
self.queue.append(request)
self._maybe_process()
return await future
def _maybe_process(self) -> None:
"""Check if we should process a batch."""
if not self.queue:
return
# Sort by token count for efficient batching
self.queue.sort(key=lambda r: r.data[1])
# Find optimal batch (queue is sorted ascending, so break is safe)
batch = []
total_tokens = 0
for request in self.queue:
_, tokens = request.data
if (total_tokens + tokens <= self.max_batch_tokens and
len(batch) < self.max_batch_size):
batch.append(request)
total_tokens += tokens
else:
break
# A request larger than max_batch_tokens must still be served (alone),
# or it would wait in the queue forever
if not batch and self.queue:
batch = [self.queue[0]]
if batch:
asyncio.create_task(self._process_batch(batch))
async def _process_batch(
self,
batch: List[BatchRequest]
) -> None:
"""Process batch."""
# Remove from queue
async with self.lock:
for request in batch:
if request in self.queue:
self.queue.remove(request)
try:
batch_data = [r.data[0] for r in batch]
loop = asyncio.get_running_loop()
results = await loop.run_in_executor(
None, self.process_fn, batch_data
)
for request, result in zip(batch, results):
request.future.set_result(result)
except Exception as e:
for request in batch:
request.future.set_exception(e)

Understanding Dynamic Batching:
┌─────────────────────────────────────────────────────────────────────────────┐
│ DYNAMIC BATCHING TIMELINE │
├─────────────────────────────────────────────────────────────────────────────┤
│ │
│ Time ──────────────────────────────────────────────────────────► │
│ │
│ 0ms 25ms 50ms 75ms 100ms │
│ │ │ │ │ │ │
│ ▼ ▼ ▼ ▼ ▼ │
│ R1 R2,R3 │ R4,R5,R6 │ │
│ │ │ │ │ │ │
│ └───────────┼───────────┘ └───────────┼───────────┘ │
│ │ │ │
│ ▼ ▼ │
│ ┌──────────────┐ ┌──────────────┐ │
│ │ Process │ │ Process │ │
│ │ Batch [R1-R3]│ │ Batch [R4-R6]│ │
│ └──────────────┘ └──────────────┘ │
│ │
│ max_wait_ms = 50ms │
│ max_batch_size = 32 │
│ │
│ Trigger batch processing when: │
│ • max_wait_ms elapsed since first request in queue, OR │
│ • max_batch_size requests collected │
│ │
└─────────────────────────────────────────────────────────────────────────────┘

Why Batching Improves Throughput:
| Approach | GPU Utilization | Latency | Throughput |
|---|---|---|---|
| One-by-one | ~5-10% | Low (per request) | Low |
| Fixed batch | ~60-80% | High (wait for full batch) | Medium |
| Dynamic batch | ~50-70% | Medium (timeout-based) | High |
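The dynamic-batching pattern can be exercised end to end with a standard-library miniature (the class and function names here are illustrative, not part of the project code):

```python
import asyncio

def process_batch(texts):
    """Stand-in for batched model inference: one call handles many inputs."""
    return [t.upper() for t in texts]

class MiniBatcher:
    """Collects concurrent requests, then resolves them with one batched call."""

    def __init__(self, process_fn, max_wait=0.01):
        self.process_fn = process_fn
        self.max_wait = max_wait
        self.queue = []
        self.draining = False

    async def submit(self, data):
        future = asyncio.get_running_loop().create_future()
        self.queue.append((data, future))
        if not self.draining:          # first request starts the timer
            self.draining = True
            asyncio.create_task(self._drain())
        return await future

    async def _drain(self):
        await asyncio.sleep(self.max_wait)   # let the batch fill
        batch, self.queue = self.queue, []
        self.draining = False
        loop = asyncio.get_running_loop()
        # Blocking work goes to the default thread pool, as in _process_batch
        results = await loop.run_in_executor(
            None, self.process_fn, [d for d, _ in batch]
        )
        for (_, future), result in zip(batch, results):
            future.set_result(result)

async def main():
    batcher = MiniBatcher(process_batch)
    # Three concurrent submissions are served by a single process_batch call
    return await asyncio.gather(*(batcher.submit(t) for t in ["a", "b", "c"]))

print(asyncio.run(main()))  # → ['A', 'B', 'C']
```

The key mechanic matches `DynamicBatcher`: each caller gets a future, a single drain task resolves all of them together after the wait window.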
Token-Aware Batching:
┌────────────────────────────────────────────────┐
│ Standard Batching (wastes compute on padding) │
│ │
│ "Hi" ───► [Hi ][PAD][PAD][PAD]...[PAD] │
│ "Hello" ───► [Hello][PAD][PAD]...[PAD] │
│ "How..." ───► [How are you doing today?] │
│ └── 80% of tokens are padding! ──┘│
└────────────────────────────────────────────────┘
┌────────────────────────────────────────────────┐
│ Token-Aware Batching (groups similar lengths) │
│ │
│ Batch 1: ["Hi", "Hey", "Yo"] ← Short │
│ Batch 2: ["How are you", ...] ← Medium │
│ Batch 3: ["Long paragraph..."] ← Long │
│ │
│ Result: Less padding = faster inference │
└────────────────────────────────────────────────┘

Step 5: Inference Engine
"""Inference engine for model serving."""
import torch
import torch.nn.functional as F
from transformers import PreTrainedModel, PreTrainedTokenizer
from typing import List, Dict, Any, Optional, Union
import logging
logger = logging.getLogger(__name__)
class InferenceEngine:
"""
Inference engine for running model predictions.
Handles:
- Text preprocessing
- Batched inference
- Output postprocessing
- GPU memory management
"""
def __init__(
self,
model: PreTrainedModel,
tokenizer: PreTrainedTokenizer,
device: torch.device,
max_length: int = 512
):
self.model = model
self.tokenizer = tokenizer
self.device = device
self.max_length = max_length
# Get label mapping if available
self.id2label = getattr(model.config, 'id2label', None)
@torch.no_grad()
def classify(
self,
texts: Union[str, List[str]],
return_all_scores: bool = False
) -> List[Dict[str, Any]]:
"""
Run classification inference.
Args:
texts: Input text(s)
return_all_scores: Return scores for all classes
Returns:
List of predictions with labels and scores
"""
if isinstance(texts, str):
texts = [texts]
# Tokenize
inputs = self.tokenizer(
texts,
return_tensors="pt",
padding=True,
truncation=True,
max_length=self.max_length
)
# Move to device
inputs = {k: v.to(self.device) for k, v in inputs.items()}
# Forward pass
outputs = self.model(**inputs)
logits = outputs.logits
# Get probabilities
probs = F.softmax(logits, dim=-1)
predictions = torch.argmax(probs, dim=-1)
confidences = torch.max(probs, dim=-1).values
# Format results
results = []
for i, (pred, conf) in enumerate(zip(predictions, confidences)):
result = {
"label": self.id2label[pred.item()] if self.id2label else pred.item(),
"score": conf.item()
}
if return_all_scores and self.id2label:
result["scores"] = {
self.id2label[j]: probs[i, j].item()
for j in range(probs.size(-1))
}
results.append(result)
return results
@torch.no_grad()
def embed(
self,
texts: Union[str, List[str]],
pooling: str = "mean"
) -> torch.Tensor:
"""
Generate embeddings.
Args:
texts: Input text(s)
pooling: Pooling strategy ('mean', 'cls', 'last')
Returns:
Embeddings tensor [batch_size, hidden_dim]
"""
if isinstance(texts, str):
texts = [texts]
inputs = self.tokenizer(
texts,
return_tensors="pt",
padding=True,
truncation=True,
max_length=self.max_length
)
inputs = {k: v.to(self.device) for k, v in inputs.items()}
outputs = self.model(**inputs)
hidden_states = outputs.last_hidden_state
# Apply pooling
if pooling == "mean":
mask = inputs["attention_mask"].unsqueeze(-1)
embeddings = (hidden_states * mask).sum(1) / mask.sum(1)
elif pooling == "cls":
embeddings = hidden_states[:, 0, :]
elif pooling == "last":
# Get last non-padding token
seq_lengths = inputs["attention_mask"].sum(1) - 1
embeddings = hidden_states[
torch.arange(hidden_states.size(0)),
seq_lengths
]
else:
raise ValueError(f"Unknown pooling: {pooling}")
# Normalize
embeddings = F.normalize(embeddings, p=2, dim=-1)
return embeddings
@torch.no_grad()
def generate(
self,
prompts: Union[str, List[str]],
max_new_tokens: int = 100,
temperature: float = 1.0,
top_p: float = 0.9,
do_sample: bool = True
) -> List[str]:
"""
Generate text completions.
Args:
prompts: Input prompt(s)
max_new_tokens: Maximum tokens to generate
temperature: Sampling temperature
top_p: Nucleus sampling parameter
do_sample: Whether to sample or use greedy
Returns:
Generated texts
"""
if isinstance(prompts, str):
prompts = [prompts]
inputs = self.tokenizer(
prompts,
return_tensors="pt",
padding=True
)
inputs = {k: v.to(self.device) for k, v in inputs.items()}
outputs = self.model.generate(
**inputs,
max_new_tokens=max_new_tokens,
temperature=temperature,
top_p=top_p,
do_sample=do_sample,
pad_token_id=self.tokenizer.pad_token_id
)
# Decode, removing input prompt
generated = self.tokenizer.batch_decode(
outputs[:, inputs["input_ids"].shape[1]:],
skip_special_tokens=True
)
return generated

Understanding Inference Patterns:
┌─────────────────────────────────────────────────────────────────────────────┐
│ INFERENCE PIPELINE COMPARISON │
├─────────────────────────────────────────────────────────────────────────────┤
│ │
│ CLASSIFICATION EMBEDDING │
│ ┌────────────────────┐ ┌────────────────────┐ │
│ │ Input Text │ │ Input Text │ │
│ └────────┬───────────┘ └────────┬───────────┘ │
│ ▼ ▼ │
│ ┌────────────────────┐ ┌────────────────────┐ │
│ │ Tokenize + Pad │ │ Tokenize + Pad │ │
│ └────────┬───────────┘ └────────┬───────────┘ │
│ ▼ ▼ │
│ ┌────────────────────┐ ┌────────────────────┐ │
│ │ Model Forward │ │ Model Forward │ │
│ │ → logits │ │ → hidden_states │ │
│ └────────┬───────────┘ └────────┬───────────┘ │
│ ▼ ▼ │
│ ┌────────────────────┐ ┌────────────────────┐ │
│ │ Softmax → Probs │ │ Pooling │ │
│ │ Argmax → Label │ │ (mean/cls/last) │ │
│ └────────┬───────────┘ └────────┬───────────┘ │
│ ▼ ▼ │
│ {"label": "POSITIVE", [0.12, -0.34, 0.56, ...] │
│ "score": 0.95} (normalized vector) │
│ │
└─────────────────────────────────────────────────────────────────────────────┘

Pooling Strategies Explained:
| Strategy | How It Works | Best For |
|---|---|---|
| mean | Average all token embeddings (weighted by attention mask) | General similarity, retrieval |
| cls | Use only the [CLS] token embedding | BERT-family models trained with [CLS] |
| last | Use the last non-padding token | Decoder-only models (GPT-style) |
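The mean strategy from the table, sketched over plain Python lists (one vector per token; padded positions are excluded by the attention mask):

```python
def mean_pool(token_embeddings, attention_mask):
    """Average token vectors, counting only positions where the mask is 1."""
    dim = len(token_embeddings[0])
    sums = [0.0] * dim
    count = 0
    for vec, mask in zip(token_embeddings, attention_mask):
        if mask:
            count += 1
            for i, value in enumerate(vec):
                sums[i] += value
    return [s / count for s in sums]

# Two real tokens and one padding token: the pad vector must not skew the mean
tokens = [[1.0, 2.0], [3.0, 4.0], [9.0, 9.0]]
mask = [1, 1, 0]
print(mean_pool(tokens, mask))  # → [2.0, 3.0]
```

This is exactly what the tensor expression `(hidden_states * mask).sum(1) / mask.sum(1)` computes in `embed`, just without the vectorization.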
Step 6: Prometheus Metrics
"""Prometheus metrics for monitoring."""
from prometheus_client import Counter, Histogram, Gauge, Info
import time
from functools import wraps
from typing import Callable
import torch
# Request metrics
REQUEST_COUNT = Counter(
"inference_requests_total",
"Total inference requests",
["endpoint", "status"]
)
REQUEST_LATENCY = Histogram(
"inference_latency_seconds",
"Inference latency in seconds",
["endpoint"],
buckets=[0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 1.0, 2.5, 5.0]
)
BATCH_SIZE = Histogram(
"inference_batch_size",
"Batch sizes",
buckets=[1, 2, 4, 8, 16, 32, 64]
)
# Model metrics
MODEL_INFO = Info(
"model_info",
"Model information"
)
GPU_MEMORY = Gauge(
"gpu_memory_bytes",
"GPU memory usage",
["type"] # allocated, reserved
)
TOKENS_PROCESSED = Counter(
"tokens_processed_total",
"Total tokens processed",
["type"] # input, output
)
def track_latency(endpoint: str) -> Callable:
"""Decorator to track endpoint latency."""
def decorator(func: Callable) -> Callable:
@wraps(func)
async def wrapper(*args, **kwargs):
start = time.time()
try:
result = await func(*args, **kwargs)
REQUEST_COUNT.labels(
endpoint=endpoint, status="success"
).inc()
return result
except Exception as e:
REQUEST_COUNT.labels(
endpoint=endpoint, status="error"
).inc()
raise
finally:
REQUEST_LATENCY.labels(endpoint=endpoint).observe(
time.time() - start
)
return wrapper
return decorator
def update_gpu_metrics() -> None:
"""Update GPU memory metrics."""
if torch.cuda.is_available():
GPU_MEMORY.labels(type="allocated").set(
torch.cuda.memory_allocated()
)
GPU_MEMORY.labels(type="reserved").set(
torch.cuda.memory_reserved()
)
def set_model_info(
name: str,
device: str,
dtype: str,
parameters: int
) -> None:
"""Set model information metric."""
MODEL_INFO.info({
"name": name,
"device": device,
"dtype": dtype,
"parameters": str(parameters)
})

Step 7: FastAPI Application
"""FastAPI application for model inference."""
import torch
import uuid
from fastapi import FastAPI, HTTPException
from fastapi.responses import Response
from pydantic import BaseModel, Field
from typing import List, Optional, Dict, Any
from prometheus_client import generate_latest, CONTENT_TYPE_LATEST
from contextlib import asynccontextmanager
from config.settings import settings
from src.model_loader import ModelLoader
from src.inference import InferenceEngine
from src.batching import DynamicBatcher
from src.metrics import (
track_latency, update_gpu_metrics, set_model_info,
BATCH_SIZE, TOKENS_PROCESSED
)
# Request/Response models
class ClassifyRequest(BaseModel):
texts: List[str] = Field(..., min_length=1, max_length=32)
return_all_scores: bool = False
class ClassifyResponse(BaseModel):
predictions: List[Dict[str, Any]]
class EmbedRequest(BaseModel):
texts: List[str] = Field(..., min_length=1, max_length=32)
pooling: str = Field(default="mean", pattern="^(mean|cls|last)$")
class EmbedResponse(BaseModel):
embeddings: List[List[float]]
dimension: int
# Global instances, populated during application startup
engine: Optional[InferenceEngine] = None
batcher: Optional[DynamicBatcher] = None
@asynccontextmanager
async def lifespan(app: FastAPI):
"""Application lifespan handler."""
global engine, batcher
# Load model
loader = ModelLoader(
model_name=settings.model_name,
task="classification",
device=settings.device,
dtype=settings.torch_dtype
)
model, tokenizer = loader.load()
loader.warmup()
# Create engine
engine = InferenceEngine(
model=model,
tokenizer=tokenizer,
device=torch.device(settings.device),
max_length=settings.max_sequence_length
)
# Create batcher
def process_batch(texts: List[str]) -> List[Dict]:
BATCH_SIZE.observe(len(texts))
return engine.classify(texts)
batcher = DynamicBatcher(
process_fn=process_batch,
max_batch_size=settings.max_batch_size,
max_wait_ms=settings.batch_timeout_ms
)
# Set model info metric
num_params = sum(p.numel() for p in model.parameters())
set_model_info(
name=settings.model_name,
device=settings.device,
dtype=settings.torch_dtype,
parameters=num_params
)
print(f"Model ready on {settings.device}")
yield
# Cleanup
loader.unload()
app = FastAPI(
title="Model Inference API",
description="Production ML model serving",
lifespan=lifespan
)
@app.get("/health")
async def health():
"""Health check endpoint."""
update_gpu_metrics()
return {
"status": "healthy",
"model": settings.model_name,
"device": settings.device
}
@app.get("/metrics")
async def metrics():
"""Prometheus metrics endpoint."""
update_gpu_metrics()
return Response(
content=generate_latest(),
media_type=CONTENT_TYPE_LATEST
)
@app.post("/classify", response_model=ClassifyResponse)
@track_latency("classify")
async def classify(request: ClassifyRequest):
"""Classify texts."""
import asyncio
# Submit all texts concurrently so they can share a dynamic batch,
# rather than awaiting each one in sequence
predictions = await asyncio.gather(*(
batcher.submit(str(uuid.uuid4()), text) for text in request.texts
))
predictions = list(predictions)
return ClassifyResponse(predictions=predictions)
@app.post("/classify/batch", response_model=ClassifyResponse)
@track_latency("classify_batch")
async def classify_batch(request: ClassifyRequest):
"""Classify texts in a single batch (no dynamic batching)."""
predictions = engine.classify(
request.texts,
return_all_scores=request.return_all_scores
)
BATCH_SIZE.observe(len(request.texts))
return ClassifyResponse(predictions=predictions)
@app.post("/embed", response_model=EmbedResponse)
@track_latency("embed")
async def embed(request: EmbedRequest):
"""Generate embeddings."""
embeddings = engine.embed(request.texts, pooling=request.pooling)
return EmbedResponse(
embeddings=embeddings.cpu().tolist(),
dimension=embeddings.size(-1)
)
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host=settings.host, port=settings.port)

Running the Project
# Install dependencies
pip install -r requirements.txt
# Run the API
uvicorn api.main:app --reload
# Test classification
curl -X POST http://localhost:8000/classify \
-H "Content-Type: application/json" \
-d '{"texts": ["I love this!", "This is terrible."]}'
# Test embeddings
curl -X POST http://localhost:8000/embed \
-H "Content-Type: application/json" \
-d '{"texts": ["Hello world"], "pooling": "mean"}'
# Get metrics
curl http://localhost:8000/metrics

Key Concepts
Dynamic Batching
Collects requests and processes them together:
# Instead of processing one by one
for request in requests:
result = model(request) # Inefficient
# Batch processing
results = model(batch_of_requests) # Efficient

GPU Memory Management
Monitor and optimize GPU memory:
# Check memory usage
torch.cuda.memory_allocated() # Currently used
torch.cuda.memory_reserved() # Reserved by PyTorch
# Clear cache
torch.cuda.empty_cache()

Model Warmup
Pre-warm model to avoid cold start latency:
# First inference is slow (JIT compilation, memory allocation)
# Warmup runs inference with dummy data
dummy_input = tokenizer("warmup text", return_tensors="pt")
model(**dummy_input) # Now model is warm

Key Concepts Recap
| Concept | What It Is | Why It Matters |
|---|---|---|
| Dynamic Batching | Collect requests and process together | Maximizes GPU utilization, improves throughput 2-10x |
| Model Warmup | Run dummy inference before serving | Avoids cold-start latency from JIT compilation and memory allocation |
| torch.no_grad() | Disable gradient computation | Saves memory and speeds up inference (no backward pass needed) |
| device_map="auto" | Automatic model placement across GPUs | Handles large models that don't fit in single GPU |
| float16 / bfloat16 | Half-precision inference | 2x memory savings, faster on modern GPUs with minimal accuracy loss |
| Prometheus Metrics | Counters, histograms, gauges for monitoring | Track latency percentiles, throughput, and GPU memory in production |
| ThreadPoolExecutor | Run blocking inference in separate thread | Keeps FastAPI async loop responsive while model computes |
| empty_cache() | Release unused GPU memory | Prevents memory fragmentation and OOM errors |
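On the Prometheus row above: histogram buckets are cumulative, so each observation increments every bucket whose upper bound it fits under. A standard-library sketch of how the latency buckets from Step 6 would count a handful of observations:

```python
def bucket_counts(observations, buckets):
    """Cumulative per-bucket counts, as a Prometheus histogram reports them."""
    return {le: sum(1 for obs in observations if obs <= le) for le in buckets}

latencies = [0.02, 0.04, 0.3, 1.2]          # seconds; hypothetical samples
buckets = [0.01, 0.05, 0.1, 0.5, 1.0, 2.5]  # subset of the Step 6 buckets
print(bucket_counts(latencies, buckets))
# → {0.01: 0, 0.05: 2, 0.1: 2, 0.5: 3, 1.0: 3, 2.5: 4}
```

Percentiles like p95 are then estimated from these cumulative counts on the Prometheus side (e.g. with `histogram_quantile`), which is why bucket boundaries should bracket your expected latencies.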
Next Steps
- LoRA Fine-tuning - Efficient model fine-tuning
- Quantization - Optimize for faster inference