Custom Text Classifier
Build a text classifier from scratch with PyTorch
Custom Text Classifier
TL;DR
Build a text classifier from scratch using PyTorch's nn.Module. Learn the core training loop pattern (forward pass → loss → backward pass → optimizer step), bidirectional LSTM for sequence modeling, and vocabulary building with tokenization.
Build a complete text classification model using PyTorch's nn.Module from scratch.
What You'll Learn
- PyTorch tensor operations and autograd
- Building neural networks with nn.Module
- Training loops and backpropagation
- Model evaluation and metrics
- Saving and loading models
Tech Stack
| Component | Technology |
|---|---|
| Framework | PyTorch |
| Tokenization | torchtext / custom |
| API | FastAPI |
| Data | pandas, scikit-learn |
Architecture
┌──────────────────────────────────────────────────────────────────────────────┐
│ TEXT CLASSIFICATION PIPELINE │
├──────────────────────────────────────────────────────────────────────────────┤
│ │
│ INFERENCE PIPELINE │
│ ┌──────────┐ ┌───────────┐ ┌───────────┐ ┌────────┐ ┌───────────┐ │
│ │ Raw Text │──▶│ Tokenizer │──▶│ Embedding │──▶│ LSTM │──▶│ Softmax │ │
│ └──────────┘ └───────────┘ │ Layer │ │ Layers │ └─────┬─────┘ │
│ └───────────┘ └────────┘ │ │
│ ▼ │
│ ┌────────────┐ │
│ │ Prediction │ │
│ └────────────┘ │
├──────────────────────────────────────────────────────────────────────────────┤
│ │
│ TRAINING LOOP │
│ ┌─────────┐ ┌────────────┐ ┌──────────┐ ┌──────────┐ ┌──────────┐ │
│ │ Dataset │──▶│ DataLoader │──▶│ Forward │──▶│ Loss │──▶│ Backward │ │
│ └─────────┘ │ (batch) │ │ Pass │ │ (CE) │ │ Pass │ │
│ └────────────┘ └──────────┘ └──────────┘ └────┬─────┘ │
│ │ │
│ ┌────────────────────────────────────────────┐ │ │
│ │ Optimizer Step (AdamW) │◀┘ │
│ │ Update weights using gradients │ │
│ └────────────────────────────────────────────┘ │
│ │
└──────────────────────────────────────────────────────────────────────────────┘Project Structure
text-classifier/
├── src/
│ ├── __init__.py
│ ├── model.py # Neural network definition
│ ├── dataset.py # Custom dataset class
│ ├── tokenizer.py # Text tokenization
│ ├── train.py # Training loop
│ └── predict.py # Inference
├── api/
│ └── main.py # FastAPI application
├── data/
│ └── sample_data.csv
├── models/ # Saved model weights
├── tests/
│ └── test_model.py
├── requirements.txt
└── Dockerfile

Implementation
Step 1: Dependencies
torch>=2.0.0
torchtext>=0.15.0
pandas>=2.0.0
scikit-learn>=1.3.0
fastapi>=0.100.0
uvicorn>=0.23.0

