"""embedder.py — Embedding generation via Ollama or sentence-transformers fallback."""
from __future__ import annotations

import logging
import time
from typing import Any

import httpx
import numpy as np

logger = logging.getLogger(__name__)
|
|
|
|
# Dimensionality per model
|
|
_MODEL_DIMS: dict[str, int] = {
|
|
'nomic-embed-text': 768,
|
|
'all-minilm-l6-v2': 384,
|
|
'mxbai-embed-large': 1024,
|
|
}
|
|
|
|
|
|
class OllamaEmbedder:
|
|
"""Generate embeddings via the Ollama /api/embed endpoint."""
|
|
|
|
def __init__(
|
|
self,
|
|
base_url: str = 'http://ollama:11434',
|
|
model: str = 'nomic-embed-text',
|
|
timeout: float = 60.0,
|
|
batch_size: int = 32,
|
|
) -> None:
|
|
self.base_url = base_url.rstrip('/')
|
|
self.model = model
|
|
self.timeout = timeout
|
|
self.batch_size = batch_size
|
|
self.dimensions = _MODEL_DIMS.get(model, 768)
|
|
|
|
def embed_batch(self, texts: list[str]) -> list[list[float]]:
|
|
"""Embed a list of texts, returning a list of float vectors."""
|
|
all_embeddings: list[list[float]] = []
|
|
|
|
for i in range(0, len(texts), self.batch_size):
|
|
batch = texts[i : i + self.batch_size]
|
|
embeddings = self._call_ollama(batch)
|
|
all_embeddings.extend(embeddings)
|
|
|
|
return all_embeddings
|
|
|
|
def embed_single(self, text: str) -> list[float]:
|
|
return self.embed_batch([text])[0]
|
|
|
|
def _call_ollama(self, texts: list[str], retries: int = 3) -> list[list[float]]:
|
|
url = f'{self.base_url}/api/embed'
|
|
payload: dict[str, Any] = {'model': self.model, 'input': texts}
|
|
|
|
for attempt in range(1, retries + 1):
|
|
try:
|
|
with httpx.Client(timeout=self.timeout) as client:
|
|
resp = client.post(url, json=payload)
|
|
resp.raise_for_status()
|
|
data = resp.json()
|
|
return data['embeddings']
|
|
except (httpx.HTTPError, KeyError) as exc:
|
|
logger.warning('Ollama embed attempt %d/%d failed: %s', attempt, retries, exc)
|
|
if attempt < retries:
|
|
time.sleep(2 ** attempt) # exponential backoff
|
|
else:
|
|
raise
|
|
|
|
|
|
class SentenceTransformerEmbedder:
    """Local fallback embedder using sentence-transformers.

    Args:
        model_name: sentence-transformers / Hugging Face model identifier.
        batch_size: Batch size passed to ``SentenceTransformer.encode``.

    Raises:
        ImportError: if the optional sentence-transformers package is not
            installed (with install instructions in the message).
    """

    def __init__(
        self,
        model_name: str = 'all-MiniLM-L6-v2',
        batch_size: int = 32,
    ) -> None:
        # Lazy import so this module stays importable when the optional
        # dependency is absent (e.g. when only the Ollama path is used).
        try:
            from sentence_transformers import SentenceTransformer  # type: ignore
        except ImportError as exc:
            raise ImportError(
                'sentence-transformers is required for the local fallback embedder. '
                'Install it with: pip install sentence-transformers'
            ) from exc

        logger.info('Loading sentence-transformer model: %s', model_name)
        self._model = SentenceTransformer(model_name)
        self.batch_size = batch_size
        # Dimensionality is reported by the loaded model itself.
        self.dimensions = self._model.get_sentence_embedding_dimension()

    def embed_batch(self, texts: list[str]) -> list[list[float]]:
        """Embed *texts* locally; returned vectors are L2-normalized."""
        vectors = self._model.encode(
            texts,
            batch_size=self.batch_size,
            show_progress_bar=False,
            normalize_embeddings=True,
        )
        return [v.tolist() for v in vectors]

    def embed_single(self, text: str) -> list[float]:
        """Embed a single text and return its vector."""
        return self.embed_batch([text])[0]

def get_embedder(
    provider: str = 'ollama',
    ollama_url: str = 'http://ollama:11434',
    model: str = 'nomic-embed-text',
) -> OllamaEmbedder | SentenceTransformerEmbedder:
    """Factory function returning the configured embedder.

    Args:
        provider: 'ollama' for the HTTP embedder, 'sentence_transformers'
            for the local fallback.
        ollama_url: Base URL of the Ollama server (used by the ollama
            provider only).
        model: Embedding model name forwarded to the chosen embedder.

    Raises:
        ValueError: if *provider* names an unknown provider.
    """
    if provider == 'ollama':
        return OllamaEmbedder(base_url=ollama_url, model=model)
    if provider == 'sentence_transformers':
        return SentenceTransformerEmbedder(model_name=model)
    raise ValueError(f'Unknown embedding provider: {provider!r}')