feat: Implement hybrid search in SQLite-vec

Signed-off-by: Varsha Prasad Narsing <varshaprasad96@gmail.com>

parent 941f505eb0, commit eab85a7121
13 changed files with 335 additions and 10 deletions
```diff
@@ -131,6 +131,15 @@ class FaissIndex(EmbeddingIndex):
     ) -> QueryChunksResponse:
         raise NotImplementedError("Keyword search is not supported in FAISS")
 
+    async def query_hybrid(
+        self,
+        embedding: NDArray,
+        query_string: str,
+        k: int,
+        score_threshold: float,
+    ) -> QueryChunksResponse:
+        raise NotImplementedError("Hybrid search is not supported in FAISS")
+
 
 class FaissVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPrivate):
     def __init__(self, config: FaissVectorIOConfig, inference_api: Inference, files_api: Files | None) -> None:
```
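FAISS implements neither keyword nor hybrid search, so both methods raise `NotImplementedError`. A minimal sketch of how a caller might degrade gracefully, using only the method signatures visible in this diff (`search_with_fallback` is a hypothetical helper, not part of the change):

```python
# Hypothetical helper: prefer hybrid search, fall back to pure vector search
# for providers such as FaissIndex that do not implement it.
async def search_with_fallback(index, embedding, query_string: str, k: int = 5):
    try:
        return await index.query_hybrid(embedding, query_string, k, score_threshold=0.0)
    except NotImplementedError:
        # e.g. FAISS: "Hybrid search is not supported in FAISS"
        return await index.query_vector(embedding, k, 0.0)
```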
```diff
@@ -34,7 +34,8 @@ logger = logging.getLogger(__name__)
 # Specifying search mode is dependent on the VectorIO provider.
 VECTOR_SEARCH = "vector"
 KEYWORD_SEARCH = "keyword"
-SEARCH_MODES = {VECTOR_SEARCH, KEYWORD_SEARCH}
+HYBRID_SEARCH = "hybrid"
+SEARCH_MODES = {VECTOR_SEARCH, KEYWORD_SEARCH, HYBRID_SEARCH}
 
 
 def serialize_vector(vector: list[float]) -> bytes:
```
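With `HYBRID_SEARCH` added to `SEARCH_MODES`, hybrid retrieval becomes a third selectable mode alongside vector and keyword search. A hedged sketch of what mode selection could look like from the Python client; the resource name, the `params` shape, and the `"my-docs"` database id are assumptions for illustration, not part of this diff:

```python
from llama_stack_client import LlamaStackClient  # assumed client package

client = LlamaStackClient(base_url="http://localhost:8321")

# "my-docs" is a hypothetical vector DB registered with the sqlite-vec provider.
response = client.vector_io.query(
    vector_db_id="my-docs",
    query="How do I rotate API keys?",
    params={
        "mode": "hybrid",  # one of: "vector", "keyword", "hybrid" (SEARCH_MODES)
        "max_chunks": 5,
        "score_threshold": 0.0,
    },
)
for chunk, score in zip(response.chunks, response.scores):
    print(f"{score:.4f}  {chunk.metadata['document_id']}")
```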
```diff
@@ -255,8 +256,6 @@ class SQLiteVecIndex(EmbeddingIndex):
         """
         Performs keyword-based search using SQLite FTS5 for relevance-ranked full-text search.
         """
-        if query_string is None:
-            raise ValueError("query_string is required for keyword search.")
 
         def _execute_query():
             connection = _create_sqlite_connection(self.db_path)
```
```diff
@@ -294,6 +293,69 @@ class SQLiteVecIndex(EmbeddingIndex):
             scores.append(score)
         return QueryChunksResponse(chunks=chunks, scores=scores)
 
+    async def query_hybrid(
+        self,
+        embedding: NDArray,
+        query_string: str,
+        k: int,
+        score_threshold: float,
+    ) -> QueryChunksResponse:
+        """
+        Hybrid search using Reciprocal Rank Fusion (RRF) to combine vector and keyword search results.
+        RRF assigns scores based on the reciprocal of the rank position in each search method,
+        then combines these scores to get a final ranking.
+
+        Args:
+            embedding: The query embedding vector
+            query_string: The text query for keyword search
+            k: Number of results to return
+            score_threshold: Minimum similarity score threshold
+
+        Returns:
+            QueryChunksResponse with combined results
+        """
+        # Get results from both search methods
+        vector_response = await self.query_vector(embedding, k * 2, score_threshold)
+        keyword_response = await self.query_keyword(query_string, k * 2, score_threshold)
+
+        # Create dictionaries to store ranks for each method
+        vector_ranks = {chunk.metadata["document_id"]: i + 1 for i, chunk in enumerate(vector_response.chunks)}
+        keyword_ranks = {chunk.metadata["document_id"]: i + 1 for i, chunk in enumerate(keyword_response.chunks)}
+
+        # Calculate RRF scores for all unique document IDs
+        all_ids = set(vector_ranks.keys()) | set(keyword_ranks.keys())
+        rrf_scores = {}
+        for doc_id in all_ids:
+            vector_rank = vector_ranks.get(doc_id, float("inf"))
+            keyword_rank = keyword_ranks.get(doc_id, float("inf"))
+            # RRF formula: score = 1/(k + r) where k is a constant and r is the rank
+            rrf_scores[doc_id] = (1.0 / (60 + vector_rank)) + (1.0 / (60 + keyword_rank))
+
+        # Sort by RRF score and get top k results
+        sorted_ids = sorted(rrf_scores.keys(), key=lambda x: rrf_scores[x], reverse=True)[:k]
+
+        # Combine results maintaining RRF scores
+        chunks = []
+        scores = []
+        for doc_id in sorted_ids:
+            score = rrf_scores[doc_id]
+            if score >= score_threshold:
+                # Try to get from vector results first
+                for chunk in vector_response.chunks:
+                    if chunk.metadata["document_id"] == doc_id:
+                        chunks.append(chunk)
+                        scores.append(score)
+                        break
+                else:
+                    # If not in vector results, get from keyword results
+                    for chunk in keyword_response.chunks:
+                        if chunk.metadata["document_id"] == doc_id:
+                            chunks.append(chunk)
+                            scores.append(score)
+                            break
+
+        return QueryChunksResponse(chunks=chunks, scores=scores)
+
 
 class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPrivate):
     """
```
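The fusion step is easy to sanity-check in isolation. Below is a self-contained sketch of the same RRF arithmetic on toy ranks; the document ids are made up, and 60 is the smoothing constant used in the method above (the value popularized by the original RRF paper):

```python
# Toy re-run of the RRF scoring used by query_hybrid.
RRF_K = 60  # smoothing constant in score = 1 / (RRF_K + rank)

vector_ranks = {"doc-a": 1, "doc-b": 2, "doc-c": 3}  # rank 1 = best
keyword_ranks = {"doc-b": 1, "doc-d": 2}

rrf_scores: dict[str, float] = {}
for doc_id in vector_ranks.keys() | keyword_ranks.keys():
    v = vector_ranks.get(doc_id, float("inf"))  # absent => 1/(60+inf) == 0.0
    w = keyword_ranks.get(doc_id, float("inf"))
    rrf_scores[doc_id] = 1.0 / (RRF_K + v) + 1.0 / (RRF_K + w)

for doc_id in sorted(rrf_scores, key=rrf_scores.get, reverse=True):
    print(doc_id, round(rrf_scores[doc_id], 5))
# doc-b 0.03252   (1/62 + 1/61: ranked in both lists, so it wins)
# doc-a 0.01639   (1/61)
# doc-d 0.01613   (1/62)
# doc-c 0.01587   (1/63)
```

Two consequences worth noting: a document present in both rankings dominates documents that top only a single list, and fused scores are bounded by 2/61 ≈ 0.033, a much smaller scale than cosine similarities, so the `score_threshold` check inside `query_hybrid` compares against RRF scores rather than similarity scores.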
```diff
@@ -345,7 +407,9 @@ class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPrivate):
             vector_db_data = row[0]
             vector_db = VectorDB.model_validate_json(vector_db_data)
             index = await SQLiteVecIndex.create(
-                vector_db.embedding_dimension, self.config.db_path, vector_db.identifier
+                vector_db.embedding_dimension,
+                self.config.db_path,
+                vector_db.identifier,
             )
             self.cache[vector_db.identifier] = VectorDBWithIndex(vector_db, index, self.inference_api)
 
```
```diff
@@ -371,7 +435,11 @@ class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPrivate):
             connection.close()
 
         await asyncio.to_thread(_register_db)
-        index = await SQLiteVecIndex.create(vector_db.embedding_dimension, self.config.db_path, vector_db.identifier)
+        index = await SQLiteVecIndex.create(
+            vector_db.embedding_dimension,
+            self.config.db_path,
+            vector_db.identifier,
+        )
         self.cache[vector_db.identifier] = VectorDBWithIndex(vector_db, index, self.inference_api)
 
     async def list_vector_dbs(self) -> list[VectorDB]:
```