MLOpsBeginner
Model Serving with FastAPI
Deploy ML models as production-ready REST APIs with FastAPI
Model Serving with FastAPI
Build a production-ready ML model serving API with health checks, validation, and batching.
TL;DR
Wrap your ML model in FastAPI with Pydantic validation, add /health and /ready endpoints for Kubernetes, expose Prometheus metrics (/metrics), and support batch predictions for throughput. Use lifespan context for model loading.
What You'll Learn
- FastAPI setup for ML model serving
- Request/response validation with Pydantic
- Health checks and readiness probes
- Batch prediction for efficiency
- Structured logging and error handling
Tech Stack
| Component | Technology |
|---|---|
| Framework | FastAPI |
| Validation | Pydantic |
| ML Runtime | PyTorch / ONNX |
| Server | Uvicorn |
Architecture
┌──────────────────────────────────────────────────────────────────────────────┐
│ MODEL SERVING ARCHITECTURE │
├──────────────────────────────────────────────────────────────────────────────┤
│ │
│ ┌──────────┐ ┌───────────────┐ ┌───────────────────────┐ │
│ │ Client │────────▶│ Load Balancer │────────▶│ FastAPI Service │ │
│ └──────────┘ └───────────────┘ └───────────┬───────────┘ │
│ │ │
│ ┌─────────────────────────────────────┼──────────┐ │
│ │ SERVICE │ │ │
│ │ ▼ │ │
│ │ ┌────────────┐ ┌────────────────────────┐ │ │
│ │ │ Validation │───▶│ Batching │ │ │
│ │ │ (Pydantic) │ │ (group requests) │ │ │
│ │ └────────────┘ └───────────┬────────────┘ │ │
│ │ │ │ │
│ │ ▼ │ │
│ │ ┌────────────────────────┐ │ │
│ │ │ Model Inference │ │ │
│ │ │ (PyTorch/ONNX) │ │ │
│ │ └───────────┬────────────┘ │ │
│ │ │ │ │
│ │ ▼ │ │
│ │ ┌────────────────────────┐ │ │
│ │ │ Response │ │ │
│ │ └────────────────────────┘ │ │
│ └────────────────────────────────────────────────┘ │
│ │
│ Side Endpoints: │
│ ┌────────────────┐ ┌────────────────┐ ┌────────────────┐ │
│ │ GET /health │ │ GET /ready │ │ GET /metrics │ │
│ │ (liveness) │ │ (readiness) │ │ (Prometheus) │ │
│ └────────────────┘ └────────────────┘ └────────────────┘ │
│ │
└──────────────────────────────────────────────────────────────────────────────┘

Project Structure
model-serving/
├── src/
│ ├── __init__.py
│ ├── main.py # FastAPI application
│ ├── model.py # Model loading and inference
│ ├── schemas.py # Pydantic models
│ └── config.py # Configuration
├── models/ # Model artifacts
├── tests/
│ └── test_api.py
├── Dockerfile
└── requirements.txt

Implementation
Step 1: Dependencies
fastapi>=0.100.0
uvicorn>=0.23.0
pydantic>=2.0.0
pydantic-settings>=2.0.0
torch>=2.0.0
numpy>=1.24.0
python-json-logger>=2.0.0
prometheus-client>=0.17.0

Step 2: Configuration
"""Application configuration."""
from pydantic_settings import BaseSettings
from functools import lru_cache
class Settings(BaseSettings):
    """Application settings, overridable via environment variables or .env."""
    # API settings
    app_name: str = "Model Serving API"
    debug: bool = False
    # Model settings
    model_path: str = "./models/model.pt"
    model_name: str = "classifier"
    # Inference settings
    batch_size: int = 32
    max_batch_wait_ms: int = 50
    # Server settings
    host: str = "0.0.0.0"
    port: int = 8000
    workers: int = 1
    # pydantic v2 config: the inner `class Config` is deprecated.
    # protected_namespaces=() silences the v2 warning triggered by the
    # `model_path` / `model_name` field names (the "model_" prefix is
    # reserved by default).
    model_config = {"env_file": ".env", "protected_namespaces": ()}
