feat: Implement hybrid search in SQLite-vec

Signed-off-by: Varsha Prasad Narsing <varshaprasad96@gmail.com>
This commit is contained in:
Varsha Prasad Narsing 2025-05-29 14:32:54 -07:00
parent 941f505eb0
commit eab85a7121
13 changed files with 335 additions and 10 deletions

View file

@ -13994,7 +13994,7 @@
},
"mode": {
"type": "string",
"description": "Search mode for retrieval—either \"vector\" or \"keyword\". Default \"vector\"."
"description": "Search mode for retrieval—either \"vector\", \"keyword\", or \"hybrid\". Default \"vector\"."
}
},
"additionalProperties": false,

View file

@ -9756,7 +9756,8 @@ components:
mode:
type: string
description: >-
Search mode for retrieval—either "vector" or "keyword". Default "vector".
Search mode for retrieval—either "vector", "keyword", or "hybrid". Default
"vector".
additionalProperties: false
required:
- query_generator_config

View file

@ -66,6 +66,39 @@ To use sqlite-vec in your Llama Stack project, follow these steps:
2. Configure your Llama Stack project to use SQLite-Vec.
3. Start storing and querying vectors.
The SQLite-vec provider supports three search modes:
1. **Vector Search** (`mode="vector"`): Performs pure vector similarity search using the embeddings.
2. **Keyword Search** (`mode="keyword"`): Performs full-text search using SQLite's FTS5.
3. **Hybrid Search** (`mode="hybrid"`): Combines both vector and keyword search for better results. First performs keyword search to get candidate matches, then applies vector similarity search on those candidates.
Example with hybrid search:
```python
response = await vector_io.query_chunks(
vector_db_id="my_db",
query="your query here",
params={"mode": "hybrid", "max_chunks": 3, "score_threshold": 0.7},
)
```
Example with explicit vector search:
```python
response = await vector_io.query_chunks(
vector_db_id="my_db",
query="your query here",
params={"mode": "vector", "max_chunks": 3, "score_threshold": 0.7},
)
```
Example with keyword search:
```python
response = await vector_io.query_chunks(
vector_db_id="my_db",
query="your query here",
params={"mode": "keyword", "max_chunks": 3, "score_threshold": 0.7},
)
```
## Supported Search Modes
The sqlite-vec provider supports both vector-based and keyword-based (full-text) search modes.

View file

