"""
services/retriever.py — Hybrid vector + full-text search against PostgreSQL.
"""
from __future__ import annotations
import logging
import time
from typing import Optional
import asyncpg
from models.responses import ChunkResult
logger = logging.getLogger(__name__)
async def hybrid_search(
    conn: asyncpg.Connection,
    query: str,
    embedding: list[float],
    limit: int = 10,
    threshold: float = 0.65,
    tags: Optional[list[str]] = None,
) -> tuple[list[ChunkResult], float]:
    """
    Hybrid search: vector similarity + full-text search, merged by RRF.

    A single SQL statement ranks candidates two ways — cosine similarity
    over chunk embeddings (pgvector ``<=>``) and ``ts_rank_cd`` over the
    parent document's FTS vector — then fuses them with Reciprocal Rank
    Fusion: score = sum of 1 / (60 + rank) per branch (the conventional
    k=60 constant).

    Args:
        conn: Open asyncpg connection. The embedding list is bound to the
            ``$1::vector`` parameter — assumes a pgvector codec is
            registered on the connection (TODO confirm at pool setup).
        query: Plain-text query; passed to ``plainto_tsquery('english', ...)``
            for both matching and ``ts_headline`` highlighting.
        embedding: Query embedding; dimensionality must match the
            ``chunks.embedding`` column.
        limit: Maximum number of fused results returned to the caller.
        threshold: Minimum cosine similarity (1 - distance) applied to the
            vector branch ONLY; the FTS branch is not thresholded.
        tags: Optional filter — document ``tags`` array must overlap
            (``&&``) this list. Ignored when falsy (None or empty).

    Returns:
        (results, query_time_ms): at most ``limit`` ChunkResult objects
        ordered by descending RRF score, and the elapsed wall time in
        milliseconds (rounded to 2 decimals).
    """
    start = time.monotonic()
    tag_filter = ''
    # $1 = embedding, $2 = query text, $3 = per-branch candidate limit,
    # $4 = similarity threshold, $5 = tags (only when provided).
    # Each branch fetches up to 2*limit candidates so the fusion step has
    # overlap to work with; the final Python slice trims back to `limit`.
    params: list = [embedding, query, limit * 2, threshold]
    if tags:
        # Only the placeholder text is interpolated into the SQL; the tag
        # values themselves still travel as bound parameter $5.
        tag_filter = 'AND d.tags && $5'
        params.append(tags)
    # Combined RRF (Reciprocal Rank Fusion) of vector and FTS results
    sql = f"""
        WITH vector_results AS (
            SELECT
                c.id AS chunk_id,
                c.document_id,
                c.content,
                c.chunk_index,
                1 - (c.embedding <=> $1::vector) AS vector_score,
                ROW_NUMBER() OVER (ORDER BY c.embedding <=> $1::vector) AS vector_rank
            FROM chunks c
            JOIN documents d ON d.id = c.document_id
            WHERE 1 - (c.embedding <=> $1::vector) >= $4
            {tag_filter}
            ORDER BY c.embedding <=> $1::vector
            LIMIT $3
        ),
        fts_results AS (
            SELECT
                c.id AS chunk_id,
                c.document_id,
                c.content,
                c.chunk_index,
                ts_rank_cd(d.fts_vector, plainto_tsquery('english', $2)) AS fts_score,
                ROW_NUMBER() OVER (
                    ORDER BY ts_rank_cd(d.fts_vector, plainto_tsquery('english', $2)) DESC
                ) AS fts_rank
            FROM chunks c
            JOIN documents d ON d.id = c.document_id
            WHERE d.fts_vector @@ plainto_tsquery('english', $2)
            {tag_filter}
            ORDER BY fts_score DESC
            LIMIT $3
        ),
        merged AS (
            SELECT
                COALESCE(v.chunk_id, f.chunk_id) AS chunk_id,
                COALESCE(v.document_id, f.document_id) AS document_id,
                COALESCE(v.content, f.content) AS content,
                (COALESCE(1.0 / (60 + v.vector_rank), 0) +
                 COALESCE(1.0 / (60 + f.fts_rank), 0)) AS rrf_score,
                COALESCE(v.vector_score, 0) AS vector_score
            FROM vector_results v
            FULL OUTER JOIN fts_results f ON v.chunk_id = f.chunk_id
        )
        SELECT
            m.chunk_id::text,
            m.document_id::text,
            m.content,
            m.rrf_score,
            m.vector_score,
            d.title,
            d.path,
            d.tags,
            ts_headline('english', m.content, plainto_tsquery('english', $2),
                'MaxWords=20, MinWords=10, ShortWord=3') AS highlight
        FROM merged m
        JOIN documents d ON d.id = m.document_id
        ORDER BY m.rrf_score DESC
        LIMIT $3
    """
    rows = await conn.fetch(sql, *params)
    elapsed_ms = (time.monotonic() - start) * 1000
    results = [
        ChunkResult(
            chunk_id=str(row['chunk_id']),
            document_id=str(row['document_id']),
            title=row['title'] or '',       # title may be NULL in the DB
            path=row['path'],
            content=row['content'],
            score=round(float(row['rrf_score']), 4),
            tags=list(row['tags'] or []),   # tags array may be NULL
            highlight=row['highlight'],
        )
        for row in rows
    ]
    # SQL fetched up to 2*limit fused rows ($3); trim to the caller's limit.
    return results[:limit], round(elapsed_ms, 2)
async def get_related(
    conn: asyncpg.Connection,
    document_id: str,
    limit: int = 5,
) -> list[dict]:
    """Find documents related to the given document via average chunk embedding.

    The reference document is represented by the AVG of its chunk
    embeddings (its centroid); every other document is represented the same
    way, and relatedness is 1 - cosine distance between centroids.

    Args:
        conn: Open asyncpg connection with pgvector support.
        document_id: UUID (as a string) of the reference document.
        limit: Maximum number of related documents to return.

    Returns:
        Dicts with ``document_id``, ``title``, ``path``, ``tags`` and
        ``score``, ordered by descending similarity. Returns an empty list
        when the reference document is unknown or has no chunks.
    """
    rows = await conn.fetch(
        """
        WITH doc_embedding AS (
            SELECT AVG(embedding) AS avg_emb
            FROM chunks
            WHERE document_id = $1::uuid
        )
        SELECT
            d.id::text,
            d.title,
            d.path,
            d.tags,
            1 - (AVG(c.embedding) <=> (SELECT avg_emb FROM doc_embedding)) AS score
        FROM chunks c
        JOIN documents d ON d.id = c.document_id
        WHERE c.document_id != $1::uuid
          -- Guard against an unknown/chunkless reference document: its
          -- centroid is NULL, which would make every score NULL — and
          -- NULLs sort FIRST under ORDER BY ... DESC in PostgreSQL, so the
          -- caller would receive NULL scores and float(None) would raise.
          AND (SELECT avg_emb FROM doc_embedding) IS NOT NULL
        GROUP BY d.id, d.title, d.path, d.tags
        ORDER BY score DESC
        LIMIT $2
        """,
        document_id,
        limit,
    )
    return [
        {
            'document_id': row['id'],
            'title': row['title'] or '',        # title may be NULL
            'path': row['path'],
            'tags': list(row['tags'] or []),    # tags array may be NULL
            'score': round(float(row['score']), 4),
        }
        for row in rows
    ]

Powered by TurnKey Linux.