Custom Reranker
Train a cross-encoder reranker for RAG systems
TL;DR
Bi-encoders retrieve fast (near-constant time per query against a pre-computed embedding index) but miss nuanced relevance. Cross-encoders jointly process each query+document pair for higher accuracy, but pay one full forward pass per candidate (O(n)). Solution: two-stage retrieval, where a bi-encoder fetches the top-100 candidates and a cross-encoder reranks them to the final top-10.
Build and train a cross-encoder model to improve RAG retrieval quality through reranking.
Overview
| Difficulty | Intermediate |
|---|---|
| Time | ~4 hours |
| Prerequisites | PyTorch, Transformers, RAG basics |
| Learning Outcomes | Cross-encoder training, ranking losses, hard negative mining |
Introduction
RAG systems use a two-stage retrieval approach:
- Stage 1 (Retrieval): Bi-encoders retrieve top-k candidates quickly
- Stage 2 (Reranking): Cross-encoders rerank candidates for precision
Cross-encoders achieve higher accuracy by jointly processing query-document pairs, enabling deep semantic understanding.
┌─────────────────────────────────────────────────────────────────────────────┐
│ Two-Stage Retrieval Pipeline │
├─────────────────────────────────────────────────────────────────────────────┤
│ │
│ ┌─────────┐ ┌─────────────┐ ┌─────────────┐ ┌─────────────┐ │
│ │ Query │───►│ Bi-Encoder │───►│ Top-100 │───►│Cross-Encoder│ │
│ │ │ │Fast Retriev.│ │ Candidates │ │ Reranking │ │
│ └─────────┘ └─────────────┘ └─────────────┘ └──────┬──────┘ │
│ │ │ │
│ │ ▼ │
│ │ ┌─────────────┐ │
│ │ │ Top-10 │ │
│ │ │ Results │ │
│ │ └─────────────┘ │
│ │ │
│ Quality vs Speed: │ │
│ • Bi-Encoder: Fast (O(1) via index), Lower Quality │
│ • Cross-Encoder: Slow (O(n) pairs), Higher Quality │
│ │
└─────────────────────────────────────────────────────────────────────────────┘
Cross-Encoder vs Bi-Encoder
┌─────────────────────────────────────────────────────────────────────────────┐
│ Bi-Encoder vs Cross-Encoder Architecture │
├────────────────────────────────┬────────────────────────────────────────────┤
│ Bi-Encoder │ Cross-Encoder │
├────────────────────────────────┼────────────────────────────────────────────┤
│ │ │
│ ┌───────┐ ┌───────┐ │ ┌───────┐ ┌──────────┐ │
│ │ Query │ │ Doc │ │ │ Query │ │ Document │ │
│ └───┬───┘ └───┬───┘ │ └───┬───┘ └────┬─────┘ │
│ │ │ │ │ │ │
│ ▼ ▼ │ └─────┬─────┘ │
│ ┌───────┐ ┌───────┐ │ ▼ │
│ │Encoder│ │Encoder│ │ ┌──────────────────────┐ │
│ └───┬───┘ └───┬───┘ │ │[CLS] Query [SEP] Doc │ │
│ │ │ │ └──────────┬───────────┘ │
│ ▼ ▼ │ ▼ │
│ ┌───────┐ ┌───────┐ │ ┌──────────────────────┐ │
│ │Vector │ │Vector │ │ │ Encoder │ │
│ └───┬───┘ └───┬───┘ │ └──────────┬───────────┘ │
│ │ │ │ ▼ │
│ └──────┬──────┘ │ ┌──────────────────────┐ │
│ ▼ │ │ [CLS] Token │ │
│ ┌──────────────────┐ │ └──────────┬───────────┘ │
│ │ Cosine Similarity│ │ ▼ │
│ └──────────────────┘ │ ┌──────────────────────┐ │
│ │ │ Relevance Score │ │
│ Pre-compute doc vectors │ └──────────────────────┘ │
│ Compare at query time │ Joint query-doc encoding │
│ │ │
└────────────────────────────────┴────────────────────────────────────────────┘
| Aspect | Bi-Encoder | Cross-Encoder |
|---|---|---|
| Speed | O(1) per query | O(n) per query |
| Pre-computation | Yes (index docs) | No |
| Accuracy | Good | Better |
| Use Case | Initial retrieval | Reranking |
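The cost asymmetry in the table can be made concrete with a toy sketch (random vectors stand in for real encoder outputs; all names here are illustrative, not part of the tutorial code):

```python
import numpy as np

# Bi-encoder side of the pipeline: document embeddings are computed once,
# offline, so query time is a single matrix-vector product over the corpus.
rng = np.random.default_rng(0)
dim, n_docs = 8, 1000

doc_embs = rng.normal(size=(n_docs, dim))
doc_embs /= np.linalg.norm(doc_embs, axis=1, keepdims=True)  # unit vectors

def bi_encoder_search(query_emb, k=100):
    """One normalized dot product against the whole index -> top-k ids."""
    q = query_emb / np.linalg.norm(query_emb)
    scores = doc_embs @ q  # cosine similarity with every document at once
    return np.argsort(scores)[::-1][:k]

candidates = bi_encoder_search(rng.normal(size=dim))
print(len(candidates))  # -> 100 candidates handed to the cross-encoder

# A cross-encoder would now run one full transformer forward pass per
# (query, candidate) pair -- 100 passes here, not 1000 -- which is exactly
# why it reranks a short candidate list instead of the whole corpus.
```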
Project Setup
# Create project directory
mkdir custom-reranker && cd custom-reranker
# Create virtual environment
python -m venv venv
source venv/bin/activate # On Windows: venv\Scripts\activate
# Install dependencies
pip install torch transformers datasets sentence-transformers
pip install accelerate wandb scikit-learn
pip install fastapi uvicorn
Project Structure
custom-reranker/
├── data/
│ ├── dataset.py # Dataset preparation
│ └── hard_negatives.py # Hard negative mining
├── models/
│ ├── cross_encoder.py # Cross-encoder model
│ └── losses.py # Ranking losses
├── training/
│ ├── trainer.py # Training loop
│ └── evaluation.py # Evaluation metrics
├── inference/
│ └── reranker.py # Inference utilities
├── api/
│ └── app.py # FastAPI application
├── scripts/
│ └── train.py # Training script
└── requirements.txt
Data Preparation
Training Data Format
Cross-encoder training requires query-document pairs with relevance labels:
# data/dataset.py
import torch
from torch.utils.data import Dataset
from datasets import load_dataset
from transformers import PreTrainedTokenizer
from typing import List, Dict, Any, Tuple
from dataclasses import dataclass
import random
@dataclass
class RerankerSample:
"""Single training sample."""
query: str
document: str
label: float # 0.0 to 1.0 relevance score
class RerankerDataset(Dataset):
"""Dataset for cross-encoder training."""
def __init__(
self,
samples: List[RerankerSample],
tokenizer: PreTrainedTokenizer,
max_length: int = 512,
):
self.samples = samples
self.tokenizer = tokenizer
self.max_length = max_length
def __len__(self) -> int:
return len(self.samples)
def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
sample = self.samples[idx]
# Encode query-document pair
encoding = self.tokenizer(
sample.query,
sample.document,
truncation="longest_first",
max_length=self.max_length,
padding="max_length",
return_tensors="pt",
)
return {
"input_ids": encoding["input_ids"].squeeze(0),
"attention_mask": encoding["attention_mask"].squeeze(0),
"labels": torch.tensor(sample.label, dtype=torch.float),
}
class PairwiseDataset(Dataset):
"""Dataset for pairwise ranking training."""
def __init__(
self,
queries: List[str],
positives: List[str],
negatives: List[str],
tokenizer: PreTrainedTokenizer,
max_length: int = 512,
):
self.queries = queries
self.positives = positives
self.negatives = negatives
self.tokenizer = tokenizer
self.max_length = max_length
def __len__(self) -> int:
return len(self.queries)
def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
query = self.queries[idx]
positive = self.positives[idx]
negative = self.negatives[idx]
# Encode positive pair
pos_encoding = self.tokenizer(
query,
positive,
truncation="longest_first",
max_length=self.max_length,
padding="max_length",
return_tensors="pt",
)
# Encode negative pair
neg_encoding = self.tokenizer(
query,
negative,
truncation="longest_first",
max_length=self.max_length,
padding="max_length",
return_tensors="pt",
)
return {
"pos_input_ids": pos_encoding["input_ids"].squeeze(0),
"pos_attention_mask": pos_encoding["attention_mask"].squeeze(0),
"neg_input_ids": neg_encoding["input_ids"].squeeze(0),
"neg_attention_mask": neg_encoding["attention_mask"].squeeze(0),
}
def load_ms_marco(
split: str = "train",
max_samples: int = None,
) -> Tuple[List[str], List[str], List[str]]:
"""Load MS MARCO dataset for reranker training."""
dataset = load_dataset("ms_marco", "v1.1", split=split)
queries = []
positives = []
negatives = []
for item in dataset:
if max_samples and len(queries) >= max_samples:
break
query = item["query"]
passages = item["passages"]
# Find positive and negative passages
pos_passages = [
p  # p is already the passage text string
for p, is_selected in zip(passages["passage_text"], passages["is_selected"])
if is_selected == 1
]
neg_passages = [
p
for p, is_selected in zip(passages["passage_text"], passages["is_selected"])
if is_selected == 0
]
if pos_passages and neg_passages:
queries.append(query)
positives.append(pos_passages[0])
negatives.append(random.choice(neg_passages))
return queries, positives, negatives
Hard Negative Mining
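Intuitively, a hard negative is a non-relevant document that the first-stage retriever itself scores highly, which forces the cross-encoder to learn fine distinctions a random negative never exposes. A toy sketch of the selection rule (random vectors stand in for embeddings; names are illustrative):

```python
import numpy as np

# Hard negatives = the top-scoring retrieved docs, minus the known positive.
rng = np.random.default_rng(1)
query = rng.normal(size=4)
docs = rng.normal(size=(6, 4))
pos_idx = 2  # index of the known-relevant document

scores = docs @ query                 # retriever similarity scores
ranked = np.argsort(scores)[::-1]     # best-scoring documents first
hard_negs = [int(i) for i in ranked if i != pos_idx][:2]

print(hard_negs)  # the two non-relevant docs the retriever confuses most
```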
Hard negatives are challenging examples that improve model discrimination:
# data/hard_negatives.py
import torch
import numpy as np
from sentence_transformers import SentenceTransformer
from typing import List, Dict, Tuple
import faiss
from tqdm import tqdm
class HardNegativeMiner:
"""Mine hard negatives using dense retrieval."""
def __init__(
self,
model_name: str = "sentence-transformers/all-MiniLM-L6-v2",
device: str = "cuda",
):
self.model = SentenceTransformer(model_name, device=device)
self.device = device
def build_index(
self,
documents: List[str],
batch_size: int = 64,
) -> faiss.Index:
"""Build FAISS index for documents."""
print("Encoding documents...")
embeddings = self.model.encode(
documents,
batch_size=batch_size,
show_progress_bar=True,
convert_to_numpy=True,
)
# Normalize for cosine similarity
faiss.normalize_L2(embeddings)
# Build index
dimension = embeddings.shape[1]
index = faiss.IndexFlatIP(dimension) # Inner product for cosine
index.add(embeddings)
return index
def mine_hard_negatives(
self,
queries: List[str],
documents: List[str],
positive_indices: List[int],
top_k: int = 100,
num_hard_negatives: int = 5,
batch_size: int = 64,
) -> List[List[int]]:
"""Mine hard negatives for each query."""
# Build document index
index = self.build_index(documents, batch_size)
# Encode queries
print("Encoding queries...")
query_embeddings = self.model.encode(
queries,
batch_size=batch_size,
show_progress_bar=True,
convert_to_numpy=True,
)
faiss.normalize_L2(query_embeddings)
# Search for similar documents
print("Searching for hard negatives...")
distances, indices = index.search(query_embeddings, top_k)
# Filter out positives and select hard negatives
hard_negatives = []
for i, (pos_idx, retrieved_indices) in enumerate(
zip(positive_indices, indices)
):
# Remove positive from candidates
neg_indices = [
idx for idx in retrieved_indices
if idx != pos_idx
][:num_hard_negatives]
hard_negatives.append(neg_indices)
return hard_negatives
class InBatchNegatives:
"""Generate in-batch negatives during training."""
@staticmethod
def sample_negatives(
batch_positives: List[str],
num_negatives: int = 3,
) -> List[List[str]]:
"""Sample negatives from other positives in batch."""
batch_size = len(batch_positives)
negatives = []
for i in range(batch_size):
# Use other positives as negatives
neg_indices = [j for j in range(batch_size) if j != i]
sampled = np.random.choice(
neg_indices,
size=min(num_negatives, len(neg_indices)),
replace=False,
)
negatives.append([batch_positives[j] for j in sampled])
return negatives
Cross-Encoder Model
# models/cross_encoder.py
import torch
import torch.nn as nn
from transformers import AutoModel, AutoTokenizer, AutoConfig
from typing import Optional, Dict, Any, List
class CrossEncoder(nn.Module):
"""Cross-encoder for query-document scoring."""
def __init__(
self,
model_name: str = "cross-encoder/ms-marco-MiniLM-L-6-v2",
num_labels: int = 1,
dropout: float = 0.1,
):
super().__init__()
self.config = AutoConfig.from_pretrained(model_name)
self.encoder = AutoModel.from_pretrained(model_name)
self.dropout = nn.Dropout(dropout)
# Classification head
self.classifier = nn.Linear(self.config.hidden_size, num_labels)
# Initialize weights
self._init_weights(self.classifier)
def _init_weights(self, module: nn.Module):
"""Initialize classifier weights."""
if isinstance(module, nn.Linear):
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
if module.bias is not None:
module.bias.data.zero_()
def forward(
self,
input_ids: torch.Tensor,
attention_mask: torch.Tensor,
labels: Optional[torch.Tensor] = None,
) -> Dict[str, torch.Tensor]:
"""Forward pass."""
outputs = self.encoder(
input_ids=input_ids,
attention_mask=attention_mask,
)
# Use [CLS] token representation
cls_output = outputs.last_hidden_state[:, 0, :]
cls_output = self.dropout(cls_output)
# Get relevance scores
logits = self.classifier(cls_output).squeeze(-1)
loss = None
if labels is not None:
loss_fn = nn.BCEWithLogitsLoss()
loss = loss_fn(logits, labels)
return {
"loss": loss,
"logits": logits,
"scores": torch.sigmoid(logits),
}
@torch.no_grad()
def predict(
self,
queries: List[str],
documents: List[str],
tokenizer: AutoTokenizer,
batch_size: int = 32,
max_length: int = 512,
) -> List[float]:
"""Predict relevance scores."""
self.eval()
scores = []
for i in range(0, len(queries), batch_size):
batch_queries = queries[i:i + batch_size]
batch_docs = documents[i:i + batch_size]
# Tokenize
encoding = tokenizer(
batch_queries,
batch_docs,
truncation="longest_first",
max_length=max_length,
padding=True,
return_tensors="pt",
)
# Move to device
encoding = {k: v.to(next(self.parameters()).device) for k, v in encoding.items()}
# Forward pass
outputs = self(**encoding)
scores.extend(outputs["scores"].cpu().tolist())
return scores
class CrossEncoderForPairwise(CrossEncoder):
"""Cross-encoder with margin ranking loss."""
def forward_pairwise(
self,
pos_input_ids: torch.Tensor,
pos_attention_mask: torch.Tensor,
neg_input_ids: torch.Tensor,
neg_attention_mask: torch.Tensor,
margin: float = 1.0,
) -> Dict[str, torch.Tensor]:
"""Forward pass with pairwise loss."""
# Score positive pairs
pos_outputs = self.encoder(
input_ids=pos_input_ids,
attention_mask=pos_attention_mask,
)
pos_cls = pos_outputs.last_hidden_state[:, 0, :]
pos_scores = self.classifier(self.dropout(pos_cls)).squeeze(-1)
# Score negative pairs
neg_outputs = self.encoder(
input_ids=neg_input_ids,
attention_mask=neg_attention_mask,
)
neg_cls = neg_outputs.last_hidden_state[:, 0, :]
neg_scores = self.classifier(self.dropout(neg_cls)).squeeze(-1)
# Margin ranking loss
loss_fn = nn.MarginRankingLoss(margin=margin)
target = torch.ones_like(pos_scores) # pos > neg
loss = loss_fn(pos_scores, neg_scores, target)
return {
"loss": loss,
"pos_scores": pos_scores,
"neg_scores": neg_scores,
}
Ranking Losses
# models/losses.py
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional
class MarginRankingLoss(nn.Module):
"""Margin ranking loss for pairwise training."""
def __init__(self, margin: float = 1.0):
super().__init__()
self.margin = margin
def forward(
self,
pos_scores: torch.Tensor,
neg_scores: torch.Tensor,
) -> torch.Tensor:
"""
Args:
pos_scores: Scores for positive pairs [batch_size]
neg_scores: Scores for negative pairs [batch_size]
"""
# loss = max(0, margin - (pos - neg))
loss = F.relu(self.margin - pos_scores + neg_scores)
return loss.mean()
class ListwiseLoss(nn.Module):
"""Listwise ranking loss using ListMLE."""
def __init__(self, eps: float = 1e-10):
super().__init__()
self.eps = eps
def forward(
self,
scores: torch.Tensor,
labels: torch.Tensor,
) -> torch.Tensor:
"""
ListMLE loss.
Args:
scores: Predicted scores [batch_size, num_docs]
labels: Relevance labels [batch_size, num_docs]
"""
# Sort by relevance labels (descending)
sorted_indices = labels.argsort(dim=-1, descending=True)
sorted_scores = scores.gather(-1, sorted_indices)
# Compute ListMLE
max_scores = sorted_scores.max(dim=-1, keepdim=True).values
exp_scores = torch.exp(sorted_scores - max_scores)
cumsum_exp = torch.cumsum(exp_scores.flip(dims=[-1]), dim=-1).flip(dims=[-1])
loss = -torch.sum(
sorted_scores - max_scores - torch.log(cumsum_exp + self.eps),
dim=-1,
)
return loss.mean()
class ContrastiveLoss(nn.Module):
"""Contrastive loss with temperature scaling."""
def __init__(self, temperature: float = 0.05):
super().__init__()
self.temperature = temperature
def forward(
self,
pos_scores: torch.Tensor,
neg_scores: torch.Tensor,
) -> torch.Tensor:
"""
InfoNCE-style contrastive loss.
Args:
pos_scores: Scores for positive pairs [batch_size]
neg_scores: Scores for negative pairs [batch_size, num_negatives]
"""
# Scale by temperature
pos_scores = pos_scores / self.temperature
neg_scores = neg_scores / self.temperature
# Combine positive and negatives
if neg_scores.dim() == 1:
neg_scores = neg_scores.unsqueeze(-1)
# log(exp(pos) / (exp(pos) + sum(exp(neg))))
all_scores = torch.cat([pos_scores.unsqueeze(-1), neg_scores], dim=-1)
labels = torch.zeros(pos_scores.size(0), dtype=torch.long, device=pos_scores.device)
loss = F.cross_entropy(all_scores, labels)
return loss
class FocalLoss(nn.Module):
"""Focal loss for handling class imbalance."""
def __init__(self, gamma: float = 2.0, alpha: float = 0.25):
super().__init__()
self.gamma = gamma
self.alpha = alpha
def forward(
self,
logits: torch.Tensor,
labels: torch.Tensor,
) -> torch.Tensor:
"""
Args:
logits: Predicted logits [batch_size]
labels: Binary labels [batch_size]
"""
probs = torch.sigmoid(logits)
ce_loss = F.binary_cross_entropy_with_logits(logits, labels, reduction="none")
# Focal term
p_t = probs * labels + (1 - probs) * (1 - labels)
focal_weight = (1 - p_t) ** self.gamma
# Alpha balancing
alpha_t = self.alpha * labels + (1 - self.alpha) * (1 - labels)
loss = alpha_t * focal_weight * ce_loss
return loss.mean()
Training Implementation
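As a quick sanity check of the ranking losses above: the hand-written margin formula, max(0, margin - (pos - neg)), is exactly PyTorch's built-in margin ranking loss with target +1 (meaning "positive should outrank negative"):

```python
import torch
import torch.nn.functional as F

pos = torch.tensor([2.0, 0.5, 1.0])  # scores for (query, positive) pairs
neg = torch.tensor([1.0, 1.5, 1.0])  # scores for (query, negative) pairs
margin = 1.0

# Hand-written form used in MarginRankingLoss above
manual = F.relu(margin - pos + neg).mean()
# PyTorch built-in; target=+1 means "first argument should rank higher"
builtin = F.margin_ranking_loss(pos, neg, torch.ones_like(pos), margin=margin)

assert torch.allclose(manual, builtin)
print(manual.item())  # -> 1.0 (per-pair losses are 0.0, 2.0, 1.0)
```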
# training/trainer.py
import torch
from torch.utils.data import DataLoader
from transformers import AutoTokenizer, get_linear_schedule_with_warmup
from typing import Dict, Any, Optional
import wandb
from tqdm import tqdm
class RerankerTrainer:
"""Trainer for cross-encoder reranker."""
def __init__(
self,
model,
tokenizer: AutoTokenizer,
train_dataset,
eval_dataset,
output_dir: str = "./outputs",
learning_rate: float = 2e-5,
batch_size: int = 16,
num_epochs: int = 3,
warmup_ratio: float = 0.1,
weight_decay: float = 0.01,
device: str = "cuda",
):
self.model = model.to(device)
self.tokenizer = tokenizer
self.train_dataset = train_dataset
self.eval_dataset = eval_dataset
self.output_dir = output_dir
self.learning_rate = learning_rate
self.batch_size = batch_size
self.num_epochs = num_epochs
self.warmup_ratio = warmup_ratio
self.weight_decay = weight_decay
self.device = device
def train(self) -> Dict[str, Any]:
"""Run training."""
# Create dataloaders
train_loader = DataLoader(
self.train_dataset,
batch_size=self.batch_size,
shuffle=True,
num_workers=4,
)
eval_loader = DataLoader(
self.eval_dataset,
batch_size=self.batch_size,
shuffle=False,
num_workers=4,
)
# Setup optimizer
optimizer = torch.optim.AdamW(
self.model.parameters(),
lr=self.learning_rate,
weight_decay=self.weight_decay,
)
# Setup scheduler
total_steps = len(train_loader) * self.num_epochs
warmup_steps = int(total_steps * self.warmup_ratio)
scheduler = get_linear_schedule_with_warmup(
optimizer,
num_warmup_steps=warmup_steps,
num_training_steps=total_steps,
)
# Training loop
best_eval_loss = float("inf")
global_step = 0
for epoch in range(self.num_epochs):
self.model.train()
epoch_loss = 0
pbar = tqdm(train_loader, desc=f"Epoch {epoch + 1}/{self.num_epochs}")
for batch in pbar:
# Move to device
batch = {k: v.to(self.device) for k, v in batch.items()}
# Forward pass
outputs = self.model(**batch)
loss = outputs["loss"]
# Backward pass
optimizer.zero_grad()
loss.backward()
torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
optimizer.step()
scheduler.step()
epoch_loss += loss.item()
global_step += 1
pbar.set_postfix({"loss": loss.item()})
# Log to wandb
if global_step % 100 == 0:
wandb.log({
"train/loss": loss.item(),
"train/lr": scheduler.get_last_lr()[0],
"step": global_step,
})
# Evaluation
eval_loss = self._evaluate(eval_loader)
print(f"Epoch {epoch + 1} - Train Loss: {epoch_loss / len(train_loader):.4f}, Eval Loss: {eval_loss:.4f}")
wandb.log({
"eval/loss": eval_loss,
"epoch": epoch + 1,
})
# Save best model
if eval_loss < best_eval_loss:
best_eval_loss = eval_loss
self._save_model(f"{self.output_dir}/best")
# Save final model
self._save_model(f"{self.output_dir}/final")
return {
"best_eval_loss": best_eval_loss,
"final_step": global_step,
}
def _evaluate(self, eval_loader: DataLoader) -> float:
"""Evaluate model."""
self.model.eval()
total_loss = 0
with torch.no_grad():
for batch in eval_loader:
batch = {k: v.to(self.device) for k, v in batch.items()}
outputs = self.model(**batch)
total_loss += outputs["loss"].item()
return total_loss / len(eval_loader)
def _save_model(self, path: str):
"""Save model checkpoint."""
import os
os.makedirs(path, exist_ok=True)
# Save only the state dict; the tokenizer is saved separately below
torch.save({"model_state_dict": self.model.state_dict()}, f"{path}/model.pt")
self.tokenizer.save_pretrained(path)
class PairwiseTrainer(RerankerTrainer):
"""Trainer for pairwise ranking."""
def train(self) -> Dict[str, Any]:
"""Run pairwise training."""
train_loader = DataLoader(
self.train_dataset,
batch_size=self.batch_size,
shuffle=True,
num_workers=4,
)
optimizer = torch.optim.AdamW(
self.model.parameters(),
lr=self.learning_rate,
weight_decay=self.weight_decay,
)
total_steps = len(train_loader) * self.num_epochs
warmup_steps = int(total_steps * self.warmup_ratio)
scheduler = get_linear_schedule_with_warmup(
optimizer,
num_warmup_steps=warmup_steps,
num_training_steps=total_steps,
)
global_step = 0
for epoch in range(self.num_epochs):
self.model.train()
epoch_loss = 0
pbar = tqdm(train_loader, desc=f"Epoch {epoch + 1}")
for batch in pbar:
batch = {k: v.to(self.device) for k, v in batch.items()}
# Forward with pairwise loss
outputs = self.model.forward_pairwise(**batch)
loss = outputs["loss"]
optimizer.zero_grad()
loss.backward()
torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
optimizer.step()
scheduler.step()
epoch_loss += loss.item()
global_step += 1
# Log accuracy
acc = (outputs["pos_scores"] > outputs["neg_scores"]).float().mean()
pbar.set_postfix({"loss": loss.item(), "acc": acc.item()})
self._save_model(f"{self.output_dir}/final")
return {"final_step": global_step}
Evaluation Metrics
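Before the full implementation, a worked example makes MRR concrete (toy labels and scores, chosen by hand):

```python
import numpy as np

labels = [0, 0, 1, 0, 1]             # docs 2 and 4 are the relevant ones
scores = [0.1, 0.9, 0.7, 0.3, 0.2]   # reranker scores, highest = best

order = np.argsort(scores)[::-1]          # ranking by score: docs 1, 2, 3, 4, 0
ranked_labels = np.array(labels)[order]   # [0, 1, 0, 1, 0]
first_hit = int(np.where(ranked_labels == 1)[0][0]) + 1  # 1-indexed rank
rr = 1.0 / first_hit

print(rr)  # -> 0.5: the first relevant doc sits at rank 2
# MRR is simply this reciprocal rank averaged over all evaluation queries.
```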
# training/evaluation.py
import numpy as np
from typing import List, Dict, Any
from sklearn.metrics import ndcg_score
class RankingMetrics:
"""Compute ranking evaluation metrics."""
@staticmethod
def mean_reciprocal_rank(
labels: List[List[int]],
predictions: List[List[float]],
) -> float:
"""
Mean Reciprocal Rank (MRR).
Args:
labels: Binary relevance labels [num_queries, num_docs]
predictions: Predicted scores [num_queries, num_docs]
"""
mrr_scores = []
for query_labels, query_preds in zip(labels, predictions):
# Sort by predicted score
sorted_indices = np.argsort(query_preds)[::-1]
sorted_labels = np.array(query_labels)[sorted_indices]
# Find rank of first relevant document
relevant_positions = np.where(sorted_labels == 1)[0]
if len(relevant_positions) > 0:
first_rank = relevant_positions[0] + 1
mrr_scores.append(1.0 / first_rank)
else:
mrr_scores.append(0.0)
return np.mean(mrr_scores)
@staticmethod
def ndcg_at_k(
labels: List[List[int]],
predictions: List[List[float]],
k: int = 10,
) -> float:
"""
Normalized Discounted Cumulative Gain at k.
Args:
labels: Relevance labels [num_queries, num_docs]
predictions: Predicted scores [num_queries, num_docs]
k: Cutoff position
"""
ndcg_scores = []
for query_labels, query_preds in zip(labels, predictions):
# Truncate to k
query_labels = np.array(query_labels)
query_preds = np.array(query_preds)
if len(query_labels) > k:
# Get top-k by prediction
top_k_indices = np.argsort(query_preds)[-k:][::-1]
query_labels = query_labels[top_k_indices]
query_preds = query_preds[top_k_indices]
if query_labels.sum() > 0:
score = ndcg_score([query_labels], [query_preds], k=k)
ndcg_scores.append(score)
return np.mean(ndcg_scores) if ndcg_scores else 0.0
@staticmethod
def precision_at_k(
labels: List[List[int]],
predictions: List[List[float]],
k: int = 10,
) -> float:
"""
Precision at k.
Args:
labels: Binary relevance labels
predictions: Predicted scores
k: Cutoff position
"""
precisions = []
for query_labels, query_preds in zip(labels, predictions):
# Sort by predicted score
sorted_indices = np.argsort(query_preds)[::-1][:k]
top_k_labels = np.array(query_labels)[sorted_indices]
precision = top_k_labels.sum() / k
precisions.append(precision)
return np.mean(precisions)
@staticmethod
def mean_average_precision(
labels: List[List[int]],
predictions: List[List[float]],
) -> float:
"""
Mean Average Precision (MAP).
Args:
labels: Binary relevance labels
predictions: Predicted scores
"""
ap_scores = []
for query_labels, query_preds in zip(labels, predictions):
sorted_indices = np.argsort(query_preds)[::-1]
sorted_labels = np.array(query_labels)[sorted_indices]
# Calculate average precision
num_relevant = sorted_labels.sum()
if num_relevant == 0:
ap_scores.append(0.0)
continue
precisions = []
relevant_count = 0
for i, label in enumerate(sorted_labels):
if label == 1:
relevant_count += 1
precisions.append(relevant_count / (i + 1))
ap_scores.append(np.mean(precisions))
return np.mean(ap_scores)
def evaluate_reranker(
model,
tokenizer,
eval_queries: List[str],
eval_documents: List[List[str]],
eval_labels: List[List[int]],
) -> Dict[str, float]:
"""Comprehensive reranker evaluation."""
all_predictions = []
for query, docs in zip(eval_queries, eval_documents):
# Score all documents for this query
queries_expanded = [query] * len(docs)
scores = model.predict(queries_expanded, docs, tokenizer)
all_predictions.append(scores)
metrics = RankingMetrics()
return {
"mrr": metrics.mean_reciprocal_rank(eval_labels, all_predictions),
"ndcg@10": metrics.ndcg_at_k(eval_labels, all_predictions, k=10),
"ndcg@5": metrics.ndcg_at_k(eval_labels, all_predictions, k=5),
"precision@10": metrics.precision_at_k(eval_labels, all_predictions, k=10),
"map": metrics.mean_average_precision(eval_labels, all_predictions),
}
Inference and RAG Integration
# inference/reranker.py
import torch
from transformers import AutoTokenizer
from typing import List, Dict, Tuple, Any
import numpy as np
class Reranker:
"""Reranker for RAG systems."""
def __init__(
self,
model_path: str,
tokenizer_path: str = None,
device: str = "cuda",
max_length: int = 512,
):
self.device = device
self.max_length = max_length
# Load model (the trainer saves a state dict, not the module itself)
from models.cross_encoder import CrossEncoder
checkpoint = torch.load(f"{model_path}/model.pt", map_location=device)
self.model = CrossEncoder()
self.model.load_state_dict(checkpoint["model_state_dict"])
self.model.to(device)
self.model.eval()
# Load tokenizer
tokenizer_path = tokenizer_path or model_path
self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
@torch.no_grad()
def rerank(
self,
query: str,
documents: List[str],
top_k: int = None,
return_scores: bool = False,
) -> List[Tuple[int, str, float]]:
"""
Rerank documents for a query.
Args:
query: Search query
documents: List of candidate documents
top_k: Number of top documents to return
return_scores: Whether to include scores
Returns:
List of (original_index, document, score) tuples
"""
if not documents:
return []
# Score all documents
scores = self._score_batch(query, documents)
# Sort by score
sorted_indices = np.argsort(scores)[::-1]
if top_k:
sorted_indices = sorted_indices[:top_k]
results = []
for idx in sorted_indices:
if return_scores:
results.append((int(idx), documents[idx], float(scores[idx])))
else:
results.append((int(idx), documents[idx]))
return results
def _score_batch(
self,
query: str,
documents: List[str],
batch_size: int = 32,
) -> np.ndarray:
"""Score documents in batches."""
all_scores = []
for i in range(0, len(documents), batch_size):
batch_docs = documents[i:i + batch_size]
batch_queries = [query] * len(batch_docs)
encoding = self.tokenizer(
batch_queries,
batch_docs,
truncation="longest_first",
max_length=self.max_length,
padding=True,
return_tensors="pt",
).to(self.device)
outputs = self.model(**encoding)
scores = outputs["scores"].cpu().numpy()
all_scores.extend(scores)
return np.array(all_scores)
class CachedReranker(Reranker):
"""Reranker with caching for repeated queries."""
def __init__(self, *args, cache_size: int = 1000, **kwargs):
super().__init__(*args, **kwargs)
self._cache = {}  # insertion-ordered dict; evicted FIFO when full
self.cache_size = cache_size
def _get_cache_key(self, query: str, doc: str) -> str:
"""Generate cache key."""
return f"{hash(query)}:{hash(doc)}"
def rerank(
self,
query: str,
documents: List[str],
top_k: int = None,
return_scores: bool = False,
) -> List[Tuple[int, str, float]]:
"""Rerank with caching."""
# Check cache
cached_scores = []
uncached_docs = []
uncached_indices = []
for i, doc in enumerate(documents):
key = self._get_cache_key(query, doc)
if key in self._cache:
cached_scores.append((i, self._cache[key]))
else:
uncached_docs.append(doc)
uncached_indices.append(i)
# Score uncached documents
if uncached_docs:
new_scores = self._score_batch(query, uncached_docs)
for idx, doc, score in zip(uncached_indices, uncached_docs, new_scores):
key = self._get_cache_key(query, doc)
self._cache[key] = score
cached_scores.append((idx, score))
# Evict if cache is full
if len(self._cache) > self.cache_size:
oldest_key = next(iter(self._cache))
del self._cache[oldest_key]
# Sort by score
cached_scores.sort(key=lambda x: x[1], reverse=True)
if top_k:
cached_scores = cached_scores[:top_k]
results = []
for idx, score in cached_scores:
if return_scores:
results.append((idx, documents[idx], score))
else:
results.append((idx, documents[idx]))
return results
FastAPI Application
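One design note on CachedReranker above: because Python dicts preserve insertion order, next(iter(cache)) yields the oldest entry, so eviction is FIFO rather than LRU. The pattern in isolation:

```python
cache, cache_size = {}, 3

for key, score in [("a", 0.9), ("b", 0.2), ("c", 0.7), ("d", 0.5)]:
    cache[key] = score
    if len(cache) > cache_size:
        oldest = next(iter(cache))  # first-inserted key (dicts keep order)
        del cache[oldest]

print(sorted(cache))  # -> ['b', 'c', 'd']: 'a' went in first, evicted first
# For true LRU behaviour, functools.lru_cache or an OrderedDict with
# move_to_end on cache hits would be the usual choice.
```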
# api/app.py
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, Field
from typing import List, Optional
import torch
from contextlib import asynccontextmanager
from inference.reranker import Reranker
# Global reranker instance
reranker: Optional[Reranker] = None
class RerankRequest(BaseModel):
"""Rerank request."""
query: str = Field(..., description="Search query")
documents: List[str] = Field(..., description="Documents to rerank")
top_k: Optional[int] = Field(None, description="Number of results")
return_scores: bool = Field(True, description="Include scores")
class RerankResult(BaseModel):
"""Single rerank result."""
index: int
document: str
score: Optional[float] = None
class RerankResponse(BaseModel):
"""Rerank response."""
results: List[RerankResult]
query: str
class BatchRerankRequest(BaseModel):
"""Batch rerank request."""
queries: List[str]
documents: List[List[str]]
top_k: Optional[int] = None
@asynccontextmanager
async def lifespan(app: FastAPI):
"""Load model on startup."""
global reranker
print("Loading reranker model...")
reranker = Reranker(
model_path="./outputs/best",
device="cuda" if torch.cuda.is_available() else "cpu",
)
print("Model loaded!")
yield
# Cleanup
del reranker
app = FastAPI(
title="Cross-Encoder Reranker API",
description="Rerank documents using a trained cross-encoder",
lifespan=lifespan,
)
@app.post("/rerank", response_model=RerankResponse)
async def rerank(request: RerankRequest):
"""Rerank documents for a query."""
if not reranker:
raise HTTPException(status_code=503, detail="Model not loaded")
if not request.documents:
raise HTTPException(status_code=400, detail="No documents provided")
results = reranker.rerank(
query=request.query,
documents=request.documents,
top_k=request.top_k,
return_scores=request.return_scores,
)
return RerankResponse(
query=request.query,
results=[
RerankResult(
index=r[0],
document=r[1],
score=r[2] if request.return_scores else None,
)
for r in results
],
)
@app.post("/rerank/batch")
async def batch_rerank(request: BatchRerankRequest):
"""Batch rerank multiple queries."""
if not reranker:
raise HTTPException(status_code=503, detail="Model not loaded")
if len(request.queries) != len(request.documents):
raise HTTPException(
status_code=400,
detail="Queries and documents must have same length",
)
all_results = []
for query, docs in zip(request.queries, request.documents):
results = reranker.rerank(
query=query,
documents=docs,
top_k=request.top_k,
return_scores=True,
)
all_results.append({
"query": query,
"results": [
{"index": r[0], "document": r[1], "score": r[2]}
for r in results
],
})
return {"results": all_results}
@app.get("/health")
async def health():
"""Health check."""
return {
"status": "healthy",
"model_loaded": reranker is not None,
}
Training Script
# scripts/train.py
import argparse
import wandb
import torch
from transformers import AutoTokenizer
from data.dataset import RerankerDataset, PairwiseDataset, load_ms_marco, RerankerSample
from data.hard_negatives import HardNegativeMiner
from models.cross_encoder import CrossEncoder, CrossEncoderForPairwise
from training.trainer import RerankerTrainer, PairwiseTrainer
from training.evaluation import evaluate_reranker
def main():
parser = argparse.ArgumentParser()
parser.add_argument("--model", default="cross-encoder/ms-marco-MiniLM-L-6-v2")
parser.add_argument("--output-dir", default="./outputs")
parser.add_argument("--epochs", type=int, default=3)
parser.add_argument("--batch-size", type=int, default=16)
parser.add_argument("--learning-rate", type=float, default=2e-5)
parser.add_argument("--max-samples", type=int, default=100000)
parser.add_argument("--use-pairwise", action="store_true")
parser.add_argument("--mine-hard-negatives", action="store_true")
parser.add_argument("--wandb-project", default="custom-reranker")
args = parser.parse_args()
# Initialize wandb
wandb.init(project=args.wandb_project, config=vars(args))
# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(args.model)
# Load data
print("Loading MS MARCO data...")
queries, positives, negatives = load_ms_marco(
split="train",
max_samples=args.max_samples,
)
# Optionally mine hard negatives
if args.mine_hard_negatives:
print("Mining hard negatives...")
miner = HardNegativeMiner()
        all_docs = sorted(set(positives + negatives))  # sorted for deterministic indices
        doc_to_idx = {doc: i for i, doc in enumerate(all_docs)}  # O(1) lookups instead of list.index
        pos_indices = [doc_to_idx[p] for p in positives]
hard_neg_indices = miner.mine_hard_negatives(
queries=queries,
documents=all_docs,
positive_indices=pos_indices,
num_hard_negatives=3,
)
        # Keep only the top-ranked hard negative for each query
        negatives = [all_docs[idx[0]] for idx in hard_neg_indices]
# Split into train/eval
split_idx = int(len(queries) * 0.9)
train_queries, eval_queries = queries[:split_idx], queries[split_idx:]
train_pos, eval_pos = positives[:split_idx], positives[split_idx:]
train_neg, eval_neg = negatives[:split_idx], negatives[split_idx:]
# Create datasets
    if args.use_pairwise:
        train_dataset = PairwiseDataset(
            queries=train_queries,
            positives=train_pos,
            negatives=train_neg,
            tokenizer=tokenizer,
        )
        eval_dataset = None  # no pointwise eval set in pairwise mode
        model = CrossEncoderForPairwise(args.model)
        TrainerClass = PairwiseTrainer
else:
# Pointwise dataset
train_samples = []
for q, p, n in zip(train_queries, train_pos, train_neg):
train_samples.append(RerankerSample(q, p, 1.0))
train_samples.append(RerankerSample(q, n, 0.0))
train_dataset = RerankerDataset(train_samples, tokenizer)
eval_samples = []
for q, p, n in zip(eval_queries, eval_pos, eval_neg):
eval_samples.append(RerankerSample(q, p, 1.0))
eval_samples.append(RerankerSample(q, n, 0.0))
eval_dataset = RerankerDataset(eval_samples, tokenizer)
model = CrossEncoder(args.model)
TrainerClass = RerankerTrainer
# Train
trainer = TrainerClass(
model=model,
tokenizer=tokenizer,
train_dataset=train_dataset,
eval_dataset=eval_dataset if not args.use_pairwise else None,
output_dir=args.output_dir,
learning_rate=args.learning_rate,
batch_size=args.batch_size,
num_epochs=args.epochs,
)
results = trainer.train()
print(f"Training complete: {results}")
# Final evaluation
print("Running final evaluation...")
eval_docs = [[p, n] for p, n in zip(eval_pos, eval_neg)]
eval_labels = [[1, 0] for _ in eval_pos]
metrics = evaluate_reranker(
model=model,
tokenizer=tokenizer,
eval_queries=eval_queries,
eval_documents=eval_docs,
eval_labels=eval_labels,
)
print(f"Evaluation metrics: {metrics}")
wandb.log(metrics)
if __name__ == "__main__":
    main()
Integration Example
# Example: Integrate with RAG pipeline
from langchain_community.vectorstores import Chroma
from langchain_openai import OpenAIEmbeddings
from inference.reranker import Reranker
class RerankedRetriever:
"""RAG retriever with cross-encoder reranking."""
def __init__(
self,
vectorstore: Chroma,
reranker: Reranker,
initial_k: int = 100,
final_k: int = 10,
):
self.vectorstore = vectorstore
self.reranker = reranker
self.initial_k = initial_k
self.final_k = final_k
def retrieve(self, query: str) -> list:
"""Two-stage retrieval with reranking."""
# Stage 1: Dense retrieval
initial_docs = self.vectorstore.similarity_search(
query,
k=self.initial_k,
)
# Stage 2: Reranking
doc_texts = [doc.page_content for doc in initial_docs]
reranked = self.reranker.rerank(
query=query,
documents=doc_texts,
top_k=self.final_k,
return_scores=True,
)
# Return reranked documents
return [
{
"content": initial_docs[idx].page_content,
"metadata": initial_docs[idx].metadata,
"score": score,
}
for idx, _, score in reranked
]
# Usage
vectorstore = Chroma(embedding_function=OpenAIEmbeddings())
reranker = Reranker(model_path="./outputs/best")
retriever = RerankedRetriever(
vectorstore=vectorstore,
reranker=reranker,
initial_k=100,
final_k=10,
)
results = retriever.retrieve("What is machine learning?")
Key Concepts Recap
| Concept | What It Is | Why It Matters |
|---|---|---|
| Cross-Encoder | Model that jointly encodes query+document pairs | Higher accuracy than bi-encoders through deep interaction |
| Bi-Encoder | Separate encoders for query and documents | Fast retrieval via pre-computed document embeddings |
| Two-Stage Retrieval | Bi-encoder retrieves top-k, cross-encoder reranks | Balances speed (O(1) initial) with quality (O(n) rerank) |
| Hard Negative Mining | Finding challenging negative examples via dense retrieval | Improves model discrimination on difficult cases |
| Margin Ranking Loss | Loss: max(0, margin - (pos_score - neg_score)) | Ensures positive scored higher than negative by margin |
| Contrastive Loss (InfoNCE) | Softmax over positive vs. all negatives with temperature | Maximizes the positive's score relative to all negatives |
| MRR (Mean Reciprocal Rank) | Average of 1/rank of first relevant document | Measures how quickly users find relevant results |
| NDCG@k | Normalized discounted cumulative gain at k | Rewards relevant docs ranked higher, penalizes lower |
| [CLS] Token | Special token whose representation is used for scoring | Aggregates query-document pair information |
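The loss and metric rows above can be made concrete with small plain-Python reference implementations of the formulas. These are sketches for illustration, independent of the project's actual training code:

```python
import math

def margin_ranking_loss(pos_scores, neg_scores, margin=1.0):
    """Mean of max(0, margin - (pos - neg)) over (positive, negative) score pairs."""
    return sum(
        max(0.0, margin - (p - n)) for p, n in zip(pos_scores, neg_scores)
    ) / len(pos_scores)

def info_nce_loss(pos_score, neg_scores, temperature=0.05):
    """-log softmax of the positive over [positive] + negatives, with temperature."""
    logits = [pos_score] + list(neg_scores)
    m = max(logits)  # shift by the max for numerical stability
    exps = [math.exp((l - m) / temperature) for l in logits]
    return -math.log(exps[0] / sum(exps))

def mrr(ranked_relevance):
    """Reciprocal rank of the first relevant document (1 = relevant, 0 = not)."""
    for rank, rel in enumerate(ranked_relevance, start=1):
        if rel:
            return 1.0 / rank
    return 0.0

def ndcg_at_k(ranked_relevance, k):
    """DCG of the ranking divided by DCG of the ideal ordering, truncated at k."""
    dcg = sum(rel / math.log2(i + 2) for i, rel in enumerate(ranked_relevance[:k]))
    ideal = sorted(ranked_relevance, reverse=True)[:k]
    idcg = sum(rel / math.log2(i + 2) for i, rel in enumerate(ideal))
    return dcg / idcg if idcg > 0 else 0.0

# margin_ranking_loss([2.0], [0.0]) -> 0.0 (pair already separated by the margin)
# mrr([0, 1, 0]) -> 0.5 (first relevant document at rank 2)
```

Note how the margin loss goes to zero once the positive outscores the negative by the margin, while InfoNCE keeps pushing the positive's score up relative to every negative in the batch; this is why InfoNCE benefits most from the hard negatives mined earlier.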
Next Steps
After completing this project, consider:
- Knowledge Distillation - Compress your reranker
- LoRA Fine-tuning - Efficient model adaptation
- RAG with Reranking - Full RAG integration