@ -76,7 +76,7 @@ class RAGQueryConfig(BaseModel):
:param chunk_template: Template for formatting each retrieved chunk in the context.
Available placeholders: {index} (1-based chunk ordinal), {chunk.content} (chunk content string), {metadata} (chunk metadata dict).
Default: "Result {index}\\nContent: {chunk.content}\\nMetadata: {metadata}\\n"
:param mode: Search mode for retrievaleither "vector" or "keyword". Default "vector".
:param mode: Search mode for retrievaleither "vector", "keyword", or "hybrid". Default "vector".
"""
# This config defines how a query is generated using the messages

View file

@ -131,6 +131,15 @@ class FaissIndex(EmbeddingIndex):
) -> QueryChunksResponse:
raise NotImplementedError("Keyword search is not supported in FAISS")
async def query_hybrid(
self,
embedding: NDArray,
query_string: str,
k: int,
score_threshold: float,
) -> QueryChunksResponse:
raise NotImplementedError("Hybrid search is not supported in FAISS")
class FaissVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPrivate):
def __init__(self, config: FaissVectorIOConfig, inference_api: Inference, files_api: Files | None) -> None:

View file

@ -34,7 +34,8 @@ logger = logging.getLogger(__name__)
# Specifying search mode is dependent on the VectorIO provider.
VECTOR_SEARCH = "vector"
KEYWORD_SEARCH = "keyword"
SEARCH_MODES = {VECTOR_SEARCH, KEYWORD_SEARCH}
HYBRID_SEARCH = "hybrid"
SEARCH_MODES = {VECTOR_SEARCH, KEYWORD_SEARCH, HYBRID_SEARCH}
def serialize_vector(vector: list[float]) -> bytes:
@ -255,8 +256,6 @@ class SQLiteVecIndex(EmbeddingIndex):
"""
Performs keyword-based search using SQLite FTS5 for relevance-ranked full-text search.
"""
if query_string is None:
raise ValueError("query_string is required for keyword search.")
def _execute_query():
connection = _create_sqlite_connection(self.db_path)
@ -294,6 +293,69 @@ class SQLiteVecIndex(EmbeddingIndex):
scores.append(score)
return QueryChunksResponse(chunks=chunks, scores=scores)
async def query_hybrid(
self,
embedding: NDArray,
query_string: str,
k: int,
score_threshold: float,
) -> QueryChunksResponse:
"""
Hybrid search using Reciprocal Rank Fusion (RRF) to combine vector and keyword search results.
RRF assigns scores based on the reciprocal of the rank position in each search method,
then combines these scores to get a final ranking.
Args:
embedding: The query embedding vector
query_string: The text query for keyword search
k: Number of results to return
score_threshold: Minimum similarity score threshold
Returns:
QueryChunksResponse with combined results
"""
# Get results from both search methods
vector_response = await self.query_vector(embedding, k * 2, score_threshold)
keyword_response = await self.query_keyword(query_string, k * 2, score_threshold)
# Create dictionaries to store ranks for each method
vector_ranks = {chunk.metadata["document_id"]: i + 1 for i, chunk in enumerate(vector_response.chunks)}
keyword_ranks = {chunk.metadata["document_id"]: i + 1 for i, chunk in enumerate(keyword_response.chunks)}
# Calculate RRF scores for all unique document IDs
all_ids = set(vector_ranks.keys()) | set(keyword_ranks.keys())
rrf_scores = {}
for doc_id in all_ids:
vector_rank = vector_ranks.get(doc_id, float("inf"))
keyword_rank = keyword_ranks.get(doc_id, float("inf"))
# RRF formula: score = 1/(k + r) where k is a constant and r is the rank
rrf_scores[doc_id] = (1.0 / (60 + vector_rank)) + (1.0 / (60 + keyword_rank))
# Sort by RRF score and get top k results
sorted_ids = sorted(rrf_scores.keys(), key=lambda x: rrf_scores[x], reverse=True)[:k]
# Combine results maintaining RRF scores
chunks = []
scores = []
for doc_id in sorted_ids:
score = rrf_scores[doc_id]
if score >= score_threshold:
# Try to get from vector results first
for chunk in vector_response.chunks:
if chunk.metadata["document_id"] == doc_id:
chunks.append(chunk)
scores.append(score)
break
else:
# If not in vector results, get from keyword results
for chunk in keyword_response.chunks:
if chunk.metadata["document_id"] == doc_id:
chunks.append(chunk)
scores.append(score)
break
return QueryChunksResponse(chunks=chunks, scores=scores)
class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPrivate):
"""
@ -345,7 +407,9 @@ class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtoc
vector_db_data = row[0]
vector_db = VectorDB.model_validate_json(vector_db_data)
index = await SQLiteVecIndex.create(
vector_db.embedding_dimension, self.config.db_path, vector_db.identifier
vector_db.embedding_dimension,
self.config.db_path,
vector_db.identifier,
)
self.cache[vector_db.identifier] = VectorDBWithIndex(vector_db, index, self.inference_api)
@ -371,7 +435,11 @@ class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtoc
connection.close()
await asyncio.to_thread(_register_db)
index = await SQLiteVecIndex.create(vector_db.embedding_dimension, self.config.db_path, vector_db.identifier)
index = await SQLiteVecIndex.create(
vector_db.embedding_dimension,
self.config.db_path,
vector_db.identifier,
)
self.cache[vector_db.identifier] = VectorDBWithIndex(vector_db, index, self.inference_api)
async def list_vector_dbs(self) -> list[VectorDB]:

View file

@ -105,6 +105,15 @@ class ChromaIndex(EmbeddingIndex):
) -> QueryChunksResponse:
raise NotImplementedError("Keyword search is not supported in Chroma")
async def query_hybrid(
self,
embedding: NDArray,
query_string: str,
k: int,
score_threshold: float,
) -> QueryChunksResponse:
raise NotImplementedError("Hybrid search is not supported in Chroma")
class ChromaVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
def __init__(

View file

@ -103,6 +103,15 @@ class MilvusIndex(EmbeddingIndex):
) -> QueryChunksResponse:
raise NotImplementedError("Keyword search is not supported in Milvus")
async def query_hybrid(
self,
embedding: NDArray,
query_string: str,
k: int,
score_threshold: float,
) -> QueryChunksResponse:
raise NotImplementedError("Hybrid search is not supported in Milvus")
class MilvusVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
def __init__(

View file

@ -128,6 +128,15 @@ class PGVectorIndex(EmbeddingIndex):
) -> QueryChunksResponse:
raise NotImplementedError("Keyword search is not supported in PGVector")
async def query_hybrid(
self,
embedding: NDArray,
query_string: str,
k: int,
score_threshold: float,
) -> QueryChunksResponse:
raise NotImplementedError("Hybrid search is not supported in PGVector")
async def delete(self):
with self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
cur.execute(f"DROP TABLE IF EXISTS {self.table_name}")