Step 2: Vocabulary and Tokenizer
"""Simple tokenizer for text classification."""
import re
from collections import Counter
from typing import List, Dict, Optional
class SimpleTokenizer:
"""
A simple word-level tokenizer with vocabulary building.
This tokenizer:
- Lowercases text
- Splits on whitespace and punctuation
- Builds vocabulary from training data
- Handles unknown tokens with <UNK>
"""
def __init__(
self,
max_vocab_size: int = 10000,
min_freq: int = 2
):
self.max_vocab_size = max_vocab_size
self.min_freq = min_freq
# Special tokens
self.pad_token = "<PAD>"
self.unk_token = "<UNK>"
# Vocabulary mappings
self.word2idx: Dict[str, int] = {}
self.idx2word: Dict[int, str] = {}
self.vocab_size = 0
def _tokenize(self, text: str) -> List[str]:
"""Split text into tokens."""
# Lowercase and split on non-alphanumeric
text = text.lower()
tokens = re.findall(r'\b\w+\b', text)
return tokens
def build_vocab(self, texts: List[str]) -> None:
"""Build vocabulary from list of texts."""
# Count all tokens
counter = Counter()
for text in texts:
tokens = self._tokenize(text)
counter.update(tokens)
# Start with special tokens
self.word2idx = {
self.pad_token: 0,
self.unk_token: 1
}
# Add most common tokens
idx = 2
for word, freq in counter.most_common(self.max_vocab_size - 2):
if freq >= self.min_freq:
self.word2idx[word] = idx
idx += 1
# Create reverse mapping
self.idx2word = {v: k for k, v in self.word2idx.items()}
self.vocab_size = len(self.word2idx)
print(f"Vocabulary built: {self.vocab_size} tokens")
def encode(
self,
text: str,
max_length: Optional[int] = None
) -> List[int]:
"""Convert text to token indices."""
tokens = self._tokenize(text)
# Convert to indices
indices = [
self.word2idx.get(token, self.word2idx[self.unk_token])
for token in tokens
]
# Truncate or pad
if max_length:
if len(indices) > max_length:
indices = indices[:max_length]
else:
indices = indices + [0] * (max_length - len(indices))
return indices
def decode(self, indices: List[int]) -> str:
"""Convert token indices back to text."""
tokens = [
self.idx2word.get(idx, self.unk_token)
for idx in indices
if idx != 0 # Skip padding
]
return " ".join(tokens)
def save(self, path: str) -> None:
"""Save vocabulary to file."""
import json
with open(path, 'w') as f:
json.dump({
'word2idx': self.word2idx,
'max_vocab_size': self.max_vocab_size,
'min_freq': self.min_freq
}, f)
@classmethod
def load(cls, path: str) -> 'SimpleTokenizer':
"""Load vocabulary from file."""
import json
with open(path, 'r') as f:
data = json.load(f)
tokenizer = cls(
max_vocab_size=data['max_vocab_size'],
min_freq=data['min_freq']
)
tokenizer.word2idx = data['word2idx']
tokenizer.idx2word = {int(v): k for k, v in data['word2idx'].items()}
tokenizer.vocab_size = len(tokenizer.word2idx)
        return tokenizer

Understanding Vocabulary Building:
┌─────────────────────────────────────────────────────────────────────────────┐
│ WHY BUILD A VOCABULARY │
├─────────────────────────────────────────────────────────────────────────────┤
│ │
│ Neural networks need numbers, not words: │
│ │
│ Text: "Python is great" │
│ │ │
│ ▼ │
│ Tokenize: ["python", "is", "great"] │
│ │ │
│ ▼ │
│ Vocabulary: {"<PAD>": 0, "<UNK>": 1, "python": 2, "is": 3, "great": 4} │
│ │ │
│ ▼ │
│ Indices: [2, 3, 4] ──► Ready for embedding layer! │
│ │
│ Special Tokens: │
│ ┌──────────┬─────────────────────────────────────────────────────────┐ │
│ │ <PAD>=0 │ Padding for fixed-length sequences │ │
│ │ <UNK>=1 │ Unknown words not in vocabulary │ │
│ └──────────┴─────────────────────────────────────────────────────────┘ │
│ │
│ Frequency Cutoff (min_freq=2): │
│ Words appearing only once are likely typos or rare terms │
│ → Replace with <UNK> to reduce vocabulary size and overfitting │
│ │
└─────────────────────────────────────────────────────────────────────────────┘

