Self-RAG
Build a self-correcting RAG system with query rewriting and answer verification
TL;DR
Traditional RAG blindly retrieves and generates—it doesn't know when it's wrong. Self-RAG adds self-evaluation loops: the system checks if retrieval was good, rewrites queries if needed, and verifies answers are actually supported by sources. This catches errors before users see them, boosting accuracy from ~75% to ~90%.
| Property | Value |
|---|---|
| Difficulty | Intermediate |
| Time | ~5 hours |
| Code Size | ~400 LOC |
| Prerequisites | RAG with Reranking |
Tech Stack
| Technology | Purpose |
|---|---|
| LangChain | RAG orchestration |
| OpenAI | GPT-4 + Embeddings |
| ChromaDB | Vector database |
| Pydantic | Structured outputs |
| FastAPI | REST API |
Prerequisites
- Completed RAG with Reranking tutorial
- Python 3.10+
- OpenAI API key (Get one here)
What You'll Learn
- Implement query rewriting for better retrieval
- Build self-evaluation modules for answer quality
- Create retrieval quality assessment
- Design self-correction loops with iteration limits
- Measure and improve answer accuracy
The Problem with Static RAG
Traditional RAG has a fixed pipeline: retrieve → generate. But what happens when:
- The user's query is vague or poorly formed?
- Retrieved documents are irrelevant?
- The generated answer doesn't actually address the question?
THE PROBLEM:
Vague Query ──► Poor Retrieval ──► Irrelevant Answer ──► Frustrated User

Self-RAG solves this by adding reflection and correction loops:
SELF-RAG PIPELINE:
┌─────────────────────────────────────────────────────────────────────────┐
│ │
│ User Query ──► Query Rewriter ──► Retriever ──► Retrieval Evaluator │
│ ▲ │ │
│ │ │ │
│ │ (Poor Quality) ▼ (Good Quality) │
│ │ Generator │
│ │ │ │
│ │ ▼ │
│ │ Answer Evaluator │
│ │ │ │
│ └───────── (Unsupported) ◄──────┤ │
│ │ (Supported) │
│ ▼ │
│ Final Answer │
│ │
└─────────────────────────────────────────────────────────────────────────┘

Project Structure
self-rag/
├── config.py # Configuration
├── query_rewriter.py # Query transformation
├── retriever.py # Document retrieval
├── evaluators.py # Quality assessment
├── generator.py # Answer generation
├── self_rag.py # Main orchestration
├── app.py # FastAPI application
└── requirements.txt

Step 1: Configuration
# config.py
from pydantic_settings import BaseSettings
from functools import lru_cache
class Settings(BaseSettings):
"""Application configuration."""
openai_api_key: str
# Model settings
embedding_model: str = "text-embedding-3-small"
llm_model: str = "gpt-4o-mini"
# Retrieval settings
initial_k: int = 10
final_k: int = 5
# Self-correction settings
max_rewrite_attempts: int = 3
retrieval_quality_threshold: float = 0.7
answer_support_threshold: float = 0.8
# ChromaDB
chroma_persist_dir: str = "./chroma_db"
collection_name: str = "self_rag_docs"
class Config:
env_file = ".env"
@lru_cache
def get_settings() -> Settings:
return Settings()

Step 2: Query Rewriter
The query rewriter transforms vague or complex queries into more effective search queries.
# query_rewriter.py
from openai import OpenAI
from pydantic import BaseModel
from typing import Optional
from config import get_settings
class RewrittenQuery(BaseModel):
"""Structured output for query rewriting."""
original_query: str
rewritten_query: str
reasoning: str
is_clear: bool
class QueryRewriter:
"""Rewrites queries for better retrieval."""
def __init__(self):
settings = get_settings()
self.client = OpenAI(api_key=settings.openai_api_key)
self.model = settings.llm_model
def rewrite(
self,
query: str,
previous_attempts: list[str] | None = None,
feedback: str | None = None
) -> RewrittenQuery:
"""
Rewrite a query to improve retrieval.
Args:
query: Original user query
previous_attempts: List of previous rewrite attempts
feedback: Feedback from retrieval evaluation
"""
system_prompt = """You are a query optimization expert. Your job is to rewrite
user queries to improve document retrieval.
Guidelines:
1. Expand abbreviations and acronyms
2. Add relevant synonyms or related terms
3. Make implicit context explicit
4. Break complex questions into searchable components
5. Remove filler words and focus on key concepts
If the query is already clear and specific, mark is_clear=True and keep it similar."""
user_prompt = f"Original query: {query}"
if previous_attempts:
user_prompt += f"\n\nPrevious attempts that didn't work well:\n"
for i, attempt in enumerate(previous_attempts, 1):
user_prompt += f"{i}. {attempt}\n"
if feedback:
user_prompt += f"\n\nFeedback from retrieval: {feedback}"
user_prompt += "\n\nProvide an improved query that will retrieve more relevant documents."
response = self.client.beta.chat.completions.parse(
model=self.model,
messages=[
{"role": "system", "content": system_prompt},
{"role": "user", "content": user_prompt}
],
response_format=RewrittenQuery
)
return response.choices[0].message.parsed
def decompose_complex_query(self, query: str) -> list[str]:
"""Break a complex query into simpler sub-queries."""
system_prompt = """You are a query decomposition expert. Break complex
questions into simpler, independent sub-queries that can be answered separately
and then combined.
Return a JSON object of the form {"sub_queries": ["...", "..."]}. If the query is
already simple, return it as the only sub-query."""
response = self.client.chat.completions.create(
model=self.model,
messages=[
{"role": "system", "content": system_prompt},
{"role": "user", "content": f"Query: {query}"}
],
response_format={"type": "json_object"}
)
import json
result = json.loads(response.choices[0].message.content)
return result.get("sub_queries", [query])

Understanding the Query Rewriting Strategy:
Original Query: "How does ML work?"
│
▼
┌─────────────────────────────────────────────────────────────┐
│ Query Rewriter Analysis: │
│ • "ML" is an abbreviation → expand to "machine learning" │
│ • "work" is vague → specify "process" or "algorithm" │
│ • Query is too short → add relevant context │
└─────────────────────────────────────────────────────────────┘
│
▼
Rewritten: "How does machine learning work as a process
for training models on data?"

Why Track previous_attempts?
| Iteration | Query | Feedback | Next Action |
|---|---|---|---|
| 1 | "How does ML work?" | "No documents found about ML" | Rewrite |
| 2 | "machine learning process" | "Found docs but too general" | More specific |
| 3 | "machine learning training algorithm steps" | "Good retrieval!" | Continue |
By tracking previous failed attempts, the rewriter avoids repeating the same mistakes.
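The decomposition path in Step 2 depends on the model returning a JSON object with a sub_queries key, with a fallback to the original query if that key is missing. A minimal sketch of that parsing logic, using a hypothetical model reply rather than a live API call:

```python
import json

# Hypothetical raw reply from decompose_complex_query's JSON-mode call.
raw = '{"sub_queries": ["what is supervised learning", "what is unsupervised learning"]}'
query = "compare supervised and unsupervised learning"

# Same parsing as the tutorial: fall back to the original query
# if the "sub_queries" key is absent.
result = json.loads(raw)
sub_queries = result.get("sub_queries", [query])
print(sub_queries[0])  # what is supervised learning
```

A malformed or empty reply (e.g. `{}`) degrades gracefully to a single-element list containing the original query, so downstream retrieval always has something to work with.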
Step 3: Retriever with Metadata
# retriever.py
import chromadb
from chromadb.utils import embedding_functions
from typing import Optional
from pydantic import BaseModel
from config import get_settings
class RetrievedDocument(BaseModel):
"""A retrieved document with metadata."""
content: str
source: str
chunk_id: str
distance: float
relevance_score: float # 1 - distance, normalized
class RetrievalResult(BaseModel):
"""Result of a retrieval operation."""
query: str
documents: list[RetrievedDocument]
avg_relevance: float
class Retriever:
"""Document retriever with ChromaDB."""
def __init__(self):
settings = get_settings()
# Initialize ChromaDB
self.client = chromadb.PersistentClient(
path=settings.chroma_persist_dir
)
# OpenAI embeddings
self.embedding_fn = embedding_functions.OpenAIEmbeddingFunction(
api_key=settings.openai_api_key,
model_name=settings.embedding_model
)
self.collection = self.client.get_or_create_collection(
name=settings.collection_name,
embedding_function=self.embedding_fn,
metadata={"hnsw:space": "cosine"}
)
self.settings = settings
def add_documents(
self,
documents: list[str],
sources: list[str],
chunk_ids: list[str] | None = None
):
"""Add documents to the collection."""
if chunk_ids is None:
chunk_ids = [f"chunk_{i}" for i in range(len(documents))]
self.collection.add(
documents=documents,
ids=chunk_ids,
metadatas=[{"source": src} for src in sources]
)
def retrieve(
self,
query: str,
k: int | None = None
) -> RetrievalResult:
"""Retrieve relevant documents for a query."""
k = k or self.settings.initial_k
results = self.collection.query(
query_texts=[query],
n_results=k,
include=["documents", "metadatas", "distances"]
)
documents = []
for i in range(len(results["documents"][0])):
distance = results["distances"][0][i]
relevance = 1 - min(distance, 1)  # cosine distance spans [0, 2]; clamp so relevance stays in [0, 1]
doc = RetrievedDocument(
content=results["documents"][0][i],
source=results["metadatas"][0][i].get("source", "unknown"),
chunk_id=results["ids"][0][i],
distance=distance,
relevance_score=relevance
)
documents.append(doc)
avg_relevance = (
sum(d.relevance_score for d in documents) / len(documents)
if documents else 0
)
return RetrievalResult(
query=query,
documents=documents,
avg_relevance=avg_relevance
)

Step 4: Evaluators
The core of Self-RAG: evaluation modules that assess quality at each step.
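Before wiring in the LLM evaluators, it helps to see the arithmetic the retriever feeds them. A standalone sketch (plain Python, no API calls) of the distance-to-relevance normalization from Step 3:

```python
# Sketch of the relevance normalization in Retriever.retrieve.
# ChromaDB's cosine space returns distances in [0, 2]; the tutorial clamps
# at 1 before inverting, so anything at distance >= 1 gets relevance 0.
def relevance_from_distance(distance: float) -> float:
    return 1 - min(distance, 1)

scores = [relevance_from_distance(d) for d in (0.2, 0.75, 1.5)]
print(scores)  # [0.8, 0.25, 0]
```

This keeps avg_relevance in [0, 1], which is what the evaluators and the final confidence calculation assume.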
# evaluators.py
from openai import OpenAI
from pydantic import BaseModel, Field
from typing import Literal
from config import get_settings
from retriever import RetrievalResult
class RetrievalEvaluation(BaseModel):
"""Evaluation of retrieval quality."""
is_relevant: bool
relevance_score: float = Field(ge=0, le=1)
coverage_score: float = Field(ge=0, le=1)
feedback: str
should_rewrite: bool
class AnswerEvaluation(BaseModel):
"""Evaluation of answer quality."""
is_supported: bool
support_score: float = Field(ge=0, le=1)
is_complete: bool
completeness_score: float = Field(ge=0, le=1)
issues: list[str]
verdict: Literal["accept", "refine", "reject"]
class SupportToken(BaseModel):
"""Fine-grained support assessment for a claim."""
claim: str
is_supported: bool
supporting_evidence: str | None
confidence: float
class RetrievalEvaluator:
"""Evaluates retrieval quality."""
def __init__(self):
settings = get_settings()
self.client = OpenAI(api_key=settings.openai_api_key)
self.model = settings.llm_model
self.threshold = settings.retrieval_quality_threshold
def evaluate(
self,
query: str,
retrieval_result: RetrievalResult
) -> RetrievalEvaluation:
"""Evaluate if retrieved documents are relevant to the query."""
docs_text = "\n\n---\n\n".join([
f"Document {i+1} (relevance: {d.relevance_score:.2f}):\n{d.content}"
for i, d in enumerate(retrieval_result.documents)
])
system_prompt = """You are a retrieval quality evaluator. Assess whether
the retrieved documents are relevant and sufficient to answer the query.
Consider:
1. Relevance: Do documents actually address the query topic?
2. Coverage: Do documents cover all aspects of the query?
3. Quality: Is the information substantive (not just tangentially related)?
Provide specific feedback on what's missing or irrelevant."""
user_prompt = f"""Query: {query}
Retrieved Documents:
{docs_text}
Evaluate the retrieval quality."""
response = self.client.beta.chat.completions.parse(
model=self.model,
messages=[
{"role": "system", "content": system_prompt},
{"role": "user", "content": user_prompt}
],
response_format=RetrievalEvaluation
)
evaluation = response.choices[0].message.parsed
# Determine if rewrite is needed based on threshold
combined_score = (
evaluation.relevance_score * 0.6 +
evaluation.coverage_score * 0.4
)
evaluation.should_rewrite = combined_score < self.threshold
return evaluation
class AnswerEvaluator:
"""Evaluates answer quality and support from sources."""
def __init__(self):
settings = get_settings()
self.client = OpenAI(api_key=settings.openai_api_key)
self.model = settings.llm_model
self.support_threshold = settings.answer_support_threshold
def evaluate(
self,
query: str,
answer: str,
source_documents: list[str]
) -> AnswerEvaluation:
"""Evaluate if the answer is supported by source documents."""
sources_text = "\n\n---\n\n".join([
f"Source {i+1}:\n{doc}"
for i, doc in enumerate(source_documents)
])
system_prompt = """You are an answer quality evaluator. Your job is to
verify that the answer is:
1. SUPPORTED: Every claim is backed by the source documents
2. COMPLETE: The answer fully addresses the query
3. ACCURATE: No hallucinated or fabricated information
Be strict about support - if a claim cannot be traced to a source, it's unsupported.
Verdict guidelines:
- "accept": Well-supported and complete
- "refine": Mostly good but needs minor improvements
- "reject": Significant issues, needs major revision"""
user_prompt = f"""Query: {query}
Answer to evaluate:
{answer}
Source Documents:
{sources_text}
Evaluate the answer quality."""
response = self.client.beta.chat.completions.parse(
model=self.model,
messages=[
{"role": "system", "content": system_prompt},
{"role": "user", "content": user_prompt}
],
response_format=AnswerEvaluation
)
return response.choices[0].message.parsed
def extract_claims_and_verify(
self,
answer: str,
source_documents: list[str]
) -> list[SupportToken]:
"""Extract claims from answer and verify each against sources."""
sources_text = "\n\n".join(source_documents)
system_prompt = """Extract individual factual claims from the answer and
verify each one against the source documents.
For each claim:
1. State the claim clearly
2. Determine if it's supported by the sources
3. Quote the supporting evidence if found
4. Rate your confidence (0-1)

Return a JSON object of the form:
{"claims": [{"claim": "...", "is_supported": true, "supporting_evidence": "...", "confidence": 0.9}]}"""
response = self.client.chat.completions.create(
model=self.model,
messages=[
{"role": "system", "content": system_prompt},
{"role": "user", "content": f"Answer: {answer}\n\nSources:\n{sources_text}"}
],
response_format={"type": "json_object"}
)
import json
result = json.loads(response.choices[0].message.content)
return [SupportToken(**claim) for claim in result.get("claims", [])]

Understanding the Two-Stage Evaluation:
┌─────────────────────────────────────────────────────────────┐
│ STAGE 1: Retrieval Evaluation │
│ │
│ Input: Query + Retrieved Documents │
│ │
│ Checks: │
│ • Relevance: Do docs actually talk about the query topic? │
│ • Coverage: Do docs cover ALL aspects of the question? │
│ │
│ Output: should_rewrite = (score < 0.7) │
│ feedback = "Documents don't cover implementation" │
└─────────────────────────────────────────────────────────────┘
│
▼ (if retrieval is good)
┌─────────────────────────────────────────────────────────────┐
│ STAGE 2: Answer Evaluation │
│ │
│ Input: Query + Answer + Source Documents │
│ │
│ Checks: │
│ • Support: Is every claim backed by a source? │
│ • Completeness: Does answer fully address the question? │
│ │
│ Output: verdict = "accept" | "refine" | "reject" │
│ issues = ["Claim X not found in sources"] │
└─────────────────────────────────────────────────────────────┘

Why Structured Output (response_format=RetrievalEvaluation)?
| Without Structured Output | With Structured Output |
|---|---|
| Parse free-form text | Direct object access |
| "I think it's 0.7..." | eval.relevance_score = 0.7 |
| Error-prone extraction | Type-safe, validated |
| Inconsistent format | Guaranteed schema |
The beta.chat.completions.parse() method with Pydantic models ensures the LLM returns exactly the fields we need, properly typed.
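Structured scores also make the downstream gating trivial. A standalone sketch of the rewrite decision in RetrievalEvaluator.evaluate, using the tutorial's 60/40 weighting and the default retrieval_quality_threshold of 0.7:

```python
# Sketch of the rewrite gating: a weighted blend of the evaluator's
# relevance and coverage scores, compared against the threshold.
def should_rewrite(relevance: float, coverage: float, threshold: float = 0.7) -> bool:
    combined = relevance * 0.6 + coverage * 0.4
    return combined < threshold

# Relevant but incomplete retrieval still triggers a rewrite:
print(should_rewrite(relevance=0.8, coverage=0.4))  # True  (0.64 < 0.7)
print(should_rewrite(relevance=0.9, coverage=0.8))  # False (0.86 >= 0.7)
```

Weighting relevance above coverage means off-topic documents are punished harder than incomplete ones, which matches the evaluator's priority order.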
Step 5: Answer Generator
# generator.py
from openai import OpenAI
from pydantic import BaseModel
from config import get_settings
from retriever import RetrievedDocument
class GeneratedAnswer(BaseModel):
"""A generated answer with metadata."""
answer: str
confidence: float
sources_used: list[str]
class AnswerGenerator:
"""Generates answers from retrieved documents."""
def __init__(self):
settings = get_settings()
self.client = OpenAI(api_key=settings.openai_api_key)
self.model = settings.llm_model
def generate(
self,
query: str,
documents: list[RetrievedDocument],
previous_answer: str | None = None,
improvement_feedback: str | None = None
) -> GeneratedAnswer:
"""Generate an answer from retrieved documents."""
context = "\n\n---\n\n".join([
f"Source [{d.source}]:\n{d.content}"
for d in documents
])
system_prompt = """You are a precise question-answering assistant.
Generate answers that are:
1. GROUNDED: Only use information from the provided sources
2. CITED: Reference sources when making claims
3. HONEST: Say "I don't have enough information" if sources are insufficient
4. CONCISE: Be direct and avoid unnecessary elaboration
Never make up information not in the sources."""
user_prompt = f"Question: {query}\n\nSources:\n{context}"
if previous_answer and improvement_feedback:
user_prompt += f"""
Previous answer that needs improvement:
{previous_answer}
Feedback to address:
{improvement_feedback}
Please generate an improved answer addressing the feedback."""
response = self.client.chat.completions.create(
model=self.model,
messages=[
{"role": "system", "content": system_prompt},
{"role": "user", "content": user_prompt}
]
)
answer = response.choices[0].message.content
# Calculate confidence based on document relevance
avg_relevance = (
sum(d.relevance_score for d in documents) / len(documents)
if documents else 0
)
return GeneratedAnswer(
answer=answer,
confidence=avg_relevance,
sources_used=[d.source for d in documents]
)

Step 6: Self-RAG Orchestration
The main orchestrator that ties everything together with self-correction loops.
# self_rag.py
from pydantic import BaseModel
from typing import Literal
from config import get_settings
from query_rewriter import QueryRewriter, RewrittenQuery
from retriever import Retriever, RetrievalResult
from evaluators import (
RetrievalEvaluator,
AnswerEvaluator,
RetrievalEvaluation,
AnswerEvaluation
)
from generator import AnswerGenerator, GeneratedAnswer
class SelfRAGTrace(BaseModel):
"""Trace of the Self-RAG process for debugging."""
original_query: str
query_rewrites: list[RewrittenQuery]
retrieval_attempts: list[tuple[str, RetrievalEvaluation]]
generation_attempts: list[tuple[GeneratedAnswer, AnswerEvaluation]]
final_query: str
final_answer: str
total_iterations: int
outcome: Literal["success", "max_iterations", "low_confidence"]
class SelfRAGResult(BaseModel):
"""Final result of Self-RAG."""
answer: str
confidence: float
sources: list[str]
iterations: int
trace: SelfRAGTrace | None = None
class SelfRAG:
"""Self-correcting RAG system."""
def __init__(self, include_trace: bool = False):
self.query_rewriter = QueryRewriter()
self.retriever = Retriever()
self.retrieval_evaluator = RetrievalEvaluator()
self.answer_evaluator = AnswerEvaluator()
self.generator = AnswerGenerator()
settings = get_settings()
self.max_iterations = settings.max_rewrite_attempts
self.include_trace = include_trace
def query(self, user_query: str) -> SelfRAGResult:
"""
Process a query with self-correction.
The system will:
1. Rewrite the query if needed
2. Retrieve and evaluate documents
3. Generate and evaluate answer
4. Loop back if quality is insufficient
"""
# Initialize trace
trace = SelfRAGTrace(
original_query=user_query,
query_rewrites=[],
retrieval_attempts=[],
generation_attempts=[],
final_query=user_query,
final_answer="",
total_iterations=0,
outcome="success"
)
current_query = user_query
previous_rewrites: list[str] = []
iteration = 0
while iteration < self.max_iterations:
iteration += 1
trace.total_iterations = iteration
# Step 1: Query Rewriting (skip on first iteration if query seems clear)
if iteration > 1 or self._should_rewrite_initial(user_query):
rewrite_result = self.query_rewriter.rewrite(
query=user_query,
previous_attempts=previous_rewrites if previous_rewrites else None,
feedback=trace.retrieval_attempts[-1][1].feedback if trace.retrieval_attempts else None
)
trace.query_rewrites.append(rewrite_result)
if not rewrite_result.is_clear or iteration > 1:
current_query = rewrite_result.rewritten_query
previous_rewrites.append(current_query)
# Step 2: Retrieval
retrieval_result = self.retriever.retrieve(current_query)
# Step 3: Evaluate Retrieval
retrieval_eval = self.retrieval_evaluator.evaluate(
query=user_query, # Evaluate against original query
retrieval_result=retrieval_result
)
trace.retrieval_attempts.append((current_query, retrieval_eval))
# If retrieval is poor and we have iterations left, rewrite
if retrieval_eval.should_rewrite and iteration < self.max_iterations:
continue
# Step 4: Generate Answer
previous_answer = (
trace.generation_attempts[-1][0].answer
if trace.generation_attempts else None
)
previous_feedback = (
"; ".join(trace.generation_attempts[-1][1].issues)
if trace.generation_attempts else None
)
generated = self.generator.generate(
query=user_query,
documents=retrieval_result.documents,
previous_answer=previous_answer,
improvement_feedback=previous_feedback
)
# Step 5: Evaluate Answer
answer_eval = self.answer_evaluator.evaluate(
query=user_query,
answer=generated.answer,
source_documents=[d.content for d in retrieval_result.documents]
)
trace.generation_attempts.append((generated, answer_eval))
# Check if answer is acceptable
if answer_eval.verdict == "accept":
trace.final_query = current_query
trace.final_answer = generated.answer
trace.outcome = "success"
break
elif answer_eval.verdict == "refine" and iteration < self.max_iterations:
# Try to improve with another iteration
continue
else:
# Accept with lower confidence or give up
trace.final_query = current_query
trace.final_answer = generated.answer
trace.outcome = (
"max_iterations" if iteration >= self.max_iterations
else "low_confidence"
)
break
# Calculate final confidence
final_confidence = self._calculate_confidence(trace)
return SelfRAGResult(
answer=trace.final_answer,
confidence=final_confidence,
sources=trace.generation_attempts[-1][0].sources_used if trace.generation_attempts else [],
iterations=trace.total_iterations,
trace=trace if self.include_trace else None
)
def _should_rewrite_initial(self, query: str) -> bool:
"""Heuristic to check if initial query needs rewriting."""
# Short queries often benefit from expansion
if len(query.split()) < 4:
return True
# Queries with vague referents; match whole words so that, e.g.,
# "it" doesn't match inside words like "with"
vague_terms = {"this", "that", "it", "thing", "stuff", "etc"}
if any(term in query.lower().split() for term in vague_terms):
return True
return False
def _calculate_confidence(self, trace: SelfRAGTrace) -> float:
"""Calculate overall confidence score."""
if not trace.generation_attempts:
return 0.0
last_gen, last_eval = trace.generation_attempts[-1]
# Combine multiple signals
support_score = last_eval.support_score
completeness_score = last_eval.completeness_score
generation_confidence = last_gen.confidence
# Penalize for many iterations
iteration_penalty = max(0, 1 - (trace.total_iterations - 1) * 0.1)
return (
support_score * 0.4 +
completeness_score * 0.3 +
generation_confidence * 0.2 +
iteration_penalty * 0.1
)
def add_documents(self, documents: list[str], sources: list[str]):
"""Add documents to the knowledge base."""
self.retriever.add_documents(documents, sources)

Understanding the Self-Correction Loop:
┌────────────────────────────────────────────────────────────────┐
│ ITERATION 1 │
│ Query: "How does ML work?" │
│ │
│ 1. Rewrite → "machine learning process steps" │
│ 2. Retrieve → 10 documents (avg relevance: 0.65) │
│ 3. Evaluate Retrieval → score: 0.6, should_rewrite: TRUE │
│ 4. Loop back to rewrite with feedback │
└────────────────────────────────────────────────────────────────┘
│
▼
┌────────────────────────────────────────────────────────────────┐
│ ITERATION 2 │
│ Feedback: "Documents too general, need training specifics" │
│ │
│ 1. Rewrite → "machine learning model training algorithm" │
│ 2. Retrieve → 10 documents (avg relevance: 0.82) │
│ 3. Evaluate Retrieval → score: 0.85, should_rewrite: FALSE │
│ 4. Generate answer → "Machine learning trains models by..." │
│ 5. Evaluate Answer → support: 0.9, verdict: "accept" │
│ 6. Return final answer │
└────────────────────────────────────────────────────────────────┘

Key Design Decisions:
| Decision | Rationale |
|---|---|
| max_iterations = 3 | Prevent infinite loops; most queries converge in 2 |
| Evaluate against original query | Don't lose the user's actual intent |
| Track SelfRAGTrace | Debugging gold: see exactly what went wrong |
| _calculate_confidence() | Weighted combination of all quality signals |
| Iteration penalty in confidence | More iterations = less certain of result |
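The confidence formula is easy to sanity-check by hand. A standalone sketch reproducing _calculate_confidence's weights (support 0.4, completeness 0.3, generation confidence 0.2, and a 10% slot that shrinks by 0.1 per extra iteration):

```python
# Sketch of the weighted confidence combination from SelfRAG._calculate_confidence.
def overall_confidence(support: float, completeness: float,
                       generation: float, iterations: int) -> float:
    # More iterations = less certainty; the penalty floors at 0.
    iteration_penalty = max(0, 1 - (iterations - 1) * 0.1)
    return (support * 0.4 + completeness * 0.3
            + generation * 0.2 + iteration_penalty * 0.1)

# The same well-supported answer loses a little confidence if it took
# three iterations instead of one:
print(round(overall_confidence(0.9, 0.9, 0.8, iterations=1), 2))  # 0.89
print(round(overall_confidence(0.9, 0.9, 0.8, iterations=3), 2))  # 0.87
```

Because the penalty only occupies the 10% slot, extra iterations nudge the score rather than dominate it: quality signals still decide most of the final number.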
When Does Self-RAG Help Most?
| Query Type | Traditional RAG | Self-RAG | Improvement |
|---|---|---|---|
| Clear, specific | Good | Good | Minimal |
| Vague, ambiguous | Poor | Good | Large |
| Multi-aspect | Partial | Complete | Large |
| Out-of-domain | Wrong | "I don't know" | Prevents harm |
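The initial-rewrite decision is cheap string logic with no LLM call. A standalone sketch of the _should_rewrite_initial heuristic (shown here with a word-level match for the vague-term check, so "it" can't match inside words like "with"):

```python
# Sketch: very short queries and queries with vague referents are sent
# to the rewriter before the first retrieval.
def should_rewrite_initial(query: str) -> bool:
    if len(query.split()) < 4:  # short queries benefit from expansion
        return True
    vague_terms = {"this", "that", "it", "thing", "stuff", "etc"}
    return any(word in vague_terms for word in query.lower().split())

print(should_rewrite_initial("What is RAG?"))                        # True (3 words)
print(should_rewrite_initial("How does that feature handle load?"))  # True ("that")
print(should_rewrite_initial("Explain gradient descent convergence criteria"))  # False
```

Clear, specific queries skip the rewrite entirely on iteration 1, saving one LLM call on the happy path.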
Step 7: FastAPI Application
# app.py
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from self_rag import SelfRAG, SelfRAGResult
from contextlib import asynccontextmanager
# Initialize Self-RAG
self_rag: SelfRAG | None = None
@asynccontextmanager
async def lifespan(app: FastAPI):
global self_rag
self_rag = SelfRAG(include_trace=True)
# Add sample documents for demo
sample_docs = [
"Machine learning is a subset of artificial intelligence that enables systems to learn from data. It includes supervised learning, unsupervised learning, and reinforcement learning paradigms.",
"Deep learning uses neural networks with multiple layers to learn hierarchical representations. Convolutional Neural Networks (CNNs) are effective for image processing, while Recurrent Neural Networks (RNNs) handle sequential data.",
"Transfer learning allows models trained on one task to be adapted for another task. This is particularly useful when labeled data is scarce. Fine-tuning pre-trained models like BERT or GPT has become standard practice.",
"RAG (Retrieval-Augmented Generation) combines retrieval systems with generative models. It first retrieves relevant documents, then uses them as context for generation. This reduces hallucination and provides source attribution.",
"Self-RAG extends traditional RAG by adding self-reflection capabilities. The system evaluates its own retrieval quality and answer accuracy, iteratively improving until quality thresholds are met."
]
sources = [f"doc_{i+1}" for i in range(len(sample_docs))]
self_rag.add_documents(sample_docs, sources)
yield
self_rag = None
app = FastAPI(
title="Self-RAG API",
description="Self-correcting RAG with query rewriting and answer verification",
lifespan=lifespan
)
class QueryRequest(BaseModel):
query: str
include_trace: bool = False
class DocumentsRequest(BaseModel):
documents: list[str]
sources: list[str]
@app.post("/query", response_model=SelfRAGResult)
async def query(request: QueryRequest):
"""Query the Self-RAG system."""
if not self_rag:
raise HTTPException(status_code=503, detail="Service not initialized")
# Update trace setting
self_rag.include_trace = request.include_trace
result = self_rag.query(request.query)
return result
@app.post("/documents")
async def add_documents(request: DocumentsRequest):
"""Add documents to the knowledge base."""
if not self_rag:
raise HTTPException(status_code=503, detail="Service not initialized")
if len(request.documents) != len(request.sources):
raise HTTPException(
status_code=400,
detail="Documents and sources must have same length"
)
self_rag.add_documents(request.documents, request.sources)
return {"status": "success", "documents_added": len(request.documents)}
@app.get("/health")
async def health():
return {"status": "healthy", "service": "self-rag"}

Step 8: Requirements
# requirements.txt
openai>=1.40.0  # beta.chat.completions.parse requires 1.40+
chromadb>=0.4.22
pydantic>=2.0.0
pydantic-settings>=2.0.0
fastapi>=0.109.0
uvicorn>=0.27.0
python-dotenv>=1.0.0

Usage Examples
Basic Query
from self_rag import SelfRAG
# Initialize
rag = SelfRAG(include_trace=True)
# Add documents
rag.add_documents(
documents=["Your document content here..."],
sources=["source_name"]
)
# Query
result = rag.query("What is transfer learning?")
print(f"Answer: {result.answer}")
print(f"Confidence: {result.confidence:.2f}")
print(f"Iterations: {result.iterations}")
print(f"Sources: {result.sources}")

Trace Analysis
# Query with trace
result = rag.query("How does ML work?")

if result.trace:
    print("=== Query Rewrites ===")
    for rewrite in result.trace.query_rewrites:
        print(f"  {rewrite.original_query} → {rewrite.rewritten_query}")
        print(f"  Reasoning: {rewrite.reasoning}")

    print("\n=== Retrieval Attempts ===")
    for q, ev in result.trace.retrieval_attempts:
        print(f"  Query: {q}")
        print(f"  Relevance: {ev.relevance_score:.2f}, Coverage: {ev.coverage_score:.2f}")
        print(f"  Feedback: {ev.feedback}")

    print("\n=== Generation Attempts ===")
    for gen, ev in result.trace.generation_attempts:
        print(f"  Answer: {gen.answer[:100]}...")
        print(f"  Support: {ev.support_score:.2f}, Complete: {ev.completeness_score:.2f}")
        print(f"  Verdict: {ev.verdict}")

API Usage
# Start server
uvicorn app:app --reload
# Query
curl -X POST http://localhost:8000/query \
-H "Content-Type: application/json" \
-d '{"query": "What is deep learning?", "include_trace": true}'
# Add documents
curl -X POST http://localhost:8000/documents \
-H "Content-Type: application/json" \
-d '{"documents": ["New content..."], "sources": ["new_source"]}'

How Self-RAG Improves Results
┌─────────────────────────────────────────────────────────────────────────────┐
│ TRADITIONAL RAG │
│ │
│ Vague Query ───► Retrieval ───► Generation ───► May Be Wrong ⚠️ │
│ │
│ (No validation, no feedback loop, potential hallucinations) │
└─────────────────────────────────────────────────────────────────────────────┘
┌─────────────────────────────────────────────────────────────────────────────┐
│ SELF-RAG │
│ │
│ ┌──────────────────────────────────────────────────────────────────────┐ │
│ │ │ │
│ │ Vague Query ───► Query Rewriter ───► Retrieval ───► Retrieval Eval │ │
│ │ ▲ │ │ │
│ │ │ │ │ │
│ │ ┌─────────┴──────────────────────┐ │ │ │
│ │ │ │ ▼ │ │
│ │ [Poor Quality] [Good Quality] │ │
│ │ │ │ │ │
│ │ │ ▼ │ │
│ │ │ Generation │ │
│ │ │ │ │ │
│ │ │ ▼ │ │
│ │ │ Answer Eval │ │
│ │ │ / \ │ │
│ │ │ [Issues] [Good] │ │
│ │ │ │ │ │ │
│ │ └────────────────────────┘ ▼ │ │
│ │ Verified Answer ✓ │ │
│ └──────────────────────────────────────────────────────────────────────┘ │
│ │
│ (Self-correcting loop catches errors and iterates until quality met) │
└─────────────────────────────────────────────────────────────────────────────┘

| Metric | Traditional RAG | Self-RAG |
|---|---|---|
| Answer Accuracy | ~75% | ~90% |
| Source Support | Variable | Verified |
| Handles Vague Queries | Poorly | Well |
| Iteration Cost | 1 LLM call | 2-4 LLM calls |
Key Concepts Recap
| Concept | What It Is | Why It Matters |
|---|---|---|
| Query Rewriting | Transform vague queries into searchable ones | "ML" → "machine learning algorithm" improves retrieval |
| Retrieval Evaluation | LLM judges if documents are relevant | Catches bad retrieval before it poisons the answer |
| Answer Evaluation | Verify every claim has source support | Prevents hallucination, ensures groundedness |
| Self-Correction Loop | Iterate until quality thresholds met | Automatically recovers from initial failures |
| Structured Output | Pydantic models for LLM responses | Type-safe, validated evaluation results |
| Trace Logging | Record every step of the process | Essential for debugging and improvement |
| Confidence Score | Weighted combination of quality signals | Users know when to trust the answer |
Next Steps
- Add caching for repeated queries
- Implement streaming for real-time feedback
- Build async evaluation for parallel processing
- Explore Corrective RAG for source verification
- Try Agentic RAG for multi-step reasoning