@lru_cache
def get_settings() -> Settings:
"""Get cached settings instance."""
return Settings()Step 3: Pydantic Schemas
"""Request and response schemas."""
from pydantic import BaseModel, Field
from typing import Optional
from enum import Enum
class PredictionInput(BaseModel):
    """Single prediction input."""
    # Feature vector fed straight to the model; the length bounds reject
    # empty and oversized payloads before inference ever runs.
    features: list[float] = Field(
        ...,
        description="Input feature vector",
        min_length=1,
        max_length=1000
    )
    # Echoed back in the response so callers can correlate requests.
    request_id: Optional[str] = Field(
        None,
        description="Optional request ID for tracking"
    )
    # Example payload surfaced in the generated OpenAPI docs (pydantic v2 style).
    model_config = {
        "json_schema_extra": {
            "examples": [{
                "features": [0.1, 0.5, 0.3, 0.8],
                "request_id": "req-123"
            }]
        }
    }
class BatchPredictionInput(BaseModel):
    """Batch prediction input."""
    # Cap of 100 bounds the size of a single forward pass; larger
    # workloads should split into multiple batch calls.
    inputs: list[PredictionInput] = Field(
        ...,
        description="List of prediction inputs",
        min_length=1,
        max_length=100
    )
class PredictionOutput(BaseModel):
    """Single prediction output."""
    prediction: int = Field(..., description="Predicted class")
    # Softmax probability of the predicted class; constrained to [0, 1].
    confidence: float = Field(..., description="Prediction confidence", ge=0, le=1)
    probabilities: list[float] = Field(..., description="Class probabilities")
    request_id: Optional[str] = Field(None, description="Request ID if provided")
class BatchPredictionOutput(BaseModel):
    """Batch prediction output."""
    # One PredictionOutput per input, in the same order as the request.
    predictions: list[PredictionOutput]
    batch_size: int
    # Wall-clock time for the whole batch, in milliseconds.
    processing_time_ms: float
class HealthStatus(str, Enum):
    """Health check status."""
    HEALTHY = "healthy"
    UNHEALTHY = "unhealthy"
    DEGRADED = "degraded"
class HealthResponse(BaseModel):
"""Health check response."""
status: HealthStatus
model_loaded: bool
model_name: str
version: str = "1.0.0"Step 4: Model Wrapper
"""Model loading and inference."""
import torch
import torch.nn as nn
import numpy as np
from pathlib import Path
from typing import Optional
import logging
logger = logging.getLogger(__name__)