Key Design Decisions:
| Decision | Reasoning |
|---|---|
| max_vocab_size=10000 | Limits memory, covers most common words |
| min_freq=2 | Filters out rare words/typos |
| Lowercase everything | "Python" and "python" should be same token |
| Regex tokenization | Handles punctuation consistently |
Step 3: Custom Dataset
"""PyTorch Dataset for text classification."""
import torch
from torch.utils.data import Dataset, DataLoader
from typing import List, Tuple, Optional
import pandas as pd
from .tokenizer import SimpleTokenizer
class TextClassificationDataset(Dataset):
"""
Custom Dataset for text classification.
Handles:
- Text tokenization and encoding
- Label encoding
- Sequence padding
"""
def __init__(
self,
texts: List[str],
labels: List[int],
tokenizer: SimpleTokenizer,
max_length: int = 128
):
self.texts = texts
self.labels = labels
self.tokenizer = tokenizer
self.max_length = max_length
def __len__(self) -> int:
return len(self.texts)
def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
text = self.texts[idx]
label = self.labels[idx]
# Encode text to indices
encoded = self.tokenizer.encode(text, max_length=self.max_length)
# Convert to tensors
input_ids = torch.tensor(encoded, dtype=torch.long)
label_tensor = torch.tensor(label, dtype=torch.long)
return input_ids, label_tensor
def create_dataloaders(
train_texts: List[str],
train_labels: List[int],
val_texts: List[str],
val_labels: List[int],
tokenizer: SimpleTokenizer,
batch_size: int = 32,
max_length: int = 128
) -> Tuple[DataLoader, DataLoader]:
"""Create train and validation DataLoaders."""
train_dataset = TextClassificationDataset(
train_texts, train_labels, tokenizer, max_length
)
val_dataset = TextClassificationDataset(
val_texts, val_labels, tokenizer, max_length
)
train_loader = DataLoader(
train_dataset,
batch_size=batch_size,
shuffle=True,
num_workers=0,
pin_memory=True
)
val_loader = DataLoader(
val_dataset,
batch_size=batch_size,
shuffle=False,
num_workers=0,
pin_memory=True
)
return train_loader, val_loader
def load_data(filepath: str) -> Tuple[List[str], List[int]]:
"""Load data from CSV file."""
df = pd.read_csv(filepath)
# Assume columns: 'text' and 'label'
texts = df['text'].tolist()
labels = df['label'].tolist()
return texts, labelsStep 4: Neural Network Model
"""Text classification neural network."""
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional
class TextClassifier(nn.Module):
"""
A neural network for text classification.
Architecture:
- Embedding layer: Converts token indices to dense vectors
- LSTM/GRU: Captures sequential patterns
- Fully connected layers: Classification head
Args:
vocab_size: Size of the vocabulary
embedding_dim: Dimension of embeddings
hidden_dim: Hidden dimension of LSTM
num_classes: Number of output classes
num_layers: Number of LSTM layers
dropout: Dropout probability
bidirectional: Use bidirectional LSTM
"""
def __init__(
self,
vocab_size: int,
embedding_dim: int = 128,
hidden_dim: int = 256,
num_classes: int = 2,
num_layers: int = 2,
dropout: float = 0.3,
bidirectional: bool = True,
padding_idx: int = 0
):
super().__init__()
self.embedding_dim = embedding_dim
self.hidden_dim = hidden_dim
self.num_layers = num_layers
self.bidirectional = bidirectional
# Embedding layer
self.embedding = nn.Embedding(
num_embeddings=vocab_size,
embedding_dim=embedding_dim,
padding_idx=padding_idx
)
# LSTM layer
self.lstm = nn.LSTM(
input_size=embedding_dim,
hidden_size=hidden_dim,
num_layers=num_layers,
batch_first=True,
dropout=dropout if num_layers > 1 else 0,
bidirectional=bidirectional
)
# Calculate output dimension
lstm_output_dim = hidden_dim * 2 if bidirectional else hidden_dim
# Classification head
self.classifier = nn.Sequential(
nn.Dropout(dropout),
nn.Linear(lstm_output_dim, hidden_dim),
nn.ReLU(),
nn.Dropout(dropout),
nn.Linear(hidden_dim, num_classes)
)
# Initialize weights
self._init_weights()
def _init_weights(self):
"""Initialize model weights."""
for name, param in self.named_parameters():
if 'weight' in name and 'lstm' not in name:
nn.init.xavier_uniform_(param)
elif 'bias' in name:
nn.init.zeros_(param)
def forward(
self,
input_ids: torch.Tensor,
lengths: Optional[torch.Tensor] = None
) -> torch.Tensor:
"""
Forward pass.
Args:
input_ids: Token indices [batch_size, seq_len]
lengths: Actual sequence lengths (optional)
Returns:
Logits [batch_size, num_classes]
"""
# Get embeddings [batch_size, seq_len, embedding_dim]
embedded = self.embedding(input_ids)
# Pass through LSTM
lstm_out, (hidden, cell) = self.lstm(embedded)
# Use the last hidden state for classification
if self.bidirectional:
# Concatenate forward and backward final hidden states
hidden_forward = hidden[-2, :, :]
hidden_backward = hidden[-1, :, :]
hidden_cat = torch.cat([hidden_forward, hidden_backward], dim=1)
else:
hidden_cat = hidden[-1, :, :]
# Classification
logits = self.classifier(hidden_cat)
return logits
def predict(self, input_ids: torch.Tensor) -> torch.Tensor:
"""Get predicted class labels."""
self.eval()
with torch.no_grad():
logits = self.forward(input_ids)
predictions = torch.argmax(logits, dim=1)
return predictions
def predict_proba(self, input_ids: torch.Tensor) -> torch.Tensor:
"""Get class probabilities."""
self.eval()
with torch.no_grad():
logits = self.forward(input_ids)
probabilities = F.softmax(logits, dim=1)
        return probabilities

