Production Embedding Pipeline
Build a scalable embedding generation system with caching, batching, and monitoring
TL;DR
Build production-ready embedding infrastructure with Redis caching (80%+ cache hit rate), task queues for batch processing, rate limiting for API protection, and Prometheus/Grafana for monitoring latency, throughput, and errors.
| Difficulty | Time | Code | Prerequisites |
|---|---|---|---|
| Advanced | ~3 days | ~600 LOC | Semantic Search Engine, Redis basics, Docker |
Why Embedding Pipelines Need Infrastructure
A notebook that generates embeddings is not a production system. In production, you face repeated re-computation of the same text (wasting GPU cycles and API credits), traffic spikes that overwhelm your model, no visibility into latency or error rates, and no protection against abusive clients. Caching, batching, rate limiting, and monitoring transform a fragile script into a reliable service that handles thousands of requests per second at predictable latency.
What You'll Learn
- Batch embedding generation for large datasets
- Caching strategies for low latency
- Rate limiting and error handling
- Monitoring and observability
Tech Stack
| Component | Technology | Why |
|---|---|---|
| Embeddings | OpenAI / sentence-transformers | Flexible: local models or API-based |
| Queue | Redis / Celery | Async batch processing for high throughput |
| Cache | Redis | Sub-millisecond lookups, 80%+ hit rate target |
| Monitoring | Prometheus + Grafana | Track latency p50/p95/p99, cache hit rate, errors |
| API | FastAPI | Async endpoints, dependency injection, auto-docs |
Architecture
The architecture diagram groups the system into five layers: clients, the API layer (FastAPI), processing (the embedding service and workers), storage (Redis for both cache and task queue), and infrastructure (Prometheus and Grafana). The request flow through these layers is sketched below.
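A simplified view of the request flow, reconstructed from the components used in this guide:

```
Client ──► FastAPI /embed ──► RateLimiter (Redis) ──limit exceeded──► 429
                                    │ allowed
                                    ▼
                           EmbeddingCache (Redis GET)
                             │ hit            │ miss
                             ▼                ▼
                       return vector   EmbeddingService (local model or OpenAI)
                                              │
                                              ▼
                                    cache SETEX ──► return vector

Prometheus scrapes /metrics; Grafana dashboards track latency, hit rate, and errors.
```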
Project Structure
embedding-pipeline/
├── src/
│ ├── __init__.py
│ ├── config.py # Configuration
│ ├── embeddings.py # Embedding generation
│ ├── cache.py # Caching layer
│ ├── queue.py # Task queue
│ ├── rate_limiter.py # Rate limiting
│ ├── metrics.py # Prometheus metrics
│ └── api.py # FastAPI application
├── workers/
│ └── embedding_worker.py
├── docker-compose.yml
├── requirements.txt
└── README.md

Implementation
Step 1: Configuration
"""
Configuration management for embedding pipeline.
"""
from pydantic_settings import BaseSettings
from functools import lru_cache
class Settings(BaseSettings):
"""Application settings."""
# API
api_host: str = "0.0.0.0"
api_port: int = 8000
# Embedding model
embedding_model: str = "all-MiniLM-L6-v2"
embedding_dim: int = 384
batch_size: int = 32
max_text_length: int = 512
# Redis
redis_url: str = "redis://localhost:6379"
cache_ttl: int = 86400 # 24 hours
# Rate limiting
rate_limit_requests: int = 100
rate_limit_window: int = 60 # seconds
# OpenAI (optional)
openai_api_key: str = ""
use_openai: bool = False
class Config:
env_file = ".env"
@lru_cache
def get_settings() -> Settings:
    return Settings()

Step 2: Embedding Service
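The services in this step read their parameters from the `Settings` object above, which in turn loads a `.env` file. An illustrative `.env` (keys mirror the `Settings` fields; values are examples, not requirements):

```
EMBEDDING_MODEL=all-MiniLM-L6-v2
EMBEDDING_DIM=384
BATCH_SIZE=32
REDIS_URL=redis://localhost:6379
CACHE_TTL=86400
RATE_LIMIT_REQUESTS=100
RATE_LIMIT_WINDOW=60
USE_OPENAI=false
```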
"""
Embedding generation service with batching support.
"""
import numpy as np
from sentence_transformers import SentenceTransformer
from typing import Optional
import hashlib
import asyncio
from concurrent.futures import ThreadPoolExecutor
from .config import get_settings
class EmbeddingService:
"""
Production embedding service with batching and async support.
"""
def __init__(self):
settings = get_settings()
self.model = SentenceTransformer(settings.embedding_model)
self.batch_size = settings.batch_size
self.max_length = settings.max_text_length
self.executor = ThreadPoolExecutor(max_workers=4)
def _truncate_text(self, text: str) -> str:
"""Truncate text to max length."""
if len(text) > self.max_length:
return text[:self.max_length]
return text
def _compute_hash(self, text: str) -> str:
"""Compute hash for caching."""
return hashlib.sha256(text.encode()).hexdigest()[:16]
def embed_single(self, text: str) -> np.ndarray:
"""Generate embedding for a single text."""
text = self._truncate_text(text)
return self.model.encode(
text,
convert_to_numpy=True,
normalize_embeddings=True
)
def embed_batch(self, texts: list[str]) -> np.ndarray:
"""
Generate embeddings for a batch of texts.
More efficient than single embeddings for multiple texts.
"""
texts = [self._truncate_text(t) for t in texts]
return self.model.encode(
texts,
convert_to_numpy=True,
normalize_embeddings=True,
batch_size=self.batch_size,
show_progress_bar=len(texts) > 100
)
async def embed_async(self, text: str) -> np.ndarray:
"""Async embedding generation."""
        loop = asyncio.get_running_loop()
return await loop.run_in_executor(
self.executor,
self.embed_single,
text
)
async def embed_batch_async(self, texts: list[str]) -> np.ndarray:
"""Async batch embedding generation."""
        loop = asyncio.get_running_loop()
return await loop.run_in_executor(
self.executor,
self.embed_batch,
texts
)
class OpenAIEmbeddingService:
"""
    OpenAI embedding service. Note: retries and backoff are not implemented here; add them before relying on this in production.
"""
def __init__(self):
import openai
settings = get_settings()
self.client = openai.OpenAI(api_key=settings.openai_api_key)
self.model = "text-embedding-3-small"
def embed_single(self, text: str) -> np.ndarray:
"""Generate embedding using OpenAI."""
response = self.client.embeddings.create(
input=text,
model=self.model
)
return np.array(response.data[0].embedding)
def embed_batch(self, texts: list[str]) -> np.ndarray:
"""Generate batch embeddings using OpenAI."""
response = self.client.embeddings.create(
input=texts,
model=self.model
)
return np.array([d.embedding for d in response.data])
# Factory function
def get_embedding_service():
settings = get_settings()
if settings.use_openai and settings.openai_api_key:
return OpenAIEmbeddingService()
    return EmbeddingService()

Step 3: Caching Layer
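The cache in this step wraps the embedding service from Step 2. Before layering caching on top, that service can be smoke-tested on its own (a sketch, assuming the project root is on `PYTHONPATH` so the package is importable as `src`):

```python
from src.embeddings import get_embedding_service

# Uses the local sentence-transformers model unless USE_OPENAI is enabled.
service = get_embedding_service()

vectors = service.embed_batch(["hello world", "embedding pipelines"])
print(vectors.shape)  # (2, 384) for all-MiniLM-L6-v2
```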
"""
Redis-based caching for embeddings.
"""
import numpy as np
import redis
import json
import hashlib
from typing import Optional
from .config import get_settings
class EmbeddingCache:
"""
Redis cache for embeddings.
Stores embeddings as JSON-encoded lists for simplicity.
For production with high throughput, consider using
Redis with numpy serialization or a dedicated vector cache.
"""
def __init__(self):
settings = get_settings()
self.redis = redis.from_url(settings.redis_url)
self.ttl = settings.cache_ttl
self.prefix = "emb:"
def _key(self, text: str) -> str:
"""Generate cache key from text."""
text_hash = hashlib.sha256(text.encode()).hexdigest()[:16]
return f"{self.prefix}{text_hash}"
def get(self, text: str) -> Optional[np.ndarray]:
"""Get embedding from cache."""
key = self._key(text)
data = self.redis.get(key)
if data:
return np.array(json.loads(data))
return None
def set(self, text: str, embedding: np.ndarray) -> None:
"""Store embedding in cache."""
key = self._key(text)
data = json.dumps(embedding.tolist())
self.redis.setex(key, self.ttl, data)
def get_many(self, texts: list[str]) -> dict[str, Optional[np.ndarray]]:
"""Get multiple embeddings from cache."""
keys = [self._key(t) for t in texts]
values = self.redis.mget(keys)
results = {}
for text, value in zip(texts, values):
if value:
results[text] = np.array(json.loads(value))
else:
results[text] = None
return results
def set_many(self, embeddings: dict[str, np.ndarray]) -> None:
"""Store multiple embeddings in cache."""
pipe = self.redis.pipeline()
for text, embedding in embeddings.items():
key = self._key(text)
data = json.dumps(embedding.tolist())
pipe.setex(key, self.ttl, data)
pipe.execute()
def invalidate(self, text: str) -> None:
"""Remove embedding from cache."""
key = self._key(text)
self.redis.delete(key)
def stats(self) -> dict:
"""Get cache statistics."""
info = self.redis.info("stats")
        keys = self.redis.keys(f"{self.prefix}*")  # KEYS is O(N); prefer SCAN for large caches
return {
"cached_embeddings": len(keys),
"hits": info.get("keyspace_hits", 0),
"misses": info.get("keyspace_misses", 0),
"hit_rate": info.get("keyspace_hits", 0) /
max(1, info.get("keyspace_hits", 0) + info.get("keyspace_misses", 0))
}
class CachedEmbeddingService:
"""
Embedding service with caching.
"""
def __init__(self, embedding_service, cache: EmbeddingCache):
self.embedder = embedding_service
self.cache = cache
def embed(self, text: str) -> np.ndarray:
"""Get embedding with cache lookup."""
# Check cache
cached = self.cache.get(text)
if cached is not None:
return cached
# Generate and cache
embedding = self.embedder.embed_single(text)
self.cache.set(text, embedding)
return embedding
def embed_batch(self, texts: list[str]) -> np.ndarray:
"""
Get embeddings with batch cache lookup.
Only computes embeddings for cache misses.
"""
# Check cache for all texts
cached = self.cache.get_many(texts)
# Separate hits and misses
hits = {t: e for t, e in cached.items() if e is not None}
misses = [t for t in texts if cached[t] is None]
# Compute missing embeddings
if misses:
new_embeddings = self.embedder.embed_batch(misses)
# Cache new embeddings
new_cache = {t: e for t, e in zip(misses, new_embeddings)}
self.cache.set_many(new_cache)
hits.update(new_cache)
# Return in original order
        return np.array([hits[t] for t in texts])

Understanding the Cache-Aware Batch Flow:
┌────────────────────────────────────────────────────────────────┐
│ embed_batch(["text_A", "text_B", "text_C", "text_D"]) │
│ │
│ 1. cache.get_many() ──► Redis MGET │
│ Result: text_A=HIT, text_B=MISS, text_C=HIT, text_D=MISS │
│ │
│ 2. Only compute misses: │
│ embedder.embed_batch(["text_B", "text_D"]) ──► 2 vectors │
│ │
│ 3. cache.set_many({"text_B": emb_B, "text_D": emb_D}) │
│ Pipeline SETEX for atomic batch write │
│ │
│ 4. Return [emb_A, emb_B, emb_C, emb_D] in original order │
└────────────────────────────────────────────────────────────────┘

| Design Decision | Why |
|---|---|
| SHA-256 text hash as cache key | Deterministic, fixed-length, avoids key-length limits |
| JSON-encoded embeddings | Human-readable, debuggable; use msgpack for higher throughput |
| Pipeline for batch writes | Single round-trip to Redis instead of N individual calls |
| TTL of 24 hours | Balance freshness vs cache hit rate; adjust per use case |
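If JSON encoding ever shows up in your latency profiles, a compact alternative is to store raw `float32` bytes. A minimal sketch of drop-in replacements for `EmbeddingCache.get`/`set` (assuming all embeddings share one dtype and dimension):

```python
import numpy as np
from typing import Optional


def set_binary(self, text: str, embedding: np.ndarray) -> None:
    """Store the embedding as raw float32 bytes (smaller and faster than JSON)."""
    key = self._key(text)
    self.redis.setex(key, self.ttl, embedding.astype(np.float32).tobytes())


def get_binary(self, text: str) -> Optional[np.ndarray]:
    """Load an embedding that was stored as raw float32 bytes."""
    data = self.redis.get(self._key(text))
    if data is None:
        return None
    return np.frombuffer(data, dtype=np.float32)
```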
Step 4: Rate Limiting
"""
Rate limiting for API protection.
"""
import redis
import time
from typing import Optional
from .config import get_settings
class RateLimiter:
"""
    Sliding-window rate limiter backed by a Redis sorted set.
"""
def __init__(self, key_prefix: str = "rate:"):
settings = get_settings()
self.redis = redis.from_url(settings.redis_url)
self.prefix = key_prefix
self.max_requests = settings.rate_limit_requests
self.window = settings.rate_limit_window
def _key(self, identifier: str) -> str:
return f"{self.prefix}{identifier}"
def is_allowed(self, identifier: str) -> tuple[bool, dict]:
"""
Check if request is allowed under rate limit.
Args:
identifier: Unique identifier (IP, API key, user ID)
Returns:
Tuple of (allowed, info_dict)
"""
key = self._key(identifier)
now = time.time()
window_start = now - self.window
pipe = self.redis.pipeline()
# Remove old entries
pipe.zremrangebyscore(key, 0, window_start)
# Count current requests
pipe.zcard(key)
# Add new request
pipe.zadd(key, {str(now): now})
# Set expiry
pipe.expire(key, self.window)
results = pipe.execute()
current_count = results[1]
allowed = current_count < self.max_requests
return allowed, {
"remaining": max(0, self.max_requests - current_count - 1),
"limit": self.max_requests,
"reset": int(now + self.window)
}
def get_wait_time(self, identifier: str) -> float:
"""Get time to wait before next allowed request."""
key = self._key(identifier)
now = time.time()
window_start = now - self.window
# Get oldest request in window
oldest = self.redis.zrangebyscore(
key, window_start, now, start=0, num=1, withscores=True
)
if not oldest:
return 0
oldest_time = oldest[0][1]
wait = (oldest_time + self.window) - now
return max(0, wait)
class RateLimitMiddleware:
"""
FastAPI middleware for rate limiting.
"""
def __init__(self):
self.limiter = RateLimiter()
async def __call__(self, request, call_next):
from fastapi import HTTPException
from fastapi.responses import JSONResponse
# Get identifier (IP or API key)
identifier = request.client.host
if "x-api-key" in request.headers:
identifier = request.headers["x-api-key"]
allowed, info = self.limiter.is_allowed(identifier)
if not allowed:
return JSONResponse(
status_code=429,
content={
"error": "Rate limit exceeded",
"retry_after": info["reset"]
},
headers={
"X-RateLimit-Limit": str(info["limit"]),
"X-RateLimit-Remaining": str(info["remaining"]),
"X-RateLimit-Reset": str(info["reset"])
}
)
response = await call_next(request)
# Add rate limit headers
response.headers["X-RateLimit-Limit"] = str(info["limit"])
response.headers["X-RateLimit-Remaining"] = str(info["remaining"])
response.headers["X-RateLimit-Reset"] = str(info["reset"])
        return response

Step 5: Prometheus Metrics
"""
Prometheus metrics for monitoring.
"""
from prometheus_client import Counter, Histogram, Gauge, generate_latest
import time
from functools import wraps
# Counters
EMBEDDING_REQUESTS = Counter(
"embedding_requests_total",
"Total embedding requests",
["status", "cached"]
)
# Histograms
EMBEDDING_LATENCY = Histogram(
"embedding_latency_seconds",
"Embedding generation latency",
["model"],
buckets=[0.01, 0.05, 0.1, 0.25, 0.5, 1.0, 2.5, 5.0]
)
BATCH_SIZE = Histogram(
"embedding_batch_size",
"Size of embedding batches",
buckets=[1, 5, 10, 25, 50, 100, 250, 500]
)
# Gauges
CACHE_HIT_RATE = Gauge(
"embedding_cache_hit_rate",
"Cache hit rate"
)
QUEUE_SIZE = Gauge(
"embedding_queue_size",
"Number of pending embedding tasks"
)
def track_latency(model_name: str = "default"):
"""Decorator to track embedding latency."""
def decorator(func):
@wraps(func)
def wrapper(*args, **kwargs):
start = time.time()
try:
result = func(*args, **kwargs)
EMBEDDING_REQUESTS.labels(status="success", cached="false").inc()
return result
except Exception as e:
EMBEDDING_REQUESTS.labels(status="error", cached="false").inc()
raise
finally:
EMBEDDING_LATENCY.labels(model=model_name).observe(time.time() - start)
return wrapper
return decorator
def track_async_latency(model_name: str = "default"):
"""Decorator to track async embedding latency."""
def decorator(func):
@wraps(func)
async def wrapper(*args, **kwargs):
start = time.time()
try:
result = await func(*args, **kwargs)
EMBEDDING_REQUESTS.labels(status="success", cached="false").inc()
return result
except Exception as e:
EMBEDDING_REQUESTS.labels(status="error", cached="false").inc()
raise
finally:
EMBEDDING_LATENCY.labels(model=model_name).observe(time.time() - start)
return wrapper
    return decorator

Step 6: FastAPI Application
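Note that the decorators from Step 5 are defined but not applied anywhere in the code shown. One way to wire them up is to wrap the service methods during application setup, e.g. right after `embedding_service = get_embedding_service()` in `src/api.py` (a sketch, not the only option):

```python
from .config import get_settings
from .embeddings import get_embedding_service
from .metrics import track_latency

settings = get_settings()
embedding_service = get_embedding_service()

# Wrap the bound methods so every embedding call is counted and timed.
embedding_service.embed_single = track_latency(settings.embedding_model)(
    embedding_service.embed_single
)
embedding_service.embed_batch = track_latency(settings.embedding_model)(
    embedding_service.embed_batch
)
```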
"""
FastAPI application for embedding service.
"""
from fastapi import FastAPI, HTTPException, Depends, Request
from pydantic import BaseModel, Field
from prometheus_client import generate_latest
from starlette.responses import Response
from typing import Optional
from .embeddings import get_embedding_service, EmbeddingService
from .cache import EmbeddingCache, CachedEmbeddingService
from .rate_limiter import RateLimiter
from .metrics import (
EMBEDDING_REQUESTS, BATCH_SIZE, CACHE_HIT_RATE,
track_latency, track_async_latency
)
from .config import get_settings
app = FastAPI(
title="Embedding Pipeline API",
description="Production embedding service with caching",
version="1.0.0"
)
# Initialize services
embedding_service = get_embedding_service()
cache = EmbeddingCache()
cached_service = CachedEmbeddingService(embedding_service, cache)
rate_limiter = RateLimiter()
# Request/Response models
class EmbedRequest(BaseModel):
text: str = Field(..., max_length=10000)
class EmbedBatchRequest(BaseModel):
texts: list[str] = Field(..., max_length=100)
class EmbedResponse(BaseModel):
embedding: list[float]
cached: bool
model: str
class EmbedBatchResponse(BaseModel):
embeddings: list[list[float]]
count: int
cached_count: int
# Dependency for rate limiting
async def check_rate_limit(request: Request):
identifier = request.client.host
allowed, info = rate_limiter.is_allowed(identifier)
if not allowed:
raise HTTPException(
status_code=429,
detail="Rate limit exceeded",
headers={
"Retry-After": str(info["reset"]),
"X-RateLimit-Remaining": str(info["remaining"])
}
)
return info
@app.get("/health")
async def health():
"""Health check endpoint."""
return {"status": "healthy"}
@app.post("/embed", response_model=EmbedResponse)
async def embed_text(
request: EmbedRequest,
rate_info: dict = Depends(check_rate_limit)
):
"""
Generate embedding for a single text.
Uses caching for repeated requests.
"""
settings = get_settings()
# Check cache first
cached_embedding = cache.get(request.text)
if cached_embedding is not None:
EMBEDDING_REQUESTS.labels(status="success", cached="true").inc()
return EmbedResponse(
embedding=cached_embedding.tolist(),
cached=True,
model=settings.embedding_model
)
# Generate embedding
embedding = cached_service.embed(request.text)
return EmbedResponse(
embedding=embedding.tolist(),
cached=False,
model=settings.embedding_model
)
@app.post("/embed/batch", response_model=EmbedBatchResponse)
async def embed_batch(
request: EmbedBatchRequest,
rate_info: dict = Depends(check_rate_limit)
):
"""
Generate embeddings for multiple texts.
More efficient than individual requests.
"""
BATCH_SIZE.observe(len(request.texts))
# Check cache for all
cached = cache.get_many(request.texts)
cached_count = sum(1 for v in cached.values() if v is not None)
# Generate all embeddings (cache-aware)
embeddings = cached_service.embed_batch(request.texts)
return EmbedBatchResponse(
embeddings=embeddings.tolist(),
count=len(request.texts),
cached_count=cached_count
)
@app.get("/metrics")
async def metrics():
"""Prometheus metrics endpoint."""
# Update cache metrics
stats = cache.stats()
CACHE_HIT_RATE.set(stats["hit_rate"])
return Response(
content=generate_latest(),
media_type="text/plain"
)
@app.get("/stats")
async def stats():
"""Get service statistics."""
cache_stats = cache.stats()
settings = get_settings()
return {
"model": settings.embedding_model,
"embedding_dim": settings.embedding_dim,
"cache": cache_stats
    }

Step 7: Docker Compose
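The compose file below mounts a `prometheus.yml` that is not shown elsewhere in this guide. A minimal scrape config pointing Prometheus at the API's `/metrics` endpoint might look like this (illustrative):

```yaml
global:
  scrape_interval: 15s

scrape_configs:
  - job_name: "embedding-api"
    metrics_path: /metrics
    static_configs:
      - targets: ["api:8000"]
```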
version: "3.8"
services:
api:
build: .
ports:
- "8000:8000"
environment:
- REDIS_URL=redis://redis:6379
- EMBEDDING_MODEL=all-MiniLM-L6-v2
depends_on:
- redis
deploy:
replicas: 2
resources:
limits:
memory: 2G
redis:
image: redis:7-alpine
ports:
- "6379:6379"
volumes:
- redis_data:/data
prometheus:
image: prom/prometheus:latest
ports:
- "9090:9090"
volumes:
- ./prometheus.yml:/etc/prometheus/prometheus.yml
grafana:
image: grafana/grafana:latest
ports:
- "3001:3000"
environment:
- GF_SECURITY_ADMIN_PASSWORD=admin
volumes:
- grafana_data:/var/lib/grafana
volumes:
redis_data:
  grafana_data:

FROM python:3.11-slim
WORKDIR /app
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt
COPY src/ src/
CMD ["uvicorn", "src.api:app", "--host", "0.0.0.0", "--port", "8000"]

Running the Pipeline
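The Dockerfile installs dependencies from `requirements.txt`, which is not listed above. A plausible minimal set (unpinned here for brevity; pin versions in a real deployment):

```
fastapi
uvicorn[standard]
pydantic-settings
sentence-transformers
numpy
redis
prometheus-client
openai  # optional, only if use_openai is enabled
```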
# Start all services
docker-compose up -d
# Test the API
curl -X POST "http://localhost:8000/embed" \
-H "Content-Type: application/json" \
-d '{"text": "Hello, world!"}'
# Batch request
curl -X POST "http://localhost:8000/embed/batch" \
-H "Content-Type: application/json" \
-d '{"texts": ["Hello", "World", "Test"]}'
# Check metrics
curl http://localhost:8000/metrics

Key Concepts
Caching Strategy
| Strategy | Latency | Cost | Best For |
|---|---|---|---|
| No cache | High | High | Development |
| Memory cache | Very low | Medium | Single instance |
| Redis cache | Low | Low | Distributed systems |
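For the single-instance "Memory cache" row above, an in-process cache can be as simple as an LRU wrapper around the embedder. A sketch (trades process memory for zero network hops; callers should treat the returned array as read-only since the same object is reused):

```python
from functools import lru_cache

from src.embeddings import EmbeddingService

service = EmbeddingService()


@lru_cache(maxsize=50_000)
def embed_cached(text: str):
    """Cache embeddings in process memory, keyed by the exact text."""
    return service.embed_single(text)
```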
Monitoring Dashboard
Key metrics to track (example PromQL queries after this list):
- Request latency (p50, p95, p99)
- Cache hit rate (target: >80%)
- Error rate (target: less than 1%)
- Throughput (requests/second)
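Using the metric names from Step 5, these translate into PromQL queries such as (illustrative):

```
# p95 embedding latency over the last 5 minutes
histogram_quantile(0.95, rate(embedding_latency_seconds_bucket[5m]))

# cache hit rate as exported by the /metrics endpoint
embedding_cache_hit_rate

# error rate: share of requests with status="error"
sum(rate(embedding_requests_total{status="error"}[5m]))
  / sum(rate(embedding_requests_total[5m]))

# throughput in requests per second
sum(rate(embedding_requests_total[5m]))
```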
Key Concepts Recap
| Concept | What It Is | Why It Matters |
|---|---|---|
| Embedding Cache | Store computed embeddings by text hash | Avoid re-computing same text, reduce latency |
| Task Queue | Async job processing (Celery/Redis) | Handle bursts, batch for efficiency |
| Rate Limiting | Cap requests per client | Protect from abuse, ensure fairness |
| Batch Processing | Group multiple texts per API call | Better GPU utilization, lower per-item cost |
| Health Checks | /health and /ready endpoints | Enable load balancer routing |
| Prometheus Metrics | Latency, throughput, error counters | Enable alerting and debugging |
| Cache Hit Rate | % of requests served from cache | Target 80%+, indicates good caching |
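The recap mentions a `/ready` endpoint that Step 6 does not define. A minimal readiness check that verifies the Redis dependency, added to `src/api.py` alongside `/health`, might look like this (a sketch):

```python
@app.get("/ready")
async def ready():
    """Readiness probe: only report ready if Redis is reachable."""
    try:
        cache.redis.ping()
        return {"status": "ready"}
    except Exception:
        raise HTTPException(status_code=503, detail="Redis unavailable")
```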
Next Steps
- Multi-Modal Embeddings - Combine text and images
- Search at Scale - Billion-vector search