class SimpleClassifier(nn.Module):
    """Small MLP used as a stand-in model when no artifact is on disk."""
    def __init__(self, input_size: int = 4, num_classes: int = 3):
        super().__init__()
        self.network = nn.Sequential(
            nn.Linear(input_size, 64),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(64, 32),
            nn.ReLU(),
            nn.Linear(32, num_classes)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.network(x)

class ModelWrapper:
    """
    Wrapper for model loading and inference.

    Handles model lifecycle: loading, inference, and health checks.
    """
    def __init__(self, model_path: str, device: str = "cpu"):
        self.model_path = Path(model_path)
        self.device = torch.device(device)
        self.model: Optional[nn.Module] = None
        self._is_loaded = False

    def load(self) -> bool:
        """Load the model from disk, falling back to a default network.

        Returns:
            True if a model is ready for inference, False on failure.
        """
        try:
            if self.model_path.exists():
                # weights_only=False is required to unpickle a full
                # nn.Module on torch>=2.6, where the default flipped to
                # True. Unpickling executes arbitrary code -- only load
                # checkpoints from trusted sources.
                self.model = torch.load(
                    self.model_path,
                    map_location=self.device,
                    weights_only=False
                )
            else:
                # Create default model for demo
                logger.warning(f"Model not found at {self.model_path}, using default")
                self.model = SimpleClassifier()
            self.model.to(self.device)
            self.model.eval()  # disable dropout / train-mode behavior
            self._is_loaded = True
            logger.info(f"Model loaded successfully on {self.device}")
            return True
        except Exception as e:
            logger.error(f"Failed to load model: {e}")
            self._is_loaded = False
            return False

    @property
    def is_loaded(self) -> bool:
        """Check if model is loaded."""
        return self._is_loaded

    def predict(self, features: np.ndarray) -> dict:
        """
        Run inference on input features.

        Args:
            features: Input array of shape (batch_size, num_features),
                or (num_features,) for a single sample.

        Returns:
            Dictionary with predictions, confidences, and probabilities
            (plain Python lists, JSON-serializable).

        Raises:
            RuntimeError: If the model has not been loaded.
        """
        if not self._is_loaded:
            raise RuntimeError("Model not loaded")
        with torch.no_grad():
            # Convert to tensor
            inputs = torch.tensor(features, dtype=torch.float32).to(self.device)
            # Promote a single sample to a batch of one
            if inputs.dim() == 1:
                inputs = inputs.unsqueeze(0)
            # Forward pass
            logits = self.model(inputs)
            probabilities = torch.softmax(logits, dim=-1)
            # Highest-probability class per sample
            confidences, predictions = torch.max(probabilities, dim=-1)
        return {
            "predictions": predictions.cpu().numpy().tolist(),
            "confidences": confidences.cpu().numpy().tolist(),
            "probabilities": probabilities.cpu().numpy().tolist()
        }

    def warmup(self, input_size: int = 4) -> None:
        """Warm up the model with a dummy inference (avoids cold-start latency)."""
        if self._is_loaded:
            dummy_input = np.random.randn(1, input_size).astype(np.float32)
            self.predict(dummy_input)
logger.info("Model warmup complete")Understanding the Model Wrapper Pattern:
┌─────────────────────────────────────────────────────────────────────────────┐
│ WHY USE A MODEL WRAPPER CLASS │
├─────────────────────────────────────────────────────────────────────────────┤
│ │
│ Without Wrapper: With Wrapper: │
│ ┌─────────────────────────────┐ ┌─────────────────────────────────┐ │
│ │ # Globals everywhere │ │ class ModelWrapper: │ │
│ │ model = None │ │ def __init__(self, path): │ │
│ │ │ │ self.model = None │ │
│ │ @app.on_event("startup") │ │ self._is_loaded = False │ │
│ │ def load(): │ │ │ │
│ │ global model │ │ def load(self): │ │
│ │ model = torch.load(...) │ │ self.model = torch.load(..) │ │
│ │ │ │ self._is_loaded = True │ │
│ │ # Problem: Hard to test! │ │ │ │
│ │ # Problem: No health check │ │ @property │ │
│ └─────────────────────────────┘ │ def is_loaded(self): │ │
│ │ return self._is_loaded │ │
│ │ │ │
│ │ def predict(self, x): │ │
│ │ if not self._is_loaded: │ │
│ │ raise RuntimeError(...) │ │
│ └─────────────────────────────────┘ │
│ │
│ Benefits: │
│ • Encapsulated state (model, device, loaded flag) │
│ • Testable: Can mock the wrapper in tests │
│ • Health checks: is_loaded property for /ready endpoint │
│ • Error handling: predict() raises if model not loaded │
│ │
└─────────────────────────────────────────────────────────────────────────────┘

Model Warmup Explained:
┌─────────────────────────────────────────────────────────────────────────────┐
│ WHY WARMUP MATTERS │
├─────────────────────────────────────────────────────────────────────────────┤
│ │
│ First inference after load (cold): │
│ ┌─────────────────────────────────────────────────────────────────────┐ │
│ │ Time: ████████████████████████████████████████ 500ms │ │
│ │ │ │ │ │ │ │
│ │ └ CUDA └ JIT compile └ Allocate └ Actual inference │ │
│ │ init kernels buffers │ │
│ └─────────────────────────────────────────────────────────────────────┘ │
│ │
│ Second inference (warm): │
│ ┌─────────────────────────────────────────────────────────────────────┐ │
│ │ Time: ████ 10ms │ │
│ │ └ Only actual inference │ │
│ └─────────────────────────────────────────────────────────────────────┘ │
│ │
│ By calling predict() once during startup with dummy data: │
│ • First real user request gets the fast path │
│ • No surprise latency spikes in production │
│ │
└─────────────────────────────────────────────────────────────────────────────┘

Step 5: FastAPI Application
"""FastAPI application for model serving."""
import time
import logging
from contextlib import asynccontextmanager
from typing import Annotated
import numpy as np
from fastapi import FastAPI, HTTPException, Depends, Request
from fastapi.middleware.cors import CORSMiddleware
from prometheus_client import CONTENT_TYPE_LATEST, Counter, Histogram, generate_latest
from starlette.responses import Response
from .config import Settings, get_settings
from .model import ModelWrapper
from .schemas import (
PredictionInput,
BatchPredictionInput,
PredictionOutput,
BatchPredictionOutput,
HealthResponse,
HealthStatus,
)
# Configure logging once at import time for the whole service.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger(__name__)
# Prometheus metrics. The "status" label splits success/error so rates
# can be graphed and alerted on separately.
PREDICTION_COUNTER = Counter(
    "predictions_total",
    "Total number of predictions",
    ["status"]
)
# Buckets tuned for sub-second inference latencies.
PREDICTION_LATENCY = Histogram(
    "prediction_latency_seconds",
    "Prediction latency in seconds",
    buckets=[0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 1.0]
)
BATCH_SIZE_HISTOGRAM = Histogram(
    "batch_size",
    "Batch sizes",
    buckets=[1, 5, 10, 25, 50, 100]
)
# Global model instance; populated by the lifespan handler at startup,
# None until then (endpoints guard via get_model()).
model_wrapper: ModelWrapper | None = None
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the model before serving; log once more on shutdown."""
    global model_wrapper
    settings = get_settings()
    logger.info("Starting model serving API...")
    model_wrapper = ModelWrapper(settings.model_path)
    if not model_wrapper.load():
        # Keep the app running so /health and /ready can report the
        # failure; prediction endpoints will answer 503 via get_model().
        logger.error("Failed to load model")
    else:
        model_wrapper.warmup()
        logger.info("Model ready for inference")
    yield
    logger.info("Shutting down model serving API...")
def create_app() -> FastAPI:
    """Create and configure the FastAPI application instance."""
    settings = get_settings()
    app = FastAPI(
        title=settings.app_name,
        description="Production ML Model Serving API",
        version="1.0.0",
        lifespan=lifespan
    )
    # Add CORS middleware
    # NOTE(review): allow_origins=["*"] lets any website call this API
    # from a browser; tighten to known origins before production use.
    app.add_middleware(
        CORSMiddleware,
        allow_origins=["*"],
        allow_methods=["*"],
        allow_headers=["*"],
    )
    return app
# Module-level app instance used by uvicorn ("src.main:app").
app = create_app()
def get_model() -> ModelWrapper:
    """FastAPI dependency: return the loaded model, or fail with 503."""
    wrapper = model_wrapper
    if wrapper is not None and wrapper.is_loaded:
        return wrapper
    raise HTTPException(
        status_code=503,
        detail="Model not available"
    )
@app.get("/health", response_model=HealthResponse)
async def health_check(
    settings: Annotated[Settings, Depends(get_settings)]
) -> HealthResponse:
    """
    Health check endpoint (liveness probe).

    Reports HEALTHY only when the model is loaded.
    """
    loaded = bool(model_wrapper and model_wrapper.is_loaded)
    return HealthResponse(
        status=HealthStatus.HEALTHY if loaded else HealthStatus.UNHEALTHY,
        model_loaded=loaded,
        model_name=settings.model_name
    )
@app.get("/ready")
async def readiness_check():
    """Kubernetes readiness probe: 200 once the model is loaded, else 503."""
    if not (model_wrapper and model_wrapper.is_loaded):
        raise HTTPException(status_code=503, detail="Not ready")
    return {"ready": True}
@app.get("/metrics")
async def metrics():
    """Prometheus scrape endpoint."""
    # CONTENT_TYPE_LATEST carries the full exposition-format content type
    # ("text/plain; version=0.0.4; charset=utf-8") that Prometheus expects;
    # a bare "text/plain" drops the version/charset parameters.
    return Response(
        content=generate_latest(),
        media_type=CONTENT_TYPE_LATEST
    )
@app.post("/predict", response_model=PredictionOutput)
def predict(
    input_data: PredictionInput,
    model: Annotated[ModelWrapper, Depends(get_model)]
) -> PredictionOutput:
    """
    Single prediction endpoint.

    Takes a feature vector and returns the prediction. Declared as a
    sync `def` (not `async def`) so FastAPI runs it in its threadpool:
    model.predict() is blocking CPU work that would otherwise stall the
    event loop for every concurrent request.
    """
    start_time = time.time()
    try:
        # Wrap the single vector as a batch of one for inference
        features = np.array([input_data.features])
        result = model.predict(features)
        # Record metrics
        latency = time.time() - start_time
        PREDICTION_LATENCY.observe(latency)
        PREDICTION_COUNTER.labels(status="success").inc()
        return PredictionOutput(
            prediction=result["predictions"][0],
            confidence=result["confidences"][0],
            probabilities=result["probabilities"][0],
            request_id=input_data.request_id
        )
    except Exception as e:
        PREDICTION_COUNTER.labels(status="error").inc()
        logger.error(f"Prediction failed: {e}")
        # Chain the original exception for debuggability
        raise HTTPException(status_code=500, detail=str(e)) from e
@app.post("/predict/batch", response_model=BatchPredictionOutput)
def predict_batch(
    batch_input: BatchPredictionInput,
    model: Annotated[ModelWrapper, Depends(get_model)]
) -> BatchPredictionOutput:
    """
    Batch prediction endpoint.

    Takes multiple inputs and returns predictions for all. Declared as
    a sync `def` so FastAPI runs the blocking inference in its
    threadpool instead of stalling the event loop.
    """
    start_time = time.time()
    batch_size = len(batch_input.inputs)
    BATCH_SIZE_HISTOGRAM.observe(batch_size)
    try:
        # Stack all feature vectors into one (batch, features) array so
        # the model runs a single forward pass
        features = np.array([inp.features for inp in batch_input.inputs])
        result = model.predict(features)
        # Pair each result row back with its originating input
        predictions = [
            PredictionOutput(
                prediction=result["predictions"][i],
                confidence=result["confidences"][i],
                probabilities=result["probabilities"][i],
                request_id=inp.request_id
            )
            for i, inp in enumerate(batch_input.inputs)
        ]
        processing_time = (time.time() - start_time) * 1000
        PREDICTION_LATENCY.observe(processing_time / 1000)
        PREDICTION_COUNTER.labels(status="success").inc(batch_size)
        return BatchPredictionOutput(
            predictions=predictions,
            batch_size=batch_size,
            processing_time_ms=processing_time
        )
    except Exception as e:
        PREDICTION_COUNTER.labels(status="error").inc(batch_size)
        logger.error(f"Batch prediction failed: {e}")
        # Chain the original exception for debuggability
        raise HTTPException(status_code=500, detail=str(e)) from e
@app.middleware("http")
async def log_requests(request: Request, call_next):
    """Access-log every request with its status code and duration."""
    started_at = time.time()
    response = await call_next(request)
    duration = time.time() - started_at
    logger.info(
        f"{request.method} {request.url.path} "
        f"status={response.status_code} duration={duration:.3f}s"
    )
    return response
if __name__ == "__main__":
import uvicorn
settings = get_settings()
uvicorn.run(
"src.main:app",
host=settings.host,
port=settings.port,
reload=settings.debug
)Step 6: Dockerfile
FROM python:3.11-slim
WORKDIR /app
# Install dependencies first so Docker layer caching skips re-installs
# when only application code changes
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt
# Copy application
COPY src/ ./src/
COPY models/ ./models/
# Create non-root user
RUN useradd --create-home appuser
USER appuser
# Expose port
EXPOSE 8000
# Health check: python:3.11-slim does not ship curl, so probe the
# endpoint with the Python stdlib instead (curl would always fail and
# mark the container unhealthy)
HEALTHCHECK --interval=30s --timeout=10s --start-period=5s --retries=3 \
    CMD python -c "import urllib.request; urllib.request.urlopen('http://localhost:8000/health')" || exit 1
# Run server
CMD ["uvicorn", "src.main:app", "--host", "0.0.0.0", "--port", "8000"]

Step 7: Tests
"""API tests."""
import pytest
from fastapi.testclient import TestClient
from src.main import app
@pytest.fixture
def client():
    """Test client fixture.

    Use the context manager so FastAPI's lifespan runs and the model is
    loaded: a bare TestClient(app) never triggers startup, leaving the
    prediction endpoints answering 503.
    """
    with TestClient(app) as test_client:
        yield test_client
def test_health_check(client):
    """Health endpoint responds 200 with status and model flag."""
    resp = client.get("/health")
    assert resp.status_code == 200
    payload = resp.json()
    assert "status" in payload
    assert "model_loaded" in payload
def test_single_prediction(client):
    """A valid feature vector yields a prediction with confidence in [0, 1]."""
    resp = client.post(
        "/predict",
        json={"features": [0.1, 0.5, 0.3, 0.8]}
    )
    assert resp.status_code == 200
    payload = resp.json()
    assert "prediction" in payload
    assert "confidence" in payload
    assert 0 <= payload["confidence"] <= 1
def test_batch_prediction(client):
    """Two inputs come back as two predictions with matching batch_size."""
    payload = {
        "inputs": [
            {"features": [0.1, 0.5, 0.3, 0.8]},
            {"features": [0.2, 0.4, 0.6, 0.1]},
        ]
    }
    resp = client.post("/predict/batch", json=payload)
    assert resp.status_code == 200
    body = resp.json()
    assert body["batch_size"] == 2
    assert len(body["predictions"]) == 2
def test_invalid_input(client):
    """An empty feature vector must be rejected by validation (422)."""
    resp = client.post(
        "/predict",
        json={"features": []}  # violates min_length=1
    )
    assert resp.status_code == 422
def test_metrics_endpoint(client):
"""Test Prometheus metrics."""
response = client.get("/metrics")
assert response.status_code == 200
assert b"predictions_total" in response.contentRunning the Service
# Local development
uvicorn src.main:app --reload
# Docker
docker build -t model-serving .
docker run -p 8000:8000 model-serving
# Test prediction
curl -X POST http://localhost:8000/predict \
-H "Content-Type: application/json" \
  -d '{"features": [0.1, 0.5, 0.3, 0.8]}'

Key Concepts Recap
| Concept | What It Is | Why It Matters |
|---|---|---|
| Health Check | /health endpoint for liveness probe | Kubernetes restarts unhealthy pods |
| Readiness Check | /ready endpoint for readiness probe | Traffic only routes to ready pods |
| Pydantic Validation | Type-checked request/response schemas | Catch bad inputs before inference |
| Batch Prediction | Process multiple inputs in one call | GPU parallelism, higher throughput |
| Model Wrapper | Class managing load/predict lifecycle | Clean separation, testable code |
| Lifespan Context | @asynccontextmanager for startup/shutdown | Load model once, clean shutdown |
| Prometheus Metrics | Counters, histograms for observability | Monitor latency, errors, throughput |
| Model Warmup | Dummy inference after loading | First real request isn't slow |
Next Steps
- Docker Deployment - Containerize this service
- LLM Caching - Add caching for LLM calls