Adaptive RAG
Build a query complexity classifier that routes each query to the optimal retrieval strategy
TL;DR
Not every question needs retrieval. "What is Python?" wastes compute on retrieval because the LLM already knows the answer, while "Compare PostgreSQL vs MongoDB" needs multi-step retrieval to gather all the relevant information. Adaptive RAG classifies query complexity and routes each query to the optimal strategy: no retrieval, single-step, or multi-step. On the example query mix shown below, this saves roughly 8% of tokens on average while improving answer quality for complex queries.
| Property | Value |
|---|---|
| Difficulty | Intermediate |
| Time | ~5 hours |
| Code Size | ~450 LOC |
| Prerequisites | Self-RAG |
Tech Stack
| Technology | Purpose |
|---|---|
| OpenAI | GPT-4o-mini + embeddings |
| LangChain | RAG orchestration |
| ChromaDB | Vector database |
| scikit-learn | Complexity classifier |
| Pydantic | Structured outputs |
| FastAPI | REST API |
Prerequisites
- Completed Self-RAG tutorial
- Python 3.10+
- OpenAI API key
What You'll Learn
- Implement query complexity classification
- Route queries to optimal retrieval strategies
- Build no-retrieval, single-step, and multi-step pipelines
- Train a lightweight classifier for routing decisions
- Optimize compute costs based on query complexity
Research Foundation
This project implements the concepts from Adaptive-RAG: Learning to Adapt Retrieval-Augmented Large Language Models through Question Complexity (NAACL 2024).
The Problem: One Size Doesn't Fit All
Traditional RAG uses the same retrieval strategy for every query. But queries vary dramatically in complexity:
| Query Type | Example | Optimal Strategy |
|---|---|---|
| Simple | "What is Python?" | No retrieval (LLM knows) |
| Moderate | "What are Python decorators?" | Single-step retrieval |
| Complex | "Compare Python async patterns with Go goroutines for web scraping" | Multi-step retrieval |
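The routing idea in the table can be sketched as a plain dispatch map. The handler functions below are hypothetical stand-ins, not the real strategies built later in this tutorial:

```python
from typing import Callable

# Hypothetical stand-in handlers for the three strategies
def no_retrieval(q: str) -> str:
    return f"LLM answers '{q}' directly"

def single_step(q: str) -> str:
    return f"retrieve once, then answer '{q}'"

def multi_step(q: str) -> str:
    return f"decompose '{q}', retrieve per part, synthesize"

# Complexity level -> strategy handler
ROUTES: dict[str, Callable[[str], str]] = {
    "simple": no_retrieval,
    "moderate": single_step,
    "complex": multi_step,
}

def answer(query: str, level: str) -> str:
    # In the real system, `level` comes from the complexity classifier
    return ROUTES[level](query)

print(answer("What is Python?", "simple"))
```

The rest of the tutorial fills in each box: a classifier that produces `level`, and three real strategy implementations behind the dispatch map.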
┌─────────────────────────────────────────────────────────────────┐
│ TRADITIONAL RAG ❌ │
│ │
│ Any Query ───► Always Retrieve ───► Generate │
│ │
│ (Same cost regardless of query complexity) │
└─────────────────────────────────────────────────────────────────┘
Problems:
- Wastes compute on simple queries
- Under-retrieves for complex queries
- No adaptation to query characteristics
┌─────────────────────────────────────────────────────────────────┐
│ ADAPTIVE RAG ✅ │
│ │
│ ┌──────────────┐ │
│ │ Classifier │ │
│ └──────┬───────┘ │
│ │ │
│ ┌─────────────────┼─────────────────┐ │
│ │ │ │ │
│ ▼ ▼ ▼ │
│ ┌──────────┐ ┌───────────┐ ┌─────────────┐ │
│ │ Simple │ │ Moderate │ │ Complex │ │
│ │ │ │ │ │ │ │ │ │ │
│ │ ▼ │ │ ▼ │ │ ▼ │ │
│ │No Retriev│ │Single-Step│ │ Multi-Step │ │
│ │(~500 tok)│ │(~2000 tok)│ │ (~5000 tok) │ │
│ └──────────┘ └───────────┘ └─────────────┘ │
│ │
│ (Right-sized compute for each query type) │
└─────────────────────────────────────────────────────────────────┘
Project Structure
adaptive-rag/
├── config.py # Configuration
├── complexity_classifier.py # Query complexity prediction
├── strategies/
│ ├── __init__.py
│ ├── base.py # Strategy interface
│ ├── no_retrieval.py # Direct LLM answer
│ ├── single_step.py # Standard RAG
│ └── multi_step.py # Iterative retrieval
├── router.py # Strategy router
├── adaptive_rag.py # Main orchestration
├── trainer.py # Classifier training
├── app.py # FastAPI application
└── requirements.txt
Step 1: Configuration
# config.py
from pydantic_settings import BaseSettings
from pydantic import Field
from functools import lru_cache
from enum import Enum
class ComplexityLevel(str, Enum):
"""Query complexity levels."""
SIMPLE = "simple" # No retrieval needed
MODERATE = "moderate" # Single-step retrieval
COMPLEX = "complex" # Multi-step retrieval
class Settings(BaseSettings):
"""Application configuration."""
openai_api_key: str
# Model settings
embedding_model: str = "text-embedding-3-small"
llm_model: str = "gpt-4o-mini"
classifier_model: str = "gpt-4o-mini"
# Retrieval settings
single_step_k: int = 5
multi_step_k: int = 3
multi_step_iterations: int = 3
# Classifier settings
use_llm_classifier: bool = True # False = use trained classifier
classifier_path: str = "./classifier.joblib"
# Complexity thresholds (for rule-based fallback)
simple_max_words: int = 6
complex_min_entities: int = 2
# ChromaDB
chroma_persist_dir: str = "./chroma_db"
collection_name: str = "adaptive_rag_docs"
class Config:
env_file = ".env"
@lru_cache
def get_settings() -> Settings:
return Settings()
Step 2: Complexity Classifier
The core of Adaptive RAG: predicting query complexity to route appropriately.
# complexity_classifier.py
from openai import OpenAI
from pydantic import BaseModel, Field
from typing import Literal
import re
from config import get_settings, ComplexityLevel
class ComplexityPrediction(BaseModel):
"""Structured complexity prediction."""
level: ComplexityLevel
confidence: float = Field(ge=0, le=1)
reasoning: str
features: dict
class QueryFeatures(BaseModel):
"""Extracted query features for classification."""
word_count: int
has_comparison: bool
has_multiple_entities: bool
requires_reasoning: bool
is_factual: bool
temporal_scope: Literal["current", "historical", "comparative"]
estimated_steps: int
class LLMComplexityClassifier:
"""LLM-based complexity classifier."""
def __init__(self):
settings = get_settings()
self.client = OpenAI(api_key=settings.openai_api_key)
self.model = settings.classifier_model
def classify(self, query: str) -> ComplexityPrediction:
"""Classify query complexity using LLM."""
system_prompt = """You are a query complexity classifier for a RAG system.
Analyze the query and determine its complexity level:
**SIMPLE** - Direct factual questions the LLM likely knows:
- General knowledge questions
- Definitions of common terms
- Simple facts (capitals, dates, basic concepts)
- Examples: "What is Python?", "Who wrote Romeo and Juliet?"
**MODERATE** - Questions requiring single retrieval step:
- Specific technical details
- Recent information not in training data
- Domain-specific facts
- Examples: "What are the new features in Python 3.12?", "Explain React hooks"
**COMPLEX** - Questions requiring multiple retrieval steps:
- Comparisons between multiple entities
- Multi-part questions
- Questions requiring synthesis from multiple sources
- Reasoning chains across documents
- Examples: "Compare PostgreSQL vs MongoDB for time-series data",
"How do Python, Go, and Rust handle concurrency differently?"
Return your analysis as JSON with:
- level: "simple", "moderate", or "complex"
- confidence: 0.0-1.0
- reasoning: Brief explanation
- features: Object with has_comparison, entity_count, requires_synthesis"""
response = self.client.chat.completions.create(
model=self.model,
messages=[
{"role": "system", "content": system_prompt},
{"role": "user", "content": f"Query: {query}"}
],
response_format={"type": "json_object"}
)
import json
result = json.loads(response.choices[0].message.content)
return ComplexityPrediction(
level=ComplexityLevel(result["level"].lower()),  # tolerate casing like "SIMPLE"
confidence=result.get("confidence", 0.8),
reasoning=result.get("reasoning", ""),
features=result.get("features", {})
)
class RuleBasedClassifier:
"""Fast rule-based classifier for latency-sensitive applications."""
def __init__(self):
settings = get_settings()
self.simple_max_words = settings.simple_max_words
self.complex_min_entities = settings.complex_min_entities
# Patterns indicating complexity
self.comparison_patterns = [
r'\bvs\.?\b', r'\bversus\b', r'\bcompare\b', r'\bdifference\b',
r'\bbetter\b', r'\bworse\b', r'\badvantages?\b', r'\bdisadvantages?\b'
]
self.multi_step_patterns = [
r'\band\b.*\band\b', # Multiple "and"s
r'\bhow\b.*\bwhy\b', # Multiple question types
r'\bfirst\b.*\bthen\b', # Sequential
r'\bstep.?by.?step\b'
]
self.simple_patterns = [
r'^what is\b', r'^who is\b', r'^when was\b',
r'^define\b', r'^what does .* mean'
]
def extract_features(self, query: str) -> QueryFeatures:
"""Extract features from query."""
query_lower = query.lower()
words = query.split()
# Check patterns
has_comparison = any(
re.search(p, query_lower) for p in self.comparison_patterns
)
has_multi_step = any(
re.search(p, query_lower) for p in self.multi_step_patterns
)
is_simple_pattern = any(
re.search(p, query_lower) for p in self.simple_patterns
)
# Count potential entities (capitalized words, technical terms)
entities = re.findall(r'\b[A-Z][a-z]+(?:\s+[A-Z][a-z]+)*\b', query)
entity_count = len(set(entities))
# Estimate steps needed
if has_multi_step or has_comparison:
estimated_steps = 3
elif entity_count > 1:
estimated_steps = 2
else:
estimated_steps = 1
return QueryFeatures(
word_count=len(words),
has_comparison=has_comparison,
has_multiple_entities=entity_count >= self.complex_min_entities,
requires_reasoning=has_multi_step,
is_factual=is_simple_pattern,
temporal_scope="current", # Simplified
estimated_steps=estimated_steps
)
def classify(self, query: str) -> ComplexityPrediction:
"""Classify using rules."""
features = self.extract_features(query)
# Decision logic
if features.is_factual and features.word_count <= self.simple_max_words:
level = ComplexityLevel.SIMPLE
confidence = 0.85
elif features.has_comparison or features.requires_reasoning:
level = ComplexityLevel.COMPLEX
confidence = 0.80
elif features.has_multiple_entities:
level = ComplexityLevel.COMPLEX
confidence = 0.75
else:
level = ComplexityLevel.MODERATE
confidence = 0.70
return ComplexityPrediction(
level=level,
confidence=confidence,
reasoning=f"Word count: {features.word_count}, "
f"Comparison: {features.has_comparison}, "
f"Multi-entity: {features.has_multiple_entities}",
features=features.model_dump()
)
class HybridClassifier:
"""Combines rule-based speed with LLM accuracy."""
def __init__(self):
self.rule_classifier = RuleBasedClassifier()
self.llm_classifier = LLMComplexityClassifier()
def classify(
self,
query: str,
use_llm_verification: bool = True
) -> ComplexityPrediction:
"""
Classify with optional LLM verification.
Returns the rule-based prediction only when it is confident (>= 0.8)
and LLM verification is disabled; otherwise the LLM verifies or
overrides the rule-based result.
"""
rule_prediction = self.rule_classifier.classify(query)
# If rule-based is confident, use it
if rule_prediction.confidence >= 0.8 and not use_llm_verification:
return rule_prediction
# Otherwise, verify with LLM
llm_prediction = self.llm_classifier.classify(query)
# If they agree, boost confidence
if rule_prediction.level == llm_prediction.level:
return ComplexityPrediction(
level=llm_prediction.level,
confidence=min(1.0, llm_prediction.confidence + 0.1),
reasoning=f"Rule and LLM agree: {llm_prediction.reasoning}",
features=llm_prediction.features
)
# If they disagree, trust LLM but note disagreement
return ComplexityPrediction(
level=llm_prediction.level,
confidence=llm_prediction.confidence * 0.9, # Slight penalty
reasoning=f"LLM override (rule said {rule_prediction.level}): "
f"{llm_prediction.reasoning}",
features=llm_prediction.features
)
Understanding the Three Classifiers:
┌─────────────────────────────────────────────────────────────┐
│ APPROACH 1: LLM Classifier │
│ │
│ Query ──► GPT-4 ──► { level, confidence, reasoning } │
│ │
│ Pros: Accurate, understands nuance │
│ Cons: ~500ms latency, costs tokens │
└─────────────────────────────────────────────────────────────┘
┌─────────────────────────────────────────────────────────────┐
│ APPROACH 2: Rule-Based Classifier │
│ │
│ Query ──► Regex patterns ──► { level, confidence } │
│ │
│ Pros: Fast (~1ms), free │
│ Cons: Misses subtle complexity │
└─────────────────────────────────────────────────────────────┘
┌─────────────────────────────────────────────────────────────┐
│ APPROACH 3: Hybrid Classifier (Recommended) │
│ │
│ Query ──► Rules first ──► if uncertain ──► LLM verifies │
│ │
│ Pros: Best of both—fast when easy, accurate when hard │
└─────────────────────────────────────────────────────────────┘
The Rule-Based Feature Detection:
| Pattern | Example Match | Indicates |
|---|---|---|
| `\bvs\.?\b` | "Python vs Java" | Comparison → Complex |
| `\band\b.*\band\b` | "A and B and C" | Multi-part → Complex |
| `^what is\b` | "What is Python?" | Simple factual |
| Capitalized words | "PostgreSQL", "MongoDB" | Entities to count |
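Putting the patterns together, here is a minimal, self-contained sketch of the rule-based decision logic. The pattern lists are trimmed for brevity, and `classify` is a toy stand-in for `RuleBasedClassifier.classify`:

```python
import re

# Trimmed pattern lists mirroring the table above
COMPARISON = [r'\bvs\.?\b', r'\bversus\b', r'\bcompare\b', r'\bdifference\b']
MULTI_STEP = [r'\band\b.*\band\b', r'\bhow\b.*\bwhy\b']
SIMPLE = [r'^what is\b', r'^who is\b', r'^define\b']

def classify(query: str) -> str:
    """Toy rule-based complexity classifier."""
    q = query.lower()
    # Comparisons and multi-part questions need multi-step retrieval
    if any(re.search(p, q) for p in COMPARISON + MULTI_STEP):
        return "complex"
    # Short factual patterns can skip retrieval entirely
    if any(re.search(p, q) for p in SIMPLE) and len(q.split()) <= 6:
        return "simple"
    # Everything else gets standard single-step RAG
    return "moderate"

print(classify("What is Python?"))                # simple
print(classify("Compare PostgreSQL vs MongoDB"))  # complex
print(classify("Explain React hooks"))            # moderate
```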
Why Hybrid Works Best:
# Rule says SIMPLE, LLM agrees → High confidence, use SIMPLE
# Rule says SIMPLE, LLM says COMPLEX → Trust LLM (it saw nuance)
# Rule unsure (confidence < 0.8) → Always verify with LLM
Step 3: Retrieval Strategies
# strategies/base.py
from abc import ABC, abstractmethod
from pydantic import BaseModel
from typing import Any
class StrategyResult(BaseModel):
"""Result from a retrieval strategy."""
answer: str
sources: list[dict]
strategy_used: str
retrieval_steps: int
tokens_used: int | None = None
latency_ms: float
class RetrievalStrategy(ABC):
"""Base class for retrieval strategies."""
@property
@abstractmethod
def name(self) -> str:
pass
@abstractmethod
async def execute(self, query: str) -> StrategyResult:
pass
# strategies/no_retrieval.py
import time
from openai import OpenAI
from strategies.base import RetrievalStrategy, StrategyResult
from config import get_settings
class NoRetrievalStrategy(RetrievalStrategy):
"""Direct LLM answer without retrieval."""
def __init__(self):
settings = get_settings()
self.client = OpenAI(api_key=settings.openai_api_key)
self.model = settings.llm_model
@property
def name(self) -> str:
return "no_retrieval"
async def execute(self, query: str) -> StrategyResult:
start = time.time()
response = self.client.chat.completions.create(
model=self.model,
messages=[
{
"role": "system",
"content": "Answer the question directly and concisely. "
"If you're not certain, say so."
},
{"role": "user", "content": query}
]
)
latency = (time.time() - start) * 1000
return StrategyResult(
answer=response.choices[0].message.content,
sources=[],
strategy_used=self.name,
retrieval_steps=0,
tokens_used=response.usage.total_tokens if response.usage else None,
latency_ms=latency
)
# strategies/single_step.py
import time
import chromadb
from chromadb.utils import embedding_functions
from openai import OpenAI
from strategies.base import RetrievalStrategy, StrategyResult
from config import get_settings
class SingleStepStrategy(RetrievalStrategy):
"""Standard single-step RAG."""
def __init__(self):
settings = get_settings()
self.client = OpenAI(api_key=settings.openai_api_key)
self.model = settings.llm_model
self.k = settings.single_step_k
# ChromaDB setup
self.chroma = chromadb.PersistentClient(path=settings.chroma_persist_dir)
self.embedding_fn = embedding_functions.OpenAIEmbeddingFunction(
api_key=settings.openai_api_key,
model_name=settings.embedding_model
)
self.collection = self.chroma.get_or_create_collection(
name=settings.collection_name,
embedding_function=self.embedding_fn
)
@property
def name(self) -> str:
return "single_step"
async def execute(self, query: str) -> StrategyResult:
start = time.time()
# Retrieve
results = self.collection.query(
query_texts=[query],
n_results=self.k,
include=["documents", "metadatas"]
)
# Build context
context = "\n\n---\n\n".join([
f"Source: {results['metadatas'][0][i].get('source', 'unknown')}\n"
f"{results['documents'][0][i]}"
for i in range(len(results['documents'][0]))
])
# Generate
response = self.client.chat.completions.create(
model=self.model,
messages=[
{
"role": "system",
"content": "Answer based on the provided context. "
"Cite sources when possible."
},
{
"role": "user",
"content": f"Context:\n{context}\n\nQuestion: {query}"
}
]
)
latency = (time.time() - start) * 1000
sources = [
{"source": results['metadatas'][0][i].get('source', 'unknown')}
for i in range(len(results['documents'][0]))
]
return StrategyResult(
answer=response.choices[0].message.content,
sources=sources,
strategy_used=self.name,
retrieval_steps=1,
tokens_used=response.usage.total_tokens if response.usage else None,
latency_ms=latency
)
def add_documents(self, documents: list[str], sources: list[str]):
"""Add documents to the collection."""
# Offset IDs by the current count so repeated calls don't collide
start = self.collection.count()
ids = [f"doc_{start + i}" for i in range(len(documents))]
self.collection.add(
documents=documents,
ids=ids,
metadatas=[{"source": src} for src in sources]
)
# strategies/multi_step.py
import time
import chromadb
from chromadb.utils import embedding_functions
from openai import OpenAI
from pydantic import BaseModel
from strategies.base import RetrievalStrategy, StrategyResult
from config import get_settings
class SubQuery(BaseModel):
"""A decomposed sub-query."""
query: str
purpose: str
class MultiStepStrategy(RetrievalStrategy):
"""Multi-step iterative retrieval for complex queries."""
def __init__(self):
settings = get_settings()
self.client = OpenAI(api_key=settings.openai_api_key)
self.model = settings.llm_model
self.k = settings.multi_step_k
self.max_iterations = settings.multi_step_iterations
# ChromaDB setup
self.chroma = chromadb.PersistentClient(path=settings.chroma_persist_dir)
self.embedding_fn = embedding_functions.OpenAIEmbeddingFunction(
api_key=settings.openai_api_key,
model_name=settings.embedding_model
)
self.collection = self.chroma.get_or_create_collection(
name=settings.collection_name,
embedding_function=self.embedding_fn
)
@property
def name(self) -> str:
return "multi_step"
def decompose_query(self, query: str) -> list[SubQuery]:
"""Decompose complex query into sub-queries."""
response = self.client.chat.completions.create(
model=self.model,
messages=[
{
"role": "system",
"content": """Decompose the complex query into 2-4 simpler sub-queries.
Each sub-query should retrieve specific information needed to answer the main query.
Return JSON: {"sub_queries": [{"query": "...", "purpose": "..."}]}"""
},
{"role": "user", "content": f"Query: {query}"}
],
response_format={"type": "json_object"}
)
import json
result = json.loads(response.choices[0].message.content)
return [SubQuery(**sq) for sq in result.get("sub_queries", [{"query": query, "purpose": "main"}])]
async def execute(self, query: str) -> StrategyResult:
start = time.time()
# Decompose query
sub_queries = self.decompose_query(query)
# Retrieve for each sub-query
all_contexts = []
all_sources = []
for sq in sub_queries:
results = self.collection.query(
query_texts=[sq.query],
n_results=self.k,
include=["documents", "metadatas"]
)
for i in range(len(results['documents'][0])):
all_contexts.append({
"sub_query": sq.query,
"purpose": sq.purpose,
"content": results['documents'][0][i],
"source": results['metadatas'][0][i].get('source', 'unknown')
})
all_sources.append({
"source": results['metadatas'][0][i].get('source', 'unknown'),
"sub_query": sq.query
})
# Build comprehensive context
context_text = ""
for ctx in all_contexts:
context_text += f"\n[For: {ctx['purpose']}]\n"
context_text += f"Source: {ctx['source']}\n"
context_text += f"{ctx['content']}\n"
context_text += "---\n"
# Synthesize answer
response = self.client.chat.completions.create(
model=self.model,
messages=[
{
"role": "system",
"content": """Synthesize a comprehensive answer from the multi-source context.
Address all aspects of the complex query.
Cite sources and explain how different pieces connect."""
},
{
"role": "user",
"content": f"Context:\n{context_text}\n\nOriginal Question: {query}"
}
]
)
latency = (time.time() - start) * 1000
return StrategyResult(
answer=response.choices[0].message.content,
sources=all_sources,
strategy_used=self.name,
retrieval_steps=len(sub_queries),
tokens_used=response.usage.total_tokens if response.usage else None,
latency_ms=latency
)
Understanding the Three Strategies:
┌─────────────────────────────────────────────────────────────┐
│ STRATEGY 1: No Retrieval (Simple Queries) │
│ │
│ "What is Python?" │
│ │ │
│ ▼ │
│ ┌──────────┐ │
│ │ LLM │ ──► "Python is a high-level programming..." │
│ └──────────┘ │
│ │
│ Cost: ~500 tokens | Latency: ~300ms │
└─────────────────────────────────────────────────────────────┘
┌─────────────────────────────────────────────────────────────┐
│ STRATEGY 2: Single-Step (Moderate Queries) │
│ │
│ "What are the new features in Python 3.12?" │
│ │ │
│ ▼ │
│ ┌──────────┐ ┌──────────┐ ┌──────────┐ │
│ │ Retrieve │ ──► │ Context │ ──► │ LLM │ ──► Answer │
│ │ (k=5) │ │ Build │ │ │ │
│ └──────────┘ └──────────┘ └──────────┘ │
│ │
│ Cost: ~2000 tokens | Latency: ~800ms │
└─────────────────────────────────────────────────────────────┘
┌─────────────────────────────────────────────────────────────┐
│ STRATEGY 3: Multi-Step (Complex Queries) │
│ │
│ "Compare Python async with Go goroutines" │
│ │ │
│ ▼ │
│ ┌──────────────┐ │
│ │ Decompose │ ──► ["Python async?", "Go goroutines?", │
│ │ Query │ "How do they compare?"] │
│ └──────────────┘ │
│ │ │
│ ▼ │
│ ┌──────────────┐ ┌──────────────┐ ┌──────────────┐ │
│ │ Retrieve │ │ Retrieve │ │ Retrieve │ │
│ │ Sub-Q 1 │ │ Sub-Q 2 │ │ Sub-Q 3 │ │
│ └──────────────┘ └──────────────┘ └──────────────┘ │
│ │ │ │ │
│ └─────────────────────┴───────────────────┘ │
│ │ │
│ ▼ │
│ ┌──────────────┐ │
│ │ Synthesize │ ──► Comprehensive Answer│
│ │ Answer │ │
│ └──────────────┘ │
│ │
│ Cost: ~5000 tokens | Latency: ~2000ms │
└─────────────────────────────────────────────────────────────┘
Why Query Decomposition Matters for Complex Queries:
| Original Query | Sub-Queries | Why Better |
|---|---|---|
| "Compare Python async with Go goroutines" | 1. "How does Python async work?" 2. "How do Go goroutines work?" 3. "Async vs goroutines comparison" | No single doc covers both well; each sub-query retrieves focused, relevant docs, and synthesis combines the insights |
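To make the decomposition step concrete without an LLM call, here is a naive, purely heuristic stand-in for the LLM-based `decompose_query` (the regex and sub-query phrasing are illustrative only):

```python
import re

def naive_decompose(query: str) -> list[dict]:
    """Heuristic stand-in for decompose_query():
    splits 'Compare X with Y' into focused sub-queries."""
    m = re.match(r'compare (.+?) (?:with|vs\.?|versus|and) (.+)', query, re.I)
    if not m:
        # Not a recognizable comparison: fall back to the original query
        return [{"query": query, "purpose": "main"}]
    a, b = m.group(1), m.group(2)
    return [
        {"query": f"How does {a} work?", "purpose": f"understand {a}"},
        {"query": f"How does {b} work?", "purpose": f"understand {b}"},
        {"query": f"{a} vs {b} comparison", "purpose": "direct comparison"},
    ]

for sub in naive_decompose("Compare Python async with Go goroutines"):
    print(sub["query"])
```

The real implementation uses the LLM for this step because comparisons come in many phrasings, but the output shape (`query` + `purpose` pairs) is exactly what `MultiStepStrategy.execute` consumes.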
Step 4: Strategy Router
# router.py
from complexity_classifier import (
HybridClassifier,
ComplexityPrediction
)
from strategies.base import RetrievalStrategy, StrategyResult
from strategies.no_retrieval import NoRetrievalStrategy
from strategies.single_step import SingleStepStrategy
from strategies.multi_step import MultiStepStrategy
from config import ComplexityLevel
class StrategyRouter:
"""Routes queries to appropriate retrieval strategies."""
def __init__(self):
self.classifier = HybridClassifier()
# Initialize strategies
self.strategies: dict[ComplexityLevel, RetrievalStrategy] = {
ComplexityLevel.SIMPLE: NoRetrievalStrategy(),
ComplexityLevel.MODERATE: SingleStepStrategy(),
ComplexityLevel.COMPLEX: MultiStepStrategy(),
}
def get_strategy(self, level: ComplexityLevel) -> RetrievalStrategy:
"""Get strategy for complexity level."""
return self.strategies[level]
async def route_and_execute(
self,
query: str,
force_strategy: ComplexityLevel | None = None
) -> tuple[StrategyResult, ComplexityPrediction]:
"""
Classify query and route to appropriate strategy.
Args:
query: User query
force_strategy: Override classification with specific strategy
Returns:
Tuple of (result, classification)
"""
# Classify
if force_strategy:
classification = ComplexityPrediction(
level=force_strategy,
confidence=1.0,
reasoning="Forced by user",
features={}
)
else:
classification = self.classifier.classify(query)
# Get strategy and execute
strategy = self.get_strategy(classification.level)
result = await strategy.execute(query)
return result, classification
def add_documents(self, documents: list[str], sources: list[str]):
"""Add documents to retrieval strategies that need them."""
# SingleStep and MultiStep share the same ChromaDB collection,
# so adding documents through one makes them available to both
single_step = self.strategies[ComplexityLevel.MODERATE]
if hasattr(single_step, 'add_documents'):
single_step.add_documents(documents, sources)
Step 5: Main Orchestration
# adaptive_rag.py
from pydantic import BaseModel
from router import StrategyRouter
from strategies.base import StrategyResult
from complexity_classifier import ComplexityPrediction
from config import ComplexityLevel
class AdaptiveRAGResponse(BaseModel):
"""Response from Adaptive RAG."""
answer: str
sources: list[dict]
classification: ComplexityPrediction
strategy_used: str
retrieval_steps: int
latency_ms: float
tokens_used: int | None
class AdaptiveRAG:
"""Adaptive RAG system with query complexity routing."""
def __init__(self):
self.router = StrategyRouter()
async def query(
self,
question: str,
force_strategy: str | None = None
) -> AdaptiveRAGResponse:
"""
Process a query with adaptive strategy selection.
Args:
question: User question
force_strategy: Optional override ("simple", "moderate", "complex")
"""
# Parse force_strategy if provided
forced = None
if force_strategy:
forced = ComplexityLevel(force_strategy)
# Route and execute
result, classification = await self.router.route_and_execute(
query=question,
force_strategy=forced
)
return AdaptiveRAGResponse(
answer=result.answer,
sources=result.sources,
classification=classification,
strategy_used=result.strategy_used,
retrieval_steps=result.retrieval_steps,
latency_ms=result.latency_ms,
tokens_used=result.tokens_used
)
def add_documents(self, documents: list[str], sources: list[str]):
"""Add documents to the knowledge base."""
self.router.add_documents(documents, sources)
Step 6: FastAPI Application
# app.py
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from contextlib import asynccontextmanager
from typing import Literal
from adaptive_rag import AdaptiveRAG, AdaptiveRAGResponse
from config import ComplexityLevel
# Global
adaptive_rag: AdaptiveRAG | None = None
@asynccontextmanager
async def lifespan(app: FastAPI):
global adaptive_rag
adaptive_rag = AdaptiveRAG()
# Add sample documents
sample_docs = [
"Python is a high-level, interpreted programming language known for its simplicity and readability. It was created by Guido van Rossum and first released in 1991.",
"Python decorators are a powerful feature that allows you to modify the behavior of functions or classes. They use the @decorator syntax and are commonly used for logging, authentication, and caching.",
"Python 3.12 introduced several new features including improved error messages, a new type parameter syntax for generic classes, and performance improvements of up to 5%.",
"Async programming in Python uses the asyncio library with async/await syntax. It enables concurrent execution of I/O-bound operations without threads.",
"Go (Golang) uses goroutines for concurrency - lightweight threads managed by the Go runtime. Goroutines communicate through channels, following the CSP model.",
"Rust handles concurrency through its ownership system, preventing data races at compile time. It offers both message passing and shared-state concurrency.",
"PostgreSQL is an object-relational database known for ACID compliance, extensibility, and support for complex queries. It excels at write-heavy workloads.",
"MongoDB is a document database that stores data in flexible JSON-like documents. It's designed for horizontal scaling and high availability.",
"For time-series data, PostgreSQL with TimescaleDB extension provides automatic partitioning, continuous aggregates, and compression. MongoDB requires manual sharding for time-series."
]
sources = [
"python_intro", "python_decorators", "python_312", "python_async",
"go_concurrency", "rust_concurrency",
"postgresql_overview", "mongodb_overview", "timeseries_comparison"
]
adaptive_rag.add_documents(sample_docs, sources)
yield
adaptive_rag = None
app = FastAPI(
title="Adaptive RAG API",
description="Query complexity-aware RAG with adaptive retrieval strategies",
lifespan=lifespan
)
class QueryRequest(BaseModel):
query: str
force_strategy: Literal["simple", "moderate", "complex"] | None = None
class DocumentsRequest(BaseModel):
documents: list[str]
sources: list[str]
@app.post("/query", response_model=AdaptiveRAGResponse)
async def query(request: QueryRequest):
"""Query with adaptive strategy selection."""
if not adaptive_rag:
raise HTTPException(status_code=503, detail="Service not initialized")
result = await adaptive_rag.query(
question=request.query,
force_strategy=request.force_strategy
)
return result
@app.post("/documents")
async def add_documents(request: DocumentsRequest):
"""Add documents to the knowledge base."""
if not adaptive_rag:
raise HTTPException(status_code=503, detail="Service not initialized")
if len(request.documents) != len(request.sources):
raise HTTPException(
status_code=400,
detail="Documents and sources must have same length"
)
adaptive_rag.add_documents(request.documents, request.sources)
return {"status": "success", "documents_added": len(request.documents)}
@app.get("/strategies")
async def list_strategies():
"""List available strategies and their descriptions."""
return {
"strategies": {
"simple": "No retrieval - direct LLM answer for basic questions",
"moderate": "Single-step RAG - standard retrieval for specific questions",
"complex": "Multi-step RAG - query decomposition for complex questions"
}
}
@app.get("/health")
async def health():
return {"status": "healthy", "service": "adaptive-rag"}
Step 7: Requirements
# requirements.txt
openai>=1.12.0
chromadb>=0.4.22
pydantic>=2.0.0
pydantic-settings>=2.0.0
fastapi>=0.109.0
uvicorn>=0.27.0
scikit-learn>=1.4.0
joblib>=1.3.0
python-dotenv>=1.0.0
Usage Examples
Basic Usage
from adaptive_rag import AdaptiveRAG
import asyncio
async def main():
rag = AdaptiveRAG()
# Add documents
rag.add_documents(
documents=["Your content here..."],
sources=["source_name"]
)
# Simple query - no retrieval
result = await rag.query("What is Python?")
print(f"Strategy: {result.strategy_used}") # "no_retrieval"
print(f"Answer: {result.answer}")
# Moderate query - single-step
result = await rag.query("What are Python decorators?")
print(f"Strategy: {result.strategy_used}") # "single_step"
# Complex query - multi-step
result = await rag.query(
"Compare Python async with Go goroutines for web scraping"
)
print(f"Strategy: {result.strategy_used}") # "multi_step"
print(f"Retrieval steps: {result.retrieval_steps}")
asyncio.run(main())
Force Strategy
# Override classifier decision
result = await rag.query(
question="What is Python?",
force_strategy="moderate" # Force retrieval even for simple query
)
API Usage
# Start server
uvicorn app:app --reload
# Query (auto-classified)
curl -X POST http://localhost:8000/query \
-H "Content-Type: application/json" \
-d '{"query": "Compare PostgreSQL vs MongoDB for time-series data"}'
# Force specific strategy
curl -X POST http://localhost:8000/query \
-H "Content-Type: application/json" \
-d '{"query": "What is Python?", "force_strategy": "moderate"}'
How Adaptive RAG Saves Compute
┌─────────────────────────────────────────────────────────────────┐
│ COMPUTE COST BY STRATEGY │
│ │
│ Strategy Tokens Visual │
│ ───────────────────────────────────────────────────────────── │
│ Simple ~500 ██ │
│ Moderate ~2000 ████████ │
│ Complex ~5000 ████████████████████ │
│ │
│ Traditional always uses moderate cost even for simple queries! │
└─────────────────────────────────────────────────────────────────┘
| Query Type | Traditional RAG | Adaptive RAG | Savings |
|---|---|---|---|
| Simple (40%) | 2000 tokens | 500 tokens | 75% |
| Moderate (45%) | 2000 tokens | 2000 tokens | 0% |
| Complex (15%) | 2000 tokens | 5000 tokens | -150% |
| Weighted Avg | 2000 tokens | 1850 tokens | ~8% |
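As a sanity check, the weighted average follows directly from the per-strategy costs and the assumed 40/45/15 query mix (the mix itself is illustrative; your traffic will differ):

```python
# Per-strategy token costs and assumed query-mix shares (from the table)
costs  = {"simple": 500, "moderate": 2000, "complex": 5000}
shares = {"simple": 0.40, "moderate": 0.45, "complex": 0.15}

adaptive = sum(shares[k] * costs[k] for k in costs)  # weighted average cost
traditional = 2000                                    # flat single-step cost
savings = 1 - adaptive / traditional

print(adaptive)          # 1850.0
print(f"{savings:.1%}")  # 7.5%
```

Note that the complex tier costs more than traditional RAG; the savings come from the large fraction of simple queries that skip retrieval entirely, and the extra complex-tier spend buys answer quality rather than efficiency.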
Key Concepts Recap
| Concept | What It Is | Why It Matters |
|---|---|---|
| Query Complexity | Simple/Moderate/Complex classification | Different queries need different strategies |
| No Retrieval | Direct LLM answer | Saves tokens for questions LLM already knows |
| Single-Step | Standard RAG | Right-sized for most factual questions |
| Multi-Step | Decompose → Retrieve each → Synthesize | Handles comparisons and multi-part queries |
| Hybrid Classifier | Rules first, LLM verification | Fast when easy, accurate when hard |
| Query Decomposition | Break complex into sub-queries | Each sub-query retrieves focused context |
| Compute Optimization | Match strategy to complexity | ~8% average token savings on the example mix |
References
- Adaptive-RAG Paper (arXiv:2403.14403)
- Self-RAG for self-correction techniques
- Multi-Document RAG for context management
Next Steps
- Add classifier training with labeled examples
- Implement confidence-based fallback (low confidence → try multiple strategies)
- Build strategy performance tracking for continuous improvement
- Explore Speculative RAG for parallel drafting