mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-17 02:18:13 +00:00
feat: Implement hybrid search in SQLite-vec
Signed-off-by: Varsha Prasad Narsing <varshaprasad96@gmail.com>
This commit is contained in:
parent
941f505eb0
commit
eab85a7121
13 changed files with 335 additions and 10 deletions
2
docs/_static/llama-stack-spec.html
vendored
2
docs/_static/llama-stack-spec.html
vendored
|
@ -13994,7 +13994,7 @@
|
||||||
},
|
},
|
||||||
"mode": {
|
"mode": {
|
||||||
"type": "string",
|
"type": "string",
|
||||||
"description": "Search mode for retrieval—either \"vector\" or \"keyword\". Default \"vector\"."
|
"description": "Search mode for retrieval—either \"vector\", \"keyword\", or \"hybrid\". Default \"vector\"."
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"additionalProperties": false,
|
"additionalProperties": false,
|
||||||
|
|
3
docs/_static/llama-stack-spec.yaml
vendored
3
docs/_static/llama-stack-spec.yaml
vendored
|
@ -9756,7 +9756,8 @@ components:
|
||||||
mode:
|
mode:
|
||||||
type: string
|
type: string
|
||||||
description: >-
|
description: >-
|
||||||
Search mode for retrieval—either "vector" or "keyword". Default "vector".
|
Search mode for retrieval—either "vector", "keyword", or "hybrid". Default
|
||||||
|
"vector".
|
||||||
additionalProperties: false
|
additionalProperties: false
|
||||||
required:
|
required:
|
||||||
- query_generator_config
|
- query_generator_config
|
||||||
|
|
|
@ -66,6 +66,39 @@ To use sqlite-vec in your Llama Stack project, follow these steps:
|
||||||
2. Configure your Llama Stack project to use SQLite-Vec.
|
2. Configure your Llama Stack project to use SQLite-Vec.
|
||||||
3. Start storing and querying vectors.
|
3. Start storing and querying vectors.
|
||||||
|
|
||||||
|
The SQLite-vec provider supports three search modes:
|
||||||
|
|
||||||
|
1. **Vector Search** (`mode="vector"`): Performs pure vector similarity search using the embeddings.
|
||||||
|
2. **Keyword Search** (`mode="keyword"`): Performs full-text search using SQLite's FTS5.
|
||||||
|
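3. **Hybrid Search** (`mode="hybrid"`): Combines both vector and keyword search for better results. Runs the vector search and the keyword search independently, then merges the two rankings with Reciprocal Rank Fusion (RRF). The returned scores are RRF scores (at most 2/61 ≈ 0.033 with the built-in constant of 60), so choose `score_threshold` accordingly.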
|
||||||
|
|
||||||
|
Example with hybrid search:
|
||||||
|
```python
|
||||||
|
response = await vector_io.query_chunks(
|
||||||
|
vector_db_id="my_db",
|
||||||
|
query="your query here",
|
||||||
|
params={"mode": "hybrid", "max_chunks": 3, "score_threshold": 0.01},  # hybrid (RRF) scores top out near 0.033
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
Example with explicit vector search:
|
||||||
|
```python
|
||||||
|
response = await vector_io.query_chunks(
|
||||||
|
vector_db_id="my_db",
|
||||||
|
query="your query here",
|
||||||
|
params={"mode": "vector", "max_chunks": 3, "score_threshold": 0.7},
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
Example with keyword search:
|
||||||
|
```python
|
||||||
|
response = await vector_io.query_chunks(
|
||||||
|
vector_db_id="my_db",
|
||||||
|
query="your query here",
|
||||||
|
params={"mode": "keyword", "max_chunks": 3, "score_threshold": 0.7},
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
## Supported Search Modes
|
## Supported Search Modes
|
||||||
|
|
||||||
The sqlite-vec provider supports both vector-based and keyword-based (full-text) search modes.
|
The sqlite-vec provider supports vector-based, keyword-based (full-text), and hybrid search modes.
|
||||||
|
|
|
@ -76,7 +76,7 @@ class RAGQueryConfig(BaseModel):
|
||||||
:param chunk_template: Template for formatting each retrieved chunk in the context.
|
:param chunk_template: Template for formatting each retrieved chunk in the context.
|
||||||
Available placeholders: {index} (1-based chunk ordinal), {chunk.content} (chunk content string), {metadata} (chunk metadata dict).
|
Available placeholders: {index} (1-based chunk ordinal), {chunk.content} (chunk content string), {metadata} (chunk metadata dict).
|
||||||
Default: "Result {index}\\nContent: {chunk.content}\\nMetadata: {metadata}\\n"
|
Default: "Result {index}\\nContent: {chunk.content}\\nMetadata: {metadata}\\n"
|
||||||
:param mode: Search mode for retrieval—either "vector" or "keyword". Default "vector".
|
:param mode: Search mode for retrieval—either "vector", "keyword", or "hybrid". Default "vector".
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# This config defines how a query is generated using the messages
|
# This config defines how a query is generated using the messages
|
||||||
|
|
|
@ -131,6 +131,15 @@ class FaissIndex(EmbeddingIndex):
|
||||||
) -> QueryChunksResponse:
|
) -> QueryChunksResponse:
|
||||||
raise NotImplementedError("Keyword search is not supported in FAISS")
|
raise NotImplementedError("Keyword search is not supported in FAISS")
|
||||||
|
|
||||||
|
async def query_hybrid(
    self, embedding: NDArray, query_string: str, k: int, score_threshold: float
) -> QueryChunksResponse:
    """Hybrid (vector + keyword) search hook; always raises for FAISS.

    Raises:
        NotImplementedError: unconditionally.
    """
    reason = "Hybrid search is not supported in FAISS"
    raise NotImplementedError(reason)
|
||||||
|
|
||||||
|
|
||||||
class FaissVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPrivate):
|
class FaissVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPrivate):
|
||||||
def __init__(self, config: FaissVectorIOConfig, inference_api: Inference, files_api: Files | None) -> None:
|
def __init__(self, config: FaissVectorIOConfig, inference_api: Inference, files_api: Files | None) -> None:
|
||||||
|
|
|
@ -34,7 +34,8 @@ logger = logging.getLogger(__name__)
|
||||||
# Specifying search mode is dependent on the VectorIO provider.
|
# Specifying search mode is dependent on the VectorIO provider.
|
||||||
VECTOR_SEARCH = "vector"
|
VECTOR_SEARCH = "vector"
|
||||||
KEYWORD_SEARCH = "keyword"
|
KEYWORD_SEARCH = "keyword"
|
||||||
SEARCH_MODES = {VECTOR_SEARCH, KEYWORD_SEARCH}
|
HYBRID_SEARCH = "hybrid"
|
||||||
|
SEARCH_MODES = {VECTOR_SEARCH, KEYWORD_SEARCH, HYBRID_SEARCH}
|
||||||
|
|
||||||
|
|
||||||
def serialize_vector(vector: list[float]) -> bytes:
|
def serialize_vector(vector: list[float]) -> bytes:
|
||||||
|
@ -255,8 +256,6 @@ class SQLiteVecIndex(EmbeddingIndex):
|
||||||
"""
|
"""
|
||||||
Performs keyword-based search using SQLite FTS5 for relevance-ranked full-text search.
|
Performs keyword-based search using SQLite FTS5 for relevance-ranked full-text search.
|
||||||
"""
|
"""
|
||||||
if query_string is None:
|
|
||||||
raise ValueError("query_string is required for keyword search.")
|
|
||||||
|
|
||||||
def _execute_query():
|
def _execute_query():
|
||||||
connection = _create_sqlite_connection(self.db_path)
|
connection = _create_sqlite_connection(self.db_path)
|
||||||
|
@ -294,6 +293,69 @@ class SQLiteVecIndex(EmbeddingIndex):
|
||||||
scores.append(score)
|
scores.append(score)
|
||||||
return QueryChunksResponse(chunks=chunks, scores=scores)
|
return QueryChunksResponse(chunks=chunks, scores=scores)
|
||||||
|
|
||||||
|
async def query_hybrid(
    self,
    embedding: NDArray,
    query_string: str,
    k: int,
    score_threshold: float,
) -> QueryChunksResponse:
    """
    Hybrid search using Reciprocal Rank Fusion (RRF) to combine vector and keyword search results.

    Each search method contributes ``1 / (rrf_constant + rank)`` for every document it
    returned (``rank`` is 1-based); a document missing from a method contributes nothing
    for that method. The contributions are summed to obtain the fused score.

    Results are keyed by ``chunk.metadata["document_id"]``, so at most one chunk per
    document is returned. When a method returns several chunks of the same document,
    the document takes its best (first) rank and is represented by that first chunk.

    Args:
        embedding: The query embedding vector
        query_string: The text query for keyword search
        k: Number of results to return
        score_threshold: Minimum score. It is forwarded unchanged to the vector and
            keyword searches and is also applied to the fused RRF score, which can
            never exceed ``2 / (rrf_constant + 1)``.

    Returns:
        QueryChunksResponse with combined results, scores in descending order

    Raises:
        KeyError: If a returned chunk has no "document_id" metadata entry.
    """
    # Smoothing constant of the RRF formula. Named explicitly so it is not confused
    # with the result-count parameter ``k``.
    rrf_constant = 60

    # Over-fetch from both methods so fusion sees more candidates than the k it returns.
    vector_response = await self.query_vector(embedding, k * 2, score_threshold)
    keyword_response = await self.query_keyword(query_string, k * 2, score_threshold)

    def _best_rank_per_document(found_chunks):
        """Map document_id -> best 1-based rank, and document_id -> chunk at that rank."""
        ranks = {}
        representatives = {}
        for position, chunk in enumerate(found_chunks, start=1):
            doc_id = chunk.metadata["document_id"]
            # Keep the first occurrence: a later chunk of the same document must not
            # overwrite the better rank already recorded.
            if doc_id not in ranks:
                ranks[doc_id] = position
                representatives[doc_id] = chunk
        return ranks, representatives

    vector_ranks, vector_chunks = _best_rank_per_document(vector_response.chunks)
    keyword_ranks, keyword_chunks = _best_rank_per_document(keyword_response.chunks)

    # Accumulate RRF contributions. Dict insertion order (vector hits first, then
    # keyword-only hits) makes the tie-break below deterministic.
    rrf_scores = {}
    for ranks in (vector_ranks, keyword_ranks):
        for doc_id, rank in ranks.items():
            rrf_scores[doc_id] = rrf_scores.get(doc_id, 0.0) + 1.0 / (rrf_constant + rank)

    # Stable sort: documents with equal scores keep their insertion order.
    top_ids = sorted(rrf_scores, key=rrf_scores.get, reverse=True)[:k]

    chunks = []
    scores = []
    for doc_id in top_ids:
        score = rrf_scores[doc_id]
        if score >= score_threshold:
            # Prefer the chunk from the vector results; fall back to the keyword results.
            chunks.append(vector_chunks[doc_id] if doc_id in vector_chunks else keyword_chunks[doc_id])
            scores.append(score)

    return QueryChunksResponse(chunks=chunks, scores=scores)
|
||||||
|
|
||||||
|
|
||||||
class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPrivate):
|
class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPrivate):
|
||||||
"""
|
"""
|
||||||
|
@ -345,7 +407,9 @@ class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtoc
|
||||||
vector_db_data = row[0]
|
vector_db_data = row[0]
|
||||||
vector_db = VectorDB.model_validate_json(vector_db_data)
|
vector_db = VectorDB.model_validate_json(vector_db_data)
|
||||||
index = await SQLiteVecIndex.create(
|
index = await SQLiteVecIndex.create(
|
||||||
vector_db.embedding_dimension, self.config.db_path, vector_db.identifier
|
vector_db.embedding_dimension,
|
||||||
|
self.config.db_path,
|
||||||
|
vector_db.identifier,
|
||||||
)
|
)
|
||||||
self.cache[vector_db.identifier] = VectorDBWithIndex(vector_db, index, self.inference_api)
|
self.cache[vector_db.identifier] = VectorDBWithIndex(vector_db, index, self.inference_api)
|
||||||
|
|
||||||
|
@ -371,7 +435,11 @@ class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtoc
|
||||||
connection.close()
|
connection.close()
|
||||||
|
|
||||||
await asyncio.to_thread(_register_db)
|
await asyncio.to_thread(_register_db)
|
||||||
index = await SQLiteVecIndex.create(vector_db.embedding_dimension, self.config.db_path, vector_db.identifier)
|
index = await SQLiteVecIndex.create(
|
||||||
|
vector_db.embedding_dimension,
|
||||||
|
self.config.db_path,
|
||||||
|
vector_db.identifier,
|
||||||
|
)
|
||||||
self.cache[vector_db.identifier] = VectorDBWithIndex(vector_db, index, self.inference_api)
|
self.cache[vector_db.identifier] = VectorDBWithIndex(vector_db, index, self.inference_api)
|
||||||
|
|
||||||
async def list_vector_dbs(self) -> list[VectorDB]:
|
async def list_vector_dbs(self) -> list[VectorDB]:
|
||||||
|
|
|
@ -105,6 +105,15 @@ class ChromaIndex(EmbeddingIndex):
|
||||||
) -> QueryChunksResponse:
|
) -> QueryChunksResponse:
|
||||||
raise NotImplementedError("Keyword search is not supported in Chroma")
|
raise NotImplementedError("Keyword search is not supported in Chroma")
|
||||||
|
|
||||||
|
async def query_hybrid(
    self, embedding: NDArray, query_string: str, k: int, score_threshold: float
) -> QueryChunksResponse:
    """Hybrid (vector + keyword) search hook; always raises for Chroma.

    Raises:
        NotImplementedError: unconditionally.
    """
    reason = "Hybrid search is not supported in Chroma"
    raise NotImplementedError(reason)
|
||||||
|
|
||||||
|
|
||||||
class ChromaVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
|
class ChromaVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
|
||||||
def __init__(
|
def __init__(
|
||||||
|
|
|
@ -103,6 +103,15 @@ class MilvusIndex(EmbeddingIndex):
|
||||||
) -> QueryChunksResponse:
|
) -> QueryChunksResponse:
|
||||||
raise NotImplementedError("Keyword search is not supported in Milvus")
|
raise NotImplementedError("Keyword search is not supported in Milvus")
|
||||||
|
|
||||||
|
async def query_hybrid(
    self, embedding: NDArray, query_string: str, k: int, score_threshold: float
) -> QueryChunksResponse:
    """Hybrid (vector + keyword) search hook; always raises for Milvus.

    Raises:
        NotImplementedError: unconditionally.
    """
    reason = "Hybrid search is not supported in Milvus"
    raise NotImplementedError(reason)
|
||||||
|
|
||||||
|
|
||||||
class MilvusVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
|
class MilvusVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
|
||||||
def __init__(
|
def __init__(
|
||||||
|
|
|
@ -128,6 +128,15 @@ class PGVectorIndex(EmbeddingIndex):
|
||||||
) -> QueryChunksResponse:
|
) -> QueryChunksResponse:
|
||||||
raise NotImplementedError("Keyword search is not supported in PGVector")
|
raise NotImplementedError("Keyword search is not supported in PGVector")
|
||||||
|
|
||||||
|
async def query_hybrid(
    self, embedding: NDArray, query_string: str, k: int, score_threshold: float
) -> QueryChunksResponse:
    """Hybrid (vector + keyword) search hook; always raises for PGVector.

    Raises:
        NotImplementedError: unconditionally.
    """
    reason = "Hybrid search is not supported in PGVector"
    raise NotImplementedError(reason)
|
||||||
|
|
||||||
async def delete(self):
|
async def delete(self):
|
||||||
with self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
|
with self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
|
||||||
cur.execute(f"DROP TABLE IF EXISTS {self.table_name}")
|
cur.execute(f"DROP TABLE IF EXISTS {self.table_name}")
|
||||||
|
|
|
@ -112,6 +112,15 @@ class QdrantIndex(EmbeddingIndex):
|
||||||
) -> QueryChunksResponse:
|
) -> QueryChunksResponse:
|
||||||
raise NotImplementedError("Keyword search is not supported in Qdrant")
|
raise NotImplementedError("Keyword search is not supported in Qdrant")
|
||||||
|
|
||||||
|
async def query_hybrid(
    self, embedding: NDArray, query_string: str, k: int, score_threshold: float
) -> QueryChunksResponse:
    """Hybrid (vector + keyword) search hook; always raises for Qdrant.

    Raises:
        NotImplementedError: unconditionally.
    """
    reason = "Hybrid search is not supported in Qdrant"
    raise NotImplementedError(reason)
|
||||||
|
|
||||||
async def delete(self):
|
async def delete(self):
|
||||||
await self.client.delete_collection(collection_name=self.collection_name)
|
await self.client.delete_collection(collection_name=self.collection_name)
|
||||||
|
|
||||||
|
|
|
@ -92,6 +92,15 @@ class WeaviateIndex(EmbeddingIndex):
|
||||||
) -> QueryChunksResponse:
|
) -> QueryChunksResponse:
|
||||||
raise NotImplementedError("Keyword search is not supported in Weaviate")
|
raise NotImplementedError("Keyword search is not supported in Weaviate")
|
||||||
|
|
||||||
|
async def query_hybrid(
    self, embedding: NDArray, query_string: str, k: int, score_threshold: float
) -> QueryChunksResponse:
    """Hybrid (vector + keyword) search hook; always raises for Weaviate.

    Raises:
        NotImplementedError: unconditionally.
    """
    reason = "Hybrid search is not supported in Weaviate"
    raise NotImplementedError(reason)
|
||||||
|
|
||||||
|
|
||||||
class WeaviateVectorIOAdapter(
|
class WeaviateVectorIOAdapter(
|
||||||
VectorIO,
|
VectorIO,
|
||||||
|
|
|
@ -202,6 +202,12 @@ class EmbeddingIndex(ABC):
|
||||||
async def query_keyword(self, query_string: str, k: int, score_threshold: float) -> QueryChunksResponse:
|
async def query_keyword(self, query_string: str, k: int, score_threshold: float) -> QueryChunksResponse:
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
@abstractmethod
async def query_hybrid(
    self,
    embedding: NDArray,
    query_string: str,
    k: int,
    score_threshold: float,
) -> QueryChunksResponse:
    """Abstract hybrid-search hook taking both a query embedding and the query text.

    Concrete indexes must override this; the base body only raises.

    Raises:
        NotImplementedError: unconditionally.
    """
    raise NotImplementedError()
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
async def delete(self):
|
async def delete(self):
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
@ -246,9 +252,14 @@ class VectorDBWithIndex:
|
||||||
mode = params.get("mode")
|
mode = params.get("mode")
|
||||||
score_threshold = params.get("score_threshold", 0.0)
|
score_threshold = params.get("score_threshold", 0.0)
|
||||||
query_string = interleaved_content_as_str(query)
|
query_string = interleaved_content_as_str(query)
|
||||||
|
|
||||||
|
# Calculate embeddings for both vector and hybrid modes
|
||||||
|
embeddings_response = await self.inference_api.embeddings(self.vector_db.embedding_model, [query_string])
|
||||||
|
query_vector = np.array(embeddings_response.embeddings[0], dtype=np.float32)
|
||||||
|
|
||||||
if mode == "keyword":
|
if mode == "keyword":
|
||||||
return await self.index.query_keyword(query_string, k, score_threshold)
|
return await self.index.query_keyword(query_string, k, score_threshold)
|
||||||
|
elif mode == "hybrid":
|
||||||
|
return await self.index.query_hybrid(query_vector, query_string, k, score_threshold)
|
||||||
else:
|
else:
|
||||||
embeddings_response = await self.inference_api.embeddings(self.vector_db.embedding_model, [query_string])
|
|
||||||
query_vector = np.array(embeddings_response.embeddings[0], dtype=np.float32)
|
|
||||||
return await self.index.query_vector(query_vector, k, score_threshold)
|
return await self.index.query_vector(query_vector, k, score_threshold)
|
||||||
|
|
|
@ -84,6 +84,23 @@ async def test_query_chunks_full_text_search(sqlite_vec_index, sample_chunks, sa
|
||||||
assert len(response_no_results.chunks) == 0, f"Expected 0 results, but got {len(response_no_results.chunks)}"
|
assert len(response_no_results.chunks) == 0, f"Expected 0 results, but got {len(response_no_results.chunks)}"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_query_chunks_hybrid(sqlite_vec_index, sample_chunks, sample_embeddings):
    """Hybrid query with k=3 returns exactly three chunks, best score first."""
    await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings)

    # Query with the first stored embedding so the vector side has an exact match.
    response = await sqlite_vec_index.query_hybrid(
        embedding=sample_embeddings[0],
        query_string="Sentence 5",
        k=3,
        score_threshold=0.0,
    )

    assert len(response.chunks) == 3, f"Expected 3 results, got {len(response.chunks)}"
    # Higher is better, so each score must be >= the one that follows it.
    assert all(left >= right for left, right in zip(response.scores, response.scores[1:]))
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_query_chunks_full_text_search_k_greater_than_results(sqlite_vec_index, sample_chunks, sample_embeddings):
|
async def test_query_chunks_full_text_search_k_greater_than_results(sqlite_vec_index, sample_chunks, sample_embeddings):
|
||||||
# Re-initialize with a clean index
|
# Re-initialize with a clean index
|
||||||
|
@ -141,3 +158,144 @@ def test_generate_chunk_id():
|
||||||
"bc744db3-1b25-0a9c-cdff-b6ba3df73c36",
|
"bc744db3-1b25-0a9c-cdff-b6ba3df73c36",
|
||||||
"f68df25d-d9aa-ab4d-5684-64a233add20d",
|
"f68df25d-d9aa-ab4d-5684-64a233add20d",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_query_chunks_hybrid_no_keyword_matches(sqlite_vec_index, sample_chunks, sample_embeddings):
    """Hybrid search must fall back to vector hits when the keyword side finds nothing."""
    await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings)

    # The text matches no stored chunk, while the embedding is one of the stored ones.
    unmatched_text = "Sentence 499"
    stored_embedding = sample_embeddings[0]

    # Precondition: the keyword search on its own comes back empty.
    keyword_only = await sqlite_vec_index.query_keyword(unmatched_text, k=5, score_threshold=0.0)
    assert len(keyword_only.chunks) == 0, "Keyword search should return no results"

    response = await sqlite_vec_index.query_hybrid(
        embedding=stored_embedding,
        query_string=unmatched_text,
        k=3,
        score_threshold=0.0,
    )

    # The vector side alone must still produce results, ordered best first.
    assert len(response.chunks) > 0, "Should get results from vector search even with no keyword matches"
    assert all(left >= right for left, right in zip(response.scores, response.scores[1:]))
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_query_chunks_hybrid_score_threshold(sqlite_vec_index, sample_chunks, sample_embeddings):
    """A threshold far above any achievable score must filter out every hybrid result."""
    await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings)

    unreachable_threshold = 1000.0
    response = await sqlite_vec_index.query_hybrid(
        embedding=sample_embeddings[0],
        query_string="Sentence 5",
        k=3,
        score_threshold=unreachable_threshold,
    )

    # Nothing can clear the threshold, so the response is empty.
    assert len(response.chunks) == 0
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_query_chunks_hybrid_different_embedding(
    sqlite_vec_index, sample_chunks, sample_embeddings, embedding_dimension
):
    """Test hybrid search with a different embedding than the stored ones."""
    await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings)

    # Draw the query embedding from a seeded generator so a failure can be reproduced;
    # an unseeded draw would yield a different query vector on every run.
    rng = np.random.default_rng(0)
    query_embedding = rng.random(embedding_dimension, dtype=np.float32)
    query_string = "Sentence 5"

    response = await sqlite_vec_index.query_hybrid(
        embedding=query_embedding, query_string=query_string, k=3, score_threshold=0.0
    )

    # Should still get results if keyword matches exist
    assert len(response.chunks) > 0
    # Verify scores are in descending order
    assert all(left >= right for left, right in zip(response.scores, response.scores[1:]))
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_query_chunks_hybrid_rrf_ranking(sqlite_vec_index, sample_chunks, sample_embeddings):
    """Hybrid search with k=5 yields a non-empty result list ordered by descending RRF score."""
    await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings)

    # Embedding of the first stored chunk, paired with the text "Sentence 5".
    response = await sqlite_vec_index.query_hybrid(
        embedding=sample_embeddings[0],
        query_string="Sentence 5",
        k=5,
        score_threshold=0.0,
    )

    # Only non-emptiness and ordering are checked here.
    assert len(response.chunks) > 0
    assert all(left >= right for left, right in zip(response.scores, response.scores[1:]))
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_query_chunks_hybrid_score_selection(sqlite_vec_index, sample_chunks, sample_embeddings):
    """A document found by both search methods must come out on top of the hybrid ranking."""
    await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings)

    def mentions_document_0(found_chunks):
        return any(chunk.metadata["document_id"] == "document-0" for chunk in found_chunks)

    # Embedding and text both point at the first stored chunk.
    target_embedding = sample_embeddings[0]
    target_text = "Sentence 0 from document 0"

    # Preconditions: each individual search method returns document-0.
    vector_only = await sqlite_vec_index.query_vector(target_embedding, k=5, score_threshold=0.0)
    keyword_only = await sqlite_vec_index.query_keyword(target_text, k=5, score_threshold=0.0)
    assert mentions_document_0(vector_only.chunks), "document-0 not found in vector search results"
    assert mentions_document_0(keyword_only.chunks), "document-0 not found in keyword search results"

    # With k=1 the single hybrid hit has to be document-0.
    response = await sqlite_vec_index.query_hybrid(
        embedding=target_embedding,
        query_string=target_text,
        k=1,
        score_threshold=0.0,
    )

    assert len(response.chunks) == 1
    top_hit = response.chunks[0]
    assert top_hit.metadata["document_id"] == "document-0", "document-0 not ranked first in hybrid results"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_query_chunks_hybrid_mixed_results(sqlite_vec_index, sample_chunks, sample_embeddings):
    """Hybrid results include documents contributed by only one of the two search methods."""
    await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings)

    # The embedding targets the first stored chunk; the text targets document 2.
    response = await sqlite_vec_index.query_hybrid(
        embedding=sample_embeddings[0],
        query_string="Sentence 9 from document 2",
        k=3,
        score_threshold=0.0,
    )

    assert len(response.chunks) > 0
    # Scores must be ordered best first.
    assert all(left >= right for left, right in zip(response.scores, response.scores[1:]))

    # Both the vector-side document and the keyword-side document must be present.
    returned_doc_ids = {chunk.metadata["document_id"] for chunk in response.chunks}
    assert "document-0" in returned_doc_ids  # From vector search
    assert "document-2" in returned_doc_ids  # From keyword search
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue