Multi-Modal Embeddings
Combine text and image embeddings with CLIP for cross-modal search
TL;DR
CLIP embeds both text and images into the same vector space. This lets you search images with text ("a sunset over the ocean" → finds sunset photos), find similar images, or match images against candidate captions - all using cosine similarity.
What You'll Learn
- How CLIP aligns text and image embeddings
- Building cross-modal search (text-to-image, image-to-text)
- Multi-modal similarity and fusion
- Production deployment considerations
Tech Stack
| Component | Technology |
|---|---|
| Model | OpenCLIP / CLIP |
| Image Processing | PIL, torchvision |
| Vector Storage | ChromaDB |
| API | FastAPI |
How CLIP Works
┌─────────────────────────────────────────────────────────────────────────────┐
│ CLIP ARCHITECTURE │
├─────────────────────────────────────────────────────────────────────────────┤
│ │
│ ┌─────────────────┐ ┌─────────────────┐ │
│ │ INPUTS │ │ ENCODERS │ │
│ ├─────────────────┤ ├─────────────────┤ │
│ │ │ │ │ ┌─────────────────┐ │
│ │ Text: "A dog" │────────► │ Text Encoder │─────►│ │ │
│ │ │ │ (Transformer) │ │ SHARED VECTOR │ │
│ │ │ │ │ │ SPACE │ │
│ │ Image: [🐕] │────────► │ Image Encoder │─────►│ │ │
│ │ │ │ (ViT) │ │ cosine(t, i) │ │
│ └─────────────────┘ └─────────────────┘ └─────────────────┘ │
│ │
│ Key Insight: Matching text-image pairs → similar vectors │
│ Non-matching pairs → distant vectors │
└─────────────────────────────────────────────────────────────────────────────┘

CLIP (Contrastive Language-Image Pre-training) learns to align text and images in a shared embedding space. This enables:
- Text-to-Image: Find images matching a text description
- Image-to-Text: Find text describing an image
- Image-to-Image: Find visually similar images
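All three modes reduce to one operation: cosine similarity in the shared space. A toy NumPy sketch with hand-made 3-d stand-ins for CLIP vectors (real embeddings are 512-d and come from the encoders):

```python
import numpy as np

# Hypothetical toy vectors standing in for CLIP embeddings
text_dog = np.array([0.9, 0.1, 0.4])     # embed_text("a photo of a dog")
img_dog = np.array([0.8, 0.2, 0.5])      # embed_image("dog.jpg")
img_sunset = np.array([0.1, 0.9, 0.3])   # embed_image("sunset.jpg")

def normalize(v):
    return v / np.linalg.norm(v)

# Indexed images, unit-normalized the way CLIPEmbedder returns them
images = np.stack([normalize(img_dog), normalize(img_sunset)])
query = normalize(text_dog)

# Text-to-image search: cosine similarity against every image, best first
scores = images @ query
best = int(np.argmax(scores))
print(best)  # 0 → the dog image matches the dog text
```

Image-to-image and image-to-text search work identically - only which vectors play the role of query and index changes.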
Project Structure
multimodal-embeddings/
├── src/
│ ├── __init__.py
│ ├── clip_embeddings.py # CLIP model wrapper
│ ├── image_processor.py # Image preprocessing
│ ├── index.py # Vector storage
│ └── api.py # FastAPI application
├── data/
│ └── images/
├── app.py # Streamlit demo
├── requirements.txt
└── README.md

Implementation
Step 1: Setup
open-clip-torch>=2.20.0
torch>=2.0.0
torchvision>=0.15.0
Pillow>=10.0.0
chromadb>=0.4.0
fastapi>=0.100.0
uvicorn>=0.23.0
streamlit>=1.28.0
numpy>=1.24.0

Step 2: CLIP Embeddings
"""
CLIP-based multi-modal embeddings.
"""
import torch
import open_clip
from PIL import Image
import numpy as np
from typing import Union
from pathlib import Path
class CLIPEmbedder:
"""
Generate embeddings for text and images using CLIP.
Both modalities are embedded into the same vector space,
enabling cross-modal similarity search.
"""
def __init__(
self,
model_name: str = "ViT-B-32",
pretrained: str = "laion2b_s34b_b79k"
):
"""
Initialize CLIP model.
Args:
model_name: CLIP model architecture
pretrained: Pretrained weights to use
Available models:
- ViT-B-32: Good balance of speed and quality
- ViT-L-14: Higher quality, slower
- ViT-H-14: Highest quality, slowest
"""
self.device = "cuda" if torch.cuda.is_available() else "cpu"
self.model, _, self.preprocess = open_clip.create_model_and_transforms(
model_name,
pretrained=pretrained,
device=self.device
)
self.tokenizer = open_clip.get_tokenizer(model_name)
self.embedding_dim = self.model.visual.output_dim
def embed_text(self, texts: Union[str, list[str]]) -> np.ndarray:
"""
Generate embeddings for text.
Args:
texts: Single text or list of texts
Returns:
Normalized embeddings
"""
if isinstance(texts, str):
texts = [texts]
with torch.no_grad():
tokens = self.tokenizer(texts).to(self.device)
embeddings = self.model.encode_text(tokens)
embeddings = embeddings / embeddings.norm(dim=-1, keepdim=True)
return embeddings.cpu().numpy()
def embed_image(
self,
images: Union[Image.Image, list[Image.Image], str, list[str]]
) -> np.ndarray:
"""
Generate embeddings for images.
Args:
images: PIL Image(s) or path(s) to images
Returns:
Normalized embeddings
"""
# Handle paths
if isinstance(images, str):
images = [Image.open(images)]
elif isinstance(images, list) and isinstance(images[0], str):
images = [Image.open(p) for p in images]
elif isinstance(images, Image.Image):
images = [images]
# Preprocess
processed = torch.stack([
self.preprocess(img) for img in images
]).to(self.device)
with torch.no_grad():
embeddings = self.model.encode_image(processed)
embeddings = embeddings / embeddings.norm(dim=-1, keepdim=True)
return embeddings.cpu().numpy()
def similarity(
self,
text_embeddings: np.ndarray,
image_embeddings: np.ndarray
) -> np.ndarray:
"""
Compute similarity between text and image embeddings.
Args:
text_embeddings: Text embeddings (n_texts, dim)
image_embeddings: Image embeddings (n_images, dim)
Returns:
Similarity matrix (n_texts, n_images)
"""
        return np.dot(text_embeddings, image_embeddings.T)

What's Happening Here?
The CLIPEmbedder class is your gateway to multi-modal understanding. Let's break down each method:
┌─────────────────────────────────────────────────────────────────────────────┐
│ CLIPEmbedder Internal Architecture │
├─────────────────────────────────────────────────────────────────────────────┤
│ │
│ __init__() │
│ ┌───────────────────────────────────────────────────────────────────────┐ │
│ │ 1. Check GPU availability (cuda vs cpu) │ │
│ │ 2. Load pretrained model + preprocessing transforms │ │
│ │ 3. Get tokenizer for text input │ │
│ │ 4. Store embedding dimension for downstream use │ │
│ └───────────────────────────────────────────────────────────────────────┘ │
│ │
│ embed_text(texts) │
│ ┌───────────────────────────────────────────────────────────────────────┐ │
│ │ "a photo of a cat" │ │
│ │ │ │ │
│ │ ▼ │ │
│ │ ┌──────────┐ ┌───────────────┐ ┌──────────────┐ │ │
│ │ │ Tokenize │ ──► │ Text Encoder │ ──► │ Normalize │ ──► [0.1, ..│ │
│ │ │ (77 max) │ │ (Transformer) │ │ (unit length)│ 512-d] │ │
│ │ └──────────┘ └───────────────┘ └──────────────┘ │ │
│ └───────────────────────────────────────────────────────────────────────┘ │
│ │
│ embed_image(images) │
│ ┌───────────────────────────────────────────────────────────────────────┐ │
│ │ [Image File] │ │
│ │ │ │ │
│ │ ▼ │ │
│ │ ┌──────────┐ ┌───────────────┐ ┌──────────────┐ │ │
│ │ │ Preproc │ ──► │ Vision Encoder│ ──► │ Normalize │ ──► [0.2, ..│ │
│ │ │ (224x224)│ │ (ViT) │ │ (unit length)│ 512-d] │ │
│ │ └──────────┘ └───────────────┘ └──────────────┘ │ │
│ └───────────────────────────────────────────────────────────────────────┘ │
│ │
└─────────────────────────────────────────────────────────────────────────────┘

Key Design Decisions Explained:
| Code Pattern | Why It's Done This Way |
|---|---|
| `torch.no_grad()` | Disables gradient tracking - we're doing inference, not training. Saves memory and speeds up computation. |
| `embeddings / embeddings.norm(dim=-1, keepdim=True)` | L2 normalization to unit length. This makes the dot product equal to cosine similarity, giving scores in the [-1, 1] range. |
| `isinstance(images, str)` checks | Flexible API - accepts file paths, PIL images, or lists of either. Makes the class easier to use in different contexts. |
| `laion2b_s34b_b79k` pretrained weights | Trained on 2 billion image-text pairs from LAION. More diverse than the original OpenAI CLIP weights. |
Understanding the Preprocessing:
┌─────────────────────────────────────────────────────────────────────────────┐
│ Image Preprocessing Pipeline │
├─────────────────────────────────────────────────────────────────────────────┤
│ │
│ Original Image (any size) │
│ │ │
│ ▼ │
│ ┌─────────────────────┐ │
│ │ Resize to 224x224 │ ◄── CLIP was trained on this resolution │
│ │ (center crop) │ │
│ └─────────────────────┘ │
│ │ │
│ ▼ │
│ ┌─────────────────────┐ │
│ │ Convert to tensor │ ◄── [0, 255] integers → [0, 1] floats │
│ │ (H, W, C) → (C, H, W) │
│ └─────────────────────┘ │
│ │ │
│ ▼ │
│ ┌─────────────────────┐ │
│ │ Normalize channels │ ◄── Subtract ImageNet mean, divide by std │
│ │ mean=[0.48, 0.46, │ This matches CLIP's training distribution │
│ │ 0.41] │ │
│ └─────────────────────┘ │
│ │ │
│ ▼ │
│ Ready for ViT encoder │
│ │
└─────────────────────────────────────────────────────────────────────────────┘

Deep Dive: How CLIP Was Trained
CLIP's magic comes from contrastive learning on 400M+ image-text pairs:
┌─────────────────────────────────────────────────────────────────────────────┐
│ CLIP Contrastive Training Process │
├─────────────────────────────────────────────────────────────────────────────┤
│ │
│ Training Batch (e.g., 32 image-text pairs): │
│ │
│ Images: [🐕] [🐱] [🌅] ... [🏔️] → Image Encoder → [I₁, I₂, ... I₃₂] │
│ Texts: "dog" "cat" "sunset" ... "mountain" → Text Encoder → [T₁, T₂, ..T₃₂]│
│ │
│ Similarity Matrix (32 x 32): │
│ │
│ T₁ T₂ T₃ ... T₃₂ │
│ ┌─────────────────────────────────┐ │
│ I₁ │ HIGH low low ... low │ ◄── I₁ matches T₁ (diagonal) │
│ I₂ │ low HIGH low ... low │ ◄── I₂ matches T₂ │
│ I₃ │ low low HIGH ... low │ │
│ ... │ ... ... ... ... ... │ │
│ I₃₂ │ low low low ... HIGH │ │
│ └─────────────────────────────────┘ │
│ │
│ Loss: Maximize diagonal (matching pairs), minimize off-diagonal │
│ │
│ Result: After training, semantically similar content → similar vectors │
│ │
└─────────────────────────────────────────────────────────────────────────────┘

Understanding CLIP Embeddings:
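A minimal NumPy sketch of that contrastive objective, using random 4-d toy vectors rather than real encoder outputs. Actual CLIP uses a learned temperature and averages the image→text and text→image losses; only the image→text direction is shown here:

```python
import numpy as np

def softmax(x):
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

# Toy batch: 3 image embeddings with 3 matching text embeddings (unit length)
rng = np.random.default_rng(0)
I = rng.normal(size=(3, 4))
I /= np.linalg.norm(I, axis=1, keepdims=True)
T = I + 0.1 * rng.normal(size=(3, 4))   # matching texts sit near their images
T /= np.linalg.norm(T, axis=1, keepdims=True)

logits = I @ T.T / 0.07                 # similarity matrix scaled by temperature

# Image→text loss: each row should put its probability mass on the diagonal
probs = softmax(logits)
loss_i2t = -np.log(np.diag(probs)).mean()
print(round(float(loss_i2t), 4))
```

Minimizing this loss pushes the diagonal (matching pairs) up and the off-diagonal entries down, which is exactly what aligns the two modalities.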
┌─────────────────────────────────────────────────────────────────────────────┐
│ WHY CLIP IS SPECIAL: Shared Embedding Space │
├─────────────────────────────────────────────────────────────────────────────┤
│ │
│ Traditional Approach (Separate Spaces): │
│ ┌────────────────────────────────────────────────────────────────────┐ │
│ │ Text Space Image Space │ │
│ │ ● "dog" ★ 🐕 │ │
│ │ ● "cat" ★ 🐱 │ │
│ │ ● "sunset" ★ 🌅 │ │
│ │ │ │
│ │ Problem: Can't compare across modalities! │ │
│ └────────────────────────────────────────────────────────────────────┘ │
│ │
│ CLIP Approach (Shared Space): │
│ ┌────────────────────────────────────────────────────────────────────┐ │
│ │ Unified Vector Space │ │
│ │ │ │
│ │ "a photo of a dog" ●────────● 🐕 image │ │
│ │ (close!) │ │
│ │ │ │
│ │ "sunset over ocean" ●────────● 🌅 image │ │
│ │ (close!) │ │
│ │ │ │
│ │ Now cosine similarity works across modalities! │ │
│ └────────────────────────────────────────────────────────────────────┘ │
│ │
└─────────────────────────────────────────────────────────────────────────────┘

Model Selection:
| Model | Dimension | Speed | Quality | Use Case |
|---|---|---|---|---|
| ViT-B-32 | 512 | Fast | Good | Prototyping, real-time |
| ViT-B-16 | 512 | Medium | Better | Balanced |
| ViT-L-14 | 768 | Slow | Best | High accuracy needs |
| ViT-H-14 | 1024 | Slowest | Highest | Research |
Why Normalize Embeddings?
- CLIP embeddings are normalized to unit length
- This makes cosine similarity equal to the plain dot product
- Faster computation and a consistent score range (-1 to 1)
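Those bullets can be verified directly; the vectors below are arbitrary toy values, not CLIP outputs:

```python
import numpy as np

a = np.array([3.0, 4.0])            # arbitrary un-normalized vectors
b = np.array([1.0, 2.0])

# Full cosine similarity formula
cosine = a @ b / (np.linalg.norm(a) * np.linalg.norm(b))

# What CLIPEmbedder does before returning: divide by the norm
a_unit = a / np.linalg.norm(a)
b_unit = b / np.linalg.norm(b)
dot = a_unit @ b_unit               # plain dot product of unit vectors

print(np.isclose(cosine, dot))  # True
```

Because the normalization happens once at embedding time, every later comparison is a cheap matrix multiply.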
class MultiModalIndex:
"""
Index for storing and searching multi-modal embeddings.
"""
def __init__(self, embedder: CLIPEmbedder):
import chromadb
self.embedder = embedder
self.client = chromadb.Client()
# Separate collections for images and texts
self.image_collection = self.client.get_or_create_collection(
name="images",
metadata={"hnsw:space": "cosine"}
)
self.text_collection = self.client.get_or_create_collection(
name="texts",
metadata={"hnsw:space": "cosine"}
)
    def add_images(
        self,
        image_paths: list[str],
        metadata: list[dict] | None = None
    ) -> None:
        """Add images to the index."""
        embeddings = self.embedder.embed_image(image_paths)
        # Offset IDs by the current count so repeated calls don't collide
        start = self.image_collection.count()
        ids = [f"img_{start + i}" for i in range(len(image_paths))]
        meta = metadata or [{"path": p} for p in image_paths]
self.image_collection.add(
ids=ids,
embeddings=embeddings.tolist(),
metadatas=meta,
documents=image_paths
)
    def add_texts(
        self,
        texts: list[str],
        metadata: list[dict] | None = None
    ) -> None:
        """Add texts to the index."""
        embeddings = self.embedder.embed_text(texts)
        # Offset IDs by the current count so repeated calls don't collide
        start = self.text_collection.count()
        ids = [f"txt_{start + i}" for i in range(len(texts))]
        meta = metadata or [{"text": t} for t in texts]
self.text_collection.add(
ids=ids,
embeddings=embeddings.tolist(),
metadatas=meta,
documents=texts
)
def search_images_by_text(
self,
query: str,
n: int = 5
) -> list[dict]:
"""
Find images matching a text query.
Text-to-Image search.
"""
query_embedding = self.embedder.embed_text(query)
results = self.image_collection.query(
query_embeddings=query_embedding.tolist(),
n_results=n
)
return [
{
"id": results["ids"][0][i],
"path": results["documents"][0][i],
"score": 1 - results["distances"][0][i],
"metadata": results["metadatas"][0][i]
}
for i in range(len(results["ids"][0]))
]
def search_texts_by_image(
self,
image_path: str,
n: int = 5
) -> list[dict]:
"""
Find texts describing an image.
Image-to-Text search.
"""
query_embedding = self.embedder.embed_image(image_path)
results = self.text_collection.query(
query_embeddings=query_embedding.tolist(),
n_results=n
)
return [
{
"id": results["ids"][0][i],
"text": results["documents"][0][i],
"score": 1 - results["distances"][0][i]
}
for i in range(len(results["ids"][0]))
]
def search_similar_images(
self,
image_path: str,
n: int = 5
) -> list[dict]:
"""
Find visually similar images.
Image-to-Image search.
"""
query_embedding = self.embedder.embed_image(image_path)
results = self.image_collection.query(
query_embeddings=query_embedding.tolist(),
n_results=n + 1 # +1 to exclude self
)
# Filter out the query image itself
return [
{
"id": results["ids"][0][i],
"path": results["documents"][0][i],
"score": 1 - results["distances"][0][i],
"metadata": results["metadatas"][0][i]
}
for i in range(len(results["ids"][0]))
if results["documents"][0][i] != image_path
][:n]
# Example usage
if __name__ == "__main__":
embedder = CLIPEmbedder()
# Embed text
text_emb = embedder.embed_text("a photo of a cat")
print(f"Text embedding shape: {text_emb.shape}")
# Embed image (if you have one)
# image_emb = embedder.embed_image("cat.jpg")
# similarity = embedder.similarity(text_emb, image_emb)
    # print(f"Similarity: {similarity[0][0]:.4f}")

What's Happening in MultiModalIndex?
The MultiModalIndex class handles vector storage and search. Let's trace through a complete search operation:
┌─────────────────────────────────────────────────────────────────────────────┐
│ Text-to-Image Search Flow │
├─────────────────────────────────────────────────────────────────────────────┤
│ │
│ User Query: "a sunset over the ocean" │
│ │ │
│ ▼ │
│ ┌─────────────────────────────────────────────────────────────────────────┐│
│ │ Step 1: Embed Query ││
│ │ ┌─────────────────────────────────────────────────────────────────────┐ ││
│ │ │ embed_text("a sunset over the ocean") → [0.15, 0.32, -0.08, ...] │ ││
│ │ │ 512-dimensional vector │ ││
│ │ └─────────────────────────────────────────────────────────────────────┘ ││
│ └─────────────────────────────────────────────────────────────────────────┘│
│ │ │
│ ▼ │
│ ┌─────────────────────────────────────────────────────────────────────────┐│
│ │ Step 2: Search Image Collection (ChromaDB) ││
│ │ ┌─────────────────────────────────────────────────────────────────────┐ ││
│ │ │ HNSW Index finds nearest neighbors in O(log n) time │ ││
│ │ │ │ ││
│ │ │ Query Vector ────► Compare with stored image embeddings │ ││
│ │ │ │ │ ││
│ │ │ ▼ │ ││
│ │ │ 🌅 sunset.jpg [0.14, 0.31, -0.07, ...] → distance: 0.05 │ ││
│ │ │ 🏔️ mountain.jpg [0.22, 0.11, 0.15, ...] → distance: 0.72 │ ││
│ │ │ 🐱 cat.jpg [-0.05, 0.28, 0.33, ...] → distance: 0.89 │ ││
│ │ └─────────────────────────────────────────────────────────────────────┘ ││
│ └─────────────────────────────────────────────────────────────────────────┘│
│ │ │
│ ▼ │
│ ┌─────────────────────────────────────────────────────────────────────────┐│
│ │ Step 3: Return Top Results ││
│ │ ┌─────────────────────────────────────────────────────────────────────┐ ││
│ │ │ [{"path": "sunset.jpg", "score": 0.95}, ◄── 1 - distance = score │ ││
│ │ │ {"path": "beach.jpg", "score": 0.87}, │ ││
│ │ │ {"path": "sky.jpg", "score": 0.82}] │ ││
│ │ └─────────────────────────────────────────────────────────────────────┘ ││
│ └─────────────────────────────────────────────────────────────────────────┘│
│ │
└─────────────────────────────────────────────────────────────────────────────┘

ChromaDB Collection Configuration:
| Setting | Value | Why This Matters |
|---|---|---|
| `hnsw:space: "cosine"` | Cosine distance | CLIP embeddings are normalized, so cosine similarity equals the dot product. Distances fall in the [0, 2] range (0 = identical, 2 = opposite). |
| Separate collections | `images`, `texts` | Allows independent scaling and different metadata schemas per modality. |
| `n_results + 1` for similar images | Excludes self | When searching for similar images, we don't want to return the query image itself. |
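The distance-to-score conversion used by every search method, traced on hypothetical ChromaDB distances:

```python
# ChromaDB returns cosine distances (lower = closer); the search methods
# convert them to similarity scores with score = 1 - distance (higher = better).
distances = [0.05, 0.28, 0.89]                 # hypothetical query results
scores = [round(1 - d, 3) for d in distances]
print(scores)  # [0.95, 0.72, 0.11]
```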
Understanding the Score Calculation:
┌─────────────────────────────────────────────────────────────────────────────┐
│ Score vs Distance Relationship │
├─────────────────────────────────────────────────────────────────────────────┤
│ │
│ ChromaDB returns: distance (lower = better) │
│ We convert to: score = 1 - distance (higher = better) │
│ │
│ For cosine distance: │
│ ┌────────────────────────────────────────────────────────────────────────┐│
│ │ Distance 0.0 ──────────────────────────────────────────── Distance 2.0│
│ │ │ │ │ │
│ │ ▼ ▼ ▼ │
│ │ Identical Orthogonal Opposite │
│ │ (score=1.0) (score=0.0) (score=-1.0) │
│ │ │ │ │ │
│ │ Perfect match Unrelated Antonym │
│ │ "dog" ↔ 🐕 "dog" ↔ 🌅 "good" ↔ "bad"│
│ └────────────────────────────────────────────────────────────────────────┘│
│ │
│   A note on absolute values: raw CLIP text-image cosine scores run lower    │
│   than intuition suggests (often 0.2-0.4 even for strong matches, due to    │
│   the "modality gap"), so rank results relatively within a single query     │
│   and calibrate any fixed threshold on your own data                        │
│ │
└─────────────────────────────────────────────────────────────────────────────┘

Step 3: Image Processing
"""
Image processing utilities for CLIP.
"""
from PIL import Image
import io
import base64
from pathlib import Path
from typing import Union, Optional
import numpy as np
def load_image(source: Union[str, bytes, Image.Image]) -> Image.Image:
"""
Load image from various sources.
Args:
source: File path, bytes, or PIL Image
"""
if isinstance(source, Image.Image):
return source
elif isinstance(source, bytes):
return Image.open(io.BytesIO(source))
elif isinstance(source, str):
if source.startswith("data:image"):
# Base64 encoded
header, data = source.split(",", 1)
return Image.open(io.BytesIO(base64.b64decode(data)))
else:
# File path
return Image.open(source)
else:
raise ValueError(f"Unsupported image source type: {type(source)}")
def resize_image(
image: Image.Image,
max_size: int = 512,
keep_aspect: bool = True
) -> Image.Image:
"""
Resize image while maintaining aspect ratio.
"""
if keep_aspect:
image.thumbnail((max_size, max_size), Image.Resampling.LANCZOS)
return image
else:
return image.resize((max_size, max_size), Image.Resampling.LANCZOS)
def convert_to_rgb(image: Image.Image) -> Image.Image:
"""Convert image to RGB mode (required for CLIP)."""
if image.mode != "RGB":
return image.convert("RGB")
return image
def image_to_base64(image: Image.Image, format: str = "PNG") -> str:
"""Convert PIL Image to base64 string."""
buffer = io.BytesIO()
image.save(buffer, format=format)
return base64.b64encode(buffer.getvalue()).decode()
def process_image_for_clip(
source: Union[str, bytes, Image.Image]
) -> Image.Image:
"""
Prepare image for CLIP embedding.
- Load from various sources
- Convert to RGB
- Resize if too large
"""
image = load_image(source)
image = convert_to_rgb(image)
image = resize_image(image, max_size=512)
return image
def create_image_grid(
images: list[Image.Image],
cols: int = 4,
cell_size: int = 256
) -> Image.Image:
"""
Create a grid of images for visualization.
"""
n_images = len(images)
rows = (n_images + cols - 1) // cols
grid = Image.new(
"RGB",
(cols * cell_size, rows * cell_size),
color="white"
)
for i, img in enumerate(images):
row = i // cols
col = i % cols
# Resize to cell size
img_resized = img.copy()
img_resized.thumbnail((cell_size, cell_size))
# Center in cell
x = col * cell_size + (cell_size - img_resized.width) // 2
y = row * cell_size + (cell_size - img_resized.height) // 2
grid.paste(img_resized, (x, y))
    return grid

What's Happening in Image Processing?
These utilities handle the messy reality of working with images from various sources:
┌─────────────────────────────────────────────────────────────────────────────┐
│ Image Loading - Handle Any Source │
├─────────────────────────────────────────────────────────────────────────────┤
│ │
│ load_image() accepts: │
│ │
│ ┌─────────────────┐ ┌─────────────────┐ ┌─────────────────┐ │
│ │ File Path │ │ Raw Bytes │ │ Base64 String │ │
│ │ "/img/cat.jpg" │ │ b'\x89PNG...' │ │ "data:image/..."│ │
│ └────────┬────────┘ └────────┬────────┘ └────────┬────────┘ │
│ │ │ │ │
│ ▼ ▼ ▼ │
│ ┌─────────────────────────────────────────────────────────────────────────┐│
│ │ PIL.Image.open() ││
│ │ │ ││
│ │ ▼ ││
│ │ PIL Image Object ││
│ └─────────────────────────────────────────────────────────────────────────┘│
│ │
│ Why this flexibility? │
│ • File paths: Batch processing local images │
│ • Raw bytes: API uploads (FastAPI file uploads) │
│ • Base64: Web browser uploads, data URLs │
│ │
└─────────────────────────────────────────────────────────────────────────────┘

The process_image_for_clip() Pipeline:
| Step | Function | Why It's Needed |
|---|---|---|
| 1. Load | load_image() | Handle various input formats uniformly |
| 2. RGB Convert | convert_to_rgb() | CLIP requires 3-channel RGB. Handles RGBA (transparency), grayscale, CMYK. |
| 3. Resize | resize_image(max_size=512) | Prevents memory issues with huge images. CLIP will resize to 224x224 anyway. |
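The resize step is aspect-preserving: PIL's thumbnail() shrinks an oversized image so its longest side fits within max_size. A small helper (illustrative, with made-up dimensions) shows the arithmetic it approximately follows:

```python
def thumbnail_size(width: int, height: int, max_size: int = 512) -> tuple[int, int]:
    """Largest size that fits in a max_size box while keeping aspect ratio
    (approximates what Image.thumbnail does for oversized images)."""
    if width <= max_size and height <= max_size:
        return width, height              # already small enough - left untouched
    scale = min(max_size / width, max_size / height)
    return max(1, round(width * scale)), max(1, round(height * scale))

print(thumbnail_size(4000, 3000))  # (512, 384)
print(thumbnail_size(300, 200))    # (300, 200) - small images pass through
```

Downscaling before embedding matters only for memory: CLIP's preprocessing resizes to 224x224 regardless, so a 4000x3000 original buys nothing but RAM usage.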
Why Convert to RGB?
┌─────────────────────────────────────────────────────────────────────────────┐
│ Common Image Modes and Conversions │
├─────────────────────────────────────────────────────────────────────────────┤
│ │
│ Input Mode Channels CLIP Compatible? Conversion │
│ ───────────────────────────────────────────────────────────────────────── │
│ RGB 3 ✓ Yes None needed │
│ RGBA (PNG) 4 ✗ No Drop alpha channel │
│ L (Grayscale) 1 ✗ No Duplicate to RGB │
│ P (Palette) 1 ✗ No Convert via palette │
│ CMYK (Print) 4 ✗ No Convert to RGB │
│ │
│ Example RGBA → RGB: │
│ ┌────────────────────────────────────────────────────────────────────────┐│
│ │ Original: [R, G, B, A] = [255, 128, 64, 200] (semi-transparent orange) │
│ │ │ │
│ │ ▼ │
│ │ Converted: [R, G, B] = [255, 128, 64] (drop alpha, keep colors) │
│ └────────────────────────────────────────────────────────────────────────┘│
│ │
└─────────────────────────────────────────────────────────────────────────────┘

Step 4: FastAPI Application
"""
FastAPI application for multi-modal search.
"""
from fastapi import FastAPI, UploadFile, File, HTTPException
from fastapi.responses import FileResponse
from pydantic import BaseModel
from typing import Optional
import tempfile
from pathlib import Path
from .clip_embeddings import CLIPEmbedder, MultiModalIndex
from .image_processor import process_image_for_clip, image_to_base64
app = FastAPI(
title="Multi-Modal Search API",
description="Search across text and images using CLIP",
version="1.0.0"
)
# Initialize
embedder = CLIPEmbedder()
index = MultiModalIndex(embedder)
class TextSearchRequest(BaseModel):
query: str
n_results: int = 5
class TextSearchResult(BaseModel):
id: str
path: str
score: float
class SimilarityRequest(BaseModel):
text: str
image_path: str
@app.post("/index/images")
async def index_images(image_paths: list[str]):
"""Index a list of image files."""
# Validate paths
valid_paths = [p for p in image_paths if Path(p).exists()]
if not valid_paths:
raise HTTPException(status_code=400, detail="No valid image paths")
index.add_images(valid_paths)
return {"message": f"Indexed {len(valid_paths)} images"}
@app.post("/index/texts")
async def index_texts(texts: list[str]):
"""Index a list of text documents."""
index.add_texts(texts)
return {"message": f"Indexed {len(texts)} texts"}
@app.post("/search/text-to-image")
async def search_images_by_text(request: TextSearchRequest):
"""
Find images matching a text description.
Example: "a sunset over the ocean" -> finds sunset photos
"""
results = index.search_images_by_text(
request.query,
n=request.n_results
)
return {"query": request.query, "results": results}
@app.post("/search/image-to-text")
async def search_texts_by_image(file: UploadFile = File(...), n: int = 5):
"""
Find text descriptions for an uploaded image.
"""
# Save uploaded file temporarily
with tempfile.NamedTemporaryFile(delete=False, suffix=".jpg") as tmp:
content = await file.read()
tmp.write(content)
tmp_path = tmp.name
try:
results = index.search_texts_by_image(tmp_path, n=n)
return {"results": results}
finally:
Path(tmp_path).unlink()
@app.post("/search/similar-images")
async def search_similar_images(file: UploadFile = File(...), n: int = 5):
"""
Find visually similar images to the uploaded image.
"""
with tempfile.NamedTemporaryFile(delete=False, suffix=".jpg") as tmp:
content = await file.read()
tmp.write(content)
tmp_path = tmp.name
try:
results = index.search_similar_images(tmp_path, n=n)
return {"results": results}
finally:
Path(tmp_path).unlink()
@app.post("/similarity")
async def compute_similarity(request: SimilarityRequest):
"""
Compute similarity between a text and an image.
    Returns a cosine similarity score (between -1 and 1) indicating how well the text describes the image.
"""
text_emb = embedder.embed_text(request.text)
image_emb = embedder.embed_image(request.image_path)
similarity = float(embedder.similarity(text_emb, image_emb)[0][0])
return {
"text": request.text,
"image": request.image_path,
"similarity": similarity
}
@app.get("/stats")
async def get_stats():
"""Get index statistics."""
return {
"model": "ViT-B-32",
"embedding_dim": embedder.embedding_dim,
"images_indexed": index.image_collection.count(),
"texts_indexed": index.text_collection.count()
    }

What's Happening in the API?
The FastAPI application exposes multi-modal search as REST endpoints:
┌─────────────────────────────────────────────────────────────────────────────┐
│ API Endpoint Architecture │
├─────────────────────────────────────────────────────────────────────────────┤
│ │
│ Indexing Endpoints (Build the searchable index): │
│ ┌─────────────────────────────────────────────────────────────────────────┐│
│ │ POST /index/images ││
│ │ {"paths": ["/data/img1.jpg", "/data/img2.jpg"]} ││
│ │ │ ││
│ │ ▼ ││
│ │ • Validate paths exist ││
│ │ • Generate CLIP embeddings for each image ││
│ │ • Store in ChromaDB with path as metadata ││
│ │ │ ││
│ │ ▼ ││
│ │ {"message": "Indexed 2 images"} ││
│ └─────────────────────────────────────────────────────────────────────────┘│
│ │
│ Search Endpoints (Find matching content): │
│ ┌─────────────────────────────────────────────────────────────────────────┐│
│ │ ││
│ │ /search/text-to-image /search/image-to-text /search/similar ││
│ │ │ │ │ ││
│ │ ▼ ▼ ▼ ││
│ │ ┌───────────┐ ┌───────────┐ ┌───────────┐ ││
│ │ │ Text │ │ Image │ │ Image │ ││
│ │ │ Query │ │ Upload │ │ Upload │ ││
│ │ └─────┬─────┘ └─────┬─────┘ └─────┬─────┘ ││
│ │ │ │ │ ││
│ │ ▼ ▼ ▼ ││
│ │ Search Images Search Texts Search Images ││
│ │ Collection Collection (exclude self) ││
│ │ ││
│ └─────────────────────────────────────────────────────────────────────────┘│
│ │
└─────────────────────────────────────────────────────────────────────────────┘

API Design Decisions:
| Pattern | Why It's Used |
|---|---|
| `UploadFile` for image search | Allows direct file upload instead of requiring server-side paths. Users can search with any image. |
| `tempfile.NamedTemporaryFile` | Temporarily saves uploaded files for CLIP processing, then cleans up. |
| `finally: Path(tmp_path).unlink()` | Ensures temp files are deleted even if an error occurs. Prevents disk space leaks. |
| Path validation before indexing | Fail fast with a clear error message instead of cryptic CLIP errors. |
| Global `embedder`, `index` | Initialized once at startup; loading the CLIP model takes several seconds. |
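The temp-file pattern from the table, isolated as a stdlib-only sketch; the placeholder bytes stand in for `await file.read()`:

```python
import tempfile
from pathlib import Path

content = b"fake image bytes"              # stands in for `await file.read()`

# Write the upload to disk so path-based code (CLIP, PIL) can read it
with tempfile.NamedTemporaryFile(delete=False, suffix=".jpg") as tmp:
    tmp.write(content)
    tmp_path = tmp.name

try:
    # ... process the file (embed, search, etc.) ...
    size = Path(tmp_path).stat().st_size
finally:
    Path(tmp_path).unlink()                # runs even if processing raises

print(size, Path(tmp_path).exists())  # 16 False
```

`delete=False` is what lets the file outlive the `with` block; the `finally` clause then takes over responsibility for cleanup.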
Understanding the Similarity Endpoint:
┌─────────────────────────────────────────────────────────────────────────────┐
│ /similarity - How Well Does Text Match Image? │
├─────────────────────────────────────────────────────────────────────────────┤
│ │
│ Request: │
│ { │
│ "text": "a fluffy orange cat sleeping on a couch", │
│ "image_path": "/uploads/cat_napping.jpg" │
│ } │
│ │
│ Process: │
│ ┌─────────────────────────────────────────────────────────────────────────┐│
│ │ ││
│ │ Text ──► embed_text() ──► [0.12, 0.34, ...] ─┐ ││
│ │ │ ││
│ │ ├──► dot product ──► 0.87│
│ │ │ ││
│ │ Image ──► embed_image() ──► [0.11, 0.35, ...] ─┘ ││
│ │ ││
│ └─────────────────────────────────────────────────────────────────────────┘│
│ │
│ Response: │
│ { │
│ "text": "a fluffy orange cat sleeping on a couch", │
│ "image": "/uploads/cat_napping.jpg", │
│ "similarity": 0.87 ◄── High score! Text describes image well │
│ } │
│ │
│ Use Cases: │
│ • Validate auto-generated captions │
│ • Filter user-uploaded content │
│ • Score image-text relevance for ranking │
│ │
└─────────────────────────────────────────────────────────────────────────────┘

Step 5: Streamlit Demo
"""
Streamlit demo for multi-modal search.
"""
import streamlit as st
from PIL import Image
import tempfile
from pathlib import Path
from src.clip_embeddings import CLIPEmbedder, MultiModalIndex
from src.image_processor import create_image_grid
st.set_page_config(
page_title="Multi-Modal Search",
page_icon="🖼️",
layout="wide"
)
@st.cache_resource
def load_model():
embedder = CLIPEmbedder()
index = MultiModalIndex(embedder)
return embedder, index
def main():
st.title("🖼️ Multi-Modal Search with CLIP")
st.markdown("Search images with text, or find similar images")
embedder, index = load_model()
tab1, tab2, tab3 = st.tabs([
"📝 Text-to-Image",
"🖼️ Image-to-Image",
"📊 Add Images"
])
with tab1:
st.subheader("Find images by description")
query = st.text_input(
"Enter description",
placeholder="a photo of a sunset"
)
if st.button("Search", key="text_search"):
if query and index.image_collection.count() > 0:
results = index.search_images_by_text(query, n=6)
if results:
cols = st.columns(3)
for i, result in enumerate(results):
with cols[i % 3]:
try:
img = Image.open(result["path"])
st.image(img, caption=f"Score: {result['score']:.3f}")
                            except Exception:
                                st.error(f"Could not load: {result['path']}")
else:
st.warning("No results found")
else:
st.warning("Add some images first!")
with tab2:
st.subheader("Find similar images")
uploaded = st.file_uploader(
"Upload an image",
type=["jpg", "jpeg", "png"]
)
if uploaded:
# Show uploaded image
image = Image.open(uploaded)
st.image(image, caption="Query image", width=300)
if st.button("Find Similar", key="image_search"):
                # Save temporarily (convert to RGB so images with alpha can be saved as JPEG)
                with tempfile.NamedTemporaryFile(delete=False, suffix=".jpg") as tmp:
                    image.convert("RGB").save(tmp.name)
                    tmp_path = tmp.name
results = index.search_similar_images(tmp_path, n=6)
Path(tmp_path).unlink()
if results:
st.subheader("Similar images:")
cols = st.columns(3)
for i, result in enumerate(results):
with cols[i % 3]:
try:
img = Image.open(result["path"])
st.image(img, caption=f"Score: {result['score']:.3f}")
                            except Exception:
                                st.error(f"Could not load: {result['path']}")
else:
st.warning("No similar images found")
with tab3:
st.subheader("Add images to index")
uploaded_files = st.file_uploader(
"Upload images",
type=["jpg", "jpeg", "png"],
accept_multiple_files=True
)
if uploaded_files and st.button("Index Images"):
paths = []
for uploaded in uploaded_files:
# Save to temp directory
temp_dir = Path(tempfile.gettempdir()) / "clip_images"
temp_dir.mkdir(exist_ok=True)
path = temp_dir / uploaded.name
with open(path, "wb") as f:
f.write(uploaded.getbuffer())
paths.append(str(path))
with st.spinner("Indexing..."):
index.add_images(paths)
st.success(f"Indexed {len(paths)} images!")
st.metric("Images indexed", index.image_collection.count())
if __name__ == "__main__":
    main()

What's Happening in the Streamlit App?
The demo provides an interactive UI for testing multi-modal search:
┌─────────────────────────────────────────────────────────────────────────────┐
│ Streamlit App Architecture │
├─────────────────────────────────────────────────────────────────────────────┤
│ │
│ @st.cache_resource │
│ ┌─────────────────────────────────────────────────────────────────────────┐│
│ │ Why cache the model? ││
│ │ ││
│ │ Without caching: ││
│ │ User clicks button → Reload CLIP (5+ seconds) → Search → Results ││
│ │ User clicks again → Reload CLIP (5+ seconds) → Search → Results ││
│ │ ││
│ │ With caching: ││
│ │ First load → Load CLIP (5+ seconds) → Cache in memory ││
│ │ User clicks → Instant search → Results ││
│ │ User clicks → Instant search → Results ││
│ │ ││
│ │ cache_resource is for objects that shouldn't be serialized (ML models) ││
│ │ cache_data is for data that can be pickled (DataFrames, dicts) ││
│ └─────────────────────────────────────────────────────────────────────────┘│
│ │
│ Tab Structure: │
│ ┌─────────────────────────────────────────────────────────────────────────┐│
│ │ ││
│ │ ┌────────────────┬─────────────────┬────────────────┐ ││
│ │ │ 📝 Text-to- │ 🖼️ Image-to- │ 📊 Add Images │ ││
│ │ │ Image │ Image │ │ ││
│ │ ├────────────────┴─────────────────┴────────────────┤ ││
│ │ │ │ ││
│ │ │ Tab 1: Type description → Find matching images │ ││
│ │ │ │ ││
│ │ │ Tab 2: Upload image → Find similar images │ ││
│ │ │ │ ││
│ │ │ Tab 3: Upload images → Index for searching │ ││
│ │ │ │ ││
│ │ └───────────────────────────────────────────────────┘ ││
│ │ ││
│ └─────────────────────────────────────────────────────────────────────────┘│
│ │
└─────────────────────────────────────────────────────────────────────────────┘

Key Streamlit Patterns Explained:
| Pattern | Why It's Used |
|---|---|
| `st.columns(3)` | Display search results in a responsive grid. Each image gets equal width. |
| `cols[i % 3]` | Distribute images across columns: 0→col1, 1→col2, 2→col3, 3→col1, ... |
| `st.spinner("Indexing...")` | Show a loading state while embedding images, so users know the app isn't frozen. |
| `st.metric()` | Display the count as a prominent statistic. Better than plain text for numbers. |
| `try`/`except` around image loading | Gracefully handle missing or corrupt image files without crashing the app. |
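The `cols[i % 3]` round-robin from the table, traced in plain Python with placeholder result names:

```python
# Assign six results to three columns the way the app does with cols[i % 3]
columns = {0: [], 1: [], 2: []}
for i, name in enumerate(["a", "b", "c", "d", "e", "f"]):
    columns[i % 3].append(name)

print(columns)  # {0: ['a', 'd'], 1: ['b', 'e'], 2: ['c', 'f']}
```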
Running the Application
# Install dependencies
pip install -r requirements.txt
# Run the API
uvicorn src.api:app --reload --port 8000
# Run the demo
streamlit run app.py

Key Concepts
CLIP Model Variants
| Model | Dim | Speed | Quality |
|---|---|---|---|
| ViT-B-32 | 512 | Fast | Good |
| ViT-B-16 | 512 | Medium | Better |
| ViT-L-14 | 768 | Slow | Best |
Use Cases
- E-commerce: Search products by description
- Stock Photos: Find images matching keywords
- Content Moderation: Detect specific content
- Accessibility: Generate image descriptions
Key Concepts Recap
| Concept | What It Is | Why It Matters |
|---|---|---|
| CLIP | Model that embeds text and images together | Enables cross-modal search without labels |
| Contrastive Learning | Training by matching pairs | Creates aligned embedding spaces |
| Text-to-Image Search | Query images with text descriptions | Natural language interface for image search |
| Image-to-Image Search | Find visually similar images | Duplicate detection, recommendations |
| ViT (Vision Transformer) | Transformer architecture for images | CLIP's image encoder |
| Shared Embedding Space | Same vector space for both modalities | Makes cross-modal similarity possible |
| OpenCLIP | Open-source CLIP implementations | Production-ready, various model sizes |
Next Steps
- Search at Scale - Scale to billions of vectors
- Production Pipeline - Deploy for production