Understanding the LSTM Text Classifier:
┌─────────────────────────────────────────────────────────────────────────────┐
│ DATA FLOW THROUGH THE MODEL │
├─────────────────────────────────────────────────────────────────────────────┤
│ │
│ Input: [32, 15, 892, 4] (token indices, batch_size=1, seq_len=4) │
│ │ │
│ ▼ │
│ ┌─────────────────────────────────────────────────────────────────────┐ │
│ │ EMBEDDING LAYER │ │
│ │ nn.Embedding(vocab_size=10000, embedding_dim=128) │ │
│ │ Each index → 128-dim vector │ │
│ │ Output: [1, 4, 128] │ │
│ └─────────────────────────────────────────────────────────────────────┘ │
│ │ │
│ ▼ │
│ ┌─────────────────────────────────────────────────────────────────────┐ │
│ │ BIDIRECTIONAL LSTM │ │
│ │ Forward: [32] → [15] → [892] → [4] → hidden_fwd │ │
│ │ Backward: [32] ← [15] ← [892] ← [4] ← hidden_bwd │ │
│ │ │ │
│ │ Final hidden = concat(hidden_fwd, hidden_bwd) │ │
│ │ Output: [1, 512] (256 × 2 directions) │ │
│ └─────────────────────────────────────────────────────────────────────┘ │
│ │ │
│ ▼ │
│ ┌─────────────────────────────────────────────────────────────────────┐ │
│ │ CLASSIFIER HEAD │ │
│ │ Dropout → Linear(512, 256) → ReLU → Dropout → Linear(256, 2) │ │
│ │ Output: [1, 2] (logits for 2 classes) │ │
│ └─────────────────────────────────────────────────────────────────────┘ │
│ │
└─────────────────────────────────────────────────────────────────────────────┘

