feat: Implement hybrid search in SQLite-vec (#2312)
Some checks failed
Integration Auth Tests / test-matrix (oauth2_token) (push) Failing after 1s
Integration Tests / test-matrix (http, 3.10, datasets) (push) Failing after 4s
Integration Tests / test-matrix (http, 3.10, providers) (push) Failing after 6s
Integration Tests / test-matrix (http, 3.10, scoring) (push) Failing after 6s
Integration Tests / test-matrix (http, 3.10, agents) (push) Failing after 8s
Integration Tests / test-matrix (http, 3.11, datasets) (push) Failing after 5s
Integration Tests / test-matrix (http, 3.10, inference) (push) Failing after 9s
Integration Tests / test-matrix (http, 3.11, inference) (push) Failing after 5s
Integration Tests / test-matrix (http, 3.11, inspect) (push) Failing after 8s
Integration Tests / test-matrix (http, 3.10, post_training) (push) Failing after 10s
Integration Tests / test-matrix (http, 3.11, tool_runtime) (push) Failing after 5s
Integration Tests / test-matrix (http, 3.10, vector_io) (push) Failing after 7s
Integration Tests / test-matrix (http, 3.11, agents) (push) Failing after 7s
Integration Tests / test-matrix (http, 3.10, inspect) (push) Failing after 9s
Integration Tests / test-matrix (http, 3.12, agents) (push) Failing after 10s
Integration Tests / test-matrix (http, 3.12, post_training) (push) Failing after 8s
Integration Tests / test-matrix (http, 3.12, providers) (push) Failing after 8s
Integration Tests / test-matrix (http, 3.10, tool_runtime) (push) Failing after 7s
Integration Tests / test-matrix (http, 3.11, post_training) (push) Failing after 6s
Integration Tests / test-matrix (http, 3.12, scoring) (push) Failing after 8s
Integration Tests / test-matrix (library, 3.10, agents) (push) Failing after 7s
Integration Tests / test-matrix (http, 3.11, scoring) (push) Failing after 6s
Integration Tests / test-matrix (http, 3.11, providers) (push) Failing after 8s
Integration Tests / test-matrix (http, 3.12, inference) (push) Failing after 7s
Integration Tests / test-matrix (http, 3.12, datasets) (push) Failing after 7s
Integration Tests / test-matrix (library, 3.10, inference) (push) Failing after 8s
Integration Tests / test-matrix (http, 3.12, vector_io) (push) Failing after 7s
Integration Tests / test-matrix (http, 3.12, inspect) (push) Failing after 10s
Integration Tests / test-matrix (library, 3.10, post_training) (push) Failing after 9s
Integration Tests / test-matrix (http, 3.12, tool_runtime) (push) Failing after 10s
Integration Tests / test-matrix (http, 3.11, vector_io) (push) Failing after 11s
Integration Tests / test-matrix (library, 3.10, inspect) (push) Failing after 11s
Integration Tests / test-matrix (library, 3.10, datasets) (push) Failing after 13s
Integration Tests / test-matrix (library, 3.10, providers) (push) Failing after 11s
Integration Tests / test-matrix (library, 3.10, scoring) (push) Failing after 10s
Integration Tests / test-matrix (library, 3.10, vector_io) (push) Failing after 9s
Integration Tests / test-matrix (library, 3.10, tool_runtime) (push) Failing after 12s
Integration Tests / test-matrix (library, 3.11, agents) (push) Failing after 8s
Integration Tests / test-matrix (library, 3.11, datasets) (push) Failing after 7s
Integration Tests / test-matrix (library, 3.11, inspect) (push) Failing after 15s
Integration Tests / test-matrix (library, 3.11, inference) (push) Failing after 16s
Integration Tests / test-matrix (library, 3.11, vector_io) (push) Failing after 10s
Integration Tests / test-matrix (library, 3.11, post_training) (push) Failing after 25s
Integration Tests / test-matrix (library, 3.11, providers) (push) Failing after 24s
Integration Tests / test-matrix (library, 3.11, scoring) (push) Failing after 22s
Integration Tests / test-matrix (library, 3.11, tool_runtime) (push) Failing after 14s
Integration Tests / test-matrix (library, 3.12, agents) (push) Failing after 6s
Integration Tests / test-matrix (library, 3.12, datasets) (push) Failing after 9s
Integration Tests / test-matrix (library, 3.12, inference) (push) Failing after 9s
Integration Tests / test-matrix (library, 3.12, inspect) (push) Failing after 7s
Integration Tests / test-matrix (library, 3.12, post_training) (push) Failing after 7s
Integration Tests / test-matrix (library, 3.12, providers) (push) Failing after 7s
Integration Tests / test-matrix (library, 3.12, scoring) (push) Failing after 9s
Integration Tests / test-matrix (library, 3.12, tool_runtime) (push) Failing after 7s
Integration Tests / test-matrix (library, 3.12, vector_io) (push) Failing after 41s
Test Llama Stack Build / generate-matrix (push) Successful in 37s
Test Llama Stack Build / build-single-provider (push) Failing after 37s
Test Llama Stack Build / build-custom-container-distribution (push) Failing after 35s
Test External Providers / test-external-providers (venv) (push) Failing after 5s
Update ReadTheDocs / update-readthedocs (push) Failing after 5s
Unit Tests / unit-tests (3.11) (push) Failing after 6s
Unit Tests / unit-tests (3.12) (push) Failing after 6s
Unit Tests / unit-tests (3.13) (push) Failing after 6s
Test Llama Stack Build / build (push) Failing after 7s
Test Llama Stack Build / build-ubi9-container-distribution (push) Failing after 18s
Unit Tests / unit-tests (3.10) (push) Failing after 17s
Pre-commit / pre-commit (push) Successful in 2m0s

# What does this PR do?
Add support for hybrid search mode in SQLite-vec provider, which
combines
keyword and vector search for better results. The implementation:

- Adds hybrid search mode as a new option alongside vector and keyword
search
- Implements query_hybrid method in SQLiteVecIndex that:
  - First performs keyword search to get candidate matches
  - Then applies vector similarity search on those candidates
- Updates documentation to reflect the new search mode

This change improves search quality by leveraging both semantic
similarity
and keyword matching, while maintaining backward compatibility with
existing
vector and keyword search modes.

## Test Plan
```
pytest tests/unit/providers/vector_io/test_sqlite_vec.py -v -s --tb=short
/Users/vnarsing/miniconda3/envs/stack-client/lib/python3.10/site-packages/pytest_asyncio/plugin.py:217: PytestDeprecationWarning: The configuration option "asyncio_default_fixture_loop_scope" is unset.
The event loop scope for asynchronous fixtures will default to the fixture caching scope. Future versions of pytest-asyncio will default the loop scope for asynchronous fixtures to function scope. Set the default fixture loop scope explicitly in order to avoid unexpected behavior in the future. Valid fixture loop scopes are: "function", "class", "module", "package", "session"

  warnings.warn(PytestDeprecationWarning(_DEFAULT_FIXTURE_LOOP_SCOPE_UNSET))
=============================================================================================== test session starts ===============================================================================================
platform darwin -- Python 3.10.16, pytest-8.3.5, pluggy-1.5.0 -- /Users/vnarsing/miniconda3/envs/stack-client/bin/python
cachedir: .pytest_cache
metadata: {'Python': '3.10.16', 'Platform': 'macOS-14.7.6-arm64-arm-64bit', 'Packages': {'pytest': '8.3.5', 'pluggy': '1.5.0'}, 'Plugins': {'html': '4.1.1', 'json-report': '1.5.0', 'timeout': '2.4.0', 'metadata': '3.1.1', 'anyio': '4.8.0', 'asyncio': '0.26.0', 'nbval': '0.11.0', 'cov': '6.1.1'}}
rootdir: /Users/vnarsing/go/src/github/meta-llama/llama-stack
configfile: pyproject.toml
plugins: html-4.1.1, json-report-1.5.0, timeout-2.4.0, metadata-3.1.1, anyio-4.8.0, asyncio-0.26.0, nbval-0.11.0, cov-6.1.1
asyncio: mode=strict, asyncio_default_fixture_loop_scope=None, asyncio_default_test_loop_scope=function
collected 10 items                                                                                                                                                                                                

tests/unit/providers/vector_io/test_sqlite_vec.py::test_add_chunks PASSED
tests/unit/providers/vector_io/test_sqlite_vec.py::test_query_chunks_vector PASSED
tests/unit/providers/vector_io/test_sqlite_vec.py::test_query_chunks_full_text_search PASSED
tests/unit/providers/vector_io/test_sqlite_vec.py::test_query_chunks_hybrid PASSED
tests/unit/providers/vector_io/test_sqlite_vec.py::test_query_chunks_full_text_search_k_greater_than_results PASSED
tests/unit/providers/vector_io/test_sqlite_vec.py::test_chunk_id_conflict PASSED
tests/unit/providers/vector_io/test_sqlite_vec.py::test_generate_chunk_id PASSED
tests/unit/providers/vector_io/test_sqlite_vec.py::test_query_chunks_hybrid_no_keyword_matches PASSED
tests/unit/providers/vector_io/test_sqlite_vec.py::test_query_chunks_hybrid_score_threshold PASSED
tests/unit/providers/vector_io/test_sqlite_vec.py::test_query_chunks_hybrid_different_embedding PASSED
```

---------

Signed-off-by: Varsha Prasad Narsing <varshaprasad96@gmail.com>
This commit is contained in:
Varsha 2025-06-13 12:54:06 -07:00 committed by GitHub
parent 941f505eb0
commit 2e8054bede
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
14 changed files with 910 additions and 23 deletions

View file

@ -15,6 +15,48 @@ from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
@json_schema_type
class RRFRanker(BaseModel):
"""
Reciprocal Rank Fusion (RRF) ranker configuration.
:param type: The type of ranker, always "rrf"
:param impact_factor: The impact factor for RRF scoring. Higher values give more weight to higher-ranked results.
Must be greater than 0. Default of 60 is from the original RRF paper (Cormack et al., 2009).
"""
type: Literal["rrf"] = "rrf"
impact_factor: float = Field(default=60.0, gt=0.0) # default of 60 for optimal performance
@json_schema_type
class WeightedRanker(BaseModel):
"""
Weighted ranker configuration that combines vector and keyword scores.
:param type: The type of ranker, always "weighted"
:param alpha: Weight factor between 0 and 1.
0 means only use keyword scores,
1 means only use vector scores,
values in between blend both scores.
"""
type: Literal["weighted"] = "weighted"
alpha: float = Field(
default=0.5,
ge=0.0,
le=1.0,
description="Weight factor between 0 and 1. 0 means only keyword scores, 1 means only vector scores.",
)
Ranker = Annotated[
RRFRanker | WeightedRanker,
Field(discriminator="type"),
]
register_schema(Ranker, name="Ranker")
@json_schema_type
class RAGDocument(BaseModel):
"""
@ -76,7 +118,8 @@ class RAGQueryConfig(BaseModel):
:param chunk_template: Template for formatting each retrieved chunk in the context.
Available placeholders: {index} (1-based chunk ordinal), {chunk.content} (chunk content string), {metadata} (chunk metadata dict).
Default: "Result {index}\\nContent: {chunk.content}\\nMetadata: {metadata}\\n"
:param mode: Search mode for retrievaleither "vector" or "keyword". Default "vector".
:param mode: Search mode for retrievaleither "vector", "keyword", or "hybrid". Default "vector".
:param ranker: Configuration for the ranker to use in hybrid search. Defaults to RRF ranker.
"""
# This config defines how a query is generated using the messages
@ -86,6 +129,7 @@ class RAGQueryConfig(BaseModel):
max_chunks: int = 5
chunk_template: str = "Result {index}\nContent: {chunk.content}\nMetadata: {metadata}\n"
mode: str | None = None
ranker: Ranker | None = Field(default=None) # Only used for hybrid mode
@field_validator("chunk_template")
def validate_chunk_template(cls, v: str) -> str:

View file

@ -121,8 +121,10 @@ class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, RAGToolRunti
vector_db_id=vector_db_id,
query=query,
params={
"max_chunks": query_config.max_chunks,
"mode": query_config.mode,
"max_chunks": query_config.max_chunks,
"score_threshold": 0.0,
"ranker": query_config.ranker,
},
)
for vector_db_id in vector_db_ids