View file

@ -112,6 +112,15 @@ class QdrantIndex(EmbeddingIndex):
) -> QueryChunksResponse:
raise NotImplementedError("Keyword search is not supported in Qdrant")
async def query_hybrid(
self,
embedding: NDArray,
query_string: str,
k: int,
score_threshold: float,
) -> QueryChunksResponse:
raise NotImplementedError("Hybrid search is not supported in Qdrant")
async def delete(self):
await self.client.delete_collection(collection_name=self.collection_name)

View file

@ -92,6 +92,15 @@ class WeaviateIndex(EmbeddingIndex):
) -> QueryChunksResponse:
raise NotImplementedError("Keyword search is not supported in Weaviate")
async def query_hybrid(
self,
embedding: NDArray,
query_string: str,
k: int,
score_threshold: float,
) -> QueryChunksResponse:
raise NotImplementedError("Hybrid search is not supported in Weaviate")
class WeaviateVectorIOAdapter(
VectorIO,

View file

@ -202,6 +202,12 @@ class EmbeddingIndex(ABC):
async def query_keyword(self, query_string: str, k: int, score_threshold: float) -> QueryChunksResponse:
raise NotImplementedError()
@abstractmethod
async def query_hybrid(
self, embedding: NDArray, query_string: str, k: int, score_threshold: float
) -> QueryChunksResponse:
raise NotImplementedError()
@abstractmethod
async def delete(self):
raise NotImplementedError()
@ -246,9 +252,14 @@ class VectorDBWithIndex:
mode = params.get("mode")
score_threshold = params.get("score_threshold", 0.0)
query_string = interleaved_content_as_str(query)
# Calculate embeddings for both vector and hybrid modes
embeddings_response = await self.inference_api.embeddings(self.vector_db.embedding_model, [query_string])
query_vector = np.array(embeddings_response.embeddings[0], dtype=np.float32)
if mode == "keyword":
return await self.index.query_keyword(query_string, k, score_threshold)
elif mode == "hybrid":
return await self.index.query_hybrid(query_vector, query_string, k, score_threshold)
else:
embeddings_response = await self.inference_api.embeddings(self.vector_db.embedding_model, [query_string])
query_vector = np.array(embeddings_response.embeddings[0], dtype=np.float32)
return await self.index.query_vector(query_vector, k, score_threshold)

View file

@ -84,6 +84,23 @@ async def test_query_chunks_full_text_search(sqlite_vec_index, sample_chunks, sa
assert len(response_no_results.chunks) == 0, f"Expected 0 results, but got {len(response_no_results.chunks)}"
@pytest.mark.asyncio
async def test_query_chunks_hybrid(sqlite_vec_index, sample_chunks, sample_embeddings):
await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings)
# Create a query embedding that's similar to the first chunk
query_embedding = sample_embeddings[0]
query_string = "Sentence 5"
response = await sqlite_vec_index.query_hybrid(
embedding=query_embedding, query_string=query_string, k=3, score_threshold=0.0
)
assert len(response.chunks) == 3, f"Expected 3 results, got {len(response.chunks)}"
# Verify scores are in descending order (higher is better)
assert all(response.scores[i] >= response.scores[i + 1] for i in range(len(response.scores) - 1))
@pytest.mark.asyncio
async def test_query_chunks_full_text_search_k_greater_than_results(sqlite_vec_index, sample_chunks, sample_embeddings):
# Re-initialize with a clean index
@ -141,3 +158,144 @@ def test_generate_chunk_id():
"bc744db3-1b25-0a9c-cdff-b6ba3df73c36",
"f68df25d-d9aa-ab4d-5684-64a233add20d",
]
@pytest.mark.asyncio
async def test_query_chunks_hybrid_no_keyword_matches(sqlite_vec_index, sample_chunks, sample_embeddings):
"""Test hybrid search when keyword search returns no matches - should still return vector results."""
await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings)
# Use a non-existent keyword but a valid vector query
query_embedding = sample_embeddings[0]
query_string = "Sentence 499"
# First verify keyword search returns no results
keyword_response = await sqlite_vec_index.query_keyword(query_string, k=5, score_threshold=0.0)
assert len(keyword_response.chunks) == 0, "Keyword search should return no results"
# Get hybrid results
response = await sqlite_vec_index.query_hybrid(
embedding=query_embedding, query_string=query_string, k=3, score_threshold=0.0
)
# Should still get results from vector search
assert len(response.chunks) > 0, "Should get results from vector search even with no keyword matches"
# Verify scores are in descending order
assert all(response.scores[i] >= response.scores[i + 1] for i in range(len(response.scores) - 1))
@pytest.mark.asyncio
async def test_query_chunks_hybrid_score_threshold(sqlite_vec_index, sample_chunks, sample_embeddings):
"""Test hybrid search with a high score threshold."""
await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings)
# Use a very high score threshold that no results will meet
query_embedding = sample_embeddings[0]
query_string = "Sentence 5"
response = await sqlite_vec_index.query_hybrid(
embedding=query_embedding,
query_string=query_string,
k=3,
score_threshold=1000.0, # Very high threshold
)
# Should return no results due to high threshold
assert len(response.chunks) == 0
@pytest.mark.asyncio
async def test_query_chunks_hybrid_different_embedding(
sqlite_vec_index, sample_chunks, sample_embeddings, embedding_dimension
):
"""Test hybrid search with a different embedding than the stored ones."""
await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings)
# Create a random embedding that's different from stored ones
query_embedding = np.random.rand(embedding_dimension).astype(np.float32)
query_string = "Sentence 5"
response = await sqlite_vec_index.query_hybrid(
embedding=query_embedding, query_string=query_string, k=3, score_threshold=0.0
)
# Should still get results if keyword matches exist
assert len(response.chunks) > 0
# Verify scores are in descending order
assert all(response.scores[i] >= response.scores[i + 1] for i in range(len(response.scores) - 1))
@pytest.mark.asyncio
async def test_query_chunks_hybrid_rrf_ranking(sqlite_vec_index, sample_chunks, sample_embeddings):
"""Test that RRF properly combines rankings when documents appear in both search methods."""
await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings)
# Create a query embedding that's similar to the first chunk
query_embedding = sample_embeddings[0]
# Use a keyword that appears in multiple documents
query_string = "Sentence 5"
response = await sqlite_vec_index.query_hybrid(
embedding=query_embedding, query_string=query_string, k=5, score_threshold=0.0
)
# Verify we get results from both search methods
assert len(response.chunks) > 0
# Verify scores are in descending order (RRF should maintain this)
assert all(response.scores[i] >= response.scores[i + 1] for i in range(len(response.scores) - 1))
@pytest.mark.asyncio
async def test_query_chunks_hybrid_score_selection(sqlite_vec_index, sample_chunks, sample_embeddings):
"""Test that we correctly rank documents that appear in both search methods."""
await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings)
# Create a query embedding that's similar to the first chunk
query_embedding = sample_embeddings[0]
# Use a keyword that appears in the first document
query_string = "Sentence 0 from document 0"
# First get individual results to verify ranks
vector_response = await sqlite_vec_index.query_vector(query_embedding, k=5, score_threshold=0.0)
keyword_response = await sqlite_vec_index.query_keyword(query_string, k=5, score_threshold=0.0)
# Verify document-0 appears in both results
assert any(chunk.metadata["document_id"] == "document-0" for chunk in vector_response.chunks), (
"document-0 not found in vector search results"
)
assert any(chunk.metadata["document_id"] == "document-0" for chunk in keyword_response.chunks), (
"document-0 not found in keyword search results"
)
# Now get hybrid results
response = await sqlite_vec_index.query_hybrid(
embedding=query_embedding, query_string=query_string, k=1, score_threshold=0.0
)
# Verify document-0 is ranked first in hybrid results
assert len(response.chunks) == 1
assert response.chunks[0].metadata["document_id"] == "document-0", "document-0 not ranked first in hybrid results"
@pytest.mark.asyncio
async def test_query_chunks_hybrid_mixed_results(sqlite_vec_index, sample_chunks, sample_embeddings):
"""Test hybrid search with documents that appear in only one search method."""
await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings)
# Create a query embedding that's similar to the first chunk
query_embedding = sample_embeddings[0]
# Use a keyword that appears in a different document
query_string = "Sentence 9 from document 2"
response = await sqlite_vec_index.query_hybrid(
embedding=query_embedding, query_string=query_string, k=3, score_threshold=0.0
)
# Should get results from both search methods
assert len(response.chunks) > 0
# Verify scores are in descending order
assert all(response.scores[i] >= response.scores[i + 1] for i in range(len(response.scores) - 1))
# Verify we get results from both the vector-similar document and keyword-matched document
doc_ids = {chunk.metadata["document_id"] for chunk in response.chunks}
assert "document-0" in doc_ids # From vector search
assert "document-2" in doc_ids # From keyword search