From eab85a712107229144e5ae8512760e82fdf68932 Mon Sep 17 00:00:00 2001 From: Varsha Prasad Narsing Date: Thu, 29 May 2025 14:32:54 -0700 Subject: [PATCH] feat: Implement hybrid search in SQLite-vec Signed-off-by: Varsha Prasad Narsing --- docs/_static/llama-stack-spec.html | 2 +- docs/_static/llama-stack-spec.yaml | 3 +- docs/source/providers/vector_io/sqlite-vec.md | 33 ++++ llama_stack/apis/tools/rag_tool.py | 2 +- .../providers/inline/vector_io/faiss/faiss.py | 9 + .../inline/vector_io/sqlite_vec/sqlite_vec.py | 78 ++++++++- .../remote/vector_io/chroma/chroma.py | 9 + .../remote/vector_io/milvus/milvus.py | 9 + .../remote/vector_io/pgvector/pgvector.py | 9 + .../remote/vector_io/qdrant/qdrant.py | 9 + .../remote/vector_io/weaviate/weaviate.py | 9 + .../providers/utils/memory/vector_store.py | 15 +- .../providers/vector_io/test_sqlite_vec.py | 158 ++++++++++++++++++ 13 files changed, 335 insertions(+), 10 deletions(-) diff --git a/docs/_static/llama-stack-spec.html b/docs/_static/llama-stack-spec.html index ce47f8ebb..43be87464 100644 --- a/docs/_static/llama-stack-spec.html +++ b/docs/_static/llama-stack-spec.html @@ -13994,7 +13994,7 @@ }, "mode": { "type": "string", - "description": "Search mode for retrieval—either \"vector\" or \"keyword\". Default \"vector\"." + "description": "Search mode for retrieval—either \"vector\", \"keyword\", or \"hybrid\". Default \"vector\"." } }, "additionalProperties": false, diff --git a/docs/_static/llama-stack-spec.yaml b/docs/_static/llama-stack-spec.yaml index 07a176b32..cd8c75527 100644 --- a/docs/_static/llama-stack-spec.yaml +++ b/docs/_static/llama-stack-spec.yaml @@ -9756,7 +9756,8 @@ components: mode: type: string description: >- - Search mode for retrieval—either "vector" or "keyword". Default "vector". + Search mode for retrieval—either "vector", "keyword", or "hybrid". Default + "vector". additionalProperties: false required: - query_generator_config diff --git a/docs/source/providers/vector_io/sqlite-vec.md b/docs/source/providers/vector_io/sqlite-vec.md index 49ba659f7..4cfa08f18 100644 --- a/docs/source/providers/vector_io/sqlite-vec.md +++ b/docs/source/providers/vector_io/sqlite-vec.md @@ -66,6 +66,39 @@ To use sqlite-vec in your Llama Stack project, follow these steps: 2. Configure your Llama Stack project to use SQLite-Vec. 3. Start storing and querying vectors. +The SQLite-vec provider supports three search modes: + +1. **Vector Search** (`mode="vector"`): Performs pure vector similarity search using the embeddings. +2. **Keyword Search** (`mode="keyword"`): Performs full-text search using SQLite's FTS5. +3. **Hybrid Search** (`mode="hybrid"`): Combines both vector and keyword search for better results. First performs keyword search to get candidate matches, then applies vector similarity search on those candidates. + +Example with hybrid search: +```python +response = await vector_io.query_chunks( + vector_db_id="my_db", + query="your query here", + params={"mode": "hybrid", "max_chunks": 3, "score_threshold": 0.7}, +) +``` + +Example with explicit vector search: +```python +response = await vector_io.query_chunks( + vector_db_id="my_db", + query="your query here", + params={"mode": "vector", "max_chunks": 3, "score_threshold": 0.7}, +) +``` + +Example with keyword search: +```python +response = await vector_io.query_chunks( + vector_db_id="my_db", + query="your query here", + params={"mode": "keyword", "max_chunks": 3, "score_threshold": 0.7}, +) +``` + ## Supported Search Modes The sqlite-vec provider supports both vector-based and keyword-based (full-text) search modes. diff --git a/llama_stack/apis/tools/rag_tool.py b/llama_stack/apis/tools/rag_tool.py index 1e3542f74..e2ece0d91 100644 --- a/llama_stack/apis/tools/rag_tool.py +++ b/llama_stack/apis/tools/rag_tool.py @@ -76,7 +76,7 @@ class RAGQueryConfig(BaseModel): :param chunk_template: Template for formatting each retrieved chunk in the context. Available placeholders: {index} (1-based chunk ordinal), {chunk.content} (chunk content string), {metadata} (chunk metadata dict). Default: "Result {index}\\nContent: {chunk.content}\\nMetadata: {metadata}\\n" - :param mode: Search mode for retrieval—either "vector" or "keyword". Default "vector". + :param mode: Search mode for retrieval—either "vector", "keyword", or "hybrid". Default "vector". """ # This config defines how a query is generated using the messages diff --git a/llama_stack/providers/inline/vector_io/faiss/faiss.py b/llama_stack/providers/inline/vector_io/faiss/faiss.py index afb911726..bb3dbf6da 100644 --- a/llama_stack/providers/inline/vector_io/faiss/faiss.py +++ b/llama_stack/providers/inline/vector_io/faiss/faiss.py @@ -131,6 +131,15 @@ class FaissIndex(EmbeddingIndex): ) -> QueryChunksResponse: raise NotImplementedError("Keyword search is not supported in FAISS") + async def query_hybrid( + self, + embedding: NDArray, + query_string: str, + k: int, + score_threshold: float, + ) -> QueryChunksResponse: + raise NotImplementedError("Hybrid search is not supported in FAISS") + class FaissVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPrivate): def __init__(self, config: FaissVectorIOConfig, inference_api: Inference, files_api: Files | None) -> None: diff --git a/llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py b/llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py index f69cf8a32..22a4179cf 100644 --- a/llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py +++ b/llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py @@ -34,7 +34,8 @@ logger = logging.getLogger(__name__) # Specifying search mode is dependent on the VectorIO provider. VECTOR_SEARCH = "vector" KEYWORD_SEARCH = "keyword" -SEARCH_MODES = {VECTOR_SEARCH, KEYWORD_SEARCH} +HYBRID_SEARCH = "hybrid" +SEARCH_MODES = {VECTOR_SEARCH, KEYWORD_SEARCH, HYBRID_SEARCH} def serialize_vector(vector: list[float]) -> bytes: @@ -255,8 +256,6 @@ class SQLiteVecIndex(EmbeddingIndex): """ Performs keyword-based search using SQLite FTS5 for relevance-ranked full-text search. """ - if query_string is None: - raise ValueError("query_string is required for keyword search.") def _execute_query(): connection = _create_sqlite_connection(self.db_path) @@ -294,6 +293,69 @@ class SQLiteVecIndex(EmbeddingIndex): scores.append(score) return QueryChunksResponse(chunks=chunks, scores=scores) + async def query_hybrid( + self, + embedding: NDArray, + query_string: str, + k: int, + score_threshold: float, + ) -> QueryChunksResponse: + """ + Hybrid search using Reciprocal Rank Fusion (RRF) to combine vector and keyword search results. + RRF assigns scores based on the reciprocal of the rank position in each search method, + then combines these scores to get a final ranking. + + Args: + embedding: The query embedding vector + query_string: The text query for keyword search + k: Number of results to return + score_threshold: Minimum similarity score threshold + + Returns: + QueryChunksResponse with combined results + """ + # Get results from both search methods + vector_response = await self.query_vector(embedding, k * 2, score_threshold) + keyword_response = await self.query_keyword(query_string, k * 2, score_threshold) + + # Create dictionaries to store ranks for each method + vector_ranks = {chunk.metadata["document_id"]: i + 1 for i, chunk in enumerate(vector_response.chunks)} + keyword_ranks = {chunk.metadata["document_id"]: i + 1 for i, chunk in enumerate(keyword_response.chunks)} + + # Calculate RRF scores for all unique document IDs + all_ids = set(vector_ranks.keys()) | set(keyword_ranks.keys()) + rrf_scores = {} + for doc_id in all_ids: + vector_rank = vector_ranks.get(doc_id, float("inf")) + keyword_rank = keyword_ranks.get(doc_id, float("inf")) + # RRF formula: score = 1/(k + r) where k is a constant and r is the rank + rrf_scores[doc_id] = (1.0 / (60 + vector_rank)) + (1.0 / (60 + keyword_rank)) + + # Sort by RRF score and get top k results + sorted_ids = sorted(rrf_scores.keys(), key=lambda x: rrf_scores[x], reverse=True)[:k] + + # Combine results maintaining RRF scores + chunks = [] + scores = [] + for doc_id in sorted_ids: + score = rrf_scores[doc_id] + if score >= score_threshold: + # Try to get from vector results first + for chunk in vector_response.chunks: + if chunk.metadata["document_id"] == doc_id: + chunks.append(chunk) + scores.append(score) + break + else: + # If not in vector results, get from keyword results + for chunk in keyword_response.chunks: + if chunk.metadata["document_id"] == doc_id: + chunks.append(chunk) + scores.append(score) + break + + return QueryChunksResponse(chunks=chunks, scores=scores) + class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPrivate): """ @@ -345,7 +407,9 @@ class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtoc vector_db_data = row[0] vector_db = VectorDB.model_validate_json(vector_db_data) index = await SQLiteVecIndex.create( - vector_db.embedding_dimension, self.config.db_path, vector_db.identifier + vector_db.embedding_dimension, + self.config.db_path, + vector_db.identifier, ) self.cache[vector_db.identifier] = VectorDBWithIndex(vector_db, index, self.inference_api) @@ -371,7 +435,11 @@ class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtoc connection.close() await asyncio.to_thread(_register_db) - index = await SQLiteVecIndex.create(vector_db.embedding_dimension, self.config.db_path, vector_db.identifier) + index = await SQLiteVecIndex.create( + vector_db.embedding_dimension, + self.config.db_path, + vector_db.identifier, + ) self.cache[vector_db.identifier] = VectorDBWithIndex(vector_db, index, self.inference_api) async def list_vector_dbs(self) -> list[VectorDB]: diff --git a/llama_stack/providers/remote/vector_io/chroma/chroma.py b/llama_stack/providers/remote/vector_io/chroma/chroma.py index fee29cfd9..7571c4441 100644 --- a/llama_stack/providers/remote/vector_io/chroma/chroma.py +++ b/llama_stack/providers/remote/vector_io/chroma/chroma.py @@ -105,6 +105,15 @@ class ChromaIndex(EmbeddingIndex): ) -> QueryChunksResponse: raise NotImplementedError("Keyword search is not supported in Chroma") + async def query_hybrid( + self, + embedding: NDArray, + query_string: str, + k: int, + score_threshold: float, + ) -> QueryChunksResponse: + raise NotImplementedError("Hybrid search is not supported in Chroma") + class ChromaVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate): def __init__( diff --git a/llama_stack/providers/remote/vector_io/milvus/milvus.py b/llama_stack/providers/remote/vector_io/milvus/milvus.py index 51c541c02..833224541 100644 --- a/llama_stack/providers/remote/vector_io/milvus/milvus.py +++ b/llama_stack/providers/remote/vector_io/milvus/milvus.py @@ -103,6 +103,15 @@ class MilvusIndex(EmbeddingIndex): ) -> QueryChunksResponse: raise NotImplementedError("Keyword search is not supported in Milvus") + async def query_hybrid( + self, + embedding: NDArray, + query_string: str, + k: int, + score_threshold: float, + ) -> QueryChunksResponse: + raise NotImplementedError("Hybrid search is not supported in Milvus") + class MilvusVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate): def __init__( diff --git a/llama_stack/providers/remote/vector_io/pgvector/pgvector.py b/llama_stack/providers/remote/vector_io/pgvector/pgvector.py index 7d58a49f3..01548ee08 100644 --- a/llama_stack/providers/remote/vector_io/pgvector/pgvector.py +++ b/llama_stack/providers/remote/vector_io/pgvector/pgvector.py @@ -128,6 +128,15 @@ class PGVectorIndex(EmbeddingIndex): ) -> QueryChunksResponse: raise NotImplementedError("Keyword search is not supported in PGVector") + async def query_hybrid( + self, + embedding: NDArray, + query_string: str, + k: int, + score_threshold: float, + ) -> QueryChunksResponse: + raise NotImplementedError("Hybrid search is not supported in PGVector") + async def delete(self): with self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur: cur.execute(f"DROP TABLE IF EXISTS {self.table_name}") diff --git a/llama_stack/providers/remote/vector_io/qdrant/qdrant.py b/llama_stack/providers/remote/vector_io/qdrant/qdrant.py index 1631a7a2a..b9bc63567 100644 --- a/llama_stack/providers/remote/vector_io/qdrant/qdrant.py +++ b/llama_stack/providers/remote/vector_io/qdrant/qdrant.py @@ -112,6 +112,15 @@ class QdrantIndex(EmbeddingIndex): ) -> QueryChunksResponse: raise NotImplementedError("Keyword search is not supported in Qdrant") + async def query_hybrid( + self, + embedding: NDArray, + query_string: str, + k: int, + score_threshold: float, + ) -> QueryChunksResponse: + raise NotImplementedError("Hybrid search is not supported in Qdrant") + async def delete(self): await self.client.delete_collection(collection_name=self.collection_name) diff --git a/llama_stack/providers/remote/vector_io/weaviate/weaviate.py b/llama_stack/providers/remote/vector_io/weaviate/weaviate.py index 6f2027dad..9d4a65045 100644 --- a/llama_stack/providers/remote/vector_io/weaviate/weaviate.py +++ b/llama_stack/providers/remote/vector_io/weaviate/weaviate.py @@ -92,6 +92,15 @@ class WeaviateIndex(EmbeddingIndex): ) -> QueryChunksResponse: raise NotImplementedError("Keyword search is not supported in Weaviate") + async def query_hybrid( + self, + embedding: NDArray, + query_string: str, + k: int, + score_threshold: float, + ) -> QueryChunksResponse: + raise NotImplementedError("Hybrid search is not supported in Weaviate") + class WeaviateVectorIOAdapter( VectorIO, diff --git a/llama_stack/providers/utils/memory/vector_store.py b/llama_stack/providers/utils/memory/vector_store.py index 2c0c7c8e9..a5875acf8 100644 --- a/llama_stack/providers/utils/memory/vector_store.py +++ b/llama_stack/providers/utils/memory/vector_store.py @@ -202,6 +202,12 @@ class EmbeddingIndex(ABC): async def query_keyword(self, query_string: str, k: int, score_threshold: float) -> QueryChunksResponse: raise NotImplementedError() + @abstractmethod + async def query_hybrid( + self, embedding: NDArray, query_string: str, k: int, score_threshold: float + ) -> QueryChunksResponse: + raise NotImplementedError() + @abstractmethod async def delete(self): raise NotImplementedError() @@ -246,9 +252,14 @@ class VectorDBWithIndex: mode = params.get("mode") score_threshold = params.get("score_threshold", 0.0) query_string = interleaved_content_as_str(query) + + # Calculate embeddings for both vector and hybrid modes + embeddings_response = await self.inference_api.embeddings(self.vector_db.embedding_model, [query_string]) + query_vector = np.array(embeddings_response.embeddings[0], dtype=np.float32) + if mode == "keyword": return await self.index.query_keyword(query_string, k, score_threshold) + elif mode == "hybrid": + return await self.index.query_hybrid(query_vector, query_string, k, score_threshold) else: - embeddings_response = await self.inference_api.embeddings(self.vector_db.embedding_model, [query_string]) - query_vector = np.array(embeddings_response.embeddings[0], dtype=np.float32) return await self.index.query_vector(query_vector, k, score_threshold) diff --git a/tests/unit/providers/vector_io/test_sqlite_vec.py b/tests/unit/providers/vector_io/test_sqlite_vec.py index 010a0ca42..f99de87f0 100644 --- a/tests/unit/providers/vector_io/test_sqlite_vec.py +++ b/tests/unit/providers/vector_io/test_sqlite_vec.py @@ -84,6 +84,23 @@ async def test_query_chunks_full_text_search(sqlite_vec_index, sample_chunks, sa assert len(response_no_results.chunks) == 0, f"Expected 0 results, but got {len(response_no_results.chunks)}" +@pytest.mark.asyncio +async def test_query_chunks_hybrid(sqlite_vec_index, sample_chunks, sample_embeddings): + await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings) + + # Create a query embedding that's similar to the first chunk + query_embedding = sample_embeddings[0] + query_string = "Sentence 5" + + response = await sqlite_vec_index.query_hybrid( + embedding=query_embedding, query_string=query_string, k=3, score_threshold=0.0 + ) + + assert len(response.chunks) == 3, f"Expected 3 results, got {len(response.chunks)}" + # Verify scores are in descending order (higher is better) + assert all(response.scores[i] >= response.scores[i + 1] for i in range(len(response.scores) - 1)) + + @pytest.mark.asyncio async def test_query_chunks_full_text_search_k_greater_than_results(sqlite_vec_index, sample_chunks, sample_embeddings): # Re-initialize with a clean index @@ -141,3 +158,144 @@ def test_generate_chunk_id(): "bc744db3-1b25-0a9c-cdff-b6ba3df73c36", "f68df25d-d9aa-ab4d-5684-64a233add20d", ] + + +@pytest.mark.asyncio +async def test_query_chunks_hybrid_no_keyword_matches(sqlite_vec_index, sample_chunks, sample_embeddings): + """Test hybrid search when keyword search returns no matches - should still return vector results.""" + await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings) + + # Use a non-existent keyword but a valid vector query + query_embedding = sample_embeddings[0] + query_string = "Sentence 499" + + # First verify keyword search returns no results + keyword_response = await sqlite_vec_index.query_keyword(query_string, k=5, score_threshold=0.0) + assert len(keyword_response.chunks) == 0, "Keyword search should return no results" + + # Get hybrid results + response = await sqlite_vec_index.query_hybrid( + embedding=query_embedding, query_string=query_string, k=3, score_threshold=0.0 + ) + + # Should still get results from vector search + assert len(response.chunks) > 0, "Should get results from vector search even with no keyword matches" + # Verify scores are in descending order + assert all(response.scores[i] >= response.scores[i + 1] for i in range(len(response.scores) - 1)) + + +@pytest.mark.asyncio +async def test_query_chunks_hybrid_score_threshold(sqlite_vec_index, sample_chunks, sample_embeddings): + """Test hybrid search with a high score threshold.""" + await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings) + + # Use a very high score threshold that no results will meet + query_embedding = sample_embeddings[0] + query_string = "Sentence 5" + + response = await sqlite_vec_index.query_hybrid( + embedding=query_embedding, + query_string=query_string, + k=3, + score_threshold=1000.0, # Very high threshold + ) + + # Should return no results due to high threshold + assert len(response.chunks) == 0 + + +@pytest.mark.asyncio +async def test_query_chunks_hybrid_different_embedding( + sqlite_vec_index, sample_chunks, sample_embeddings, embedding_dimension +): + """Test hybrid search with a different embedding than the stored ones.""" + await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings) + + # Create a random embedding that's different from stored ones + query_embedding = np.random.rand(embedding_dimension).astype(np.float32) + query_string = "Sentence 5" + + response = await sqlite_vec_index.query_hybrid( + embedding=query_embedding, query_string=query_string, k=3, score_threshold=0.0 + ) + + # Should still get results if keyword matches exist + assert len(response.chunks) > 0 + # Verify scores are in descending order + assert all(response.scores[i] >= response.scores[i + 1] for i in range(len(response.scores) - 1)) + + +@pytest.mark.asyncio +async def test_query_chunks_hybrid_rrf_ranking(sqlite_vec_index, sample_chunks, sample_embeddings): + """Test that RRF properly combines rankings when documents appear in both search methods.""" + await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings) + + # Create a query embedding that's similar to the first chunk + query_embedding = sample_embeddings[0] + # Use a keyword that appears in multiple documents + query_string = "Sentence 5" + + response = await sqlite_vec_index.query_hybrid( + embedding=query_embedding, query_string=query_string, k=5, score_threshold=0.0 + ) + + # Verify we get results from both search methods + assert len(response.chunks) > 0 + # Verify scores are in descending order (RRF should maintain this) + assert all(response.scores[i] >= response.scores[i + 1] for i in range(len(response.scores) - 1)) + + +@pytest.mark.asyncio +async def test_query_chunks_hybrid_score_selection(sqlite_vec_index, sample_chunks, sample_embeddings): + """Test that we correctly rank documents that appear in both search methods.""" + await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings) + + # Create a query embedding that's similar to the first chunk + query_embedding = sample_embeddings[0] + # Use a keyword that appears in the first document + query_string = "Sentence 0 from document 0" + + # First get individual results to verify ranks + vector_response = await sqlite_vec_index.query_vector(query_embedding, k=5, score_threshold=0.0) + keyword_response = await sqlite_vec_index.query_keyword(query_string, k=5, score_threshold=0.0) + + # Verify document-0 appears in both results + assert any(chunk.metadata["document_id"] == "document-0" for chunk in vector_response.chunks), ( + "document-0 not found in vector search results" + ) + assert any(chunk.metadata["document_id"] == "document-0" for chunk in keyword_response.chunks), ( + "document-0 not found in keyword search results" + ) + + # Now get hybrid results + response = await sqlite_vec_index.query_hybrid( + embedding=query_embedding, query_string=query_string, k=1, score_threshold=0.0 + ) + + # Verify document-0 is ranked first in hybrid results + assert len(response.chunks) == 1 + assert response.chunks[0].metadata["document_id"] == "document-0", "document-0 not ranked first in hybrid results" + + +@pytest.mark.asyncio +async def test_query_chunks_hybrid_mixed_results(sqlite_vec_index, sample_chunks, sample_embeddings): + """Test hybrid search with documents that appear in only one search method.""" + await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings) + + # Create a query embedding that's similar to the first chunk + query_embedding = sample_embeddings[0] + # Use a keyword that appears in a different document + query_string = "Sentence 9 from document 2" + + response = await sqlite_vec_index.query_hybrid( + embedding=query_embedding, query_string=query_string, k=3, score_threshold=0.0 + ) + + # Should get results from both search methods + assert len(response.chunks) > 0 + # Verify scores are in descending order + assert all(response.scores[i] >= response.scores[i + 1] for i in range(len(response.scores) - 1)) + # Verify we get results from both the vector-similar document and keyword-matched document + doc_ids = {chunk.metadata["document_id"] for chunk in response.chunks} + assert "document-0" in doc_ids # From vector search + assert "document-2" in doc_ids # From keyword search