View file

@ -131,6 +131,17 @@ class FaissIndex(EmbeddingIndex):
) -> QueryChunksResponse:
raise NotImplementedError("Keyword search is not supported in FAISS")
async def query_hybrid(
self,
embedding: NDArray,
query_string: str,
k: int,
score_threshold: float,
reranker_type: str,
reranker_params: dict[str, Any] | None = None,
) -> QueryChunksResponse:
raise NotImplementedError("Hybrid search is not supported in FAISS")
class FaissVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPrivate):
def __init__(self, config: FaissVectorIOConfig, inference_api: Inference, files_api: Files | None) -> None:

View file

@ -27,14 +27,20 @@ from llama_stack.apis.vector_io import (
)
from llama_stack.providers.datatypes import VectorDBsProtocolPrivate
from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin
from llama_stack.providers.utils.memory.vector_store import EmbeddingIndex, VectorDBWithIndex
from llama_stack.providers.utils.memory.vector_store import (
RERANKER_TYPE_RRF,
RERANKER_TYPE_WEIGHTED,
EmbeddingIndex,
VectorDBWithIndex,
)
logger = logging.getLogger(__name__)
# Specifying search mode is dependent on the VectorIO provider.
VECTOR_SEARCH = "vector"
KEYWORD_SEARCH = "keyword"
SEARCH_MODES = {VECTOR_SEARCH, KEYWORD_SEARCH}
HYBRID_SEARCH = "hybrid"
SEARCH_MODES = {VECTOR_SEARCH, KEYWORD_SEARCH, HYBRID_SEARCH}
def serialize_vector(vector: list[float]) -> bytes:
@ -51,6 +57,59 @@ def _create_sqlite_connection(db_path):
return connection
def _normalize_scores(scores: dict[str, float]) -> dict[str, float]:
"""Normalize scores to [0,1] range using min-max normalization."""
if not scores:
return {}
min_score = min(scores.values())
max_score = max(scores.values())
score_range = max_score - min_score
if score_range > 0:
return {doc_id: (score - min_score) / score_range for doc_id, score in scores.items()}
return {doc_id: 1.0 for doc_id in scores}
def _weighted_rerank(
vector_scores: dict[str, float],
keyword_scores: dict[str, float],
alpha: float = 0.5,
) -> dict[str, float]:
"""ReRanker that uses weighted average of scores."""
all_ids = set(vector_scores.keys()) | set(keyword_scores.keys())
normalized_vector_scores = _normalize_scores(vector_scores)
normalized_keyword_scores = _normalize_scores(keyword_scores)
return {
doc_id: (alpha * normalized_keyword_scores.get(doc_id, 0.0))
+ ((1 - alpha) * normalized_vector_scores.get(doc_id, 0.0))
for doc_id in all_ids
}
def _rrf_rerank(
vector_scores: dict[str, float],
keyword_scores: dict[str, float],
impact_factor: float = 60.0,
) -> dict[str, float]:
"""ReRanker that uses Reciprocal Rank Fusion."""
# Convert scores to ranks
vector_ranks = {
doc_id: i + 1 for i, (doc_id, _) in enumerate(sorted(vector_scores.items(), key=lambda x: x[1], reverse=True))
}
keyword_ranks = {
doc_id: i + 1 for i, (doc_id, _) in enumerate(sorted(keyword_scores.items(), key=lambda x: x[1], reverse=True))
}
all_ids = set(vector_scores.keys()) | set(keyword_scores.keys())
rrf_scores = {}
for doc_id in all_ids:
vector_rank = vector_ranks.get(doc_id, float("inf"))
keyword_rank = keyword_ranks.get(doc_id, float("inf"))
# RRF formula: score = 1/(k + r) where k is impact_factor and r is the rank
rrf_scores[doc_id] = (1.0 / (impact_factor + vector_rank)) + (1.0 / (impact_factor + keyword_rank))
return rrf_scores
class SQLiteVecIndex(EmbeddingIndex):
"""
An index implementation that stores embeddings in a SQLite virtual table using sqlite-vec.
@ -255,8 +314,6 @@ class SQLiteVecIndex(EmbeddingIndex):
"""
Performs keyword-based search using SQLite FTS5 for relevance-ranked full-text search.
"""
if query_string is None:
raise ValueError("query_string is required for keyword search.")
def _execute_query():
connection = _create_sqlite_connection(self.db_path)
@ -294,6 +351,81 @@ class SQLiteVecIndex(EmbeddingIndex):
scores.append(score)
return QueryChunksResponse(chunks=chunks, scores=scores)
async def query_hybrid(
self,
embedding: NDArray,
query_string: str,
k: int,
score_threshold: float,
reranker_type: str = RERANKER_TYPE_RRF,
reranker_params: dict[str, Any] | None = None,
) -> QueryChunksResponse:
"""
Hybrid search using a configurable re-ranking strategy.
Args:
embedding: The query embedding vector
query_string: The text query for keyword search
k: Number of results to return
score_threshold: Minimum similarity score threshold
reranker_type: Type of reranker to use ("rrf" or "weighted")
reranker_params: Parameters for the reranker
Returns:
QueryChunksResponse with combined results
"""
if reranker_params is None:
reranker_params = {}
# Get results from both search methods
vector_response = await self.query_vector(embedding, k, score_threshold)
keyword_response = await self.query_keyword(query_string, k, score_threshold)
# Convert responses to score dictionaries using generate_chunk_id
vector_scores = {
generate_chunk_id(chunk.metadata["document_id"], str(chunk.content)): score
for chunk, score in zip(vector_response.chunks, vector_response.scores, strict=False)
}
keyword_scores = {
generate_chunk_id(chunk.metadata["document_id"], str(chunk.content)): score
for chunk, score in zip(keyword_response.chunks, keyword_response.scores, strict=False)
}
# Combine scores using the specified reranker
if reranker_type == RERANKER_TYPE_WEIGHTED:
alpha = reranker_params.get("alpha", 0.5)
combined_scores = _weighted_rerank(vector_scores, keyword_scores, alpha)
else:
# Default to RRF for None, RRF, or any unknown types
impact_factor = reranker_params.get("impact_factor", 60.0)
combined_scores = _rrf_rerank(vector_scores, keyword_scores, impact_factor)
# Sort by combined score and get top k results
sorted_items = sorted(combined_scores.items(), key=lambda x: x[1], reverse=True)
top_k_items = sorted_items[:k]
# Filter by score threshold
filtered_items = [(doc_id, score) for doc_id, score in top_k_items if score >= score_threshold]
# Create a map of chunk_id to chunk for both responses
chunk_map = {}
for c in vector_response.chunks:
chunk_id = generate_chunk_id(c.metadata["document_id"], str(c.content))
chunk_map[chunk_id] = c
for c in keyword_response.chunks:
chunk_id = generate_chunk_id(c.metadata["document_id"], str(c.content))
chunk_map[chunk_id] = c
# Use the map to look up chunks by their IDs
chunks = []
scores = []
for doc_id, score in filtered_items:
if doc_id in chunk_map:
chunks.append(chunk_map[doc_id])
scores.append(score)
return QueryChunksResponse(chunks=chunks, scores=scores)
class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPrivate):
"""
@ -345,7 +477,9 @@ class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtoc
vector_db_data = row[0]
vector_db = VectorDB.model_validate_json(vector_db_data)
index = await SQLiteVecIndex.create(
vector_db.embedding_dimension, self.config.db_path, vector_db.identifier
vector_db.embedding_dimension,
self.config.db_path,
vector_db.identifier,
)
self.cache[vector_db.identifier] = VectorDBWithIndex(vector_db, index, self.inference_api)
@ -371,7 +505,11 @@ class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtoc
connection.close()
await asyncio.to_thread(_register_db)
index = await SQLiteVecIndex.create(vector_db.embedding_dimension, self.config.db_path, vector_db.identifier)
index = await SQLiteVecIndex.create(
vector_db.embedding_dimension,
self.config.db_path,
vector_db.identifier,
)
self.cache[vector_db.identifier] = VectorDBWithIndex(vector_db, index, self.inference_api)
async def list_vector_dbs(self) -> list[VectorDB]:

