You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

120 lines
3.9 KiB

"""
embedder.py — Embedding generation via Ollama or sentence-transformers fallback.
"""
from __future__ import annotations
import logging
import time
from typing import Any
import httpx
import numpy as np
logger = logging.getLogger(__name__)
# Embedding-vector dimensionality per known model name.
# Read by OllamaEmbedder.__init__ to set `dimensions` without a server
# round-trip; models missing from this table fall back to 768 there.
_MODEL_DIMS: dict[str, int] = {
    'nomic-embed-text': 768,
    'all-minilm-l6-v2': 384,
    'mxbai-embed-large': 1024,
}
class OllamaEmbedder:
    """Generate embeddings via the Ollama /api/embed endpoint.

    Texts are sent in chunks of ``batch_size`` per HTTP request, and each
    request is retried with exponential backoff on transport errors or a
    malformed response body.
    """

    def __init__(
        self,
        base_url: str = 'http://ollama:11434',
        model: str = 'nomic-embed-text',
        timeout: float = 60.0,
        batch_size: int = 32,
    ) -> None:
        """
        Args:
            base_url: Root URL of the Ollama server; a trailing slash is stripped.
            model: Embedding model name, also used to look up vector size.
            timeout: Per-request timeout in seconds.
            batch_size: Maximum number of texts per /api/embed request.
        """
        self.base_url = base_url.rstrip('/')
        self.model = model
        self.timeout = timeout
        self.batch_size = batch_size
        # Unknown models fall back to 768 (nomic-embed-text's size).
        self.dimensions = _MODEL_DIMS.get(model, 768)

    def embed_batch(self, texts: list[str]) -> list[list[float]]:
        """Embed a list of texts, returning one float vector per text.

        An empty input yields an empty list without any HTTP call.
        """
        all_embeddings: list[list[float]] = []
        for i in range(0, len(texts), self.batch_size):
            batch = texts[i : i + self.batch_size]
            all_embeddings.extend(self._call_ollama(batch))
        return all_embeddings

    def embed_single(self, text: str) -> list[float]:
        """Embed a single text and return its vector."""
        return self.embed_batch([text])[0]

    def _call_ollama(self, texts: list[str], retries: int = 3) -> list[list[float]]:
        """POST one batch of *texts* to /api/embed, retrying on failure.

        Args:
            texts: Batch of input strings (at most ``batch_size`` of them).
            retries: Total number of attempts before giving up.

        Returns:
            One embedding vector per input text, in order.

        Raises:
            httpx.HTTPError: Transport or HTTP-status failure on the last attempt.
            KeyError: Response body lacked an 'embeddings' field on the last attempt.
            ValueError: The server returned a different number of embeddings
                than inputs, or ``retries`` was not a positive integer.
        """
        url = f'{self.base_url}/api/embed'
        payload: dict[str, Any] = {'model': self.model, 'input': texts}
        for attempt in range(1, retries + 1):
            try:
                with httpx.Client(timeout=self.timeout) as client:
                    resp = client.post(url, json=payload)
                    resp.raise_for_status()
                    data = resp.json()
                    embeddings = data['embeddings']
            except (httpx.HTTPError, KeyError) as exc:
                logger.warning('Ollama embed attempt %d/%d failed: %s', attempt, retries, exc)
                if attempt < retries:
                    time.sleep(2 ** attempt)  # exponential backoff: 2s, 4s, 8s, ...
                    continue
                raise
            # Validate the count: a truncated response would otherwise surface
            # much later as an obscure IndexError in callers (e.g. embed_single).
            if len(embeddings) != len(texts):
                raise ValueError(
                    f'Ollama returned {len(embeddings)} embeddings '
                    f'for {len(texts)} inputs'
                )
            return embeddings
        # Reachable only when retries <= 0; the original fell off the end
        # here and implicitly returned None.
        raise ValueError(f'retries must be >= 1, got {retries}')
class SentenceTransformerEmbedder:
    """Local fallback embedder using sentence-transformers."""

    def __init__(
        self,
        model_name: str = 'all-MiniLM-L6-v2',
        batch_size: int = 32,
    ) -> None:
        """Load *model_name* and record its embedding dimensionality.

        Raises:
            ImportError: If the optional sentence-transformers package is
                not installed.
        """
        # Import lazily so this module stays importable when the optional
        # dependency is absent (the Ollama path needs none of it).
        try:
            from sentence_transformers import SentenceTransformer  # type: ignore
        except ImportError as exc:
            raise ImportError(
                'sentence-transformers is required for the local fallback embedder. '
                'Install it with: pip install sentence-transformers'
            ) from exc
        logger.info('Loading sentence-transformer model: %s', model_name)
        self._model = SentenceTransformer(model_name)
        self.batch_size = batch_size
        self.dimensions = self._model.get_sentence_embedding_dimension()

    def embed_batch(self, texts: list[str]) -> list[list[float]]:
        """Encode *texts* into unit-normalized vectors, one list per text."""
        return [
            row.tolist()
            for row in self._model.encode(
                texts,
                batch_size=self.batch_size,
                show_progress_bar=False,
                normalize_embeddings=True,
            )
        ]

    def embed_single(self, text: str) -> list[float]:
        """Encode one text and return its vector."""
        return self.embed_batch([text])[0]
def get_embedder(
    provider: str = 'ollama',
    ollama_url: str = 'http://ollama:11434',
    model: str = 'nomic-embed-text',
) -> OllamaEmbedder | SentenceTransformerEmbedder:
    """Build and return the embedder selected by *provider*.

    Args:
        provider: Either 'ollama' or 'sentence_transformers'.
        ollama_url: Base URL of the Ollama server (ollama provider only).
        model: Model name forwarded to the chosen backend.

    Raises:
        ValueError: If *provider* names no known backend.
    """
    if provider == 'ollama':
        return OllamaEmbedder(base_url=ollama_url, model=model)
    if provider == 'sentence_transformers':
        return SentenceTransformerEmbedder(model_name=model)
    raise ValueError(f'Unknown embedding provider: {provider!r}')

Powered by TurnKey Linux.