Production SLM System
Build an enterprise-grade small language model deployment with routing, fallbacks, and observability
TL;DR
Build enterprise SLM infrastructure with a model registry for centralized config, complexity-based routing (simple→small models, complex→large), automatic fallback with circuit breakers, semantic caching via Redis + embeddings, and full observability through Prometheus metrics + OpenTelemetry tracing. Track costs per request and optimize with adaptive model selection.
Build a production-ready SLM infrastructure with intelligent model routing, automatic fallbacks, comprehensive monitoring, and cost optimization.
| Difficulty | Advanced |
| Time | 5 days |
| Code | ~1000 lines |
| Prerequisites | SLM Fine-tuning, Edge Deployment, MLOps basics |
What You'll Learn
- Intelligent Model Routing - Route requests to optimal models based on complexity, latency, and cost
- Fallback Strategies - Graceful degradation with automatic escalation to larger models
- Monitoring & Observability - Prometheus metrics, distributed tracing, and alerting
- Cost Optimization - Track and minimize inference costs across model tiers
- Multi-Model Orchestration - Manage multiple SLMs with unified API
- Rate Limiting & Circuit Breakers - Production resilience patterns
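Rate limiting appears later only as a box in the gateway diagram; as a reference point, the token-bucket pattern behind it can be sketched as follows (a minimal illustration with made-up `rate`/`capacity` values, not the gateway's actual implementation):

```python
import time

class TokenBucket:
    """Minimal token-bucket rate limiter: refill at `rate` tokens/sec,
    allow a request only when a whole token is available."""

    def __init__(self, rate: float, capacity: int):
        self.rate = rate
        self.capacity = capacity
        self.tokens = float(capacity)
        self.last = time.monotonic()

    def allow(self) -> bool:
        # Refill proportionally to elapsed time, capped at capacity
        now = time.monotonic()
        self.tokens = min(self.capacity, self.tokens + (now - self.last) * self.rate)
        self.last = now
        if self.tokens >= 1.0:
            self.tokens -= 1.0
            return True
        return False

bucket = TokenBucket(rate=5.0, capacity=2)   # ~5 requests/sec, burst of 2
print([bucket.allow() for _ in range(4)])    # first 2 pass, the rest are throttled
```

The same shape generalizes to per-tenant limits by keeping one bucket per API key.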
Architecture Overview
┌─────────────────────────────────────────────────────────────────────────────┐
│ PRODUCTION SLM SYSTEM ARCHITECTURE │
├─────────────────────────────────────────────────────────────────────────────┤
│ │
│ ┌─────────────────────────────────────────────────────────────────────┐ │
│ │ API GATEWAY │ │
│ │ │ │
│ │ Request ──► Load Balancer ──► Rate Limiter ──► Authentication │ │
│ │ │ │
│ └──────────────────────────────────┬──────────────────────────────────┘ │
│ │ │
│ ▼ │
│ ┌─────────────────────────────────────────────────────────────────────┐ │
│ │ INTELLIGENT ROUTER │ │
│ │ │ │
│ │ Query Analyzer ──► Complexity Router ──► Circuit Breaker │ │
│ │ • Token count • Score 0.0-1.0 • Failure tracking │ │
│ │ • Pattern match • Tier selection • Auto recovery │ │
│ │ • Domain detect • Load balancing • Fallback trigger │ │
│ │ │ │
│ └──────────────────────────────────┬──────────────────────────────────┘ │
│ │ │
│ ┌────────────────────────────┼────────────────────────────┐ │
│ │ │ │ │
│ ▼ ▼ ▼ │
│ ┌───────────────┐ ┌───────────────┐ ┌───────────────┐ │
│ │ TIER 1: SMALL │ │ TIER 2: MEDIUM│ │ TIER 3: LARGE │ │
│ │ │ │ │ │ │ │
│ │ Phi-3 Mini │ │ Qwen2.5-7B │ │ GPT-4o-mini │ │
│ │ ~50ms/req │ │ ~150ms/req │ │ ~500ms/req │ │
│ │ Score: 0-0.4 │ │ Score: 0.3-0.7│ │ Score: 0.6-1.0│ │
│ │ Cost: $0 │ │ Cost: $0 │ │ Cost: $$ │ │
│ │ Local │ │ Local │ │ API │ │
│ └───────┬───────┘ └───────┬───────┘ └───────┬───────┘ │
│ │ │ │ │
│ └──────────────────────────┼──────────────────────────┘ │
│ │ │
│ ┌────────────────────────────┴────────────────────────────┐ │
│ │ │ │
│ ▼ ▼ │
│ ┌─────────────────────────────┐ ┌─────────────────────────────┐ │
│ │ CACHING LAYER │ │ OBSERVABILITY │ │
│ │ │ │ │ │
│ │ Semantic Cache (Redis) │ │ Prometheus ──► Grafana │ │
│ │ • Embedding similarity │ │ │ │ │
│ │ • 0.95 threshold │ │ └──► Alertmanager │ │
│ │ • TTL: 1 hour │ │ │ │
│ │ │ │ OpenTelemetry Tracing │ │
│ │ Response Cache │ │ • Request spans │ │
│ │ • Exact match │ │ • Model latency │ │
│ │ • Fast lookup │ │ • Error tracking │ │
│ └─────────────────────────────┘ └─────────────────────────────┘ │
│ │
└─────────────────────────────────────────────────────────────────────────────┘

Project Setup
mkdir production-slm && cd production-slm
python -m venv venv
source venv/bin/activate  # Windows: venv\Scripts\activate

pip install torch transformers accelerate
pip install fastapi uvicorn httpx
pip install prometheus-client opentelemetry-api opentelemetry-sdk
pip install opentelemetry-instrumentation-fastapi
pip install redis sentence-transformers numpy
pip install pydantic tenacity circuitbreaker
pip install ollama openai tiktoken

Part 1: Model Registry & Configuration
★ Insight ─────────────────────────────────────
Production systems need a centralized registry to manage model metadata, costs, and capabilities. This enables dynamic routing decisions without hardcoding model information throughout the codebase.
─────────────────────────────────────────────────
# models/registry.py
"""
Model Registry - Central configuration for all available models.
"""
from dataclasses import dataclass, field
from enum import Enum
from typing import Optional
import yaml
class ModelTier(Enum):
"""Model tiers based on capability and cost."""
SMALL = "small" # Fast, cheap, local
MEDIUM = "medium" # Balanced
LARGE = "large" # Most capable, expensive
class ModelType(Enum):
"""Model deployment type."""
LOCAL = "local" # Ollama or local inference
API = "api" # External API (OpenAI, etc.)
VLLM = "vllm" # vLLM server
@dataclass
class ModelConfig:
"""Configuration for a single model."""
name: str
tier: ModelTier
model_type: ModelType
# Performance characteristics
max_tokens: int = 4096
avg_latency_ms: float = 100.0
tokens_per_second: float = 50.0
# Cost (per 1M tokens)
input_cost: float = 0.0
output_cost: float = 0.0
# Capabilities
supports_json: bool = True
supports_tools: bool = False
supports_vision: bool = False
context_window: int = 4096
# Connection
endpoint: Optional[str] = None
api_key_env: Optional[str] = None
# Health
enabled: bool = True
weight: float = 1.0 # For load balancing
# Complexity thresholds
min_complexity: float = 0.0
max_complexity: float = 1.0
@dataclass
class ModelRegistry:
"""Registry of all available models."""
models: dict[str, ModelConfig] = field(default_factory=dict)
default_model: str = ""
fallback_chain: list[str] = field(default_factory=list)
def register(self, config: ModelConfig) -> None:
"""Register a model configuration."""
self.models[config.name] = config
def get(self, name: str) -> Optional[ModelConfig]:
"""Get model configuration by name."""
return self.models.get(name)
def get_by_tier(self, tier: ModelTier) -> list[ModelConfig]:
"""Get all models in a tier."""
return [m for m in self.models.values()
if m.tier == tier and m.enabled]
def get_enabled(self) -> list[ModelConfig]:
"""Get all enabled models."""
return [m for m in self.models.values() if m.enabled]
def get_fallback_chain(self, model_name: str) -> list[str]:
"""Get fallback chain for a model."""
if model_name in self.fallback_chain:
idx = self.fallback_chain.index(model_name)
return self.fallback_chain[idx + 1:]
return self.fallback_chain
@classmethod
def from_yaml(cls, path: str) -> "ModelRegistry":
"""Load registry from YAML configuration."""
with open(path) as f:
data = yaml.safe_load(f)
registry = cls(
default_model=data.get("default_model", ""),
fallback_chain=data.get("fallback_chain", [])
)
for model_data in data.get("models", []):
config = ModelConfig(
name=model_data["name"],
tier=ModelTier(model_data["tier"]),
model_type=ModelType(model_data["type"]),
max_tokens=model_data.get("max_tokens", 4096),
avg_latency_ms=model_data.get("avg_latency_ms", 100),
tokens_per_second=model_data.get("tokens_per_second", 50),
input_cost=model_data.get("input_cost", 0),
output_cost=model_data.get("output_cost", 0),
supports_json=model_data.get("supports_json", True),
supports_tools=model_data.get("supports_tools", False),
supports_vision=model_data.get("supports_vision", False),
context_window=model_data.get("context_window", 4096),
endpoint=model_data.get("endpoint"),
api_key_env=model_data.get("api_key_env"),
enabled=model_data.get("enabled", True),
weight=model_data.get("weight", 1.0),
min_complexity=model_data.get("min_complexity", 0.0),
max_complexity=model_data.get("max_complexity", 1.0),
)
registry.register(config)
return registry
def create_default_registry() -> ModelRegistry:
"""Create registry with default model configurations."""
registry = ModelRegistry(
default_model="phi3-mini",
fallback_chain=["phi3-mini", "qwen2.5-7b", "gpt-4o-mini"]
)
# Tier 1: Small local models
registry.register(ModelConfig(
name="phi3-mini",
tier=ModelTier.SMALL,
model_type=ModelType.LOCAL,
max_tokens=4096,
avg_latency_ms=50,
tokens_per_second=80,
input_cost=0,
output_cost=0,
supports_json=True,
supports_tools=True,
context_window=4096,
endpoint="http://localhost:11434",
min_complexity=0.0,
max_complexity=0.4,
))
registry.register(ModelConfig(
name="qwen2.5-3b",
tier=ModelTier.SMALL,
model_type=ModelType.LOCAL,
max_tokens=8192,
avg_latency_ms=60,
tokens_per_second=70,
input_cost=0,
output_cost=0,
supports_json=True,
supports_tools=True,
context_window=8192,
endpoint="http://localhost:11434",
min_complexity=0.0,
max_complexity=0.5,
))
# Tier 2: Medium models
registry.register(ModelConfig(
name="qwen2.5-7b",
tier=ModelTier.MEDIUM,
model_type=ModelType.LOCAL,
max_tokens=8192,
avg_latency_ms=150,
tokens_per_second=40,
input_cost=0,
output_cost=0,
supports_json=True,
supports_tools=True,
context_window=32768,
endpoint="http://localhost:11434",
min_complexity=0.3,
max_complexity=0.7,
))
registry.register(ModelConfig(
name="llama3.2-8b",
tier=ModelTier.MEDIUM,
model_type=ModelType.LOCAL,
max_tokens=8192,
avg_latency_ms=180,
tokens_per_second=35,
input_cost=0,
output_cost=0,
supports_json=True,
supports_tools=True,
context_window=128000,
endpoint="http://localhost:11434",
min_complexity=0.3,
max_complexity=0.8,
))
# Tier 3: Large API models (fallback)
registry.register(ModelConfig(
name="gpt-4o-mini",
tier=ModelTier.LARGE,
model_type=ModelType.API,
max_tokens=16384,
avg_latency_ms=500,
tokens_per_second=100,
input_cost=0.15, # per 1M tokens
output_cost=0.60,
supports_json=True,
supports_tools=True,
supports_vision=True,
context_window=128000,
api_key_env="OPENAI_API_KEY",
min_complexity=0.6,
max_complexity=1.0,
))
registry.register(ModelConfig(
name="claude-3-haiku",
tier=ModelTier.LARGE,
model_type=ModelType.API,
max_tokens=4096,
avg_latency_ms=400,
tokens_per_second=120,
input_cost=0.25,
output_cost=1.25,
supports_json=True,
supports_tools=True,
supports_vision=True,
context_window=200000,
api_key_env="ANTHROPIC_API_KEY",
min_complexity=0.6,
max_complexity=1.0,
))
    return registry

Model configuration file:
# config/models.yaml
default_model: phi3-mini
fallback_chain:
- phi3-mini
- qwen2.5-7b
- gpt-4o-mini
models:
- name: phi3-mini
tier: small
type: local
max_tokens: 4096
avg_latency_ms: 50
tokens_per_second: 80
context_window: 4096
supports_json: true
supports_tools: true
endpoint: http://localhost:11434
min_complexity: 0.0
max_complexity: 0.4
- name: qwen2.5-7b
tier: medium
type: local
max_tokens: 8192
avg_latency_ms: 150
tokens_per_second: 40
context_window: 32768
supports_json: true
supports_tools: true
endpoint: http://localhost:11434
min_complexity: 0.3
max_complexity: 0.7
- name: gpt-4o-mini
tier: large
type: api
max_tokens: 16384
avg_latency_ms: 500
tokens_per_second: 100
input_cost: 0.15
output_cost: 0.60
context_window: 128000
supports_json: true
supports_tools: true
supports_vision: true
api_key_env: OPENAI_API_KEY
min_complexity: 0.6
    max_complexity: 1.0

Understanding the Model Tiering Strategy:
┌─────────────────────────────────────────────────────────────────────────────┐
│ MODEL TIER SELECTION │
├─────────────────────────────────────────────────────────────────────────────┤
│ │
│ Query Complexity ─────────────────────────────────────────────────► │
│ 0.0 0.4 0.6 1.0 │
│ │ │ │ │ │
│ ▼ ▼ ▼ ▼ │
│ ┌─────────────────────┬─────────────────────┬─────────────────────┐ │
│ │ TIER 1: SMALL │ TIER 2: MEDIUM │ TIER 3: LARGE │ │
│ │ │ │ │ │
│ │ phi3-mini (3.8B) │ qwen2.5-7b │ gpt-4o-mini │ │
│ │ qwen2.5-3b │ llama3.2-8b │ claude-3-haiku │ │
│ │ │ │ │ │
│ │ Cost: $0 (local) │ Cost: $0 (local) │ Cost: $0.15-1.25 │ │
│ │ Latency: 50-60ms │ Latency: 150-180ms │ Latency: 400-500ms │ │
│ │ │ │ (per 1M tokens) │ │
│ │ "What is 2+2?" │ "Explain ML" │ "Write a business │ │
│ │ "Hello" │ "Summarize this" │ plan for..." │ │
│ └─────────────────────┴─────────────────────┴─────────────────────┘ │
│ │
│ Fallback Chain: phi3-mini → qwen2.5-7b → gpt-4o-mini │
│ (If small fails, try medium, then large) │
│ │
└─────────────────────────────────────────────────────────────────────────────┘

Key Configuration Fields:
| Field | Purpose | Example |
|---|---|---|
| min/max_complexity | Query complexity range for this model | 0.0-0.4 for simple queries |
| fallback_chain | Ordered list of models to try on failure | ["phi3", "qwen7b", "gpt-4o"] |
| weight | Load balancing weight (for same-tier models) | 1.0 = equal, 2.0 = 2x traffic |
| enabled | Disable a model without removing its config | false during maintenance |
| endpoint | Where to send requests (local vs API) | http://localhost:11434 |
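To make these fields concrete, here is a compressed, standalone sketch of the two lookups the router performs against the registry — the fallback-chain slice and the per-request cost estimate. Function names mirror the methods above; the token counts are made up:

```python
fallback_chain = ["phi3-mini", "qwen2.5-7b", "gpt-4o-mini"]

def get_fallback_chain(model_name: str) -> list[str]:
    """Models to try after `model_name` fails: everything later in the chain."""
    if model_name in fallback_chain:
        return fallback_chain[fallback_chain.index(model_name) + 1:]
    return fallback_chain

def estimate_cost(input_tokens: int, output_tokens: int,
                  input_cost: float, output_cost: float) -> float:
    """Dollar cost of one request; prices are per 1M tokens."""
    return ((input_tokens / 1_000_000) * input_cost
            + (output_tokens / 1_000_000) * output_cost)

print(get_fallback_chain("qwen2.5-7b"))        # only the large tier remains
# gpt-4o-mini at $0.15 in / $0.60 out per 1M tokens, 2k in / 1k out:
print(estimate_cost(2_000, 1_000, 0.15, 0.60)) # ≈ $0.0009
```

Note the local tiers carry zero input/output cost, so the estimate only becomes nonzero once a request escalates to the API tier.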
Part 2: Query Complexity Analysis
★ Insight ─────────────────────────────────────
Complexity analysis determines which model tier should handle a request. Simple queries (greetings, factual lookups) go to small models, while complex reasoning tasks escalate to larger models. This balances cost and quality.
─────────────────────────────────────────────────
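The core idea can be previewed in miniature before the full analyzer: match a few regex indicators, mix in a length factor, and clamp the result to [0, 1]. The weights here are illustrative only, not the ones the analyzer below uses:

```python
import re

# Toy subset of the complexity indicators (illustrative weights only)
COMPLEX_INDICATORS = [r"\bwhy\b", r"\banalyze\b", r"\bcompare\b", r"\bstep.by.step\b"]

def rough_complexity(query: str) -> float:
    q = query.lower()
    pattern_hits = sum(bool(re.search(p, q)) for p in COMPLEX_INDICATORS)
    length_factor = min(1.0, len(q.split()) / 100)   # longer queries score higher
    return min(1.0, 0.25 * pattern_hits + 0.5 * length_factor)

print(rough_complexity("Hi there!"))                                  # near 0.0
print(rough_complexity("Compare and analyze why transformers win"))   # much higher
```

The production analyzer below does the same thing with many more signals (domain keywords, conversation history, vision) and a weighted sum instead of two ad-hoc terms.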
# routing/complexity.py
"""
Query Complexity Analyzer - Determines optimal model tier.
"""
import re
from dataclasses import dataclass
from enum import Enum
from typing import Optional
import tiktoken
class ComplexityLevel(Enum):
"""Query complexity levels."""
TRIVIAL = "trivial" # Greetings, simple questions
SIMPLE = "simple" # Factual, single-step
MODERATE = "moderate" # Some reasoning required
COMPLEX = "complex" # Multi-step reasoning
EXPERT = "expert" # Specialized knowledge, long context
@dataclass
class ComplexityAnalysis:
"""Result of complexity analysis."""
level: ComplexityLevel
score: float # 0.0 to 1.0
token_count: int
reasoning_required: bool
domain_specific: bool
multi_turn: bool
requires_tools: bool
requires_vision: bool
confidence: float
factors: dict[str, float]
class QueryComplexityAnalyzer:
"""Analyzes query complexity for routing decisions."""
# Patterns indicating complexity
TRIVIAL_PATTERNS = [
r"^(hi|hello|hey|thanks|thank you|bye|goodbye)",
r"^what (is|are) (your|the) name",
r"^how are you",
]
SIMPLE_PATTERNS = [
r"^what (is|are|was|were)\b",
r"^who (is|are|was|were)\b",
r"^when (did|was|is)\b",
r"^where (is|are|was|were)\b",
r"^define\b",
r"^translate\b",
]
COMPLEX_INDICATORS = [
r"\bwhy\b",
r"\bhow (do|does|can|could|would|should)\b",
r"\bcompare\b",
r"\banalyze\b",
r"\bexplain.*(detail|depth)",
r"\bstep.by.step\b",
r"\breason(ing)?\b",
r"\bprove\b",
r"\bderive\b",
]
EXPERT_INDICATORS = [
r"\bimplement\b",
r"\boptimize\b",
r"\bdebug\b",
r"\barchitect(ure)?\b",
r"\balgorithm\b",
r"\btheorem\b",
r"\bproof\b",
r"\bresearch\b",
]
DOMAIN_KEYWORDS = {
"code": ["function", "class", "variable", "loop", "api", "code",
"program", "bug", "error", "syntax", "compile"],
"math": ["equation", "integral", "derivative", "matrix", "vector",
"probability", "statistics", "theorem", "proof"],
"science": ["hypothesis", "experiment", "molecule", "quantum",
"physics", "chemistry", "biology", "research"],
"legal": ["contract", "liability", "statute", "regulation",
"compliance", "jurisdiction", "precedent"],
"medical": ["diagnosis", "symptom", "treatment", "medication",
"patient", "clinical", "pathology"],
}
def __init__(self, model: str = "gpt-4o"):
"""Initialize analyzer."""
try:
self.tokenizer = tiktoken.encoding_for_model(model)
except KeyError:
self.tokenizer = tiktoken.get_encoding("cl100k_base")
def analyze(
self,
query: str,
conversation_history: Optional[list[dict]] = None,
has_images: bool = False
) -> ComplexityAnalysis:
"""Analyze query complexity."""
factors = {}
# Token count
token_count = len(self.tokenizer.encode(query))
history_tokens = 0
if conversation_history:
for msg in conversation_history:
history_tokens += len(self.tokenizer.encode(msg.get("content", "")))
total_tokens = token_count + history_tokens
# Length factor (longer = more complex)
length_factor = min(1.0, total_tokens / 2000)
factors["length"] = length_factor
# Pattern matching
query_lower = query.lower().strip()
# Check trivial patterns
is_trivial = any(re.search(p, query_lower) for p in self.TRIVIAL_PATTERNS)
factors["trivial_pattern"] = 0.0 if is_trivial else 0.3
# Check simple patterns
is_simple = any(re.search(p, query_lower) for p in self.SIMPLE_PATTERNS)
factors["simple_pattern"] = 0.1 if is_simple else 0.3
# Check complex indicators
complex_count = sum(1 for p in self.COMPLEX_INDICATORS
if re.search(p, query_lower))
factors["complex_indicators"] = min(1.0, complex_count * 0.2)
# Check expert indicators
expert_count = sum(1 for p in self.EXPERT_INDICATORS
if re.search(p, query_lower))
factors["expert_indicators"] = min(1.0, expert_count * 0.25)
# Domain specificity
domain_matches = {}
for domain, keywords in self.DOMAIN_KEYWORDS.items():
matches = sum(1 for kw in keywords if kw in query_lower)
if matches > 0:
domain_matches[domain] = matches
domain_specific = len(domain_matches) > 0
domain_factor = min(1.0, sum(domain_matches.values()) * 0.1) if domain_matches else 0.0
factors["domain_specific"] = domain_factor
# Multi-turn factor
multi_turn = conversation_history is not None and len(conversation_history) > 0
factors["multi_turn"] = 0.2 if multi_turn else 0.0
# Vision factor
factors["vision"] = 0.3 if has_images else 0.0
# Reasoning indicators
reasoning_words = ["because", "therefore", "however", "although",
"consequently", "furthermore", "moreover"]
reasoning_count = sum(1 for w in reasoning_words if w in query_lower)
reasoning_required = reasoning_count > 0 or complex_count > 0
factors["reasoning"] = min(0.3, reasoning_count * 0.1)
# Question complexity (number of questions)
question_count = query.count("?")
factors["questions"] = min(0.3, question_count * 0.1)
# Calculate overall score
weights = {
"length": 0.15,
"trivial_pattern": 0.1,
"simple_pattern": 0.1,
"complex_indicators": 0.2,
"expert_indicators": 0.15,
"domain_specific": 0.1,
"multi_turn": 0.05,
"vision": 0.05,
"reasoning": 0.05,
"questions": 0.05,
}
score = sum(factors[k] * weights[k] for k in weights)
# Override for trivial queries
if is_trivial and token_count < 20:
score = 0.05
# Determine level
if score < 0.15:
level = ComplexityLevel.TRIVIAL
elif score < 0.35:
level = ComplexityLevel.SIMPLE
elif score < 0.55:
level = ComplexityLevel.MODERATE
elif score < 0.75:
level = ComplexityLevel.COMPLEX
else:
level = ComplexityLevel.EXPERT
# Confidence based on pattern clarity
confidence = 0.9 if is_trivial or expert_count > 2 else 0.7
return ComplexityAnalysis(
level=level,
score=score,
token_count=total_tokens,
reasoning_required=reasoning_required,
domain_specific=domain_specific,
multi_turn=multi_turn,
requires_tools=False, # Could be expanded
requires_vision=has_images,
confidence=confidence,
factors=factors,
)
# Example usage
if __name__ == "__main__":
analyzer = QueryComplexityAnalyzer()
test_queries = [
"Hi there!",
"What is Python?",
"Explain how neural networks learn through backpropagation",
"Compare and analyze the architectural differences between transformers and RNNs, "
"and explain why transformers have become dominant in NLP tasks",
"Implement a production-ready distributed training pipeline with gradient "
"checkpointing and automatic mixed precision",
]
for query in test_queries:
analysis = analyzer.analyze(query)
print(f"\nQuery: {query[:60]}...")
print(f" Level: {analysis.level.value}")
print(f" Score: {analysis.score:.3f}")
        print(f"  Tokens: {analysis.token_count}")

Part 3: Intelligent Router
# routing/router.py
"""
Intelligent Router - Routes requests to optimal models.
"""
import random
from dataclasses import dataclass
from typing import Optional
from models.registry import ModelRegistry, ModelConfig
from routing.complexity import QueryComplexityAnalyzer, ComplexityAnalysis
@dataclass
class RoutingDecision:
"""Result of routing decision."""
model: ModelConfig
reason: str
complexity: ComplexityAnalysis
fallback_models: list[str]
estimated_latency_ms: float
estimated_cost: float
@dataclass
class RouterConfig:
"""Router configuration."""
# Latency constraints
max_latency_ms: Optional[float] = None
# Cost constraints
max_cost_per_request: Optional[float] = None
prefer_local: bool = True
# Load balancing
enable_load_balancing: bool = True
# Complexity routing
complexity_routing: bool = True
complexity_threshold_small: float = 0.3
complexity_threshold_medium: float = 0.6
class IntelligentRouter:
"""Routes requests to optimal models."""
def __init__(
self,
registry: ModelRegistry,
config: Optional[RouterConfig] = None
):
"""Initialize router."""
self.registry = registry
self.config = config or RouterConfig()
self.analyzer = QueryComplexityAnalyzer()
# Health tracking
self.model_health: dict[str, bool] = {
m.name: True for m in registry.get_enabled()
}
self.model_latencies: dict[str, list[float]] = {
m.name: [] for m in registry.get_enabled()
}
def route(
self,
query: str,
conversation_history: Optional[list[dict]] = None,
has_images: bool = False,
required_capabilities: Optional[set[str]] = None,
preferred_model: Optional[str] = None,
) -> RoutingDecision:
"""Route query to optimal model."""
# Analyze complexity
complexity = self.analyzer.analyze(
query, conversation_history, has_images
)
# Get candidate models
candidates = self._get_candidates(
complexity, has_images, required_capabilities
)
# Apply preferred model if valid
if preferred_model and preferred_model in [c.name for c in candidates]:
selected = self.registry.get(preferred_model)
return RoutingDecision(
model=selected,
reason="User preference",
complexity=complexity,
fallback_models=self.registry.get_fallback_chain(preferred_model),
estimated_latency_ms=selected.avg_latency_ms,
estimated_cost=self._estimate_cost(selected, complexity.token_count),
)
# Select best model
selected = self._select_model(candidates, complexity)
# Determine reason
reason = self._get_routing_reason(selected, complexity)
return RoutingDecision(
model=selected,
reason=reason,
complexity=complexity,
fallback_models=self.registry.get_fallback_chain(selected.name),
estimated_latency_ms=self._get_expected_latency(selected),
estimated_cost=self._estimate_cost(selected, complexity.token_count),
)
def _get_candidates(
self,
complexity: ComplexityAnalysis,
has_images: bool,
required_capabilities: Optional[set[str]]
) -> list[ModelConfig]:
"""Get candidate models based on requirements."""
candidates = []
for model in self.registry.get_enabled():
# Check health
if not self.model_health.get(model.name, False):
continue
# Check capabilities
if has_images and not model.supports_vision:
continue
if required_capabilities:
if "tools" in required_capabilities and not model.supports_tools:
continue
if "json" in required_capabilities and not model.supports_json:
continue
# Check complexity range
if self.config.complexity_routing:
if complexity.score < model.min_complexity:
continue
if complexity.score > model.max_complexity:
continue
# Check latency constraint
if self.config.max_latency_ms:
if model.avg_latency_ms > self.config.max_latency_ms:
continue
candidates.append(model)
# Fallback to any enabled model if no candidates
if not candidates:
candidates = [m for m in self.registry.get_enabled()
if self.model_health.get(m.name, False)]
return candidates
def _select_model(
self,
candidates: list[ModelConfig],
complexity: ComplexityAnalysis
) -> ModelConfig:
"""Select best model from candidates."""
if not candidates:
# Ultimate fallback
return self.registry.get(self.registry.default_model)
if len(candidates) == 1:
return candidates[0]
# Score candidates
scored = []
for model in candidates:
score = self._score_model(model, complexity)
scored.append((score, model))
# Sort by score (higher is better)
scored.sort(key=lambda x: x[0], reverse=True)
# Load balancing among top candidates
if self.config.enable_load_balancing:
top_candidates = [m for s, m in scored[:3]
if s >= scored[0][0] * 0.9]
if len(top_candidates) > 1:
weights = [c.weight for c in top_candidates]
return random.choices(top_candidates, weights=weights)[0]
return scored[0][1]
def _score_model(
self,
model: ModelConfig,
complexity: ComplexityAnalysis
) -> float:
"""Score a model for selection."""
score = 0.0
# Prefer local models
if self.config.prefer_local and model.model_type.value == "local":
score += 0.2
# Complexity fit (prefer models in their sweet spot)
complexity_mid = (model.min_complexity + model.max_complexity) / 2
complexity_fit = 1.0 - abs(complexity.score - complexity_mid)
score += complexity_fit * 0.3
# Latency score (lower is better)
expected_latency = self._get_expected_latency(model)
latency_score = max(0, 1.0 - expected_latency / 1000)
score += latency_score * 0.2
# Cost score (lower is better)
estimated_cost = self._estimate_cost(model, complexity.token_count)
cost_score = max(0, 1.0 - estimated_cost * 10) # Normalize
score += cost_score * 0.2
# Weight from config
score += model.weight * 0.1
return score
def _get_expected_latency(self, model: ModelConfig) -> float:
"""Get expected latency based on history."""
history = self.model_latencies.get(model.name, [])
if history:
return sum(history[-10:]) / len(history[-10:])
return model.avg_latency_ms
def _estimate_cost(self, model: ModelConfig, input_tokens: int) -> float:
"""Estimate cost for request."""
# Assume output is roughly equal to input for estimation
output_tokens = input_tokens
input_cost = (input_tokens / 1_000_000) * model.input_cost
output_cost = (output_tokens / 1_000_000) * model.output_cost
return input_cost + output_cost
def _get_routing_reason(
self,
model: ModelConfig,
complexity: ComplexityAnalysis
) -> str:
"""Generate human-readable routing reason."""
reasons = []
# Complexity match
if complexity.score <= self.config.complexity_threshold_small:
reasons.append(f"Simple query ({complexity.level.value})")
elif complexity.score <= self.config.complexity_threshold_medium:
reasons.append(f"Moderate complexity ({complexity.level.value})")
else:
reasons.append(f"Complex query ({complexity.level.value})")
# Model tier
reasons.append(f"{model.tier.value} tier model")
# Local preference
if model.model_type.value == "local":
reasons.append("local inference preferred")
return " - ".join(reasons)
def update_health(self, model_name: str, healthy: bool) -> None:
"""Update model health status."""
self.model_health[model_name] = healthy
def record_latency(self, model_name: str, latency_ms: float) -> None:
"""Record actual latency for a model."""
if model_name not in self.model_latencies:
self.model_latencies[model_name] = []
self.model_latencies[model_name].append(latency_ms)
# Keep last 100 measurements
if len(self.model_latencies[model_name]) > 100:
self.model_latencies[model_name] = \
                self.model_latencies[model_name][-100:]

Part 4: Model Clients & Fallback
★ Insight ─────────────────────────────────────
Unified model clients abstract away differences between local (Ollama) and API-based models. Circuit breakers prevent cascading failures, and automatic fallback ensures requests complete even when primary models fail.
─────────────────────────────────────────────────
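As a preview, here is the fallback pattern in a synchronous miniature — `flaky_primary` and `api_fallback` are hypothetical stand-ins for real model clients, which the async `FallbackClient` below wraps:

```python
def generate_with_fallback(clients, prompt: str) -> str:
    """Try each client in order; raise only if every one fails."""
    last_error = None
    for client in clients:
        try:
            return client(prompt)
        except Exception as e:      # record and try the next model in the chain
            last_error = e
    raise RuntimeError(f"All models failed. Last error: {last_error}")

def flaky_primary(prompt):          # hypothetical: the local model is down
    raise ConnectionError("model offline")

def api_fallback(prompt):           # hypothetical: the API tier answers
    return f"[gpt-4o-mini] {prompt}"

print(generate_with_fallback([flaky_primary, api_fallback], "hello"))
```

The circuit breakers on each client keep this loop fast: once a model has tripped its breaker, calls fail immediately instead of waiting out a timeout before falling back.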
# clients/base.py
"""
Model Clients - Unified interface for different model backends.
"""
import os
import time
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import AsyncGenerator, Optional
import httpx
from circuitbreaker import circuit
from tenacity import retry, stop_after_attempt, wait_exponential
@dataclass
class GenerationResult:
"""Result of model generation."""
text: str
model: str
input_tokens: int
output_tokens: int
latency_ms: float
finish_reason: str
cached: bool = False
@dataclass
class GenerationRequest:
"""Request for model generation."""
messages: list[dict]
max_tokens: int = 1024
temperature: float = 0.7
stream: bool = False
json_mode: bool = False
stop: Optional[list[str]] = None
class ModelClient(ABC):
"""Abstract base class for model clients."""
@abstractmethod
async def generate(self, request: GenerationRequest) -> GenerationResult:
"""Generate completion."""
pass
@abstractmethod
async def generate_stream(
self, request: GenerationRequest
) -> AsyncGenerator[str, None]:
"""Generate completion with streaming."""
pass
@abstractmethod
async def health_check(self) -> bool:
"""Check if model is healthy."""
pass
class OllamaClient(ModelClient):
"""Client for Ollama models."""
def __init__(self, model_name: str, endpoint: str = "http://localhost:11434"):
"""Initialize Ollama client."""
self.model_name = model_name
self.endpoint = endpoint
self.client = httpx.AsyncClient(timeout=120.0)
@circuit(failure_threshold=3, recovery_timeout=30)
@retry(stop=stop_after_attempt(2), wait=wait_exponential(min=1, max=10))
async def generate(self, request: GenerationRequest) -> GenerationResult:
"""Generate completion using Ollama."""
start_time = time.time()
payload = {
"model": self.model_name,
"messages": request.messages,
"stream": False,
"options": {
"temperature": request.temperature,
"num_predict": request.max_tokens,
}
}
if request.json_mode:
payload["format"] = "json"
if request.stop:
payload["options"]["stop"] = request.stop
response = await self.client.post(
f"{self.endpoint}/api/chat",
json=payload
)
response.raise_for_status()
data = response.json()
latency_ms = (time.time() - start_time) * 1000
return GenerationResult(
text=data["message"]["content"],
model=self.model_name,
input_tokens=data.get("prompt_eval_count", 0),
output_tokens=data.get("eval_count", 0),
latency_ms=latency_ms,
finish_reason=data.get("done_reason", "stop"),
)
async def generate_stream(
self, request: GenerationRequest
) -> AsyncGenerator[str, None]:
"""Generate completion with streaming."""
payload = {
"model": self.model_name,
"messages": request.messages,
"stream": True,
"options": {
"temperature": request.temperature,
"num_predict": request.max_tokens,
}
}
if request.json_mode:
payload["format"] = "json"
        import json  # stdlib; Ollama streams newline-delimited JSON objects

        async with self.client.stream(
            "POST",
            f"{self.endpoint}/api/chat",
            json=payload
        ) as response:
            async for line in response.aiter_lines():
                if line:
                    data = json.loads(line)
                    if "message" in data and "content" in data["message"]:
                        yield data["message"]["content"]
async def health_check(self) -> bool:
"""Check if Ollama is healthy."""
try:
response = await self.client.get(f"{self.endpoint}/api/tags")
return response.status_code == 200
except Exception:
return False
class OpenAIClient(ModelClient):
"""Client for OpenAI API models."""
def __init__(self, model_name: str, api_key_env: str = "OPENAI_API_KEY"):
"""Initialize OpenAI client."""
self.model_name = model_name
self.api_key = os.environ.get(api_key_env, "")
self.client = httpx.AsyncClient(
timeout=60.0,
headers={
"Authorization": f"Bearer {self.api_key}",
"Content-Type": "application/json"
}
)
self.endpoint = "https://api.openai.com/v1/chat/completions"
@circuit(failure_threshold=3, recovery_timeout=60)
@retry(stop=stop_after_attempt(3), wait=wait_exponential(min=1, max=30))
async def generate(self, request: GenerationRequest) -> GenerationResult:
"""Generate completion using OpenAI API."""
start_time = time.time()
payload = {
"model": self.model_name,
"messages": request.messages,
"max_tokens": request.max_tokens,
"temperature": request.temperature,
"stream": False,
}
if request.json_mode:
payload["response_format"] = {"type": "json_object"}
if request.stop:
payload["stop"] = request.stop
response = await self.client.post(self.endpoint, json=payload)
response.raise_for_status()
data = response.json()
latency_ms = (time.time() - start_time) * 1000
choice = data["choices"][0]
usage = data.get("usage", {})
return GenerationResult(
text=choice["message"]["content"],
model=self.model_name,
input_tokens=usage.get("prompt_tokens", 0),
output_tokens=usage.get("completion_tokens", 0),
latency_ms=latency_ms,
finish_reason=choice.get("finish_reason", "stop"),
)
async def generate_stream(
self, request: GenerationRequest
) -> AsyncGenerator[str, None]:
"""Generate completion with streaming."""
payload = {
"model": self.model_name,
"messages": request.messages,
"max_tokens": request.max_tokens,
"temperature": request.temperature,
"stream": True,
}
async with self.client.stream(
"POST", self.endpoint, json=payload
) as response:
            import json  # local for this excerpt; module-level in the full file
            # Parse server-sent events line by line, stopping at the [DONE] sentinel
            async for line in response.aiter_lines():
                if not line.startswith("data: "):
                    continue
                data_str = line[len("data: "):]
                if data_str == "[DONE]":
                    break
                data = json.loads(data_str)
                delta = data["choices"][0]["delta"]
                if delta.get("content"):
                    yield delta["content"]
async def health_check(self) -> bool:
"""Check if OpenAI API is accessible."""
try:
response = await self.client.get(
"https://api.openai.com/v1/models",
headers={"Authorization": f"Bearer {self.api_key}"}
)
return response.status_code == 200
except Exception:
return False
class FallbackClient:
"""Client with automatic fallback to other models."""
def __init__(
self,
primary: ModelClient,
fallbacks: list[ModelClient],
on_fallback: Optional[callable] = None
):
"""Initialize fallback client."""
self.primary = primary
self.fallbacks = fallbacks
self.on_fallback = on_fallback
async def generate(self, request: GenerationRequest) -> GenerationResult:
"""Generate with automatic fallback."""
clients = [self.primary] + self.fallbacks
last_error = None
for i, client in enumerate(clients):
try:
result = await client.generate(request)
# Notify if we fell back
if i > 0 and self.on_fallback:
self.on_fallback(
primary=self.primary,
fallback=client,
attempt=i
)
return result
except Exception as e:
last_error = e
continue
# All clients failed
raise RuntimeError(
f"All models failed. Last error: {last_error}"
)
async def generate_stream(
self, request: GenerationRequest
) -> AsyncGenerator[str, None]:
"""Generate stream with fallback (falls back on first error)."""
clients = [self.primary] + self.fallbacks
for client in clients:
try:
async for chunk in client.generate_stream(request):
yield chunk
return
except Exception:
continue
        raise RuntimeError("All models failed for streaming")
Part 5: Semantic Caching
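The implementation that follows matches queries by embedding similarity rather than exact keys. As a quick intuition builder, here is a stdlib-only sketch of that lookup, using hypothetical toy vectors in place of real sentence embeddings:

```python
import math

def cosine(a: list[float], b: list[float]) -> float:
    """Cosine similarity between two vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    na = math.sqrt(sum(x * x for x in a))
    nb = math.sqrt(sum(x * x for x in b))
    return dot / (na * nb)

def lookup(query_emb, cache, threshold=0.95):
    """Return the best cached response if its similarity clears the threshold."""
    best_key, best_sim = None, 0.0
    for key, (emb, _response) in cache.items():
        sim = cosine(query_emb, emb)
        if sim > best_sim:
            best_key, best_sim = key, sim
    if best_sim >= threshold:
        return cache[best_key][1], best_sim
    return None, best_sim

# Toy cache: key -> (embedding, cached response)
cache = {
    "k1": ([1.0, 0.0, 0.2], "Paris is the capital of France."),
    "k2": ([0.0, 1.0, 0.0], "Water boils at 100 C."),
}
# A near-duplicate query embedding should hit the first entry
hit, sim = lookup([0.99, 0.01, 0.21], cache)
```

In the real cache below, the vectors come from a sentence-transformer model and live in Redis, but the threshold comparison plays exactly this role.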
# caching/semantic.py
"""
Semantic Cache - Cache responses based on query similarity.
"""
import hashlib
import json
import time
from dataclasses import dataclass
from typing import Optional
import numpy as np
from sentence_transformers import SentenceTransformer
import redis
@dataclass
class CacheEntry:
"""Cached response entry."""
response: str
model: str
input_tokens: int
output_tokens: int
created_at: float
hit_count: int = 0
@dataclass
class CacheResult:
"""Result of cache lookup."""
hit: bool
entry: Optional[CacheEntry] = None
similarity: float = 0.0
class SemanticCache:
"""Semantic similarity-based cache."""
def __init__(
self,
redis_url: str = "redis://localhost:6379",
embedding_model: str = "all-MiniLM-L6-v2",
similarity_threshold: float = 0.95,
ttl_seconds: int = 3600,
max_entries: int = 10000
):
"""Initialize semantic cache."""
        # Synchronous Redis client for brevity; prefer redis.asyncio in
        # production so the async get/set methods do not block the event loop.
        self.redis = redis.from_url(redis_url)
self.encoder = SentenceTransformer(embedding_model)
self.similarity_threshold = similarity_threshold
self.ttl_seconds = ttl_seconds
self.max_entries = max_entries
# Index name for embeddings
self.index_key = "semantic_cache:index"
self.entries_key = "semantic_cache:entries"
self.embeddings_key = "semantic_cache:embeddings"
def _compute_key(self, query: str, system_prompt: Optional[str] = None) -> str:
"""Compute cache key from query."""
content = query
if system_prompt:
content = f"{system_prompt}|||{query}"
return hashlib.sha256(content.encode()).hexdigest()[:16]
def _embed(self, text: str) -> np.ndarray:
"""Compute embedding for text."""
return self.encoder.encode(text, convert_to_numpy=True)
async def get(
self,
query: str,
system_prompt: Optional[str] = None
) -> CacheResult:
"""Look up query in cache."""
# Compute query embedding
query_embedding = self._embed(query)
# Get all cached embeddings
all_keys = self.redis.smembers(self.index_key)
if not all_keys:
return CacheResult(hit=False)
best_similarity = 0.0
best_key = None
for key in all_keys:
# Get stored embedding
emb_data = self.redis.hget(self.embeddings_key, key)
if not emb_data:
continue
stored_embedding = np.frombuffer(emb_data, dtype=np.float32)
# Compute cosine similarity
similarity = np.dot(query_embedding, stored_embedding) / (
np.linalg.norm(query_embedding) * np.linalg.norm(stored_embedding)
)
if similarity > best_similarity:
best_similarity = similarity
best_key = key
# Check if best match exceeds threshold
if best_similarity >= self.similarity_threshold and best_key:
# Get cached entry
entry_data = self.redis.hget(self.entries_key, best_key)
if entry_data:
entry_dict = json.loads(entry_data)
entry = CacheEntry(**entry_dict)
# Update hit count
entry.hit_count += 1
self.redis.hset(
self.entries_key,
best_key,
json.dumps(entry.__dict__)
)
return CacheResult(
hit=True,
entry=entry,
similarity=best_similarity
)
return CacheResult(hit=False, similarity=best_similarity)
async def set(
self,
query: str,
response: str,
model: str,
input_tokens: int,
output_tokens: int,
system_prompt: Optional[str] = None
) -> str:
"""Store response in cache."""
key = self._compute_key(query, system_prompt)
# Compute and store embedding
embedding = self._embed(query)
self.redis.hset(
self.embeddings_key,
key,
embedding.astype(np.float32).tobytes()
)
# Store entry
entry = CacheEntry(
response=response,
model=model,
input_tokens=input_tokens,
output_tokens=output_tokens,
created_at=time.time(),
hit_count=0
)
self.redis.hset(self.entries_key, key, json.dumps(entry.__dict__))
self.redis.sadd(self.index_key, key)
        # Redis hashes have no per-field TTL: this expires the whole cache
        # ttl_seconds after the last write. Expire all three structures
        # together so the index never outlives the entries.
        self.redis.expire(self.entries_key, self.ttl_seconds)
        self.redis.expire(self.embeddings_key, self.ttl_seconds)
        self.redis.expire(self.index_key, self.ttl_seconds)
# Prune if over limit
await self._prune_if_needed()
return key
async def _prune_if_needed(self) -> None:
"""Remove old entries if cache is too large."""
current_size = self.redis.scard(self.index_key)
if current_size > self.max_entries:
# Get all entries with timestamps
entries_with_time = []
for key in self.redis.smembers(self.index_key):
entry_data = self.redis.hget(self.entries_key, key)
if entry_data:
entry = json.loads(entry_data)
entries_with_time.append((key, entry["created_at"]))
# Sort by time, remove oldest
entries_with_time.sort(key=lambda x: x[1])
to_remove = entries_with_time[:current_size - self.max_entries]
for key, _ in to_remove:
self.redis.srem(self.index_key, key)
self.redis.hdel(self.entries_key, key)
self.redis.hdel(self.embeddings_key, key)
def get_stats(self) -> dict:
"""Get cache statistics."""
total_entries = self.redis.scard(self.index_key)
total_hits = 0
total_tokens_saved = 0
for key in self.redis.smembers(self.index_key):
entry_data = self.redis.hget(self.entries_key, key)
if entry_data:
entry = json.loads(entry_data)
total_hits += entry["hit_count"]
total_tokens_saved += entry["hit_count"] * (
entry["input_tokens"] + entry["output_tokens"]
)
return {
"total_entries": total_entries,
"total_hits": total_hits,
"tokens_saved": total_tokens_saved,
}
def clear(self) -> None:
"""Clear all cache entries."""
self.redis.delete(self.index_key)
self.redis.delete(self.entries_key)
        self.redis.delete(self.embeddings_key)
Part 6: Monitoring & Observability
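Before wiring up Prometheus, here is a toy, stdlib-only illustration of what a histogram metric records: each observation increments every cumulative bucket whose upper bound it fits under. The bounds below mirror the request-latency buckets used in this part:

```python
# Cumulative latency buckets, in seconds (same bounds as REQUEST_LATENCY)
BUCKETS = [0.1, 0.25, 0.5, 1.0, 2.5, 5.0, 10.0]

def observe(counts: dict, value: float) -> None:
    """Record one latency observation into cumulative buckets."""
    for bound in BUCKETS:
        if value <= bound:
            counts[bound] = counts.get(bound, 0) + 1
    # +Inf bucket counts every observation
    counts["+Inf"] = counts.get("+Inf", 0) + 1

counts: dict = {}
for latency in (0.08, 0.3, 0.3, 1.2, 7.0):
    observe(counts, latency)
```

Quantiles such as the P95 used in the alert rules later are then estimated from these cumulative counts.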
★ Insight ─────────────────────────────────────
Production systems need comprehensive observability. Prometheus metrics track performance, OpenTelemetry provides distributed tracing across services, and structured logging enables debugging. The combination allows you to understand system behavior and quickly diagnose issues.
─────────────────────────────────────────────────
# monitoring/metrics.py
"""
Monitoring - Prometheus metrics and OpenTelemetry tracing.
"""
import time
from contextlib import contextmanager
from dataclasses import dataclass
from typing import Optional
from prometheus_client import Counter, Histogram, Gauge, Info
# Request metrics
REQUESTS_TOTAL = Counter(
"slm_requests_total",
"Total number of requests",
["model", "tier", "status"]
)
REQUEST_LATENCY = Histogram(
"slm_request_latency_seconds",
"Request latency in seconds",
["model", "tier"],
buckets=[0.1, 0.25, 0.5, 1.0, 2.5, 5.0, 10.0]
)
TOKENS_TOTAL = Counter(
"slm_tokens_total",
"Total tokens processed",
["model", "direction"] # direction: input/output
)
# Routing metrics
ROUTING_DECISIONS = Counter(
"slm_routing_decisions_total",
"Routing decisions by reason",
["selected_model", "complexity_level", "reason"]
)
COMPLEXITY_SCORES = Histogram(
"slm_complexity_scores",
"Query complexity score distribution",
buckets=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0]
)
# Fallback metrics
FALLBACKS_TOTAL = Counter(
"slm_fallbacks_total",
"Total fallback events",
["from_model", "to_model", "reason"]
)
# Cache metrics
CACHE_OPERATIONS = Counter(
"slm_cache_operations_total",
"Cache operations",
["operation", "result"] # operation: get/set, result: hit/miss/error
)
CACHE_SIZE = Gauge(
"slm_cache_size",
"Current cache size"
)
# Cost metrics
COST_TOTAL = Counter(
"slm_cost_dollars_total",
"Total cost in dollars",
["model"]
)
# Health metrics
MODEL_HEALTH = Gauge(
"slm_model_health",
"Model health status (1=healthy, 0=unhealthy)",
["model"]
)
ACTIVE_REQUESTS = Gauge(
"slm_active_requests",
"Currently active requests",
["model"]
)
@dataclass
class RequestMetrics:
"""Metrics for a single request."""
model: str
tier: str
latency_ms: float
input_tokens: int
output_tokens: int
status: str
complexity_score: float
complexity_level: str
routing_reason: str
cost: float
cached: bool
fallback_used: bool
fallback_from: Optional[str] = None
class MetricsCollector:
"""Collects and records metrics."""
def record_request(self, metrics: RequestMetrics) -> None:
"""Record metrics for a completed request."""
# Request counter
REQUESTS_TOTAL.labels(
model=metrics.model,
tier=metrics.tier,
status=metrics.status
).inc()
# Latency histogram
REQUEST_LATENCY.labels(
model=metrics.model,
tier=metrics.tier
).observe(metrics.latency_ms / 1000) # Convert to seconds
# Token counters
TOKENS_TOTAL.labels(
model=metrics.model,
direction="input"
).inc(metrics.input_tokens)
TOKENS_TOTAL.labels(
model=metrics.model,
direction="output"
).inc(metrics.output_tokens)
# Routing decision
ROUTING_DECISIONS.labels(
selected_model=metrics.model,
complexity_level=metrics.complexity_level,
reason=metrics.routing_reason
).inc()
# Complexity score
COMPLEXITY_SCORES.observe(metrics.complexity_score)
# Cache hit/miss
if metrics.cached:
CACHE_OPERATIONS.labels(
operation="get",
result="hit"
).inc()
# Cost
COST_TOTAL.labels(model=metrics.model).inc(metrics.cost)
# Fallback
if metrics.fallback_used and metrics.fallback_from:
FALLBACKS_TOTAL.labels(
from_model=metrics.fallback_from,
to_model=metrics.model,
reason="error"
).inc()
def record_cache_miss(self) -> None:
"""Record cache miss."""
CACHE_OPERATIONS.labels(
operation="get",
result="miss"
).inc()
def record_cache_set(self) -> None:
"""Record cache set."""
CACHE_OPERATIONS.labels(
operation="set",
result="success"
).inc()
def update_model_health(self, model: str, healthy: bool) -> None:
"""Update model health status."""
MODEL_HEALTH.labels(model=model).set(1 if healthy else 0)
def update_cache_size(self, size: int) -> None:
"""Update cache size gauge."""
CACHE_SIZE.set(size)
@contextmanager
def track_active_request(self, model: str):
"""Context manager to track active requests."""
ACTIVE_REQUESTS.labels(model=model).inc()
try:
yield
finally:
ACTIVE_REQUESTS.labels(model=model).dec()
# OpenTelemetry tracing setup
from opentelemetry import trace
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import BatchSpanProcessor
from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter
from opentelemetry.sdk.resources import Resource
def setup_tracing(service_name: str = "production-slm") -> trace.Tracer:
"""Set up OpenTelemetry tracing."""
resource = Resource.create({"service.name": service_name})
provider = TracerProvider(resource=resource)
# Export to OTLP collector (Jaeger, etc.)
exporter = OTLPSpanExporter(
endpoint="http://localhost:4317",
insecure=True
)
provider.add_span_processor(BatchSpanProcessor(exporter))
trace.set_tracer_provider(provider)
return trace.get_tracer(service_name)
# Tracer instance
tracer = setup_tracing()
import functools

def trace_request(func):
    """Decorator to trace async request handlers."""
    @functools.wraps(func)  # preserve the wrapped coroutine's metadata
    async def wrapper(*args, **kwargs):
with tracer.start_as_current_span(
func.__name__,
attributes={"function": func.__name__}
) as span:
try:
result = await func(*args, **kwargs)
span.set_status(trace.Status(trace.StatusCode.OK))
return result
except Exception as e:
span.set_status(
trace.Status(trace.StatusCode.ERROR, str(e))
)
span.record_exception(e)
raise
    return wrapper
Part 7: Cost Optimization
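The cost model in this part is plain token arithmetic: prices are quoted per million tokens, and cache hits cost nothing. A minimal sketch of the calculation (the example prices are hypothetical):

```python
def request_cost(input_tokens: int, output_tokens: int,
                 input_per_million: float, output_per_million: float,
                 cached: bool = False) -> float:
    """Dollar cost of one request; cached responses are free."""
    if cached:
        return 0.0
    return ((input_tokens / 1_000_000) * input_per_million
            + (output_tokens / 1_000_000) * output_per_million)

# 1,000 prompt tokens and 500 completion tokens at hypothetical prices
cost = request_cost(1_000, 500, input_per_million=0.15, output_per_million=0.60)
```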
# optimization/cost.py
"""
Cost Optimization - Track and minimize inference costs.
"""
from dataclasses import dataclass, field
from datetime import datetime, timedelta
from typing import Optional
import json
@dataclass
class CostRecord:
"""Record of a single cost event."""
timestamp: datetime
model: str
input_tokens: int
output_tokens: int
input_cost: float
output_cost: float
total_cost: float
cached: bool
@dataclass
class CostBudget:
"""Cost budget configuration."""
daily_limit: float = 100.0 # Daily budget in dollars
hourly_limit: float = 10.0 # Hourly budget
per_request_limit: float = 0.50 # Max cost per request
warning_threshold: float = 0.8 # Warn at 80% of budget
class CostTracker:
"""Tracks and optimizes inference costs."""
def __init__(self, budget: Optional[CostBudget] = None):
"""Initialize cost tracker."""
self.budget = budget or CostBudget()
self.records: list[CostRecord] = []
self.hourly_costs: dict[str, float] = {}
self.daily_costs: dict[str, float] = {}
self.model_costs: dict[str, float] = {}
def record(
self,
model: str,
input_tokens: int,
output_tokens: int,
input_cost_per_million: float,
output_cost_per_million: float,
cached: bool = False
) -> CostRecord:
"""Record a cost event."""
# Calculate costs (0 if cached)
if cached:
input_cost = 0.0
output_cost = 0.0
else:
input_cost = (input_tokens / 1_000_000) * input_cost_per_million
output_cost = (output_tokens / 1_000_000) * output_cost_per_million
total_cost = input_cost + output_cost
record = CostRecord(
timestamp=datetime.now(),
model=model,
input_tokens=input_tokens,
output_tokens=output_tokens,
input_cost=input_cost,
output_cost=output_cost,
total_cost=total_cost,
cached=cached
)
self.records.append(record)
# Update aggregations
hour_key = record.timestamp.strftime("%Y-%m-%d-%H")
day_key = record.timestamp.strftime("%Y-%m-%d")
self.hourly_costs[hour_key] = \
self.hourly_costs.get(hour_key, 0) + total_cost
self.daily_costs[day_key] = \
self.daily_costs.get(day_key, 0) + total_cost
self.model_costs[model] = \
self.model_costs.get(model, 0) + total_cost
return record
def check_budget(self) -> dict:
"""Check current budget status."""
now = datetime.now()
hour_key = now.strftime("%Y-%m-%d-%H")
day_key = now.strftime("%Y-%m-%d")
hourly_spent = self.hourly_costs.get(hour_key, 0)
daily_spent = self.daily_costs.get(day_key, 0)
return {
"hourly_spent": hourly_spent,
"hourly_limit": self.budget.hourly_limit,
"hourly_remaining": max(0, self.budget.hourly_limit - hourly_spent),
"hourly_percent": (hourly_spent / self.budget.hourly_limit) * 100,
"daily_spent": daily_spent,
"daily_limit": self.budget.daily_limit,
"daily_remaining": max(0, self.budget.daily_limit - daily_spent),
"daily_percent": (daily_spent / self.budget.daily_limit) * 100,
"warning": (
hourly_spent >= self.budget.hourly_limit * self.budget.warning_threshold or
daily_spent >= self.budget.daily_limit * self.budget.warning_threshold
),
"exceeded": (
hourly_spent >= self.budget.hourly_limit or
daily_spent >= self.budget.daily_limit
)
}
def should_use_cheaper_model(self) -> bool:
"""Determine if we should prefer cheaper models."""
status = self.check_budget()
return status["daily_percent"] > 60 or status["hourly_percent"] > 70
def get_cost_breakdown(
self,
start: Optional[datetime] = None,
end: Optional[datetime] = None
) -> dict:
"""Get detailed cost breakdown."""
# Filter records by time range
filtered = self.records
if start:
filtered = [r for r in filtered if r.timestamp >= start]
if end:
filtered = [r for r in filtered if r.timestamp <= end]
# Aggregate by model
by_model = {}
for record in filtered:
if record.model not in by_model:
by_model[record.model] = {
"requests": 0,
"input_tokens": 0,
"output_tokens": 0,
"total_cost": 0,
"cached_requests": 0,
"cost_saved_by_cache": 0
}
stats = by_model[record.model]
stats["requests"] += 1
stats["input_tokens"] += record.input_tokens
stats["output_tokens"] += record.output_tokens
stats["total_cost"] += record.total_cost
if record.cached:
stats["cached_requests"] += 1
# Calculate totals
total_cost = sum(r.total_cost for r in filtered)
total_requests = len(filtered)
cached_requests = sum(1 for r in filtered if r.cached)
return {
"total_cost": total_cost,
"total_requests": total_requests,
"cached_requests": cached_requests,
"cache_hit_rate": (cached_requests / total_requests * 100) if total_requests > 0 else 0,
"avg_cost_per_request": total_cost / total_requests if total_requests > 0 else 0,
"by_model": by_model
}
def get_optimization_suggestions(self) -> list[str]:
"""Get suggestions for cost optimization."""
suggestions = []
breakdown = self.get_cost_breakdown()
# Check cache hit rate
if breakdown["cache_hit_rate"] < 20:
suggestions.append(
f"Low cache hit rate ({breakdown['cache_hit_rate']:.1f}%). "
"Consider lowering similarity threshold."
)
# Check model usage
for model, stats in breakdown["by_model"].items():
if "gpt-4" in model.lower() or "claude" in model.lower():
pct = (stats["requests"] / breakdown["total_requests"]) * 100
if pct > 30:
suggestions.append(
f"{model} handles {pct:.1f}% of requests. "
"Consider routing more to local models."
)
# Check average cost
if breakdown["avg_cost_per_request"] > 0.01:
suggestions.append(
f"Average cost ${breakdown['avg_cost_per_request']:.4f}/request. "
"Review complexity thresholds."
)
        return suggestions
Part 8: Production API Server
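Stripped of async, HTTP, and metrics, the fallback control flow the server relies on reduces to a few lines. The stand-in clients here are hypothetical callables:

```python
def generate_with_fallback(clients, prompt):
    """Try each client in order; raise only if every one fails."""
    last_error = None
    for client in clients:
        try:
            return client(prompt)
        except Exception as e:
            last_error = e
    raise RuntimeError(f"All models failed. Last error: {last_error}")

def flaky(prompt):
    # Simulates an overloaded model
    raise TimeoutError("model overloaded")

def stable(prompt):
    return f"echo: {prompt}"

result = generate_with_fallback([flaky, stable], "hi")
```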
# server.py
"""
Production SLM API Server.
"""
import asyncio
import time
from contextlib import asynccontextmanager
from typing import Optional
from fastapi import FastAPI, HTTPException, BackgroundTasks
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from prometheus_client import make_asgi_app
from models.registry import create_default_registry, ModelType
from routing.router import IntelligentRouter, RouterConfig
from routing.complexity import ComplexityLevel
from clients.base import (
OllamaClient, OpenAIClient, FallbackClient,
GenerationRequest, GenerationResult
)
from caching.semantic import SemanticCache
from monitoring.metrics import MetricsCollector, RequestMetrics, tracer
from optimization.cost import CostTracker, CostBudget
# Request/Response models
class ChatMessage(BaseModel):
role: str
content: str
class ChatRequest(BaseModel):
messages: list[ChatMessage]
model: Optional[str] = None
max_tokens: int = 1024
temperature: float = 0.7
stream: bool = False
use_cache: bool = True
class ChatResponse(BaseModel):
text: str
model: str
usage: dict
routing: dict
cached: bool
latency_ms: float
class HealthResponse(BaseModel):
status: str
models: dict[str, bool]
cache: dict
budget: dict
# Global instances
registry = None
router = None
clients = {}
cache = None
metrics = None
cost_tracker = None
@asynccontextmanager
async def lifespan(app: FastAPI):
"""Initialize and cleanup resources."""
global registry, router, clients, cache, metrics, cost_tracker
# Initialize registry
registry = create_default_registry()
# Initialize router
router_config = RouterConfig(
prefer_local=True,
enable_load_balancing=True,
complexity_routing=True
)
router = IntelligentRouter(registry, router_config)
# Initialize clients
for model in registry.get_enabled():
if model.model_type == ModelType.LOCAL:
clients[model.name] = OllamaClient(
model_name=model.name,
endpoint=model.endpoint
)
elif model.model_type == ModelType.API:
clients[model.name] = OpenAIClient(
model_name=model.name,
api_key_env=model.api_key_env
)
# Initialize cache
try:
cache = SemanticCache(
redis_url="redis://localhost:6379",
similarity_threshold=0.92,
ttl_seconds=3600
)
except Exception:
cache = None
print("Warning: Redis not available, caching disabled")
# Initialize metrics and cost tracking
metrics = MetricsCollector()
cost_tracker = CostTracker(CostBudget(
daily_limit=50.0,
hourly_limit=5.0
))
# Run health checks
asyncio.create_task(periodic_health_check())
yield
# Cleanup
for client in clients.values():
if hasattr(client, 'client'):
await client.client.aclose()
async def periodic_health_check():
"""Periodically check model health."""
while True:
for name, client in clients.items():
try:
healthy = await client.health_check()
router.update_health(name, healthy)
metrics.update_model_health(name, healthy)
except Exception:
router.update_health(name, False)
metrics.update_model_health(name, False)
await asyncio.sleep(30)
# Create FastAPI app
app = FastAPI(
title="Production SLM API",
description="Enterprise-grade small language model deployment",
version="1.0.0",
lifespan=lifespan
)
# Mount Prometheus metrics
metrics_app = make_asgi_app()
app.mount("/metrics", metrics_app)
@app.post("/v1/chat", response_model=ChatResponse)
async def chat(request: ChatRequest, background_tasks: BackgroundTasks):
"""Chat completion endpoint with intelligent routing."""
start_time = time.time()
# Convert messages
messages = [{"role": m.role, "content": m.content} for m in request.messages]
query = messages[-1]["content"] if messages else ""
history = messages[:-1] if len(messages) > 1 else None
# Check cache
cached = False
if request.use_cache and cache:
cache_result = await cache.get(query)
if cache_result.hit:
latency_ms = (time.time() - start_time) * 1000
# Record metrics
request_metrics = RequestMetrics(
model=cache_result.entry.model,
tier="cached",
latency_ms=latency_ms,
input_tokens=cache_result.entry.input_tokens,
output_tokens=cache_result.entry.output_tokens,
status="success",
complexity_score=0,
complexity_level="cached",
routing_reason="cache_hit",
cost=0,
cached=True,
fallback_used=False
)
metrics.record_request(request_metrics)
return ChatResponse(
text=cache_result.entry.response,
model=cache_result.entry.model,
usage={
"input_tokens": cache_result.entry.input_tokens,
"output_tokens": cache_result.entry.output_tokens
},
routing={
"selected_model": cache_result.entry.model,
"reason": "cache_hit",
"similarity": cache_result.similarity
},
cached=True,
latency_ms=latency_ms
)
else:
metrics.record_cache_miss()
# Route request
with tracer.start_as_current_span("route_request") as span:
routing_decision = router.route(
query=query,
conversation_history=history,
preferred_model=request.model
)
span.set_attribute("model", routing_decision.model.name)
span.set_attribute("complexity", routing_decision.complexity.score)
# Get client with fallbacks
primary_client = clients.get(routing_decision.model.name)
fallback_clients = [
clients[name] for name in routing_decision.fallback_models
if name in clients
]
client = FallbackClient(
primary=primary_client,
fallbacks=fallback_clients,
on_fallback=lambda p, f, a: metrics.record_request(RequestMetrics(
model=f.model_name if hasattr(f, 'model_name') else 'unknown',
tier="fallback",
latency_ms=0,
input_tokens=0,
output_tokens=0,
status="fallback",
complexity_score=routing_decision.complexity.score,
complexity_level=routing_decision.complexity.level.value,
routing_reason="fallback",
cost=0,
cached=False,
fallback_used=True,
fallback_from=p.model_name if hasattr(p, 'model_name') else 'unknown'
))
)
# Generate response
gen_request = GenerationRequest(
messages=messages,
max_tokens=request.max_tokens,
temperature=request.temperature,
stream=request.stream
)
try:
with metrics.track_active_request(routing_decision.model.name):
result = await client.generate(gen_request)
except Exception as e:
latency_ms = (time.time() - start_time) * 1000
metrics.record_request(RequestMetrics(
model=routing_decision.model.name,
tier=routing_decision.model.tier.value,
latency_ms=latency_ms,
input_tokens=0,
output_tokens=0,
status="error",
complexity_score=routing_decision.complexity.score,
complexity_level=routing_decision.complexity.level.value,
routing_reason=routing_decision.reason,
cost=0,
cached=False,
fallback_used=False
))
raise HTTPException(status_code=500, detail=str(e))
# Update latency tracking
router.record_latency(result.model, result.latency_ms)
# Calculate cost
model_config = registry.get(result.model)
cost = cost_tracker.record(
model=result.model,
input_tokens=result.input_tokens,
output_tokens=result.output_tokens,
input_cost_per_million=model_config.input_cost if model_config else 0,
output_cost_per_million=model_config.output_cost if model_config else 0,
cached=False
)
# Cache response
if request.use_cache and cache and not cached:
background_tasks.add_task(
cache.set,
query=query,
response=result.text,
model=result.model,
input_tokens=result.input_tokens,
output_tokens=result.output_tokens
)
metrics.record_cache_set()
latency_ms = (time.time() - start_time) * 1000
# Record metrics
request_metrics = RequestMetrics(
model=result.model,
tier=routing_decision.model.tier.value,
latency_ms=latency_ms,
input_tokens=result.input_tokens,
output_tokens=result.output_tokens,
status="success",
complexity_score=routing_decision.complexity.score,
complexity_level=routing_decision.complexity.level.value,
routing_reason=routing_decision.reason,
cost=cost.total_cost,
cached=False,
fallback_used=False
)
metrics.record_request(request_metrics)
return ChatResponse(
text=result.text,
model=result.model,
usage={
"input_tokens": result.input_tokens,
"output_tokens": result.output_tokens
},
routing={
"selected_model": routing_decision.model.name,
"reason": routing_decision.reason,
"complexity": {
"score": routing_decision.complexity.score,
"level": routing_decision.complexity.level.value
},
"estimated_latency_ms": routing_decision.estimated_latency_ms,
"estimated_cost": routing_decision.estimated_cost
},
cached=False,
latency_ms=latency_ms
)
@app.post("/v1/chat/stream")
async def chat_stream(request: ChatRequest):
"""Streaming chat completion endpoint."""
messages = [{"role": m.role, "content": m.content} for m in request.messages]
query = messages[-1]["content"] if messages else ""
history = messages[:-1] if len(messages) > 1 else None
# Route request
routing_decision = router.route(
query=query,
conversation_history=history,
preferred_model=request.model
)
client = clients.get(routing_decision.model.name)
if not client:
raise HTTPException(status_code=500, detail="No available model")
gen_request = GenerationRequest(
messages=messages,
max_tokens=request.max_tokens,
temperature=request.temperature,
stream=True
)
async def generate():
async for chunk in client.generate_stream(gen_request):
yield f"data: {chunk}\n\n"
yield "data: [DONE]\n\n"
return StreamingResponse(
generate(),
media_type="text/event-stream"
)
@app.get("/health", response_model=HealthResponse)
async def health():
"""Health check endpoint."""
model_health = {}
for name, client in clients.items():
try:
model_health[name] = await client.health_check()
except Exception:
model_health[name] = False
cache_stats = cache.get_stats() if cache else {}
budget_status = cost_tracker.check_budget()
overall_status = "healthy" if any(model_health.values()) else "unhealthy"
return HealthResponse(
status=overall_status,
models=model_health,
cache=cache_stats,
budget=budget_status
)
@app.get("/v1/models")
async def list_models():
"""List available models."""
models = []
for model in registry.get_enabled():
        if model.name in clients:
            health = await clients[model.name].health_check()
        else:
            health = False
models.append({
"name": model.name,
"tier": model.tier.value,
"type": model.model_type.value,
"healthy": health,
"capabilities": {
"json": model.supports_json,
"tools": model.supports_tools,
"vision": model.supports_vision
},
"limits": {
"max_tokens": model.max_tokens,
"context_window": model.context_window
},
"pricing": {
"input_per_million": model.input_cost,
"output_per_million": model.output_cost
}
})
return {"models": models}
@app.get("/v1/costs")
async def get_costs():
"""Get cost breakdown."""
return cost_tracker.get_cost_breakdown()
@app.get("/v1/costs/suggestions")
async def get_cost_suggestions():
"""Get cost optimization suggestions."""
return {
"suggestions": cost_tracker.get_optimization_suggestions(),
"budget_status": cost_tracker.check_budget()
}
if __name__ == "__main__":
import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)
Part 9: Docker Deployment
# docker-compose.yml
version: '3.8'
services:
slm-api:
build:
context: .
dockerfile: Dockerfile
ports:
- "8000:8000"
environment:
- OPENAI_API_KEY=${OPENAI_API_KEY}
- ANTHROPIC_API_KEY=${ANTHROPIC_API_KEY}
- REDIS_URL=redis://redis:6379
- OTEL_EXPORTER_OTLP_ENDPOINT=http://jaeger:4317
depends_on:
- redis
- ollama
- jaeger
volumes:
- ./config:/app/config
deploy:
resources:
limits:
memory: 4G
healthcheck:
test: ["CMD", "curl", "-f", "http://localhost:8000/health"]
interval: 30s
timeout: 10s
retries: 3
ollama:
image: ollama/ollama:latest
ports:
- "11434:11434"
volumes:
- ollama_data:/root/.ollama
deploy:
resources:
reservations:
devices:
- driver: nvidia
count: 1
capabilities: [gpu]
healthcheck:
test: ["CMD", "curl", "-f", "http://localhost:11434/api/tags"]
interval: 30s
timeout: 10s
retries: 3
redis:
image: redis:7-alpine
ports:
- "6379:6379"
volumes:
- redis_data:/data
command: redis-server --appendonly yes
healthcheck:
test: ["CMD", "redis-cli", "ping"]
interval: 10s
timeout: 5s
retries: 5
prometheus:
image: prom/prometheus:latest
ports:
- "9090:9090"
volumes:
- ./monitoring/prometheus.yml:/etc/prometheus/prometheus.yml
- prometheus_data:/prometheus
command:
- '--config.file=/etc/prometheus/prometheus.yml'
- '--storage.tsdb.path=/prometheus'
grafana:
image: grafana/grafana:latest
ports:
- "3000:3000"
volumes:
- ./monitoring/grafana/dashboards:/etc/grafana/provisioning/dashboards
- ./monitoring/grafana/datasources:/etc/grafana/provisioning/datasources
- grafana_data:/var/lib/grafana
environment:
- GF_SECURITY_ADMIN_PASSWORD=admin
- GF_USERS_ALLOW_SIGN_UP=false
jaeger:
image: jaegertracing/all-in-one:latest
ports:
- "16686:16686" # UI
- "4317:4317" # OTLP gRPC
- "4318:4318" # OTLP HTTP
environment:
- COLLECTOR_OTLP_ENABLED=true
alertmanager:
image: prom/alertmanager:latest
ports:
- "9093:9093"
volumes:
- ./monitoring/alertmanager.yml:/etc/alertmanager/alertmanager.yml
volumes:
ollama_data:
redis_data:
prometheus_data:
  grafana_data:
# Dockerfile
FROM python:3.11-slim
WORKDIR /app
# Install system dependencies
RUN apt-get update && apt-get install -y \
curl \
&& rm -rf /var/lib/apt/lists/*
# Install Python dependencies
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt
# Copy application code
COPY . .
# Create non-root user
RUN useradd -m -u 1000 appuser && chown -R appuser:appuser /app
USER appuser
# Expose port
EXPOSE 8000
# Health check
HEALTHCHECK --interval=30s --timeout=10s --start-period=5s --retries=3 \
CMD curl -f http://localhost:8000/health || exit 1
# Run server
CMD ["uvicorn", "server:app", "--host", "0.0.0.0", "--port", "8000"]Prometheus configuration:
# monitoring/prometheus.yml
global:
scrape_interval: 15s
evaluation_interval: 15s
alerting:
alertmanagers:
- static_configs:
- targets:
- alertmanager:9093
rule_files:
- /etc/prometheus/rules/*.yml
scrape_configs:
- job_name: 'slm-api'
static_configs:
- targets: ['slm-api:8000']
metrics_path: /metrics
- job_name: 'prometheus'
static_configs:
      - targets: ['localhost:9090']
Alert rules:
# monitoring/rules/alerts.yml
groups:
- name: slm-alerts
rules:
- alert: HighErrorRate
expr: rate(slm_requests_total{status="error"}[5m]) / rate(slm_requests_total[5m]) > 0.05
for: 2m
labels:
severity: warning
annotations:
summary: High error rate detected
        description: Error rate is {{ $value | humanizePercentage }}
- alert: HighLatency
expr: histogram_quantile(0.95, rate(slm_request_latency_seconds_bucket[5m])) > 5
for: 5m
labels:
severity: warning
annotations:
summary: High latency detected
description: P95 latency is {{ $value | printf "%.2f" }}s
- alert: ModelUnhealthy
expr: slm_model_health == 0
for: 1m
labels:
severity: critical
annotations:
summary: Model {{ $labels.model }} is unhealthy
- alert: BudgetExceeded
expr: increase(slm_cost_dollars_total[1h]) > 5
for: 0m
labels:
severity: warning
annotations:
summary: Hourly budget exceeded
description: Spent ${{ $value | printf "%.2f" }} in the last hour
- alert: CacheLowHitRate
expr: rate(slm_cache_operations_total{result="hit"}[1h]) / rate(slm_cache_operations_total{operation="get"}[1h]) < 0.1
for: 30m
labels:
severity: info
annotations:
summary: Cache hit rate is low
          description: Cache hit rate is {{ $value | humanizePercentage }}

Running the System
# Pull models in Ollama first
docker compose up -d ollama
docker exec -it production-slm-ollama-1 ollama pull phi3:mini
docker exec -it production-slm-ollama-1 ollama pull qwen2.5:7b
# Start all services
docker compose up -d
# Check health
curl http://localhost:8000/health
# Test chat endpoint
curl -X POST http://localhost:8000/v1/chat \
-H "Content-Type: application/json" \
-d '{
"messages": [{"role": "user", "content": "Explain quantum computing"}],
"max_tokens": 256
}'
# View metrics
open http://localhost:9090 # Prometheus
open http://localhost:3000 # Grafana
open http://localhost:16686 # Jaeger traces

Exercises
- Add Model A/B Testing: Implement A/B testing to compare model performance with traffic splitting
- Implement Rate Limiting: Add per-user rate limiting with token bucket algorithm
- Add Batch Processing: Create endpoint for processing multiple requests efficiently
- Custom Routing Rules: Allow users to define custom routing rules via configuration
- Implement Model Warmup: Pre-warm models on startup to reduce cold start latency
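As a starting point for the rate-limiting exercise, here is a minimal token-bucket sketch. It is a self-contained illustration, not part of the system built above; all names (`TokenBucket`, `check_rate_limit`) are illustrative.

```python
import time

class TokenBucket:
    """Allow bursts up to `capacity` requests, refilled at `rate` tokens per second."""

    def __init__(self, rate: float, capacity: float):
        self.rate = rate
        self.capacity = capacity
        self.tokens = capacity
        self.last = time.monotonic()

    def allow(self, cost: float = 1.0) -> bool:
        now = time.monotonic()
        # Refill in proportion to elapsed time, capped at capacity.
        self.tokens = min(self.capacity, self.tokens + (now - self.last) * self.rate)
        self.last = now
        if self.tokens >= cost:
            self.tokens -= cost
            return True
        return False

# Per-user buckets: burst of 5, sustained 1 request/second.
buckets: dict[str, TokenBucket] = {}

def check_rate_limit(user_id: str) -> bool:
    bucket = buckets.setdefault(user_id, TokenBucket(rate=1.0, capacity=5))
    return bucket.allow()
```

If the server above is FastAPI-based, this check would typically live in a dependency or middleware that returns HTTP 429 when `check_rate_limit` is false.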
Summary
You've built a production-grade SLM deployment system with:
- Model Registry: Centralized configuration for multiple models
- Intelligent Routing: Complexity-based routing to optimal models
- Automatic Fallback: Graceful degradation when models fail
- Semantic Caching: Reduce costs by caching similar queries
- Comprehensive Monitoring: Prometheus metrics and OpenTelemetry tracing
- Cost Optimization: Track and minimize inference costs
- Docker Deployment: Production-ready containerized deployment
This architecture enables serving millions of requests while maintaining quality, minimizing costs, and ensuring reliability through automatic fallbacks and comprehensive observability.
Key Concepts Recap
| Concept | What It Is | Why It Matters |
|---|---|---|
| Model Registry | Centralized config for all models (tier, cost, capabilities) | Dynamic routing without hardcoding; easy to add/remove models |
| Complexity Analysis | Score queries 0-1 based on patterns, length, domain | Route simple queries to small models, complex to large |
| Intelligent Router | Select optimal model based on complexity, latency, cost | Maximize quality while minimizing cost and latency |
| Circuit Breaker | Track failures, open circuit after threshold, auto-recover | Prevent cascading failures; fail fast when model is down |
| Automatic Fallback | Try backup models when primary fails | Ensure requests complete even during outages |
| Semantic Caching | Cache by query embedding similarity, not exact match | Similar questions get cached answers; significant cost savings |
| Prometheus Metrics | Counters, histograms, gauges for all operations | Monitor latency, throughput, errors, costs in real-time |
| OpenTelemetry Tracing | Distributed tracing across routing and inference | Debug slow requests; understand request flow |
| Cost Tracking | Record input/output tokens and model pricing per request | Budget enforcement; optimize model selection |
| Adaptive Routing | Adjust model selection based on budget consumption | Prefer cheaper models when approaching budget limits |
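The circuit-breaker row above can be illustrated with a condensed sketch of the pattern: open the circuit after a run of consecutive failures, fail fast while open, and allow a probe request after a cooldown. This is simplified relative to the full implementation, and all names are illustrative.

```python
import time

class CircuitBreaker:
    """Open after `threshold` consecutive failures; allow a probe after `cooldown` seconds."""

    def __init__(self, threshold: int = 3, cooldown: float = 30.0):
        self.threshold = threshold
        self.cooldown = cooldown
        self.failures = 0
        self.opened_at: float | None = None

    def allow(self) -> bool:
        if self.opened_at is None:
            return True  # closed: requests flow normally
        if time.monotonic() - self.opened_at >= self.cooldown:
            return True  # half-open: let one probe request through
        return False     # open: fail fast

    def record_success(self) -> None:
        self.failures = 0
        self.opened_at = None

    def record_failure(self) -> None:
        self.failures += 1
        if self.failures >= self.threshold:
            self.opened_at = time.monotonic()
```

In the router, `allow()` gates each model call; a failed call records a failure, and a fallback model is tried when the primary's circuit is open.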
Next Steps
- Explore Speculative Decoding for faster inference
- Review SLM Agents for agentic workflows
- Study MLOps Complete Pipeline for CI/CD integration