Why Bidirectional LSTM:
| Approach | Context Used | Example Benefit |
|---|---|---|
| Forward-only | Left context | "The movie was ___" |
| Bidirectional | Left + Right | "The movie was ___ but the ending saved it" |
The word after "was" gets context from BOTH "movie" AND "but the ending saved it" - crucial for sentiment!
Layer Configuration Explained:
| Parameter | Value | Purpose |
|---|---|---|
| hidden_dim=256 | LSTM hidden state size | Higher = more capacity, more memory |
| num_layers=2 | Stacked LSTM layers | Deeper = better representations |
| dropout=0.3 | 30% dropout probability | Prevents overfitting |
| padding_idx=0 | PAD token gets zero embedding | Don't learn from padding |
Step 5: Training Loop
"""Training script for text classifier."""
import torch
import torch.nn as nn
from torch.optim import AdamW
from torch.optim.lr_scheduler import OneCycleLR
from torch.utils.data import DataLoader
from typing import Dict, List, Tuple, Optional
import time
from dataclasses import dataclass
@dataclass
class TrainingConfig:
"""Training configuration."""
epochs: int = 10
learning_rate: float = 1e-3
weight_decay: float = 0.01
max_grad_norm: float = 1.0
device: str = "cuda" if torch.cuda.is_available() else "cpu"
class Trainer:
"""Trainer class for text classification."""
def __init__(
self,
model: nn.Module,
train_loader: DataLoader,
val_loader: DataLoader,
config: TrainingConfig
):
self.model = model.to(config.device)
self.train_loader = train_loader
self.val_loader = val_loader
self.config = config
self.device = config.device
self.criterion = nn.CrossEntropyLoss()
self.optimizer = AdamW(
model.parameters(),
lr=config.learning_rate,
weight_decay=config.weight_decay
)
total_steps = len(train_loader) * config.epochs
self.scheduler = OneCycleLR(
self.optimizer,
max_lr=config.learning_rate,
total_steps=total_steps
)
self.history: Dict[str, List[float]] = {
'train_loss': [],
'val_loss': [],
'val_accuracy': []
}
def train_epoch(self) -> float:
"""Train for one epoch."""
self.model.train()
total_loss = 0
for input_ids, labels in self.train_loader:
input_ids = input_ids.to(self.device)
labels = labels.to(self.device)
self.optimizer.zero_grad()
logits = self.model(input_ids)
loss = self.criterion(logits, labels)
loss.backward()
torch.nn.utils.clip_grad_norm_(
self.model.parameters(),
self.config.max_grad_norm
)
self.optimizer.step()
self.scheduler.step()
total_loss += loss.item()
return total_loss / len(self.train_loader)
@torch.no_grad()
def evaluate(self) -> Tuple[float, float]:
"""Evaluate on validation set."""
self.model.eval()
total_loss = 0
correct = 0
total = 0
for input_ids, labels in self.val_loader:
input_ids = input_ids.to(self.device)
labels = labels.to(self.device)
logits = self.model(input_ids)
loss = self.criterion(logits, labels)
total_loss += loss.item()
predictions = torch.argmax(logits, dim=1)
correct += (predictions == labels).sum().item()
total += labels.size(0)
return total_loss / len(self.val_loader), correct / total
def train(self) -> Dict[str, List[float]]:
"""Full training loop."""
best_accuracy = 0
for epoch in range(self.config.epochs):
train_loss = self.train_epoch()
val_loss, val_accuracy = self.evaluate()
self.history['train_loss'].append(train_loss)
self.history['val_loss'].append(val_loss)
self.history['val_accuracy'].append(val_accuracy)
print(f"Epoch {epoch + 1}/{self.config.epochs}")
print(f" Train Loss: {train_loss:.4f}")
print(f" Val Loss: {val_loss:.4f}, Accuracy: {val_accuracy:.4f}")
if val_accuracy > best_accuracy:
best_accuracy = val_accuracy
self.save_checkpoint('best_model.pt')
return self.history
def save_checkpoint(self, path: str) -> None:
"""Save model checkpoint."""
torch.save({
'model_state_dict': self.model.state_dict(),
'optimizer_state_dict': self.optimizer.state_dict(),
'history': self.history
}, path)Understanding the Training Loop:
Understanding the Training Loop:
┌─────────────────────────────────────────────────────────────────────────────┐
│ THE FUNDAMENTAL TRAINING PATTERN │
├─────────────────────────────────────────────────────────────────────────────┤
│ │
│ For each batch: │
│ ┌─────────────────────────────────────────────────────────────────────┐ │
│ │ │ │
│ │ 1. optimizer.zero_grad() Clear previous gradients │ │
│ │ │ │ │
│ │ ▼ │ │
│ │ 2. logits = model(input_ids) Forward pass │ │
│ │ │ │ │
│ │ ▼ │ │
│ │ 3. loss = criterion(logits, labels) Compute loss │ │
│ │ │ │ │
│ │ ▼ │ │
│ │ 4. loss.backward() Compute gradients (backprop) │ │
│ │ │ │ │
│ │ ▼ │ │
│ │ 5. clip_grad_norm_() Prevent exploding gradients │ │
│ │ │ │ │
│ │ ▼ │ │
│ │ 6. optimizer.step() Update weights: w = w - lr * grad │ │
│ │ │ │ │
│ │ ▼ │ │
│ │ 7. scheduler.step() Adjust learning rate │ │
│ │ │ │
│ └─────────────────────────────────────────────────────────────────────┘ │
│ │
│ Why OneCycleLR Scheduler: │
│ ┌─────────────────────────────────────────────────────────────────────┐ │
│ │ LR │ /\ │ │
│ │ │ / \ │ │
│ │ │ / \ │ │
│ │ │ / \____ │ │
│ │ │_____/ \ │ │
│ │ └─────────────────────► Epochs │ │
│ │ warmup peak decay │ │
│ │ │ │
│ │ • Warmup: Start low, increase (lets model settle) │ │
│ │ • Peak: Maximum learning rate (fast learning) │ │
│ │ • Decay: Decrease to fine-tune (converge smoothly) │ │
│ └─────────────────────────────────────────────────────────────────────┘ │
│ │
└─────────────────────────────────────────────────────────────────────────────┘