View file

@ -105,6 +105,17 @@ class ChromaIndex(EmbeddingIndex):
) -> QueryChunksResponse:
raise NotImplementedError("Keyword search is not supported in Chroma")
async def query_hybrid(
self,
embedding: NDArray,
query_string: str,
k: int,
score_threshold: float,
reranker_type: str,
reranker_params: dict[str, Any] | None = None,
) -> QueryChunksResponse:
raise NotImplementedError("Hybrid search is not supported in Chroma")
class ChromaVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
def __init__(

View file

@ -103,6 +103,17 @@ class MilvusIndex(EmbeddingIndex):
) -> QueryChunksResponse:
raise NotImplementedError("Keyword search is not supported in Milvus")
async def query_hybrid(
self,
embedding: NDArray,
query_string: str,
k: int,
score_threshold: float,
reranker_type: str,
reranker_params: dict[str, Any] | None = None,
) -> QueryChunksResponse:
raise NotImplementedError("Hybrid search is not supported in Milvus")
class MilvusVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
def __init__(

View file

@ -128,6 +128,17 @@ class PGVectorIndex(EmbeddingIndex):
) -> QueryChunksResponse:
raise NotImplementedError("Keyword search is not supported in PGVector")
async def query_hybrid(
self,
embedding: NDArray,
query_string: str,
k: int,
score_threshold: float,
reranker_type: str,
reranker_params: dict[str, Any] | None = None,
) -> QueryChunksResponse:
raise NotImplementedError("Hybrid search is not supported in PGVector")
async def delete(self):
with self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
cur.execute(f"DROP TABLE IF EXISTS {self.table_name}")

