Production RAG Pipeline
Build a production-ready RAG system with evaluation, monitoring, caching, and A/B testing
| Property | Value |
|---|---|
| Difficulty | 🔴 Advanced |
| Time | ~3 days |
| Code Size | ~800 LOC |
| Prerequisites | All intermediate RAG projects |
Tech Stack
| Technology | Purpose |
|---|---|
| LangChain | RAG orchestration |
| OpenAI | GPT-4 + Embeddings |
| RAGAS | RAG evaluation metrics |
| LangSmith | Observability and tracing |
| Redis | Semantic caching |
| ChromaDB | Vector storage |
| FastAPI | Production API |
| Prometheus | Metrics collection |
| PostgreSQL | A/B test results storage |
Prerequisites
- Completed all intermediate RAG projects
- Python 3.10+ with async/await expertise
- OpenAI API key
- LangSmith API key
- Docker for local services (Redis, PostgreSQL, Prometheus)
- Understanding of production system design patterns
What You'll Learn
- Evaluate RAG quality with RAGAS metrics (faithfulness, relevance, context precision)
- Implement comprehensive observability with tracing and logging
- Build semantic caching for latency reduction and cost savings
- Design A/B testing framework for retrieval experiments
- Apply production patterns: circuit breakers, rate limiting, graceful degradation
- Monitor and alert on RAG performance in real-time
Why Production RAG is Different
Moving from prototype to production requires addressing critical concerns:
The Production Gap
| Aspect | Prototype | Production |
|---|---|---|
| Error Handling | Print and continue | Structured errors, retries, fallbacks |
| Performance | Single request | Concurrent users, caching, optimization |
| Quality | "Looks good" | Quantitative metrics, regression tests |
| Monitoring | Console logs | Distributed tracing, dashboards, alerts |
| Experiments | Code changes | A/B tests, gradual rollouts |
System Architecture
Request Flow
A request passes through rate limiting, then the semantic cache; on a miss it receives an A/B variant assignment, runs retrieval and generation behind circuit breakers, is written back to the cache, and a sampled evaluation runs asynchronously while traces and metrics are recorded.
Project Structure
production-rag/
├── src/
│ ├── __init__.py
│ ├── config.py
│ ├── core/
│ │ ├── __init__.py
│ │ ├── pipeline.py # Main RAG pipeline
│ │ ├── retriever.py # Retrieval strategies
│ │ └── generator.py # Response generation
│ ├── evaluation/
│ │ ├── __init__.py
│ │ ├── metrics.py # RAGAS metrics
│ │ ├── evaluator.py # Online evaluation
│ │ └── ground_truth.py # Test dataset management
│ ├── observability/
│ │ ├── __init__.py
│ │ ├── tracer.py # LangSmith integration
│ │ ├── metrics.py # Prometheus metrics
│ │ └── logging.py # Structured logging
│ ├── middleware/
│ │ ├── __init__.py
│ │ ├── cache.py # Semantic caching
│ │ ├── rate_limiter.py # Rate limiting
│ │ └── circuit_breaker.py # Circuit breaker
│ ├── experiments/
│ │ ├── __init__.py
│ │ ├── ab_router.py # A/B test routing
│ │ └── experiment.py # Experiment management
│ └── api.py # FastAPI application
├── tests/
│ ├── test_evaluation.py
│ ├── test_cache.py
│ └── test_experiments.py
├── evaluation_data/
│ └── ground_truth.json # Test questions and answers
├── docker-compose.yml
├── prometheus.yml
├── .env
├── pyproject.toml
└── README.md
Implementation
Step 1: Project Setup
mkdir production-rag && cd production-rag
uv init
uv venv && source .venv/bin/activate
uv add langchain langchain-openai langchain-chroma langchain-community
uv add chromadb openai tiktoken
uv add ragas datasets # RAG evaluation
uv add langsmith # Observability
uv add redis # Caching (async client via redis.asyncio)
uv add prometheus-client # Metrics
uv add fastapi uvicorn httpx # API
uv add sqlalchemy asyncpg # A/B test storage
uv add python-dotenv pydantic-settings
uv add tenacity # Retries
Create a .env file:
# API Keys
OPENAI_API_KEY=sk-your-key-here
LANGSMITH_API_KEY=lsv2_your-key-here
LANGSMITH_PROJECT=production-rag
# Storage
CHROMA_PERSIST_DIR=./chroma_db
REDIS_URL=redis://localhost:6379
DATABASE_URL=postgresql+asyncpg://user:pass@localhost:5432/rag_experiments
# Feature Flags
ENABLE_CACHE=true
ENABLE_TRACING=true
ENABLE_AB_TESTING=true
# Rate Limiting
RATE_LIMIT_REQUESTS=100
RATE_LIMIT_WINDOW=60
# Models
LLM_MODEL=gpt-4o-mini
EMBEDDING_MODEL=text-embedding-3-small
Step 2: Configuration
"""Production configuration with validation."""
from pydantic_settings import BaseSettings
from pydantic import Field
from functools import lru_cache
class Settings(BaseSettings):
"""Application settings with environment variable support."""
# API Keys
openai_api_key: str = Field(..., env="OPENAI_API_KEY")
langsmith_api_key: str = Field(default="", env="LANGSMITH_API_KEY")
langsmith_project: str = Field(default="production-rag", env="LANGSMITH_PROJECT")
# Storage
chroma_persist_dir: str = Field(default="./chroma_db", env="CHROMA_PERSIST_DIR")
redis_url: str = Field(default="redis://localhost:6379", env="REDIS_URL")
database_url: str = Field(default="", env="DATABASE_URL")
# Feature Flags
enable_cache: bool = Field(default=True, env="ENABLE_CACHE")
enable_tracing: bool = Field(default=True, env="ENABLE_TRACING")
enable_ab_testing: bool = Field(default=False, env="ENABLE_AB_TESTING")
# Rate Limiting
rate_limit_requests: int = Field(default=100, env="RATE_LIMIT_REQUESTS")
rate_limit_window: int = Field(default=60, env="RATE_LIMIT_WINDOW")
# Models
llm_model: str = Field(default="gpt-4o-mini", env="LLM_MODEL")
embedding_model: str = Field(default="text-embedding-3-small", env="EMBEDDING_MODEL")
# RAG Settings
retrieval_top_k: int = 4
chunk_size: int = 1000
chunk_overlap: int = 200
# Cache Settings
cache_ttl_seconds: int = 3600
cache_similarity_threshold: float = 0.95
class Config:
env_file = ".env"
env_file_encoding = "utf-8"
@lru_cache
def get_settings() -> Settings:
"""Get cached settings instance."""
return Settings()
settings = get_settings()
Step 3: RAG Evaluation Module
The heart of production RAG is measuring quality. RAGAS provides four key metrics:
| Metric | Measures | Range | Ideal |
|---|---|---|---|
| Faithfulness | Is the answer grounded in context? | 0-1 | > 0.8 |
| Answer Relevancy | Does the answer address the question? | 0-1 | > 0.8 |
| Context Precision | Are retrieved chunks relevant? | 0-1 | > 0.7 |
| Context Recall | Did we retrieve all needed info? | 0-1 | > 0.7 |
"""RAG evaluation metrics using RAGAS."""
from dataclasses import dataclass
from typing import List, Optional
import asyncio
from ragas import evaluate
from ragas.metrics import (
faithfulness,
answer_relevancy,
context_precision,
context_recall,
)
from datasets import Dataset
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from src.config import settings
@dataclass
class EvaluationResult:
"""Container for evaluation scores."""
faithfulness: float
answer_relevancy: float
context_precision: float
context_recall: Optional[float] = None # Requires ground truth
@property
def overall_score(self) -> float:
"""Calculate weighted overall score."""
scores = [self.faithfulness, self.answer_relevancy, self.context_precision]
if self.context_recall is not None:
scores.append(self.context_recall)
return sum(scores) / len(scores)
def to_dict(self) -> dict:
return {
"faithfulness": self.faithfulness,
"answer_relevancy": self.answer_relevancy,
"context_precision": self.context_precision,
"context_recall": self.context_recall,
"overall_score": self.overall_score,
}
@dataclass
class RAGSample:
"""A single RAG evaluation sample."""
question: str
answer: str
contexts: List[str]
ground_truth: Optional[str] = None
class RAGASEvaluator:
"""Evaluate RAG responses using RAGAS metrics."""
def __init__(self):
self.llm = ChatOpenAI(
model=settings.llm_model,
openai_api_key=settings.openai_api_key,
temperature=0,
)
self.embeddings = OpenAIEmbeddings(
model=settings.embedding_model,
openai_api_key=settings.openai_api_key,
)
# Select metrics based on whether we have ground truth
self.metrics_with_gt = [
faithfulness,
answer_relevancy,
context_precision,
context_recall,
]
self.metrics_without_gt = [
faithfulness,
answer_relevancy,
context_precision,
]
def evaluate_single(self, sample: RAGSample) -> EvaluationResult:
"""Evaluate a single RAG response."""
# Prepare dataset
data = {
"question": [sample.question],
"answer": [sample.answer],
"contexts": [sample.contexts],
}
metrics = self.metrics_without_gt
if sample.ground_truth:
data["ground_truth"] = [sample.ground_truth]
metrics = self.metrics_with_gt
dataset = Dataset.from_dict(data)
# Run evaluation
results = evaluate(
dataset,
metrics=metrics,
llm=self.llm,
embeddings=self.embeddings,
)
return EvaluationResult(
faithfulness=results["faithfulness"],
answer_relevancy=results["answer_relevancy"],
context_precision=results["context_precision"],
context_recall=results.get("context_recall"),
)
def evaluate_batch(self, samples: List[RAGSample]) -> List[EvaluationResult]:
"""Evaluate multiple RAG responses."""
results = []
for sample in samples:
result = self.evaluate_single(sample)
results.append(result)
return results
def evaluate_from_ground_truth(
self,
ground_truth_path: str,
rag_pipeline,
) -> dict:
"""Run evaluation against a ground truth dataset."""
import json
with open(ground_truth_path) as f:
test_data = json.load(f)
samples = []
for item in test_data["questions"]:
# Get RAG response
response = rag_pipeline.query(item["question"])
samples.append(RAGSample(
question=item["question"],
answer=response["answer"],
contexts=response["contexts"],
ground_truth=item.get("expected_answer"),
))
results = self.evaluate_batch(samples)
# Aggregate results
return {
"num_samples": len(results),
"avg_faithfulness": sum(r.faithfulness for r in results) / len(results),
"avg_relevancy": sum(r.answer_relevancy for r in results) / len(results),
"avg_precision": sum(r.context_precision for r in results) / len(results),
"avg_overall": sum(r.overall_score for r in results) / len(results),
"details": [r.to_dict() for r in results],
}
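To run this offline evaluation end to end you need something that answers the ground-truth questions. The production pipeline built later in Step 11 is async and returns a RAGResponse, so one approach is a small driver script that gathers answers first and then scores them with evaluate_batch. This is a sketch only: the scripts/run_eval.py name and the "offline-eval" user id are assumptions, and it expects Redis to be running (or ENABLE_CACHE=false).

```python
# scripts/run_eval.py (assumed path): answer ground-truth questions with the
# production pipeline, then score the samples with RAGAS offline.
import asyncio
import json

from src.core.pipeline import ProductionRAGPipeline
from src.evaluation.metrics import RAGASEvaluator, RAGSample


async def build_samples(path: str) -> list[RAGSample]:
    """Answer every ground-truth question with the production pipeline."""
    pipeline = ProductionRAGPipeline()
    with open(path) as f:
        test_data = json.load(f)
    samples = []
    for item in test_data["questions"]:
        response = await pipeline.query(item["question"], user_id="offline-eval")
        samples.append(RAGSample(
            question=item["question"],
            answer=response.answer,
            contexts=response.contexts,
            ground_truth=item.get("expected_answer"),
        ))
    return samples


if __name__ == "__main__":
    samples = asyncio.run(build_samples("evaluation_data/ground_truth.json"))
    results = RAGASEvaluator().evaluate_batch(samples)
    for sample, result in zip(samples, results):
        print(f"{sample.question[:50]:50s}  overall={result.overall_score:.2f}")
```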
Step 4: Online Evaluator
For production, we need lightweight online evaluation that doesn't add significant latency:
"""Lightweight online evaluation for production."""
import hashlib
import random
from typing import Optional, Dict, Any
from datetime import datetime
import asyncio
from langchain_openai import ChatOpenAI
from langchain_core.prompts import ChatPromptTemplate
from src.config import settings
from src.observability.metrics import rag_metrics
class OnlineEvaluator:
"""Fast online evaluation with sampling.
Performs lightweight quality checks on a sample of requests
to avoid adding latency to every request.
"""
def __init__(self, sample_rate: float = 0.1):
"""
Args:
sample_rate: Fraction of requests to evaluate (0.0 to 1.0)
"""
self.sample_rate = sample_rate
self.llm = ChatOpenAI(
model=settings.llm_model,
openai_api_key=settings.openai_api_key,
temperature=0,
max_tokens=100, # Keep evaluation fast
)
# Lightweight faithfulness check prompt
self.faithfulness_prompt = ChatPromptTemplate.from_messages([
("system", """You are evaluating if an answer is faithful to the given context.
Score from 0 to 1 where:
- 1.0 = Answer is fully supported by context
- 0.5 = Answer is partially supported
- 0.0 = Answer contains claims not in context
ONLY output a number between 0 and 1. Nothing else."""),
("human", """Context:
{context}
Answer:
{answer}
Faithfulness score:"""),
])
# Relevancy check prompt
self.relevancy_prompt = ChatPromptTemplate.from_messages([
("system", """You are evaluating if an answer is relevant to the question.
Score from 0 to 1 where:
- 1.0 = Directly and completely answers the question
- 0.5 = Partially answers the question
- 0.0 = Does not address the question
ONLY output a number between 0 and 1. Nothing else."""),
("human", """Question: {question}
Answer: {answer}
Relevancy score:"""),
])
def should_evaluate(self) -> bool:
"""Determine if this request should be evaluated."""
return random.random() < self.sample_rate
async def evaluate_async(
self,
question: str,
answer: str,
contexts: list[str],
) -> Optional[Dict[str, float]]:
"""Perform async evaluation without blocking response."""
if not self.should_evaluate():
return None
try:
context_text = "\n\n".join(contexts[:3]) # Limit context for speed
# Run both evaluations concurrently
faithfulness_task = self._evaluate_faithfulness(context_text, answer)
relevancy_task = self._evaluate_relevancy(question, answer)
faithfulness, relevancy = await asyncio.gather(
faithfulness_task,
relevancy_task,
return_exceptions=True,
)
# Handle any exceptions
if isinstance(faithfulness, Exception):
faithfulness = None
if isinstance(relevancy, Exception):
relevancy = None
scores = {
"faithfulness": faithfulness,
"relevancy": relevancy,
"timestamp": datetime.utcnow().isoformat(),
}
# Record metrics
if faithfulness is not None:
rag_metrics.record_faithfulness(faithfulness)
if relevancy is not None:
rag_metrics.record_relevancy(relevancy)
return scores
except Exception as e:
# Don't let evaluation failures affect the main request
return {"error": str(e)}
async def _evaluate_faithfulness(self, context: str, answer: str) -> float:
"""Evaluate faithfulness score."""
chain = self.faithfulness_prompt | self.llm
result = await chain.ainvoke({"context": context, "answer": answer})
return self._parse_score(result.content)
async def _evaluate_relevancy(self, question: str, answer: str) -> float:
"""Evaluate relevancy score."""
chain = self.relevancy_prompt | self.llm
result = await chain.ainvoke({"question": question, "answer": answer})
return self._parse_score(result.content)
def _parse_score(self, content: str) -> float:
"""Parse score from LLM response."""
try:
score = float(content.strip())
return max(0.0, min(1.0, score)) # Clamp to [0, 1]
except ValueError:
return 0.5 # Default on parse failure
Step 5: Observability with LangSmith
"""LangSmith integration for distributed tracing."""
import asyncio
import os
from typing import Optional, Any, Dict
from contextlib import contextmanager
from functools import wraps
import time
from langsmith import Client
from langsmith.run_trees import RunTree
from src.config import settings
class LangSmithTracer:
"""Comprehensive tracing for RAG pipelines."""
def __init__(self):
self.enabled = settings.enable_tracing and settings.langsmith_api_key
self.client = None
if self.enabled:
os.environ["LANGCHAIN_TRACING_V2"] = "true"
os.environ["LANGCHAIN_API_KEY"] = settings.langsmith_api_key
os.environ["LANGCHAIN_PROJECT"] = settings.langsmith_project
self.client = Client()
@contextmanager
def trace_query(
self,
query: str,
session_id: Optional[str] = None,
metadata: Optional[Dict[str, Any]] = None,
):
"""Context manager for tracing a complete query."""
if not self.enabled:
yield None
return
run_tree = RunTree(
name="rag_query",
run_type="chain",
inputs={"query": query},
extra={
"metadata": {
"session_id": session_id,
**(metadata or {}),
}
},
)
try:
yield run_tree
run_tree.end()
run_tree.post()
except Exception as e:
run_tree.end(error=str(e))
run_tree.post()
raise
def trace_retrieval(
self,
parent_run: Optional[RunTree],
query: str,
results: list,
retriever_name: str = "default",
):
"""Trace retrieval step."""
if not self.enabled or not parent_run:
return
child_run = parent_run.create_child(
name=f"retrieval_{retriever_name}",
run_type="retriever",
inputs={"query": query},
)
child_run.end(outputs={
"documents": [doc.page_content[:200] for doc in results],
"num_results": len(results),
})
child_run.post()
def trace_generation(
self,
parent_run: Optional[RunTree],
prompt: str,
response: str,
model: str,
tokens_used: Optional[int] = None,
):
"""Trace generation step."""
if not self.enabled or not parent_run:
return
child_run = parent_run.create_child(
name="generation",
run_type="llm",
inputs={"prompt": prompt[:500]}, # Truncate for readability
)
child_run.end(outputs={
"response": response,
"model": model,
"tokens_used": tokens_used,
})
child_run.post()
def trace_evaluation(
self,
parent_run: Optional[RunTree],
scores: Dict[str, float],
):
"""Trace evaluation scores."""
if not self.enabled or not parent_run:
return
child_run = parent_run.create_child(
name="evaluation",
run_type="tool",
inputs={"action": "evaluate_response"},
)
child_run.end(outputs=scores)
child_run.post()
# Decorator for easy tracing
def traced(name: str = None):
"""Decorator to trace function execution."""
def decorator(func):
@wraps(func)
async def async_wrapper(*args, **kwargs):
tracer = LangSmithTracer()
func_name = name or func.__name__
with tracer.trace_query(
query=str(kwargs.get("query", args[0] if args else "")),
metadata={"function": func_name},
):
return await func(*args, **kwargs)
@wraps(func)
def sync_wrapper(*args, **kwargs):
tracer = LangSmithTracer()
func_name = name or func.__name__
with tracer.trace_query(
query=str(kwargs.get("query", args[0] if args else "")),
metadata={"function": func_name},
):
return func(*args, **kwargs)
if asyncio.iscoroutinefunction(func):
return async_wrapper
return sync_wrapper
return decorator
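The traced decorator is not exercised elsewhere in this guide. Here is a minimal usage sketch with a hypothetical summarize helper; note that the decorator reads the query from a keyword argument named query.

```python
# Minimal usage sketch for the traced decorator; `summarize` is a hypothetical helper.
import asyncio

from src.observability.tracer import traced


@traced("summarize_docs")
async def summarize(query: str) -> str:
    # ... call the RAG pipeline or any chain here ...
    return f"summary for: {query}"


async def main() -> None:
    # Pass the query as a keyword argument so the decorator can record it.
    print(await summarize(query="What changed in the Q3 report?"))


if __name__ == "__main__":
    asyncio.run(main())
```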
Step 6: Prometheus Metrics
"""Prometheus metrics for monitoring."""
from prometheus_client import Counter, Histogram, Gauge, generate_latest
import time
from functools import wraps
class RAGMetrics:
"""Centralized metrics collection for RAG pipeline."""
def __init__(self):
# Request metrics
self.requests_total = Counter(
"rag_requests_total",
"Total RAG requests",
["status", "variant"],
)
self.request_latency = Histogram(
"rag_request_latency_seconds",
"Request latency in seconds",
["stage"],
buckets=[0.1, 0.25, 0.5, 1.0, 2.5, 5.0, 10.0],
)
# Retrieval metrics
self.retrieval_count = Histogram(
"rag_retrieval_document_count",
"Number of documents retrieved",
buckets=[1, 2, 3, 4, 5, 10, 20],
)
# Cache metrics
self.cache_hits = Counter(
"rag_cache_hits_total",
"Cache hit count",
)
self.cache_misses = Counter(
"rag_cache_misses_total",
"Cache miss count",
)
# Quality metrics
self.faithfulness_score = Histogram(
"rag_faithfulness_score",
"Faithfulness scores from evaluation",
buckets=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0],
)
self.relevancy_score = Histogram(
"rag_relevancy_score",
"Answer relevancy scores",
buckets=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0],
)
# Token usage
self.tokens_used = Counter(
"rag_tokens_total",
"Total tokens used",
["type"], # prompt, completion
)
# Circuit breaker state
self.circuit_state = Gauge(
"rag_circuit_breaker_state",
"Circuit breaker state (0=closed, 1=open, 2=half-open)",
)
# A/B test metrics
self.ab_test_requests = Counter(
"rag_ab_test_requests_total",
"A/B test requests by variant",
["experiment", "variant"],
)
def record_request(self, status: str, variant: str = "default"):
"""Record a request."""
self.requests_total.labels(status=status, variant=variant).inc()
def record_latency(self, stage: str, duration: float):
"""Record latency for a pipeline stage."""
self.request_latency.labels(stage=stage).observe(duration)
def record_retrieval(self, doc_count: int):
"""Record retrieval results."""
self.retrieval_count.observe(doc_count)
def record_cache_hit(self):
"""Record a cache hit."""
self.cache_hits.inc()
def record_cache_miss(self):
"""Record a cache miss."""
self.cache_misses.inc()
def record_faithfulness(self, score: float):
"""Record faithfulness score."""
self.faithfulness_score.observe(score)
def record_relevancy(self, score: float):
"""Record relevancy score."""
self.relevancy_score.observe(score)
def record_tokens(self, prompt_tokens: int, completion_tokens: int):
"""Record token usage."""
self.tokens_used.labels(type="prompt").inc(prompt_tokens)
self.tokens_used.labels(type="completion").inc(completion_tokens)
def set_circuit_state(self, state: int):
"""Set circuit breaker state."""
self.circuit_state.set(state)
def record_ab_request(self, experiment: str, variant: str):
"""Record A/B test request."""
self.ab_test_requests.labels(experiment=experiment, variant=variant).inc()
def get_metrics(self) -> bytes:
"""Generate Prometheus metrics output."""
return generate_latest()
# Singleton instance
rag_metrics = RAGMetrics()
# Timing decorator
def timed(stage: str):
"""Decorator to time function execution."""
def decorator(func):
@wraps(func)
async def async_wrapper(*args, **kwargs):
start = time.perf_counter()
try:
return await func(*args, **kwargs)
finally:
duration = time.perf_counter() - start
rag_metrics.record_latency(stage, duration)
@wraps(func)
def sync_wrapper(*args, **kwargs):
start = time.perf_counter()
try:
return func(*args, **kwargs)
finally:
duration = time.perf_counter() - start
rag_metrics.record_latency(stage, duration)
import asyncio
if asyncio.iscoroutinefunction(func):
return async_wrapper
return sync_wrapper
return decorator
Step 7: Semantic Caching
Semantic caching reduces latency and costs by serving cached responses for similar queries:
"""Semantic caching for RAG responses."""
import json
import hashlib
from typing import Optional, Dict, Any
from datetime import datetime
import numpy as np
import redis.asyncio as redis
from langchain_openai import OpenAIEmbeddings
from src.config import settings
from src.observability.metrics import rag_metrics
class SemanticCache:
"""Cache RAG responses using semantic similarity.
Instead of exact query matching, finds similar queries
that have been answered before.
"""
def __init__(self):
self.enabled = settings.enable_cache
self.redis_client = None
self.embeddings = None
self.ttl = settings.cache_ttl_seconds
self.similarity_threshold = settings.cache_similarity_threshold
if self.enabled:
self.embeddings = OpenAIEmbeddings(
model=settings.embedding_model,
openai_api_key=settings.openai_api_key,
)
async def connect(self):
"""Initialize Redis connection."""
if self.enabled and not self.redis_client:
self.redis_client = await redis.from_url(
settings.redis_url,
decode_responses=True,
)
async def get(self, query: str) -> Optional[Dict[str, Any]]:
"""Get cached response for similar query."""
if not self.enabled:
return None
await self.connect()
try:
# Generate query embedding
query_embedding = await self._embed_async(query)
# Get all cached query embeddings
cache_keys = await self.redis_client.keys("rag:cache:*")
best_match = None
best_similarity = 0
for key in cache_keys:
cached = await self.redis_client.hgetall(key)
if not cached:
continue
cached_embedding = json.loads(cached.get("embedding", "[]"))
similarity = self._cosine_similarity(query_embedding, cached_embedding)
if similarity > best_similarity:
best_similarity = similarity
best_match = cached
if best_similarity >= self.similarity_threshold and best_match:
rag_metrics.record_cache_hit()
return {
"answer": best_match["answer"],
"contexts": json.loads(best_match.get("contexts", "[]")),
"cached": True,
"similarity": best_similarity,
"original_query": best_match.get("query"),
}
rag_metrics.record_cache_miss()
return None
except Exception as e:
# Don't let cache errors affect main flow
rag_metrics.record_cache_miss()
return None
async def set(
self,
query: str,
answer: str,
contexts: list[str],
metadata: Optional[Dict] = None,
):
"""Cache a RAG response."""
if not self.enabled:
return
await self.connect()
try:
query_embedding = await self._embed_async(query)
cache_key = f"rag:cache:{self._hash_query(query)}"
cache_data = {
"query": query,
"answer": answer,
"contexts": json.dumps(contexts),
"embedding": json.dumps(query_embedding),
"timestamp": datetime.utcnow().isoformat(),
"metadata": json.dumps(metadata or {}),
}
await self.redis_client.hset(cache_key, mapping=cache_data)
await self.redis_client.expire(cache_key, self.ttl)
except Exception:
pass # Silently fail on cache errors
async def invalidate(self, query: str):
"""Invalidate a cached response."""
if not self.enabled:
return
await self.connect()
cache_key = f"rag:cache:{self._hash_query(query)}"
await self.redis_client.delete(cache_key)
async def clear(self):
"""Clear all cached responses."""
if not self.enabled:
return
await self.connect()
keys = await self.redis_client.keys("rag:cache:*")
if keys:
await self.redis_client.delete(*keys)
async def _embed_async(self, text: str) -> list[float]:
"""Generate embedding asynchronously."""
return await self.embeddings.aembed_query(text)
def _cosine_similarity(self, a: list[float], b: list[float]) -> float:
"""Calculate cosine similarity between two vectors."""
a = np.array(a)
b = np.array(b)
return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))
def _hash_query(self, query: str) -> str:
"""Generate hash for cache key."""
return hashlib.sha256(query.encode()).hexdigest()[:16]
# Exact match cache for frequently repeated queries
class ExactMatchCache:
"""Simple exact-match cache for common queries."""
def __init__(self):
self.enabled = settings.enable_cache
self.redis_client = None
self.ttl = settings.cache_ttl_seconds
async def connect(self):
if self.enabled and not self.redis_client:
self.redis_client = await redis.from_url(
settings.redis_url,
decode_responses=True,
)
async def get(self, query: str) -> Optional[str]:
"""Get exact match from cache."""
if not self.enabled:
return None
await self.connect()
key = f"rag:exact:{self._hash(query)}"
return await self.redis_client.get(key)
async def set(self, query: str, response: str):
"""Set exact match in cache."""
if not self.enabled:
return
await self.connect()
key = f"rag:exact:{self._hash(query)}"
await self.redis_client.setex(key, self.ttl, response)
def _hash(self, query: str) -> str:
return hashlib.sha256(query.lower().strip().encode()).hexdigest()[:16]
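ExactMatchCache is defined here but not wired into the pipeline in Step 11. One way to use it, sketched below under assumed helper names and an assumed JSON payload for the exact-match entries, is as a cheap first tier in front of the semantic cache.

```python
# Sketch of a two-tier lookup: exact-match first (no embedding call), then semantic.
# The helper names and the JSON payload stored in the exact cache are assumptions.
import json
from typing import Any, Dict, Optional

from src.middleware.cache import ExactMatchCache, SemanticCache

exact_cache = ExactMatchCache()
semantic_cache = SemanticCache()


async def cached_lookup(query: str) -> Optional[Dict[str, Any]]:
    """Return a cached answer from either tier, or None on a full miss."""
    exact = await exact_cache.get(query)
    if exact:
        return json.loads(exact)  # stored as a JSON string by cache_answer below
    return await semantic_cache.get(query)


async def cache_answer(query: str, answer: str, contexts: list[str]) -> None:
    """Write a fresh response to both tiers after a miss."""
    await exact_cache.set(query, json.dumps({"answer": answer, "contexts": contexts}))
    await semantic_cache.set(query, answer, contexts)
```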
Step 8: Circuit Breaker
Protect your system from cascading failures:
"""Circuit breaker for fault tolerance."""
import asyncio
from enum import Enum
from typing import Callable, Any, Optional
from datetime import datetime, timedelta
from dataclasses import dataclass, field
from src.observability.metrics import rag_metrics
class CircuitState(Enum):
CLOSED = 0 # Normal operation
OPEN = 1 # Failing, reject requests
HALF_OPEN = 2 # Testing if service recovered
@dataclass
class CircuitBreaker:
"""Circuit breaker with configurable thresholds.
States:
- CLOSED: Normal operation, requests pass through
- OPEN: Service failing, requests rejected immediately
- HALF_OPEN: Testing if service recovered
Transitions:
- CLOSED -> OPEN: failure_threshold failures within window
- OPEN -> HALF_OPEN: After recovery_timeout
- HALF_OPEN -> CLOSED: success_threshold successes
- HALF_OPEN -> OPEN: Any failure
"""
name: str
failure_threshold: int = 5
success_threshold: int = 2
recovery_timeout: int = 30 # seconds
failure_window: int = 60 # seconds
state: CircuitState = field(default=CircuitState.CLOSED)
failures: list = field(default_factory=list)
successes: int = field(default=0)
last_failure_time: Optional[datetime] = field(default=None)
last_state_change: datetime = field(default_factory=datetime.utcnow)
def __post_init__(self):
self._lock = asyncio.Lock()
async def call(
self,
func: Callable,
*args,
fallback: Optional[Callable] = None,
**kwargs,
) -> Any:
"""Execute function through circuit breaker."""
async with self._lock:
self._update_state()
if self.state == CircuitState.OPEN:
if fallback:
return await self._execute_fallback(fallback, *args, **kwargs)
raise CircuitOpenError(f"Circuit {self.name} is OPEN")
try:
result = await func(*args, **kwargs)
await self._record_success()
return result
except Exception as e:
await self._record_failure()
if fallback and self.state == CircuitState.OPEN:
return await self._execute_fallback(fallback, *args, **kwargs)
raise
async def _record_success(self):
"""Record successful execution."""
async with self._lock:
if self.state == CircuitState.HALF_OPEN:
self.successes += 1
if self.successes >= self.success_threshold:
self._transition_to(CircuitState.CLOSED)
async def _record_failure(self):
"""Record failed execution."""
async with self._lock:
now = datetime.utcnow()
self.last_failure_time = now
# Add to failures list
self.failures.append(now)
# Clean old failures outside window
cutoff = now - timedelta(seconds=self.failure_window)
self.failures = [f for f in self.failures if f > cutoff]
if self.state == CircuitState.HALF_OPEN:
self._transition_to(CircuitState.OPEN)
elif (
self.state == CircuitState.CLOSED
and len(self.failures) >= self.failure_threshold
):
self._transition_to(CircuitState.OPEN)
def _update_state(self):
"""Check if state should change based on time."""
if self.state == CircuitState.OPEN:
if self.last_failure_time:
elapsed = (datetime.utcnow() - self.last_failure_time).total_seconds()
if elapsed >= self.recovery_timeout:
self._transition_to(CircuitState.HALF_OPEN)
def _transition_to(self, new_state: CircuitState):
"""Transition to new state."""
self.state = new_state
self.last_state_change = datetime.utcnow()
rag_metrics.set_circuit_state(new_state.value)
if new_state == CircuitState.CLOSED:
self.failures = []
self.successes = 0
elif new_state == CircuitState.HALF_OPEN:
self.successes = 0
async def _execute_fallback(self, fallback: Callable, *args, **kwargs) -> Any:
"""Execute fallback function."""
if asyncio.iscoroutinefunction(fallback):
return await fallback(*args, **kwargs)
return fallback(*args, **kwargs)
@property
def is_available(self) -> bool:
"""Check if circuit allows requests."""
self._update_state()
return self.state != CircuitState.OPEN
class CircuitOpenError(Exception):
"""Raised when circuit breaker is open."""
pass
# Pre-configured circuit breakers
llm_circuit = CircuitBreaker(
name="llm",
failure_threshold=3,
recovery_timeout=60,
)
retrieval_circuit = CircuitBreaker(
name="retrieval",
failure_threshold=5,
recovery_timeout=30,
)
Step 9: Rate Limiter
"""Token bucket rate limiter."""
import asyncio
from typing import Optional
from datetime import datetime, timedelta
from dataclasses import dataclass
import redis.asyncio as redis
from src.config import settings
@dataclass
class RateLimitResult:
"""Rate limit check result."""
allowed: bool
remaining: int
reset_at: datetime
retry_after: Optional[int] = None
class RateLimiter:
"""Token bucket rate limiter with Redis backend."""
def __init__(
self,
requests_per_window: int = None,
window_seconds: int = None,
):
self.requests = requests_per_window or settings.rate_limit_requests
self.window = window_seconds or settings.rate_limit_window
self.redis_client = None
async def connect(self):
"""Initialize Redis connection."""
if not self.redis_client:
self.redis_client = await redis.from_url(
settings.redis_url,
decode_responses=True,
)
async def check(self, key: str) -> RateLimitResult:
"""Check if request is allowed.
Args:
key: Unique identifier (e.g., user_id, ip_address)
Returns:
RateLimitResult with allowed status and metadata
"""
await self.connect()
now = datetime.utcnow()
window_start = now.replace(second=0, microsecond=0)
redis_key = f"ratelimit:{key}:{window_start.timestamp()}"
# Increment counter
current = await self.redis_client.incr(redis_key)
# Set expiry on first request in window
if current == 1:
await self.redis_client.expire(redis_key, self.window)
remaining = max(0, self.requests - current)
reset_at = window_start + timedelta(seconds=self.window)
if current > self.requests:
retry_after = int((reset_at - now).total_seconds())
return RateLimitResult(
allowed=False,
remaining=0,
reset_at=reset_at,
retry_after=retry_after,
)
return RateLimitResult(
allowed=True,
remaining=remaining,
reset_at=reset_at,
)
async def reset(self, key: str):
"""Reset rate limit for a key."""
await self.connect()
pattern = f"ratelimit:{key}:*"
keys = await self.redis_client.keys(pattern)
if keys:
await self.redis_client.delete(*keys)
# Global rate limiter
rate_limiter = RateLimiter()
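The pipeline in Step 11 applies this limiter per user_id. If you also want to protect the API edge itself, one option, sketched here and not part of the original code, is a FastAPI dependency keyed by client IP.

```python
# Sketch of enforcing the limiter at the API edge as a FastAPI dependency
# (enforce_rate_limit is an assumed helper name).
from fastapi import HTTPException, Request

from src.middleware.rate_limiter import rate_limiter


async def enforce_rate_limit(request: Request) -> None:
    """Reject the request with 429 before any pipeline work happens."""
    client_ip = request.client.host if request.client else "unknown"
    result = await rate_limiter.check(f"ip:{client_ip}")
    if not result.allowed:
        raise HTTPException(
            status_code=429,
            detail="Rate limit exceeded",
            headers={"Retry-After": str(result.retry_after)},
        )

# In src/api.py this could be attached with:
#   @app.post("/query", dependencies=[Depends(enforce_rate_limit)])
```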
"""A/B testing router for RAG experiments."""
import hashlib
import random
from typing import Optional, Dict, Any, List
from dataclasses import dataclass, field
from datetime import datetime
from enum import Enum
from src.config import settings
from src.observability.metrics import rag_metrics
class AllocationStrategy(Enum):
RANDOM = "random" # Random assignment
DETERMINISTIC = "deterministic" # Hash-based (same user = same variant)
STICKY = "sticky" # Remember assignment
@dataclass
class Variant:
"""Experiment variant configuration."""
name: str
weight: float # 0.0 to 1.0
config: Dict[str, Any] = field(default_factory=dict)
@dataclass
class Experiment:
"""A/B test experiment definition."""
name: str
variants: List[Variant]
allocation_strategy: AllocationStrategy = AllocationStrategy.DETERMINISTIC
enabled: bool = True
def __post_init__(self):
# Validate weights sum to 1.0
total_weight = sum(v.weight for v in self.variants)
if abs(total_weight - 1.0) > 0.01:
raise ValueError(f"Variant weights must sum to 1.0, got {total_weight}")
@dataclass
class Assignment:
"""Experiment assignment result."""
experiment: str
variant: str
config: Dict[str, Any]
timestamp: datetime = field(default_factory=datetime.utcnow)
class ABRouter:
"""Routes requests to experiment variants."""
def __init__(self):
self.enabled = settings.enable_ab_testing
self.experiments: Dict[str, Experiment] = {}
self._sticky_assignments: Dict[str, Dict[str, str]] = {}
def register_experiment(self, experiment: Experiment):
"""Register an experiment."""
self.experiments[experiment.name] = experiment
def get_assignment(
self,
experiment_name: str,
user_id: str,
) -> Optional[Assignment]:
"""Get variant assignment for a user.
Args:
experiment_name: Name of the experiment
user_id: Unique user identifier
Returns:
Assignment with variant config, or None if experiment not found
"""
if not self.enabled:
return None
experiment = self.experiments.get(experiment_name)
if not experiment or not experiment.enabled:
return None
variant = self._select_variant(experiment, user_id)
# Record metrics
rag_metrics.record_ab_request(experiment_name, variant.name)
return Assignment(
experiment=experiment_name,
variant=variant.name,
config=variant.config,
)
def _select_variant(self, experiment: Experiment, user_id: str) -> Variant:
"""Select variant based on allocation strategy."""
strategy = experiment.allocation_strategy
if strategy == AllocationStrategy.RANDOM:
return self._random_select(experiment.variants)
elif strategy == AllocationStrategy.DETERMINISTIC:
return self._deterministic_select(experiment, user_id)
elif strategy == AllocationStrategy.STICKY:
return self._sticky_select(experiment, user_id)
return experiment.variants[0] # Fallback to first variant
def _random_select(self, variants: List[Variant]) -> Variant:
"""Randomly select variant based on weights."""
r = random.random()
cumulative = 0.0
for variant in variants:
cumulative += variant.weight
if r <= cumulative:
return variant
return variants[-1]
def _deterministic_select(self, experiment: Experiment, user_id: str) -> Variant:
"""Deterministically select variant based on user_id hash."""
hash_input = f"{experiment.name}:{user_id}"
hash_value = int(hashlib.sha256(hash_input.encode()).hexdigest(), 16)
bucket = (hash_value % 1000) / 1000.0
cumulative = 0.0
for variant in experiment.variants:
cumulative += variant.weight
if bucket <= cumulative:
return variant
return experiment.variants[-1]
def _sticky_select(self, experiment: Experiment, user_id: str) -> Variant:
"""Select and remember variant for user."""
if experiment.name not in self._sticky_assignments:
self._sticky_assignments[experiment.name] = {}
assignments = self._sticky_assignments[experiment.name]
if user_id in assignments:
variant_name = assignments[user_id]
for v in experiment.variants:
if v.name == variant_name:
return v
# New assignment
variant = self._deterministic_select(experiment, user_id)
assignments[user_id] = variant.name
return variant
# Pre-configured experiments
def setup_experiments(router: ABRouter):
"""Set up default experiments."""
# Retrieval strategy experiment
router.register_experiment(Experiment(
name="retrieval_strategy",
variants=[
Variant(
name="semantic_only",
weight=0.5,
config={"retriever": "semantic", "top_k": 4},
),
Variant(
name="hybrid",
weight=0.5,
config={"retriever": "hybrid", "top_k": 4, "alpha": 0.5},
),
],
))
# Reranking experiment
router.register_experiment(Experiment(
name="reranking",
variants=[
Variant(
name="no_rerank",
weight=0.5,
config={"rerank": False},
),
Variant(
name="with_rerank",
weight=0.5,
config={"rerank": True, "rerank_top_n": 3},
),
],
))
# Global router
ab_router = ABRouter()
setup_experiments(ab_router)
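The tech stack lists PostgreSQL for A/B test results and the project tree includes src/experiments/experiment.py, but no storage code is shown above. Below is a minimal sketch for persisting assignments and sampled quality scores so variants can be compared offline; the ab_results schema and helper names are assumptions, and DATABASE_URL must be set.

```python
# src/experiments/experiment.py (sketch): persist A/B outcomes to PostgreSQL.
# The ab_results schema and helper names are assumptions.
from typing import Optional

from sqlalchemy import text
from sqlalchemy.ext.asyncio import create_async_engine

from src.config import settings

engine = create_async_engine(settings.database_url, pool_size=5)

CREATE_SQL = """
CREATE TABLE IF NOT EXISTS ab_results (
    id SERIAL PRIMARY KEY,
    experiment TEXT NOT NULL,
    variant TEXT NOT NULL,
    user_id TEXT NOT NULL,
    faithfulness FLOAT,
    relevancy FLOAT,
    latency_ms FLOAT,
    created_at TIMESTAMP DEFAULT NOW()
)
"""


async def init_schema() -> None:
    """Create the results table if it does not exist."""
    async with engine.begin() as conn:
        await conn.execute(text(CREATE_SQL))


async def record_result(
    experiment: str,
    variant: str,
    user_id: str,
    scores: Optional[dict],
    latency_ms: Optional[float],
) -> None:
    """Store one request outcome so variants can be compared offline."""
    scores = scores or {}
    async with engine.begin() as conn:
        await conn.execute(
            text(
                "INSERT INTO ab_results "
                "(experiment, variant, user_id, faithfulness, relevancy, latency_ms) "
                "VALUES (:experiment, :variant, :user_id, :faithfulness, :relevancy, :latency_ms)"
            ),
            {
                "experiment": experiment,
                "variant": variant,
                "user_id": user_id,
                "faithfulness": scores.get("faithfulness"),
                "relevancy": scores.get("relevancy"),
                "latency_ms": latency_ms,
            },
        )
```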
"""Production RAG pipeline with all middleware."""
import asyncio
from typing import Optional, Dict, Any, List
from dataclasses import dataclass
from langchain_openai import OpenAIEmbeddings, ChatOpenAI
from langchain_chroma import Chroma
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_core.documents import Document
from src.config import settings
from src.middleware.cache import SemanticCache
from src.middleware.circuit_breaker import llm_circuit, retrieval_circuit
from src.middleware.rate_limiter import rate_limiter
from src.experiments.ab_router import ab_router
from src.evaluation.evaluator import OnlineEvaluator
from src.observability.tracer import LangSmithTracer
from src.observability.metrics import rag_metrics, timed
@dataclass
class RAGResponse:
"""Complete RAG response with metadata."""
answer: str
query: str
contexts: List[str]
sources: List[Dict[str, Any]]
cached: bool = False
variant: Optional[str] = None
evaluation: Optional[Dict[str, float]] = None
latency_ms: Optional[float] = None
class ProductionRAGPipeline:
"""Production-ready RAG pipeline with full observability."""
def __init__(self):
# Core components
self.embeddings = OpenAIEmbeddings(
model=settings.embedding_model,
openai_api_key=settings.openai_api_key,
)
self.vectorstore = Chroma(
persist_directory=settings.chroma_persist_dir,
embedding_function=self.embeddings,
)
self.llm = ChatOpenAI(
model=settings.llm_model,
openai_api_key=settings.openai_api_key,
temperature=0.1,
)
# Middleware
self.cache = SemanticCache()
self.evaluator = OnlineEvaluator(sample_rate=0.1)
self.tracer = LangSmithTracer()
# Prompt
self.qa_prompt = ChatPromptTemplate.from_messages([
("system", """You are a helpful assistant that answers questions based on the provided context.
Guidelines:
- Answer only based on the context provided
- If the context doesn't contain the answer, say so clearly
- Cite specific parts of the context when possible
- Be concise but thorough"""),
("human", """Context:
{context}
Question: {question}
Answer:"""),
])
async def query(
self,
query: str,
user_id: str = "anonymous",
session_id: Optional[str] = None,
) -> RAGResponse:
"""Process a query through the production pipeline."""
import time
start_time = time.perf_counter()
# Rate limiting
rate_result = await rate_limiter.check(user_id)
if not rate_result.allowed:
raise RateLimitExceeded(
retry_after=rate_result.retry_after,
message="Rate limit exceeded",
)
# Check cache
cached = await self.cache.get(query)
if cached:
return RAGResponse(
answer=cached["answer"],
query=query,
contexts=cached.get("contexts", []),
sources=[],
cached=True,
latency_ms=(time.perf_counter() - start_time) * 1000,
)
# Get A/B test assignment
assignment = ab_router.get_assignment("retrieval_strategy", user_id)
variant_config = assignment.config if assignment else {}
variant_name = assignment.variant if assignment else "default"
# Trace the request
with self.tracer.trace_query(query, session_id, {"variant": variant_name}) as trace:
try:
# Retrieve
docs = await self._retrieve_with_circuit_breaker(
query,
variant_config,
)
self.tracer.trace_retrieval(trace, query, docs, variant_name)
contexts = [doc.page_content for doc in docs]
context_text = "\n\n".join(contexts)
# Generate
answer = await self._generate_with_circuit_breaker(
query,
context_text,
)
self.tracer.trace_generation(
trace,
f"Query: {query}\nContext length: {len(context_text)}",
answer,
settings.llm_model,
)
# Cache the response
await self.cache.set(query, answer, contexts)
# Async evaluation (non-blocking)
evaluation_task = asyncio.create_task(
self.evaluator.evaluate_async(query, answer, contexts)
)
# Build response
latency_ms = (time.perf_counter() - start_time) * 1000
rag_metrics.record_request("success", variant_name)
rag_metrics.record_latency("total", latency_ms / 1000)
response = RAGResponse(
answer=answer,
query=query,
contexts=contexts,
sources=[
{
"content": doc.page_content[:200],
"metadata": doc.metadata,
}
for doc in docs
],
cached=False,
variant=variant_name,
latency_ms=latency_ms,
)
# Wait for evaluation (with timeout)
try:
evaluation = await asyncio.wait_for(evaluation_task, timeout=5.0)
response.evaluation = evaluation
self.tracer.trace_evaluation(trace, evaluation or {})
except asyncio.TimeoutError:
pass
return response
except Exception as e:
rag_metrics.record_request("error", variant_name)
raise
@timed("retrieval")
async def _retrieve_with_circuit_breaker(
self,
query: str,
config: Dict[str, Any],
) -> List[Document]:
"""Retrieve documents with circuit breaker protection."""
async def retrieval_fn():
top_k = config.get("top_k", settings.retrieval_top_k)
retriever = self.vectorstore.as_retriever(
search_type="similarity",
search_kwargs={"k": top_k},
)
docs = await retriever.ainvoke(query)
rag_metrics.record_retrieval(len(docs))
return docs
async def fallback_fn():
# Return empty results on failure
return []
return await retrieval_circuit.call(
retrieval_fn,
fallback=fallback_fn,
)
@timed("generation")
async def _generate_with_circuit_breaker(
self,
query: str,
context: str,
) -> str:
"""Generate response with circuit breaker protection."""
async def generation_fn():
chain = self.qa_prompt | self.llm | StrOutputParser()
return await chain.ainvoke({
"context": context,
"question": query,
})
async def fallback_fn():
return "I'm sorry, but I'm currently unable to process your request. Please try again later."
return await llm_circuit.call(
generation_fn,
fallback=fallback_fn,
)
async def ingest_documents(self, documents: List[Document]) -> int:
"""Ingest documents into the vector store."""
self.vectorstore.add_documents(documents)
return len(documents)
class RateLimitExceeded(Exception):
"""Raised when rate limit is exceeded."""
def __init__(self, retry_after: int, message: str):
self.retry_after = retry_after
super().__init__(message)
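Nothing above actually ingests documents into the vector store. A small driver script, sketched below, loads and chunks local text files, ingests them, and exercises the full pipeline once before you put the API in front of it. The scripts/ingest.py name and the ./docs folder are assumptions, and it expects Redis to be running.

```python
# scripts/ingest.py (assumed path): load, chunk, ingest, then run one query.
import asyncio

from langchain_community.document_loaders import DirectoryLoader, TextLoader
from langchain_text_splitters import RecursiveCharacterTextSplitter

from src.config import settings
from src.core.pipeline import ProductionRAGPipeline


async def main() -> None:
    pipeline = ProductionRAGPipeline()

    # Load and chunk documents (the ./docs path is an assumption).
    loader = DirectoryLoader("./docs", glob="**/*.txt", loader_cls=TextLoader)
    splitter = RecursiveCharacterTextSplitter(
        chunk_size=settings.chunk_size,
        chunk_overlap=settings.chunk_overlap,
    )
    chunks = splitter.split_documents(loader.load())
    count = await pipeline.ingest_documents(chunks)
    print(f"Ingested {count} chunks")

    # Exercise the full middleware stack once.
    response = await pipeline.query("What is RAG?", user_id="local-dev")
    print(response.answer)
    print(f"variant={response.variant} cached={response.cached} latency={response.latency_ms:.0f}ms")


if __name__ == "__main__":
    asyncio.run(main())
```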
"""Production FastAPI application."""
import asyncio
from typing import Optional
from contextlib import asynccontextmanager
from fastapi import FastAPI, HTTPException, Request, Depends
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import Response
from pydantic import BaseModel
from src.config import settings
from src.core.pipeline import ProductionRAGPipeline, RateLimitExceeded
from src.observability.metrics import rag_metrics
from src.middleware.rate_limiter import rate_limiter
# Global pipeline
pipeline: Optional[ProductionRAGPipeline] = None
@asynccontextmanager
async def lifespan(app: FastAPI):
"""Initialize resources on startup."""
global pipeline
pipeline = ProductionRAGPipeline()
await pipeline.cache.connect()
yield
# Cleanup
app = FastAPI(
title="Production RAG API",
description="Production-ready RAG system with evaluation, monitoring, and A/B testing",
version="1.0.0",
lifespan=lifespan,
)
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# Request/Response Models
class QueryRequest(BaseModel):
query: str
user_id: Optional[str] = "anonymous"
session_id: Optional[str] = None
class QueryResponse(BaseModel):
answer: str
query: str
sources: list
cached: bool
variant: Optional[str] = None
evaluation: Optional[dict] = None
latency_ms: Optional[float] = None
class HealthResponse(BaseModel):
status: str
cache_enabled: bool
tracing_enabled: bool
ab_testing_enabled: bool
# Endpoints
@app.get("/", response_model=HealthResponse)
async def health():
"""Health check endpoint."""
return HealthResponse(
status="healthy",
cache_enabled=settings.enable_cache,
tracing_enabled=settings.enable_tracing,
ab_testing_enabled=settings.enable_ab_testing,
)
@app.get("/metrics")
async def metrics():
"""Prometheus metrics endpoint."""
return Response(
content=rag_metrics.get_metrics(),
media_type="text/plain",
)
@app.post("/query", response_model=QueryResponse)
async def query(request: QueryRequest):
"""Process a RAG query."""
try:
response = await pipeline.query(
query=request.query,
user_id=request.user_id,
session_id=request.session_id,
)
return QueryResponse(
answer=response.answer,
query=response.query,
sources=response.sources,
cached=response.cached,
variant=response.variant,
evaluation=response.evaluation,
latency_ms=response.latency_ms,
)
except RateLimitExceeded as e:
raise HTTPException(
status_code=429,
detail=f"Rate limit exceeded. Retry after {e.retry_after} seconds.",
headers={"Retry-After": str(e.retry_after)},
)
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@app.post("/cache/clear")
async def clear_cache():
"""Clear the semantic cache."""
await pipeline.cache.clear()
return {"message": "Cache cleared"}
@app.get("/experiments")
async def list_experiments():
"""List active A/B experiments."""
from src.experiments.ab_router import ab_router
experiments = []
for name, exp in ab_router.experiments.items():
experiments.append({
"name": name,
"enabled": exp.enabled,
"variants": [
{"name": v.name, "weight": v.weight}
for v in exp.variants
],
})
return {"experiments": experiments}
# Run with: uvicorn src.api:app --reload
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=8000)
Docker Compose Setup
# docker-compose.yml
version: '3.8'
services:
redis:
image: redis:7-alpine
ports:
- "6379:6379"
volumes:
- redis_data:/data
postgres:
image: postgres:15-alpine
environment:
POSTGRES_USER: rag
POSTGRES_PASSWORD: ragpass
POSTGRES_DB: rag_experiments
ports:
- "5432:5432"
volumes:
- postgres_data:/var/lib/postgresql/data
prometheus:
image: prom/prometheus:latest
ports:
- "9090:9090"
volumes:
- ./prometheus.yml:/etc/prometheus/prometheus.yml
command:
- '--config.file=/etc/prometheus/prometheus.yml'
rag-api:
build: .
ports:
- "8000:8000"
environment:
- REDIS_URL=redis://redis:6379
- DATABASE_URL=postgresql+asyncpg://rag:ragpass@postgres:5432/rag_experiments
depends_on:
- redis
- postgres
volumes:
redis_data:
postgres_data:
# prometheus.yml
global:
scrape_interval: 15s
scrape_configs:
- job_name: 'rag-api'
static_configs:
- targets: ['rag-api:8000']
metrics_path: '/metrics'
Testing
"""Tests for RAG evaluation."""
import pytest
from src.evaluation.metrics import RAGSample, RAGASEvaluator, EvaluationResult
class TestEvaluationResult:
"""Tests for EvaluationResult."""
def test_overall_score_without_recall(self):
"""Test overall score calculation without context recall."""
result = EvaluationResult(
faithfulness=0.8,
answer_relevancy=0.9,
context_precision=0.7,
)
# Average of 0.8, 0.9, 0.7 = 0.8
assert abs(result.overall_score - 0.8) < 0.01
def test_overall_score_with_recall(self):
"""Test overall score calculation with context recall."""
result = EvaluationResult(
faithfulness=0.8,
answer_relevancy=0.8,
context_precision=0.8,
context_recall=0.8,
)
assert result.overall_score == 0.8
def test_to_dict(self):
"""Test dictionary conversion."""
result = EvaluationResult(
faithfulness=0.8,
answer_relevancy=0.9,
context_precision=0.7,
)
d = result.to_dict()
assert "faithfulness" in d
assert "overall_score" in d
class TestRAGSample:
"""Tests for RAGSample."""
def test_sample_creation(self):
"""Test creating a RAG sample."""
sample = RAGSample(
question="What is RAG?",
answer="RAG stands for Retrieval-Augmented Generation.",
contexts=["RAG combines retrieval with generation."],
)
assert sample.question == "What is RAG?"
assert sample.ground_truth is None
"""Tests for semantic cache."""
import pytest
import asyncio
from unittest.mock import AsyncMock, patch
from src.middleware.cache import SemanticCache
class TestSemanticCache:
"""Tests for SemanticCache."""
@pytest.fixture
def cache(self):
"""Create cache instance."""
cache = SemanticCache()
cache.enabled = True
return cache
@pytest.mark.asyncio
async def test_cache_disabled(self):
"""Test cache returns None when disabled."""
cache = SemanticCache()
cache.enabled = False
result = await cache.get("test query")
assert result is None
def test_cosine_similarity(self):
"""Test cosine similarity calculation."""
cache = SemanticCache()
# Same vectors should have similarity 1.0
vec = [1.0, 0.0, 0.0]
assert cache._cosine_similarity(vec, vec) == 1.0
# Orthogonal vectors should have similarity 0.0
vec1 = [1.0, 0.0, 0.0]
vec2 = [0.0, 1.0, 0.0]
assert abs(cache._cosine_similarity(vec1, vec2)) < 0.01
def test_hash_query(self):
"""Test query hashing."""
cache = SemanticCache()
hash1 = cache._hash_query("test query")
hash2 = cache._hash_query("test query")
hash3 = cache._hash_query("different query")
assert hash1 == hash2
assert hash1 != hash3
"""Tests for A/B testing."""
import pytest
from src.experiments.ab_router import (
ABRouter,
Experiment,
Variant,
AllocationStrategy,
)
class TestABRouter:
"""Tests for A/B router."""
@pytest.fixture
def router(self):
"""Create router with test experiment."""
router = ABRouter()
router.enabled = True
router.register_experiment(Experiment(
name="test_exp",
variants=[
Variant(name="control", weight=0.5, config={"version": "A"}),
Variant(name="treatment", weight=0.5, config={"version": "B"}),
],
))
return router
def test_deterministic_assignment(self, router):
"""Test same user gets same variant."""
assignment1 = router.get_assignment("test_exp", "user123")
assignment2 = router.get_assignment("test_exp", "user123")
assert assignment1.variant == assignment2.variant
def test_different_users_can_get_different_variants(self, router):
"""Test different users can be assigned to different variants."""
assignments = set()
for i in range(100):
assignment = router.get_assignment("test_exp", f"user_{i}")
assignments.add(assignment.variant)
# With 100 users, we should have both variants
assert len(assignments) == 2
def test_disabled_router_returns_none(self, router):
"""Test disabled router returns None."""
router.enabled = False
assignment = router.get_assignment("test_exp", "user123")
assert assignment is None
def test_invalid_weights_raise_error(self):
"""Test experiment with invalid weights raises error."""
with pytest.raises(ValueError):
Experiment(
name="bad_exp",
variants=[
Variant(name="a", weight=0.3),
Variant(name="b", weight=0.3), # Sum = 0.6, not 1.0
],
)
Ground Truth Dataset
{
"version": "1.0",
"description": "Ground truth dataset for RAG evaluation",
"questions": [
{
"id": "q1",
"question": "What is RAG and how does it work?",
"expected_answer": "RAG (Retrieval-Augmented Generation) is a technique that combines information retrieval with language model generation. It works by first retrieving relevant documents from a knowledge base based on the user's query, then using those documents as context for the language model to generate a response.",
"category": "concept"
},
{
"id": "q2",
"question": "What are the main components of a RAG system?",
"expected_answer": "The main components of a RAG system are: 1) Document store/knowledge base, 2) Embedding model for vectorizing documents and queries, 3) Vector database for similarity search, 4) Retriever for finding relevant documents, 5) Language model for generating responses.",
"category": "architecture"
},
{
"id": "q3",
"question": "What is the difference between semantic search and keyword search?",
"expected_answer": "Semantic search uses embeddings to find documents based on meaning similarity, understanding context and synonyms. Keyword search matches exact terms in documents. Semantic search is better for natural language queries while keyword search excels at finding specific terms or identifiers.",
"category": "concept"
}
]
}
Running the Application
Start Dependencies
docker-compose up -d redis postgres prometheus
Run the API
uvicorn src.api:app --reload --host 0.0.0.0 --port 8000
Test the Endpoints
# Health check
curl http://localhost:8000/
# Query with evaluation
curl -X POST http://localhost:8000/query \
-H "Content-Type: application/json" \
-d '{"query": "What is RAG?", "user_id": "test_user"}'
# View metrics
curl http://localhost:8000/metrics
# View experiments
curl http://localhost:8000/experiments
# Clear cache
curl -X POST http://localhost:8000/cache/clear
Access Dashboards
- API Docs: http://localhost:8000/docs
- Prometheus: http://localhost:9090
- LangSmith: https://smith.langchain.com
Monitoring Dashboard
Create a Prometheus query dashboard with these key metrics:
| Panel | Query | Purpose |
|---|---|---|
| Request Rate | rate(rag_requests_total[5m]) | Traffic volume |
| Error Rate | rate(rag_requests_total{status="error"}[5m]) | Error tracking |
| P95 Latency | histogram_quantile(0.95, rate(rag_request_latency_seconds_bucket[5m])) | Performance |
| Cache Hit Rate | rate(rag_cache_hits_total[5m]) / (rate(rag_cache_hits_total[5m]) + rate(rag_cache_misses_total[5m])) | Cache efficiency |
| Median Faithfulness | histogram_quantile(0.5, rate(rag_faithfulness_score_bucket[5m])) | Quality |
Debugging Tips
Low faithfulness scores
- Check if retrieved contexts are relevant
- Verify chunk sizes aren't too small
- Review the generation prompt for grounding instructions
High cache miss rate
- Lower cache_similarity_threshold (e.g., 0.90)
- Check if queries are highly variable
- Verify embeddings are being generated correctly
Circuit breaker keeps opening
- Check LLM API availability
- Review error logs for root cause
- Increase failure_threshold if the failures are transient
A/B tests show no difference
- Ensure enough traffic for statistical significance (see the significance-test sketch below)
- Check variant configs are actually different
- Verify assignment is working correctly
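For the significance check, a quick two-proportion z-test over per-variant outcome counts (for example, responses whose sampled faithfulness cleared your threshold) is often enough to decide whether a difference is real. The counts below are placeholders.

```python
# Two-proportion z-test using only the standard library; the counts are placeholders.
from math import sqrt
from statistics import NormalDist


def two_proportion_z_test(success_a: int, total_a: int, success_b: int, total_b: int) -> tuple[float, float]:
    """Return (z statistic, two-sided p-value) for the difference in proportions."""
    p_a = success_a / total_a
    p_b = success_b / total_b
    pooled = (success_a + success_b) / (total_a + total_b)
    se = sqrt(pooled * (1 - pooled) * (1 / total_a + 1 / total_b))
    z = (p_a - p_b) / se
    p_value = 2 * (1 - NormalDist().cdf(abs(z)))
    return z, p_value


# Example: 5,000 requests per variant, judged "faithful" at rates of 0.82 vs 0.85.
z, p = two_proportion_z_test(4100, 5000, 4250, 5000)
print(f"z={z:.2f}, p={p:.4f}")  # p < 0.05 suggests the difference is unlikely to be noise
```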
Production Checklist
Before deploying to production:
- Configure appropriate rate limits
- Set up alerting on key metrics (error rate, latency, faithfulness)
- Configure Redis persistence for cache durability
- Set up log aggregation (ELK, CloudWatch, etc.)
- Enable HTTPS with proper certificates
- Configure authentication for API endpoints
- Set up database backups for experiment results
- Load test with expected traffic patterns
- Create runbooks for common issues
- Set up on-call rotation
Extensions
| Level | Ideas |
|---|---|
| Easy | Add request ID tracking, implement cache warming, add healthcheck for dependencies |
| Medium | Implement streaming responses, add user feedback collection, create Grafana dashboards |
| Advanced | Build automated regression testing, implement online learning from feedback, add multi-region deployment |
Key Takeaways
Resources
- RAGAS Documentation
- LangSmith Documentation
- Prometheus Best Practices
- Circuit Breaker Pattern
- A/B Testing at Scale
- Production ML Systems
Summary
You've built a production-grade RAG system that includes:
- Evaluation - RAGAS metrics with online sampling
- Observability - LangSmith tracing + Prometheus metrics
- Caching - Semantic caching for latency and cost reduction
- Reliability - Circuit breakers and rate limiting
- Experimentation - A/B testing framework for continuous improvement
This foundation allows you to iterate on RAG quality with confidence, catch regressions early, and make data-driven decisions about improvements.
Next: Graph RAG - Build knowledge graphs for complex reasoning