Training Configuration Choices:
| Setting | Default | Reasoning |
|---|---|---|
| learning_rate=1e-3 | Standard for AdamW | Too high = diverge, too low = slow |
| weight_decay=0.01 | L2 regularization | Prevents weights from growing too large |
| max_grad_norm=1.0 | Gradient clipping | LSTMs prone to exploding gradients |
| epochs=10 | Full passes over data | Monitor validation loss to stop early |
Step 6: FastAPI Application
"""FastAPI application for text classification."""
import torch
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, Field
from typing import List
from src.model import TextClassifier
from src.tokenizer import SimpleTokenizer
class PredictionRequest(BaseModel):
text: str = Field(..., min_length=1)
class PredictionResponse(BaseModel):
label: int
confidence: float
probabilities: List[float]
model: TextClassifier = None
tokenizer: SimpleTokenizer = None
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
app = FastAPI(title="Text Classification API")
@app.on_event("startup")
async def startup():
global model, tokenizer
tokenizer = SimpleTokenizer.load("models/tokenizer.json")
model = TextClassifier(vocab_size=tokenizer.vocab_size)
checkpoint = torch.load("models/best_model.pt", map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
model.to(device)
model.eval()
@app.get("/health")
async def health():
return {"status": "healthy", "device": str(device)}
@app.post("/predict", response_model=PredictionResponse)
async def predict(request: PredictionRequest):
input_ids = tokenizer.encode(request.text, max_length=128)
input_tensor = torch.tensor([input_ids]).to(device)
with torch.no_grad():
logits = model(input_tensor)
probs = torch.softmax(logits, dim=1)[0].tolist()
return PredictionResponse(
label=probs.index(max(probs)),
confidence=max(probs),
probabilities=probs
)Running the Project
# Install dependencies
pip install -r requirements.txt
# Train the model
python -c "from src.train import train_model; train_model(...)"
# Run the API
uvicorn api.main:app --reload
# Test prediction
curl -X POST http://localhost:8000/predict \
-H "Content-Type: application/json" \
  -d '{"text": "This is amazing!"}'

Key Concepts
nn.Module
The base class for all PyTorch neural networks:
class MyModel(nn.Module):
    """Minimal nn.Module example: one linear layer mapping 10 features to 5."""

    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(10, 5)

    def forward(self, x):
        return self.layer(x)
for epoch in range(epochs):
model.train()
for batch in train_loader:
optimizer.zero_grad() # Clear gradients
outputs = model(inputs) # Forward pass
loss = criterion(outputs, labels)
loss.backward() # Backward pass
optimizer.step() # Update weightsGradient Clipping
Prevents exploding gradients:
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)Key Concepts Recap
| Concept | What It Is | Why It Matters |
|---|---|---|
| nn.Module | PyTorch base class for neural networks | Provides parameter tracking, device movement, and save/load |
| Embedding Layer | Maps token indices to dense vectors | Converts discrete words into learnable continuous representations |
| Bidirectional LSTM | LSTM that reads sequence forward and backward | Captures context from both directions for better representations |
| Training Loop | Forward → Loss → Backward → Optimizer step | The fundamental pattern for all neural network training |
| Gradient Clipping | Caps gradient magnitude during backprop | Prevents exploding gradients that destabilize training |
| CrossEntropyLoss | Combines LogSoftmax + NLLLoss | Standard loss for multi-class classification |
| OneCycleLR | Learning rate scheduler with warmup and decay | Faster convergence than constant learning rate |
| Vocabulary Building | Map words to indices with frequency cutoff | Handles out-of-vocabulary words via <UNK> token |
Next Steps
- Embedding Model - Create custom embeddings
- LoRA Fine-tuning - Efficient LLM fine-tuning