View file

@ -112,6 +112,17 @@ class QdrantIndex(EmbeddingIndex):
) -> QueryChunksResponse:
raise NotImplementedError("Keyword search is not supported in Qdrant")
async def query_hybrid(
self,
embedding: NDArray,
query_string: str,
k: int,
score_threshold: float,
reranker_type: str,
reranker_params: dict[str, Any] | None = None,
) -> QueryChunksResponse:
raise NotImplementedError("Hybrid search is not supported in Qdrant")
async def delete(self):
await self.client.delete_collection(collection_name=self.collection_name)

View file

@ -92,6 +92,17 @@ class WeaviateIndex(EmbeddingIndex):
) -> QueryChunksResponse:
raise NotImplementedError("Keyword search is not supported in Weaviate")
async def query_hybrid(
self,
embedding: NDArray,
query_string: str,
k: int,
score_threshold: float,
reranker_type: str,
reranker_params: dict[str, Any] | None = None,
) -> QueryChunksResponse:
raise NotImplementedError("Hybrid search is not supported in Weaviate")
class WeaviateVectorIOAdapter(
VectorIO,

View file

@ -32,6 +32,10 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
log = logging.getLogger(__name__)
# Constants for reranker types
RERANKER_TYPE_RRF = "rrf"
RERANKER_TYPE_WEIGHTED = "weighted"
def parse_pdf(data: bytes) -> str:
# For PDF and DOC/DOCX files, we can't reliably convert to string
@ -202,6 +206,18 @@ class EmbeddingIndex(ABC):
async def query_keyword(self, query_string: str, k: int, score_threshold: float) -> QueryChunksResponse:
raise NotImplementedError()
@abstractmethod
async def query_hybrid(
self,
embedding: NDArray,
query_string: str,
k: int,
score_threshold: float,
reranker_type: str,
reranker_params: dict[str, Any] | None = None,
) -> QueryChunksResponse:
raise NotImplementedError()
@abstractmethod
async def delete(self):
raise NotImplementedError()
@ -245,10 +261,29 @@ class VectorDBWithIndex:
k = params.get("max_chunks", 3)
mode = params.get("mode")
score_threshold = params.get("score_threshold", 0.0)
# Get ranker configuration
ranker = params.get("ranker")
if ranker is None:
# Default to RRF with impact_factor=60.0
reranker_type = RERANKER_TYPE_RRF
reranker_params = {"impact_factor": 60.0}
else:
reranker_type = ranker.type
reranker_params = (
{"impact_factor": ranker.impact_factor} if ranker.type == RERANKER_TYPE_RRF else {"alpha": ranker.alpha}
)
query_string = interleaved_content_as_str(query)
if mode == "keyword":
return await self.index.query_keyword(query_string, k, score_threshold)
# Calculate embeddings for both vector and hybrid modes
embeddings_response = await self.inference_api.embeddings(self.vector_db.embedding_model, [query_string])
query_vector = np.array(embeddings_response.embeddings[0], dtype=np.float32)
if mode == "hybrid":
return await self.index.query_hybrid(
query_vector, query_string, k, score_threshold, reranker_type, reranker_params
)
else:
embeddings_response = await self.inference_api.embeddings(self.vector_db.embedding_model, [query_string])
query_vector = np.array(embeddings_response.embeddings[0], dtype=np.float32)
return await self.index.query_vector(query_vector, k, score_threshold)