LLM Caching Layer
Reduce costs and latency with intelligent semantic caching for LLM calls
Build an intelligent caching system that reduces LLM API costs by up to 90% through semantic similarity matching.
TL;DR
Cache LLM responses by embedding queries and matching semantically similar ones (cosine similarity > 0.92). Store embeddings + responses in Redis with TTL. On cache hit, return stored response; on miss, call LLM and cache. Track cost savings.
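The whole pattern fits in a few lines. Here is a minimal in-memory sketch of that flow, with toy hand-picked vectors standing in for real embeddings (the Redis-backed version is built step by step below):

```python
import math

def cosine(a, b):
    """Cosine similarity between two equal-length vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return dot / norm

class TinySemanticCache:
    """In-memory stand-in for the Redis-backed cache built below."""
    def __init__(self, threshold=0.92):
        self.threshold = threshold
        self.entries = []  # (embedding, response) pairs

    def get_or_call(self, query, embed, call_llm):
        q = embed(query)
        best_sim, best_resp = max(
            ((cosine(q, e), r) for e, r in self.entries), default=(0.0, None)
        )
        if best_sim >= self.threshold:
            return best_resp, True          # hit: no API call
        response = call_llm(query)          # miss: pay for the call
        self.entries.append((q, response))
        return response, False

# Toy embeddings: similar questions map to nearby vectors.
VECS = {
    "What is Python?": [1.0, 0.1],
    "Tell me about Python": [0.98, 0.15],
}
cache = TinySemanticCache()
r1, hit1 = cache.get_or_call("What is Python?", VECS.get, lambda q: "A language.")
r2, hit2 = cache.get_or_call("Tell me about Python", VECS.get, lambda q: "never called")
# hit1 is False (cold cache); hit2 is True and r2 reuses r1's response
```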
What You'll Learn
- Semantic caching with embeddings
- Redis for high-performance caching
- Cache invalidation strategies
- Cost tracking and optimization
- TTL and eviction policies
Tech Stack
| Component | Technology |
|---|---|
| Cache Store | Redis |
| Embeddings | OpenAI / Sentence Transformers |
| Similarity | FAISS / Redis Vector |
| API | FastAPI |
Architecture
┌──────────────────────────────────────────────────────────────────────────────┐
│ SEMANTIC CACHING FLOW │
├──────────────────────────────────────────────────────────────────────────────┤
│ │
│ ┌─────────────┐ ┌─────────────────┐ ┌────────────────────────┐ │
│ │ Query │─────▶│ Embed Query │─────▶│ Similar in Cache? │ │
│ │ "What is │ │ (vector) │ │ (cosine > 0.92) │ │
│ │ Python?" │ └─────────────────┘ └───────────┬────────────┘ │
│ └─────────────┘ │ │
│ ┌──────────────┴──────────────┐ │
│ │ │ │
│ YES ▼ NO ▼ │
│ ┌─────────────────┐ ┌───────────┐ │
│ │ Return Cached │ │ Call LLM │ │
│ │ Response │ │ API │ │
│ │ (instant, $0) │ └─────┬─────┘ │
│ └─────────────────┘ │ │
│ ▼ │
│ ┌─────────────────┐ │
│ │ Store Response │ │
│ │ + Embedding │ │
│ └────────┬────────┘ │
│ ┌──────────────────────────────────────────────┐ │ │
│ │ CACHE LAYER │ │ │
│ │ ┌─────────────────┐ ┌─────────────────┐ │ │ │
│ │ │ Redis │ │ Embeddings │◀──┼──────────────┘ │
│ │ │ (responses, │ │ (similarity │ │ │
│ │ │ TTL, stats) │ │ search) │ │ │
│ │ └─────────────────┘ └─────────────────┘ │ │
│ └──────────────────────────────────────────────┘ │
│ │
└──────────────────────────────────────────────────────────────────────────────┘
Project Structure
llm-caching/
├── src/
│ ├── __init__.py
│ ├── cache.py # Caching logic
│ ├── embeddings.py # Embedding generation
│ ├── llm.py # LLM client wrapper
│ ├── metrics.py # Cost tracking
│ └── api.py # FastAPI application
├── tests/
├── docker-compose.yml
└── requirements.txt
Implementation
Step 1: Dependencies
fastapi>=0.100.0
uvicorn>=0.23.0
redis>=5.0.0
openai>=1.0.0
numpy>=1.24.0
sentence-transformers>=2.2.0
prometheus-client>=0.17.0
Step 2: Embedding Generator
"""Embedding generation for semantic caching."""
from abc import ABC, abstractmethod
import numpy as np
from openai import OpenAI
from sentence_transformers import SentenceTransformer
import hashlib
class EmbeddingProvider(ABC):
"""Abstract embedding provider."""
@abstractmethod
def embed(self, text: str) -> np.ndarray:
"""Generate embedding for text."""
pass
@abstractmethod
def embed_batch(self, texts: list[str]) -> np.ndarray:
"""Generate embeddings for multiple texts."""
pass
class OpenAIEmbeddings(EmbeddingProvider):
"""OpenAI embedding provider."""
def __init__(self, model: str = "text-embedding-3-small"):
self.client = OpenAI()
self.model = model
self.dimension = 1536
def embed(self, text: str) -> np.ndarray:
"""Generate embedding using OpenAI."""
response = self.client.embeddings.create(
model=self.model,
input=text
)
return np.array(response.data[0].embedding)
def embed_batch(self, texts: list[str]) -> np.ndarray:
"""Generate batch embeddings."""
response = self.client.embeddings.create(
model=self.model,
input=texts
)
return np.array([d.embedding for d in response.data])
class LocalEmbeddings(EmbeddingProvider):
"""Local sentence-transformers embedding provider."""
def __init__(self, model: str = "all-MiniLM-L6-v2"):
self.model = SentenceTransformer(model)
self.dimension = self.model.get_sentence_embedding_dimension()
def embed(self, text: str) -> np.ndarray:
"""Generate embedding locally."""
return self.model.encode(text, convert_to_numpy=True)
def embed_batch(self, texts: list[str]) -> np.ndarray:
"""Generate batch embeddings."""
return self.model.encode(texts, convert_to_numpy=True)
def compute_hash(text: str) -> str:
"""Compute hash for exact match caching."""
    return hashlib.sha256(text.encode()).hexdigest()[:16]
Understanding Embedding Provider Choice:
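Whichever provider you pick, the rest of the system only sees the `embed`/`embed_batch` interface. That also makes it easy to swap in a deterministic stub for tests. The helper below is hypothetical (not part of the project) and returns plain lists rather than numpy arrays, so wrap its output in `np.array` if you feed it to the cache:

```python
import hashlib

class FakeEmbeddings:
    """Deterministic stand-in satisfying the EmbeddingProvider interface
    (embed / embed_batch) without downloading a model. For tests only."""
    dimension = 8

    def embed(self, text):
        # Derive a repeatable pseudo-vector from the text's SHA-256 digest.
        digest = hashlib.sha256(text.encode()).digest()
        return [b / 255.0 for b in digest[: self.dimension]]

    def embed_batch(self, texts):
        return [self.embed(t) for t in texts]

emb = FakeEmbeddings()
vec = emb.embed("What is Python?")
assert vec == emb.embed("What is Python?")  # same input, same vector
assert len(vec) == emb.dimension
```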
┌─────────────────────────────────────────────────────────────────────────────┐
│ EMBEDDING PROVIDER COMPARISON │
├─────────────────────────────────────────────────────────────────────────────┤
│ │
│ LOCAL (sentence-transformers) vs OPENAI (API) │
│ ┌─────────────────────────────┐ ┌─────────────────────────────┐ │
│ │ all-MiniLM-L6-v2 │ │ text-embedding-3-small │ │
│ │ │ │ │ │
│ │ Cost: $0 (free) │ │ Cost: $0.02/1M tokens │ │
│ │ Latency: ~5ms │ │ Latency: ~100ms │ │
│ │ Quality: Good (384 dims) │ │ Quality: Better (1536 dims) │ │
│ │ │ │ │ │
│ │ Best for: High-volume │ │ Best for: Higher accuracy │ │
│ │ caching, cost-sensitive │ │ semantic matching │ │
│ └─────────────────────────────┘ └─────────────────────────────┘ │
│ │
│ For caching: Local is usually better (faster, free, good enough) │
│ │
└─────────────────────────────────────────────────────────────────────────────┘
Step 3: Semantic Cache
"""Semantic caching with Redis."""
import json
import time
from dataclasses import dataclass
from typing import Optional
import numpy as np
import redis
@dataclass
class CacheEntry:
"""Cached LLM response."""
query: str
response: str
model: str
embedding: list[float]
created_at: float
hit_count: int = 0
tokens_saved: int = 0
@dataclass
class CacheResult:
"""Cache lookup result."""
hit: bool
response: Optional[str] = None
similarity: float = 0.0
entry_id: Optional[str] = None
class SemanticCache:
"""
Semantic caching layer for LLM responses.
Uses embedding similarity to find cached responses
for semantically similar queries.
"""
def __init__(
self,
redis_client: redis.Redis,
embedding_provider,
similarity_threshold: float = 0.95,
ttl_seconds: int = 3600,
max_entries: int = 10000
):
self.redis = redis_client
self.embeddings = embedding_provider
self.threshold = similarity_threshold
self.ttl = ttl_seconds
self.max_entries = max_entries
# Keys
self.cache_prefix = "llm:cache:"
self.embedding_key = "llm:embeddings"
self.stats_key = "llm:stats"
def get(self, query: str, model: str = "gpt-4") -> CacheResult:
"""
Look up query in cache.
Args:
query: The user query
model: Model name for cache partitioning
Returns:
CacheResult with hit status and response if found
"""
# Generate query embedding
query_embedding = self.embeddings.embed(query)
# Find similar entries
similar_id, similarity = self._find_similar(
query_embedding, model
)
if similar_id and similarity >= self.threshold:
# Cache hit
entry = self._get_entry(similar_id)
if entry:
self._update_hit_count(similar_id)
self._record_stats("hit")
return CacheResult(
hit=True,
response=entry.response,
similarity=similarity,
entry_id=similar_id
)
self._record_stats("miss")
return CacheResult(hit=False, similarity=similarity)
def set(
self,
query: str,
response: str,
model: str = "gpt-4",
tokens_used: int = 0
) -> str:
"""
Store response in cache.
Args:
query: The user query
response: LLM response
model: Model name
tokens_used: Tokens used for cost tracking
Returns:
Cache entry ID
"""
# Generate embedding
embedding = self.embeddings.embed(query)
# Create entry
entry = CacheEntry(
query=query,
response=response,
model=model,
embedding=embedding.tolist(),
created_at=time.time(),
tokens_saved=tokens_used
)
        # Store in Redis under a stable content hash. (Python's built-in
        # hash() is salted per process, so it would produce different keys
        # across restarts and workers.)
        import hashlib
        query_id = hashlib.sha256(query.encode()).hexdigest()[:16]
        entry_id = f"{self.cache_prefix}{model}:{query_id}"
        self.redis.setex(
            entry_id,
            self.ttl,
            json.dumps(entry.__dict__)
        )
# Store embedding for similarity search
self._store_embedding(entry_id, embedding, model)
# Enforce max entries
self._enforce_limit(model)
return entry_id
def _find_similar(
self,
query_embedding: np.ndarray,
model: str
) -> tuple[Optional[str], float]:
"""Find most similar cached entry."""
# Get all embeddings for model
pattern = f"{self.cache_prefix}{model}:*"
keys = list(self.redis.scan_iter(pattern, count=100))
if not keys:
return None, 0.0
best_id = None
best_similarity = 0.0
for key in keys:
entry_data = self.redis.get(key)
if not entry_data:
continue
entry = json.loads(entry_data)
cached_embedding = np.array(entry["embedding"])
# Cosine similarity
similarity = self._cosine_similarity(
query_embedding, cached_embedding
)
if similarity > best_similarity:
best_similarity = similarity
best_id = key.decode() if isinstance(key, bytes) else key
return best_id, best_similarity
def _cosine_similarity(self, a: np.ndarray, b: np.ndarray) -> float:
"""Compute cosine similarity between vectors."""
return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))
def _get_entry(self, entry_id: str) -> Optional[CacheEntry]:
"""Get cache entry by ID."""
data = self.redis.get(entry_id)
if not data:
return None
entry_dict = json.loads(data)
return CacheEntry(**entry_dict)
def _update_hit_count(self, entry_id: str) -> None:
"""Increment hit count for entry."""
data = self.redis.get(entry_id)
if data:
entry = json.loads(data)
entry["hit_count"] = entry.get("hit_count", 0) + 1
# Refresh TTL on hit
self.redis.setex(entry_id, self.ttl, json.dumps(entry))
    def _store_embedding(
        self,
        entry_id: str,
        embedding: np.ndarray,
        model: str
    ) -> None:
        """Store embedding for similarity search.

        Note: hash fields have no per-field TTL, so IDs of expired entries
        accumulate here (_find_similar currently scans the JSON entries
        instead). Prune this hash in _enforce_limit, or use Redis vector
        search, in production.
        """
        key = f"{self.embedding_key}:{model}"
        self.redis.hset(key, entry_id, embedding.tobytes())
def _enforce_limit(self, model: str) -> None:
"""Remove oldest entries if over limit."""
pattern = f"{self.cache_prefix}{model}:*"
keys = list(self.redis.scan_iter(pattern))
if len(keys) > self.max_entries:
# Get entries with timestamps
entries = []
for key in keys:
data = self.redis.get(key)
if data:
entry = json.loads(data)
entries.append((key, entry.get("created_at", 0)))
# Sort by timestamp and remove oldest
entries.sort(key=lambda x: x[1])
to_remove = len(entries) - self.max_entries
for key, _ in entries[:to_remove]:
self.redis.delete(key)
def _record_stats(self, event: str) -> None:
"""Record cache statistics."""
self.redis.hincrby(self.stats_key, event, 1)
def get_stats(self) -> dict:
"""Get cache statistics."""
stats = self.redis.hgetall(self.stats_key)
hits = int(stats.get(b"hit", 0))
misses = int(stats.get(b"miss", 0))
total = hits + misses
return {
"hits": hits,
"misses": misses,
"hit_rate": hits / total if total > 0 else 0,
"total_requests": total
}
def clear(self, model: Optional[str] = None) -> int:
"""Clear cache entries."""
if model:
pattern = f"{self.cache_prefix}{model}:*"
else:
pattern = f"{self.cache_prefix}*"
keys = list(self.redis.scan_iter(pattern))
if keys:
return self.redis.delete(*keys)
        return 0
Understanding the Cache Lookup Flow:
┌─────────────────────────────────────────────────────────────────────────────┐
│ CACHE GET() FLOW │
├─────────────────────────────────────────────────────────────────────────────┤
│ │
│ Query: "What is machine learning?" │
│ │ │
│ ▼ │
│ ┌─────────────────────────────────────────────────────────────────────┐ │
│ │ 1. EMBED QUERY │ │
│ │ query_embedding = embeddings.embed(query) │ │
│ │ [0.12, -0.34, 0.56, ...] (384 or 1536 dimensions) │ │
│ └─────────────────────────────────────────────────────────────────────┘ │
│ │ │
│ ▼ │
│ ┌─────────────────────────────────────────────────────────────────────┐ │
│ │ 2. SCAN CACHED ENTRIES │ │
│ │ For each cached entry: │ │
│ │ • Load cached embedding │ │
│ │ • Compute cosine_similarity(query_emb, cached_emb) │ │
│ │ • Track best match │ │
│ └─────────────────────────────────────────────────────────────────────┘ │
│ │ │
│ ▼ │
│ ┌─────────────────────────────────────────────────────────────────────┐ │
│ │ 3. THRESHOLD CHECK │ │
│ │ │ │
│ │ Best similarity: 0.94 │ │
│ │ Threshold: 0.92 │ │
│ │ │ │
│ │ 0.94 >= 0.92 → CACHE HIT ✓ │ │
│ │ │ │
│ │ Return cached response: "Machine learning is a subset..." │ │
│ └─────────────────────────────────────────────────────────────────────┘ │
│ │
└─────────────────────────────────────────────────────────────────────────────┘
Threshold Tuning Guide:
| Threshold | Hit Rate | Risk | Use Case |
|---|---|---|---|
| 0.98 | Low (~20%) | Very safe | Critical applications |
| 0.95 | Medium (~40%) | Safe | General use |
| 0.92 | High (~60%) | Some false positives | Cost-sensitive |
| 0.88 | Very high (~75%) | More false positives | FAQ/known queries |
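The hit rates above are illustrative; measure your own by logging each request's best-match similarity and replaying candidate thresholds offline. A sketch with made-up similarity values:

```python
def hit_rate(best_similarities, threshold):
    """Fraction of requests whose best cached match clears the threshold."""
    hits = sum(1 for s in best_similarities if s >= threshold)
    return hits / len(best_similarities)

# Best-match similarities logged over a (hypothetical) day of traffic
observed = [0.99, 0.97, 0.94, 0.93, 0.91, 0.89, 0.85, 0.70, 0.60, 0.40]

for t in (0.98, 0.95, 0.92, 0.88):
    print(f"threshold {t}: hit rate {hit_rate(observed, t):.0%}")
# threshold 0.98 → 10%, 0.95 → 20%, 0.92 → 40%, 0.88 → 60%
```

Raising the threshold trades hits for safety: every point below it is a request you pay the LLM for, every point above it is a cached answer that had better really match.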
Step 4: LLM Client with Caching
"""LLM client with integrated caching."""
from dataclasses import dataclass
from typing import Optional
from openai import OpenAI
import tiktoken
@dataclass
class LLMResponse:
"""LLM response with metadata."""
content: str
cached: bool
similarity: float
tokens_used: int
cost_saved: float
class CachedLLMClient:
"""
LLM client with semantic caching.
Automatically caches responses and returns
cached results for similar queries.
"""
# Cost per 1K tokens (GPT-4 pricing)
COSTS = {
"gpt-4": {"input": 0.03, "output": 0.06},
"gpt-4-turbo": {"input": 0.01, "output": 0.03},
"gpt-3.5-turbo": {"input": 0.0005, "output": 0.0015},
}
def __init__(
self,
cache,
model: str = "gpt-4-turbo",
temperature: float = 0.7
):
self.client = OpenAI()
self.cache = cache
self.model = model
self.temperature = temperature
self.encoder = tiktoken.encoding_for_model("gpt-4")
def complete(
self,
prompt: str,
system: str = "You are a helpful assistant.",
use_cache: bool = True
) -> LLMResponse:
"""
Generate completion with caching.
Args:
prompt: User prompt
system: System message
use_cache: Whether to use caching
Returns:
LLMResponse with content and metadata
"""
# Check cache first
if use_cache:
cache_result = self.cache.get(prompt, self.model)
if cache_result.hit:
                # Estimate what the avoided call would have cost:
                # prompt tokens at the input rate, response at the output rate
                input_tokens = self._count_tokens(prompt)
                output_tokens = self._count_tokens(cache_result.response)
                cost_saved = self._calculate_cost(input_tokens, output_tokens)
return LLMResponse(
content=cache_result.response,
cached=True,
similarity=cache_result.similarity,
tokens_used=0,
cost_saved=cost_saved
)
# Call LLM
response = self.client.chat.completions.create(
model=self.model,
messages=[
{"role": "system", "content": system},
{"role": "user", "content": prompt}
],
temperature=self.temperature
)
content = response.choices[0].message.content
tokens_used = response.usage.total_tokens
# Store in cache
if use_cache:
self.cache.set(
query=prompt,
response=content,
model=self.model,
tokens_used=tokens_used
)
return LLMResponse(
content=content,
cached=False,
similarity=0.0,
tokens_used=tokens_used,
cost_saved=0.0
)
def _count_tokens(self, text: str) -> int:
"""Count tokens in text."""
return len(self.encoder.encode(text))
def _calculate_cost(self, input_tokens: int, output_tokens: int) -> float:
"""Calculate API cost."""
costs = self.COSTS.get(self.model, self.COSTS["gpt-4"])
return (
(input_tokens / 1000) * costs["input"] +
(output_tokens / 1000) * costs["output"]
        )
Step 5: Metrics Tracking
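The collector below persists running totals in Redis. As a sanity check on the dollar figures it accumulates, here is the per-request cost arithmetic from Step 4 worked by hand (GPT-4-turbo rates from the COSTS table):

```python
PRICE = {"input": 0.01, "output": 0.03}  # gpt-4-turbo, $ per 1K tokens

def request_cost(input_tokens, output_tokens, price=PRICE):
    """Same arithmetic as CachedLLMClient._calculate_cost."""
    return ((input_tokens / 1000) * price["input"]
            + (output_tokens / 1000) * price["output"])

# A 500-token prompt answered with 1,000 tokens:
cost = request_cost(500, 1000)   # 0.5 * 0.01 + 1.0 * 0.03 = $0.035
# At a 60% hit rate, 6 of every 10 such requests cost $0 instead of $0.035
```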
"""Cost and performance metrics."""
from dataclasses import dataclass, field
from datetime import datetime
from typing import Optional
import redis
@dataclass
class CostMetrics:
"""Track LLM costs and savings."""
total_requests: int = 0
cache_hits: int = 0
cache_misses: int = 0
tokens_used: int = 0
tokens_saved: int = 0
cost_incurred: float = 0.0
cost_saved: float = 0.0
@property
def hit_rate(self) -> float:
"""Cache hit rate."""
total = self.cache_hits + self.cache_misses
return self.cache_hits / total if total > 0 else 0
@property
def savings_rate(self) -> float:
"""Cost savings rate."""
total = self.cost_incurred + self.cost_saved
return self.cost_saved / total if total > 0 else 0
class MetricsCollector:
"""Collect and report caching metrics."""
def __init__(self, redis_client: redis.Redis):
self.redis = redis_client
self.key_prefix = "metrics:"
def record_request(
self,
cached: bool,
tokens_used: int,
cost: float,
tokens_saved: int = 0,
cost_saved: float = 0.0
) -> None:
"""Record a request."""
pipe = self.redis.pipeline()
# Increment counters
pipe.hincrby(f"{self.key_prefix}totals", "requests", 1)
pipe.hincrby(
f"{self.key_prefix}totals",
"hits" if cached else "misses",
1
)
pipe.hincrbyfloat(f"{self.key_prefix}totals", "tokens_used", tokens_used)
pipe.hincrbyfloat(f"{self.key_prefix}totals", "cost", cost)
pipe.hincrbyfloat(f"{self.key_prefix}totals", "tokens_saved", tokens_saved)
pipe.hincrbyfloat(f"{self.key_prefix}totals", "cost_saved", cost_saved)
# Daily metrics
today = datetime.now().strftime("%Y-%m-%d")
pipe.hincrby(f"{self.key_prefix}daily:{today}", "requests", 1)
pipe.hincrbyfloat(f"{self.key_prefix}daily:{today}", "cost", cost)
pipe.hincrbyfloat(f"{self.key_prefix}daily:{today}", "cost_saved", cost_saved)
pipe.execute()
def get_metrics(self) -> CostMetrics:
"""Get current metrics."""
data = self.redis.hgetall(f"{self.key_prefix}totals")
return CostMetrics(
total_requests=int(data.get(b"requests", 0)),
cache_hits=int(data.get(b"hits", 0)),
cache_misses=int(data.get(b"misses", 0)),
tokens_used=int(float(data.get(b"tokens_used", 0))),
tokens_saved=int(float(data.get(b"tokens_saved", 0))),
cost_incurred=float(data.get(b"cost", 0)),
cost_saved=float(data.get(b"cost_saved", 0))
)
def get_daily_metrics(self, date: Optional[str] = None) -> dict:
"""Get metrics for a specific day."""
if date is None:
date = datetime.now().strftime("%Y-%m-%d")
data = self.redis.hgetall(f"{self.key_prefix}daily:{date}")
return {
"date": date,
"requests": int(data.get(b"requests", 0)),
"cost": float(data.get(b"cost", 0)),
"cost_saved": float(data.get(b"cost_saved", 0))
        }
Step 6: FastAPI Application
"""FastAPI application with caching."""
from contextlib import asynccontextmanager
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from typing import Optional
import redis
from .cache import SemanticCache
from .embeddings import LocalEmbeddings
from .llm import CachedLLMClient
from .metrics import MetricsCollector
# Global instances (initialized in lifespan)
cache: Optional[SemanticCache] = None
llm_client: Optional[CachedLLMClient] = None
metrics: Optional[MetricsCollector] = None
@asynccontextmanager
async def lifespan(app: FastAPI):
"""Application lifespan."""
global cache, llm_client, metrics
    # Initialize Redis (REDIS_HOST lets docker-compose point at the redis service)
    import os
    redis_client = redis.Redis(
        host=os.environ.get("REDIS_HOST", "localhost"), port=6379, db=0
    )
# Initialize components
embeddings = LocalEmbeddings()
cache = SemanticCache(
redis_client=redis_client,
embedding_provider=embeddings,
similarity_threshold=0.92,
ttl_seconds=3600
)
llm_client = CachedLLMClient(cache=cache)
metrics = MetricsCollector(redis_client)
yield
redis_client.close()
app = FastAPI(
title="LLM Caching API",
description="Intelligent caching for LLM calls",
lifespan=lifespan
)
class CompletionRequest(BaseModel):
"""Completion request."""
prompt: str
system: Optional[str] = "You are a helpful assistant."
use_cache: bool = True
class CompletionResponse(BaseModel):
"""Completion response."""
content: str
cached: bool
similarity: float
tokens_used: int
cost_saved: float
@app.post("/complete", response_model=CompletionResponse)
async def complete(request: CompletionRequest):
"""Generate completion with caching."""
try:
response = llm_client.complete(
prompt=request.prompt,
system=request.system,
use_cache=request.use_cache
)
# Record metrics
metrics.record_request(
cached=response.cached,
tokens_used=response.tokens_used,
            # Rough estimate: price all tokens at the input rate
            # (zero on cache hits, since tokens_used is 0)
            cost=llm_client._calculate_cost(response.tokens_used, 0),
cost_saved=response.cost_saved
)
return CompletionResponse(
content=response.content,
cached=response.cached,
similarity=response.similarity,
tokens_used=response.tokens_used,
cost_saved=response.cost_saved
)
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@app.get("/metrics")
async def get_metrics():
"""Get caching metrics."""
return metrics.get_metrics().__dict__
@app.get("/cache/stats")
async def cache_stats():
"""Get cache statistics."""
return cache.get_stats()
@app.delete("/cache")
async def clear_cache(model: Optional[str] = None):
"""Clear cache entries."""
count = cache.clear(model)
    return {"cleared": count}
Step 7: Docker Compose
version: '3.8'
services:
api:
build: .
ports:
- "8000:8000"
environment:
- REDIS_HOST=redis
- OPENAI_API_KEY=${OPENAI_API_KEY}
depends_on:
- redis
redis:
image: redis:7-alpine
ports:
- "6379:6379"
volumes:
- redis-data:/data
volumes:
  redis-data:
Usage Example
# Direct usage
from src.cache import SemanticCache
from src.embeddings import LocalEmbeddings
from src.llm import CachedLLMClient
import redis
# Setup
redis_client = redis.Redis()
embeddings = LocalEmbeddings()
cache = SemanticCache(redis_client, embeddings)
client = CachedLLMClient(cache)
# First call - goes to LLM
response1 = client.complete("What is Python?")
print(f"Cached: {response1.cached}") # False
# Similar query - cache hit
response2 = client.complete("Tell me about Python")
print(f"Cached: {response2.cached}") # True
print(f"Similarity: {response2.similarity}")  # ~0.95
Tuning the Cache
| Parameter | Effect | Recommended |
|---|---|---|
| similarity_threshold | Lower = more hits, less accuracy | 0.90-0.95 |
| ttl_seconds | Cache lifetime | 1-24 hours |
| max_entries | Memory usage | 10K-100K |
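`max_entries` is ultimately a memory budget. A back-of-envelope estimate, assuming the Step 3 layout where each embedding is stored twice (as JSON floats in the entry and as raw float64 bytes in the embeddings hash) — the per-float and response-size constants are ballpark assumptions:

```python
def cache_memory_mb(entries, dims, avg_response_chars=2000):
    """Rough Redis footprint in MB; constants are ballpark assumptions."""
    json_floats = dims * 20   # ~20 chars per float serialized in JSON
    raw_bytes = dims * 8      # float64 copy in the embeddings hash
    per_entry = json_floats + raw_bytes + avg_response_chars
    return entries * per_entry / 1e6

# 10K entries with 384-dim local embeddings → roughly 128 MB
print(f"{cache_memory_mb(10_000, 384):.0f} MB")
```

With 1536-dim OpenAI embeddings the same budget roughly triples, which is one more argument for the local model when caching at scale.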
Key Concepts Recap
| Concept | What It Is | Why It Matters |
|---|---|---|
| Semantic Cache | Match queries by meaning, not exact string | "What is Python?" ≈ "Tell me about Python" |
| Embedding Provider | Converts text to vectors (OpenAI or local) | Local = free, OpenAI = better quality |
| Cosine Similarity | Measures vector angle (0-1 scale) | Higher = more similar queries |
| Similarity Threshold | Minimum score to count as cache hit | 0.92-0.95 balances hits vs accuracy |
| TTL (Time-to-Live) | How long entries stay in cache | Prevents stale data, manages memory |
| Cost Tracking | Record tokens saved per cache hit | Prove ROI of caching layer |
| Hit Rate | % of requests served from cache | Target 50-80% for good savings |
| Cache Warming | Pre-populate with common queries | Improve cold-start performance |
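Cache warming from the table above is just a startup loop over curated query/answer pairs. A sketch against the `SemanticCache.set` signature from Step 3 — the in-memory stand-in and the FAQ list are illustrative, not project code:

```python
def warm_cache(cache, qa_pairs, model="gpt-4-turbo"):
    """Pre-populate the cache so the first real users hit warm entries."""
    for query, answer in qa_pairs:
        cache.set(query=query, response=answer, model=model)
    return len(qa_pairs)

class InMemoryCache:
    """Stand-in exposing the same .set() shape as SemanticCache,
    so this sketch runs without Redis."""
    def __init__(self):
        self.entries = {}

    def set(self, query, response, model="gpt-4-turbo", tokens_used=0):
        self.entries[(model, query)] = response
        return query

faq = [
    ("What is Python?", "Python is a general-purpose language..."),
    ("How do I reset my password?", "Open Settings > Security..."),
]
cache = InMemoryCache()
warmed = warm_cache(cache, faq)   # warmed == 2
```

In production you would run the same loop against the real `SemanticCache` at deploy time, sourcing the pairs from FAQ logs or your most frequent historical queries.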
Next Steps
- Monitoring Dashboard - Visualize cache performance
- A/B Testing - Test different thresholds