You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
161 lines
4.5 KiB
161 lines
4.5 KiB
"""
|
|
services/retriever.py — Hybrid vector + full-text search against PostgreSQL.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import logging
|
|
import time
|
|
from typing import Optional
|
|
|
|
import asyncpg
|
|
|
|
from models.responses import ChunkResult
|
|
|
|
# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
async def hybrid_search(
    conn: asyncpg.Connection,
    query: str,
    embedding: list[float],
    limit: int = 10,
    threshold: float = 0.65,
    tags: Optional[list[str]] = None,
) -> tuple[list[ChunkResult], float]:
    """
    Hybrid search: vector similarity + full-text search, merged by RRF.

    Each retriever (pgvector similarity and Postgres full-text search)
    contributes up to ``2 * limit`` candidate chunks. Candidates are fused
    with Reciprocal Rank Fusion, ``sum(1 / (60 + rank))`` over the retrievers
    that returned the chunk, and the best ``limit`` chunks are returned.

    Args:
        conn: Open asyncpg connection. ``embedding`` is bound as ``$1::vector``;
            presumably a pgvector codec is registered on the connection
            elsewhere -- TODO confirm.
        query: Raw user text, parsed with ``plainto_tsquery('english', ...)``
            for matching and for the ``highlight`` snippet.
        embedding: Query embedding to compare against ``chunks.embedding``.
        limit: Maximum number of results. ``limit <= 0`` returns no results
            without querying the database.
        threshold: Minimum ``1 - (embedding <=> query_embedding)`` a chunk
            needs to enter the vector candidate list.
        tags: If non-empty, only chunks whose document ``tags`` overlap
            (``&&``) this list are considered by both retrievers.

    Returns (results, query_time_ms): results ordered by descending RRF score
    (rounded to 4 decimals), and elapsed wall-clock milliseconds measured with
    ``time.monotonic()`` (rounded to 2 decimals).
    """
    start = time.monotonic()

    if limit <= 0:
        # Postgres rejects a negative LIMIT and LIMIT 0 yields nothing, so
        # skip the round trip entirely.
        return [], round((time.monotonic() - start) * 1000, 2)

    # $1 query embedding, $2 query text, $3 per-retriever candidate pool,
    # $4 similarity threshold, $5 final result count, $6 optional tag list.
    params: list = [embedding, query, limit * 2, threshold, limit]
    tag_filter = ''
    if tags:
        # Fixed SQL fragment (never user text); the tags themselves are bound.
        tag_filter = 'AND d.tags && $6'
        params.append(tags)

    # Combined RRF (Reciprocal Rank Fusion) of vector and FTS results.
    # NOTE(review): the FTS score uses the parent document's fts_vector, so
    # every chunk of a matched document ties. c.id breaks those ties
    # deterministically, and ordering the CTE by fts_rank guarantees the ranks
    # kept by LIMIT are exactly 1..N, so RRF sees contiguous ranks.
    sql = f"""
        WITH vector_results AS (
            SELECT
                c.id AS chunk_id,
                c.document_id,
                c.content,
                c.chunk_index,
                1 - (c.embedding <=> $1::vector) AS vector_score,
                ROW_NUMBER() OVER (ORDER BY c.embedding <=> $1::vector) AS vector_rank
            FROM chunks c
            JOIN documents d ON d.id = c.document_id
            WHERE 1 - (c.embedding <=> $1::vector) >= $4
            {tag_filter}
            ORDER BY c.embedding <=> $1::vector
            LIMIT $3
        ),
        fts_results AS (
            SELECT
                c.id AS chunk_id,
                c.document_id,
                c.content,
                c.chunk_index,
                ts_rank_cd(d.fts_vector, plainto_tsquery('english', $2)) AS fts_score,
                ROW_NUMBER() OVER (
                    ORDER BY ts_rank_cd(d.fts_vector, plainto_tsquery('english', $2)) DESC,
                             c.id
                ) AS fts_rank
            FROM chunks c
            JOIN documents d ON d.id = c.document_id
            WHERE d.fts_vector @@ plainto_tsquery('english', $2)
            {tag_filter}
            ORDER BY fts_rank
            LIMIT $3
        ),
        merged AS (
            SELECT
                COALESCE(v.chunk_id, f.chunk_id) AS chunk_id,
                COALESCE(v.document_id, f.document_id) AS document_id,
                COALESCE(v.content, f.content) AS content,
                (COALESCE(1.0 / (60 + v.vector_rank), 0) +
                 COALESCE(1.0 / (60 + f.fts_rank), 0)) AS rrf_score,
                COALESCE(v.vector_score, 0) AS vector_score
            FROM vector_results v
            FULL OUTER JOIN fts_results f ON v.chunk_id = f.chunk_id
        )
        SELECT
            m.chunk_id::text,
            m.document_id::text,
            m.content,
            m.rrf_score,
            m.vector_score,
            d.title,
            d.path,
            d.tags,
            ts_headline('english', m.content, plainto_tsquery('english', $2),
                        'MaxWords=20, MinWords=10, ShortWord=3') AS highlight
        FROM merged m
        JOIN documents d ON d.id = m.document_id
        ORDER BY m.rrf_score DESC
        LIMIT $5
    """

    rows = await conn.fetch(sql, *params)
    elapsed_ms = (time.monotonic() - start) * 1000

    results = [
        ChunkResult(
            chunk_id=str(row['chunk_id']),
            document_id=str(row['document_id']),
            title=row['title'] or '',
            path=row['path'],
            content=row['content'],
            score=round(float(row['rrf_score']), 4),
            tags=list(row['tags'] or []),
            highlight=row['highlight'],
        )
        for row in rows
    ]

    logger.debug(
        'hybrid_search: %d result(s) in %.2f ms (limit=%d, tags=%s)',
        len(results), elapsed_ms, limit, tags,
    )
    return results, round(elapsed_ms, 2)
|
|
|
|
|
|
async def get_related(
    conn: asyncpg.Connection,
    document_id: str,
    limit: int = 5,
) -> list[dict]:
    """Find documents related to the given document via average chunk embedding.

    Every document is represented by the mean (pgvector ``AVG``) of its
    non-NULL chunk embeddings. Other documents are ranked by
    ``1 - (mean <=> source_mean)``.

    Args:
        conn: Open asyncpg connection.
        document_id: UUID string of the source document (cast with
            ``::uuid``); the source document itself is excluded.
        limit: Maximum number of related documents. ``limit <= 0`` returns
            ``[]`` without querying the database.

    Returns:
        Dicts with keys ``document_id``, ``title``, ``path``, ``tags`` and
        ``score`` (rounded to 4 decimals), highest score first. Empty when the
        source document has no non-NULL chunk embeddings (including an
        unknown id).
    """
    if limit <= 0:
        return []

    # NULL embeddings are filtered out so that no candidate can end up with a
    # NULL score. Postgres sorts NULLs first under DESC, and float(None) below
    # would raise. When the source document has no embedding at all (AVG over
    # zero rows is NULL), the subquery guard makes the result empty.
    # NOTE(review): this aggregates every other document's chunks on each
    # call; a precomputed per-document embedding would scale better.
    rows = await conn.fetch(
        """
        WITH doc_embedding AS (
            SELECT AVG(embedding) AS avg_emb
            FROM chunks
            WHERE document_id = $1::uuid
        )
        SELECT
            d.id::text,
            d.title,
            d.path,
            d.tags,
            1 - (AVG(c.embedding) <=> (SELECT avg_emb FROM doc_embedding)) AS score
        FROM chunks c
        JOIN documents d ON d.id = c.document_id
        WHERE c.document_id != $1::uuid
          AND c.embedding IS NOT NULL
          AND (SELECT avg_emb FROM doc_embedding) IS NOT NULL
        GROUP BY d.id, d.title, d.path, d.tags
        ORDER BY score DESC NULLS LAST
        LIMIT $2
        """,
        document_id,
        limit,
    )
    return [
        {
            'document_id': row['id'],
            'title': row['title'] or '',
            'path': row['path'],
            'tags': list(row['tags'] or []),
            'score': round(float(row['score']), 4),
        }
        for row in rows
    ]
|