diff --git a/docs/_static/llama-stack-spec.html b/docs/_static/llama-stack-spec.html
index ce47f8ebb..e26725907 100644
--- a/docs/_static/llama-stack-spec.html
+++ b/docs/_static/llama-stack-spec.html
@@ -13994,7 +13994,11 @@
},
"mode": {
"type": "string",
- "description": "Search mode for retrieval—either \"vector\" or \"keyword\". Default \"vector\"."
+ "description": "Search mode for retrieval—either \"vector\", \"keyword\", or \"hybrid\". Default \"vector\"."
+ },
+ "ranker": {
+ "$ref": "#/components/schemas/Ranker",
+ "description": "Configuration for the ranker to use in hybrid search. Defaults to RRF ranker."
}
},
"additionalProperties": false,
@@ -14024,6 +14028,69 @@
}
}
},
+ "RRFRanker": {
+ "type": "object",
+ "properties": {
+ "type": {
+ "type": "string",
+ "const": "rrf",
+ "default": "rrf",
+ "description": "The type of ranker, always \"rrf\""
+ },
+ "impact_factor": {
+ "type": "number",
+ "default": 60.0,
+ "description": "The impact factor for RRF scoring. Higher values give more weight to higher-ranked results. Must be greater than 0. Default of 60 is from the original RRF paper (Cormack et al., 2009)."
+ }
+ },
+ "additionalProperties": false,
+ "required": [
+ "type",
+ "impact_factor"
+ ],
+ "title": "RRFRanker",
+ "description": "Reciprocal Rank Fusion (RRF) ranker configuration."
+ },
+ "Ranker": {
+ "oneOf": [
+ {
+ "$ref": "#/components/schemas/RRFRanker"
+ },
+ {
+ "$ref": "#/components/schemas/WeightedRanker"
+ }
+ ],
+ "discriminator": {
+ "propertyName": "type",
+ "mapping": {
+ "rrf": "#/components/schemas/RRFRanker",
+ "weighted": "#/components/schemas/WeightedRanker"
+ }
+ }
+ },
+ "WeightedRanker": {
+ "type": "object",
+ "properties": {
+ "type": {
+ "type": "string",
+ "const": "weighted",
+ "default": "weighted",
+ "description": "The type of ranker, always \"weighted\""
+ },
+ "alpha": {
+ "type": "number",
+ "default": 0.5,
+ "description": "Weight factor between 0 and 1. 0 means only use keyword scores, 1 means only use vector scores, values in between blend both scores."
+ }
+ },
+ "additionalProperties": false,
+ "required": [
+ "type",
+ "alpha"
+ ],
+ "title": "WeightedRanker",
+ "description": "Weighted ranker configuration that combines vector and keyword scores."
+ },
"QueryRequest": {
"type": "object",
"properties": {
diff --git a/docs/_static/llama-stack-spec.yaml b/docs/_static/llama-stack-spec.yaml
index 07a176b32..c4f356791 100644
--- a/docs/_static/llama-stack-spec.yaml
+++ b/docs/_static/llama-stack-spec.yaml
@@ -9756,7 +9756,13 @@ components:
mode:
type: string
description: >-
- Search mode for retrieval—either "vector" or "keyword". Default "vector".
+ Search mode for retrieval—either "vector", "keyword", or "hybrid". Default
+ "vector".
+ ranker:
+ $ref: '#/components/schemas/Ranker'
+ description: >-
+ Configuration for the ranker to use in hybrid search. Defaults to RRF
+ ranker.
additionalProperties: false
required:
- query_generator_config
@@ -9775,6 +9781,58 @@ components:
mapping:
default: '#/components/schemas/DefaultRAGQueryGeneratorConfig'
llm: '#/components/schemas/LLMRAGQueryGeneratorConfig'
+ RRFRanker:
+ type: object
+ properties:
+ type:
+ type: string
+ const: rrf
+ default: rrf
+ description: The type of ranker, always "rrf"
+ impact_factor:
+ type: number
+ default: 60.0
+ description: >-
+ The impact factor for RRF scoring. Higher values give more weight to higher-ranked
+ results. Must be greater than 0. Default of 60 is from the original RRF
+ paper (Cormack et al., 2009).
+ additionalProperties: false
+ required:
+ - type
+ - impact_factor
+ title: RRFRanker
+ description: >-
+ Reciprocal Rank Fusion (RRF) ranker configuration.
+ Ranker:
+ oneOf:
+ - $ref: '#/components/schemas/RRFRanker'
+ - $ref: '#/components/schemas/WeightedRanker'
+ discriminator:
+ propertyName: type
+ mapping:
+ rrf: '#/components/schemas/RRFRanker'
+ weighted: '#/components/schemas/WeightedRanker'
+ WeightedRanker:
+ type: object
+ properties:
+ type:
+ type: string
+ const: weighted
+ default: weighted
+ description: The type of ranker, always "weighted"
+ alpha:
+ type: number
+ default: 0.5
+ description: >-
+ Weight factor between 0 and 1. 0 means only use keyword scores, 1 means
+ only use vector scores, values in between blend both scores.
+ additionalProperties: false
+ required:
+ - type
+ - alpha
+ title: WeightedRanker
+ description: >-
+ Weighted ranker configuration that combines vector and keyword scores.
QueryRequest:
type: object
properties:
diff --git a/docs/source/providers/vector_io/sqlite-vec.md b/docs/source/providers/vector_io/sqlite-vec.md
index 49ba659f7..3c7c4cbee 100644
--- a/docs/source/providers/vector_io/sqlite-vec.md
+++ b/docs/source/providers/vector_io/sqlite-vec.md
@@ -66,25 +66,126 @@ To use sqlite-vec in your Llama Stack project, follow these steps:
2. Configure your Llama Stack project to use SQLite-Vec.
3. Start storing and querying vectors.
## Supported Search Modes

-The sqlite-vec provider supports both vector-based and keyword-based (full-text) search modes.
-
-When using the RAGTool interface, you can specify the desired search behavior via the `mode` parameter in
-`RAGQueryConfig`. For example:
+The sqlite-vec provider supports three search modes:
+
+1. **Vector Search** (`mode="vector"`): Performs pure vector similarity search using the embeddings.
+2. **Keyword Search** (`mode="keyword"`): Performs full-text search using SQLite's FTS5.
+3. **Hybrid Search** (`mode="hybrid"`): Runs vector and keyword search independently, then merges the two result sets with a configurable ranker (RRF by default).
+
+Example with hybrid search:
```python
-from llama_stack.apis.tool_runtime.rag import RAGQueryConfig
+response = await vector_io.query_chunks(
+ vector_db_id="my_db",
+ query="your query here",
+ params={"mode": "hybrid", "max_chunks": 3, "score_threshold": 0.7},
+)
-query_config = RAGQueryConfig(max_chunks=6, mode="vector")
+# Using RRF ranker
+response = await vector_io.query_chunks(
+ vector_db_id="my_db",
+ query="your query here",
+ params={
+ "mode": "hybrid",
+ "max_chunks": 3,
+ "score_threshold": 0.7,
+ "ranker": {"type": "rrf", "impact_factor": 60.0},
+ },
+)
-results = client.tool_runtime.rag_tool.query(
- vector_db_ids=[vector_db_id],
- content="what is torchtune",
- query_config=query_config,
+# Using weighted ranker
+response = await vector_io.query_chunks(
+ vector_db_id="my_db",
+ query="your query here",
+ params={
+ "mode": "hybrid",
+ "max_chunks": 3,
+ "score_threshold": 0.7,
+ "ranker": {"type": "weighted", "alpha": 0.7}, # 70% vector, 30% keyword
+ },
)
```
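+
+Note that with the sqlite-vec provider the `score_threshold` in hybrid mode is applied to the *combined* ranker score, not to the raw vector similarity. With the RRF ranker a chunk's combined score is at most `2 / (impact_factor + 1)` (about 0.033 for the default impact factor of 60), so keep the threshold at or near zero unless you are using the weighted ranker. A quick sanity check of that bound (illustrative sketch, not part of the provider API):
+
+```python
+impact_factor = 60.0
+# Best case: the chunk is ranked first in both the vector and the keyword result lists.
+max_rrf_score = 1 / (impact_factor + 1) + 1 / (impact_factor + 1)
+print(max_rrf_score)  # ~0.0328
+```
+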
+Example with explicit vector search:
+```python
+response = await vector_io.query_chunks(
+ vector_db_id="my_db",
+ query="your query here",
+ params={"mode": "vector", "max_chunks": 3, "score_threshold": 0.7},
+)
+```
+
+Example with keyword search:
+```python
+response = await vector_io.query_chunks(
+ vector_db_id="my_db",
+ query="your query here",
+ params={"mode": "keyword", "max_chunks": 3, "score_threshold": 0.7},
+)
+```
+
+### Hybrid Search
+
+Hybrid search combines the strengths of both vector and keyword search by:
+- Computing vector similarity scores
+- Computing keyword match scores
+- Using a ranker to combine these scores
+
+Two ranker types are supported (see the worked example after this list):
+
+1. **RRF (Reciprocal Rank Fusion)**:
+ - Combines ranks from both vector and keyword results
+ - Uses an impact factor (default: 60.0) to control the weight of higher-ranked results
+ - Good for balancing between vector and keyword results
+   - The default impact factor of 60.0 comes from the original RRF paper by Cormack et al. (2009) [^1], which found this value to work well across a range of retrieval tasks
+
+2. **Weighted**:
+ - Linearly combines normalized vector and keyword scores
+ - Uses an alpha parameter (0-1) to control the blend:
+ - alpha=0: Only use keyword scores
+ - alpha=1: Only use vector scores
+ - alpha=0.5: Equal weight to both (default)
+
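+To make the ranker math concrete, here is a small standalone sketch of how the two strategies combine scores. It mirrors the `_rrf_rerank` and `_weighted_rerank` helpers added to the sqlite-vec provider in this change; the chunk ids and scores are made up for illustration and this is not a provider API:
+
+```python
+vector_scores = {"chunk-a": 0.92, "chunk-b": 0.85}  # cosine similarities
+keyword_scores = {"chunk-b": 4.2, "chunk-c": 3.1}  # FTS5 relevance scores
+
+# RRF: convert each result list to ranks, then sum 1 / (impact_factor + rank).
+impact_factor = 60.0
+vector_ranks = {d: r for r, d in enumerate(sorted(vector_scores, key=vector_scores.get, reverse=True), start=1)}
+keyword_ranks = {d: r for r, d in enumerate(sorted(keyword_scores, key=keyword_scores.get, reverse=True), start=1)}
+rrf_scores = {
+    d: 1 / (impact_factor + vector_ranks.get(d, float("inf")))
+    + 1 / (impact_factor + keyword_ranks.get(d, float("inf")))
+    for d in set(vector_scores) | set(keyword_scores)
+}
+# chunk-b appears in both result lists, so it gets the highest combined score:
+# {"chunk-a": ~0.0164, "chunk-b": ~0.0325, "chunk-c": ~0.0161}
+
+# Weighted: min-max normalize each score dict to [0, 1], then blend with
+# combined = alpha * normalized_vector + (1 - alpha) * normalized_keyword.
+```
+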
+Example using RAGQueryConfig with different search modes:
+
+```python
+from llama_stack.apis.tools import RAGQueryConfig, RRFRanker, WeightedRanker
+
+# Vector search
+config = RAGQueryConfig(mode="vector", max_chunks=5)
+
+# Keyword search
+config = RAGQueryConfig(mode="keyword", max_chunks=5)
+
+# Hybrid search with custom RRF ranker
+config = RAGQueryConfig(
+ mode="hybrid",
+ max_chunks=5,
+ ranker=RRFRanker(impact_factor=50.0), # Custom impact factor
+)
+
+# Hybrid search with weighted ranker
+config = RAGQueryConfig(
+ mode="hybrid",
+ max_chunks=5,
+ ranker=WeightedRanker(alpha=0.7), # 70% vector, 30% keyword
+)
+
+# Hybrid search with default RRF ranker
+config = RAGQueryConfig(
+ mode="hybrid", max_chunks=5
+) # Will use RRF with impact_factor=60.0
+```
+
+Note: The ranker configuration is only used in hybrid mode. For vector or keyword modes, the ranker parameter is ignored.
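+
+Any of these configs can then be passed to the RAG tool in the usual way; for example (assuming a configured `client` and a registered `vector_db_id`, as elsewhere in the Llama Stack docs):
+
+```python
+results = client.tool_runtime.rag_tool.query(
+    vector_db_ids=[vector_db_id],
+    content="what is torchtune",
+    query_config=RAGQueryConfig(mode="hybrid", max_chunks=5),
+)
+```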
+
## Installation
You can install SQLite-Vec using pip:
@@ -96,3 +197,5 @@ pip install sqlite-vec
## Documentation
See [sqlite-vec's GitHub repo](https://github.com/asg017/sqlite-vec/tree/main) for more details about sqlite-vec in general.
+
+[^1]: Cormack, G. V., Clarke, C. L., & Buettcher, S. (2009). [Reciprocal rank fusion outperforms condorcet and individual rank learning methods](https://dl.acm.org/doi/10.1145/1571941.1572114). In Proceedings of the 32nd international ACM SIGIR conference on Research and development in information retrieval (pp. 758-759).
diff --git a/llama_stack/apis/tools/rag_tool.py b/llama_stack/apis/tools/rag_tool.py
index 1e3542f74..72f68b7cb 100644
--- a/llama_stack/apis/tools/rag_tool.py
+++ b/llama_stack/apis/tools/rag_tool.py
@@ -15,6 +15,48 @@ from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
+@json_schema_type
+class RRFRanker(BaseModel):
+ """
+ Reciprocal Rank Fusion (RRF) ranker configuration.
+
+ :param type: The type of ranker, always "rrf"
+ :param impact_factor: The impact factor for RRF scoring. Higher values give more weight to higher-ranked results.
+ Must be greater than 0. Default of 60 is from the original RRF paper (Cormack et al., 2009).
+ """
+
+ type: Literal["rrf"] = "rrf"
+    impact_factor: float = Field(default=60.0, gt=0.0)  # default of 60 from the original RRF paper (Cormack et al., 2009)
+
+
+@json_schema_type
+class WeightedRanker(BaseModel):
+ """
+ Weighted ranker configuration that combines vector and keyword scores.
+
+ :param type: The type of ranker, always "weighted"
+ :param alpha: Weight factor between 0 and 1.
+ 0 means only use keyword scores,
+ 1 means only use vector scores,
+ values in between blend both scores.
+ """
+
+ type: Literal["weighted"] = "weighted"
+ alpha: float = Field(
+ default=0.5,
+ ge=0.0,
+ le=1.0,
+ description="Weight factor between 0 and 1. 0 means only keyword scores, 1 means only vector scores.",
+ )
+
+
+Ranker = Annotated[
+ RRFRanker | WeightedRanker,
+ Field(discriminator="type"),
+]
+register_schema(Ranker, name="Ranker")
+
+
@json_schema_type
class RAGDocument(BaseModel):
"""
@@ -76,7 +118,8 @@ class RAGQueryConfig(BaseModel):
:param chunk_template: Template for formatting each retrieved chunk in the context.
Available placeholders: {index} (1-based chunk ordinal), {chunk.content} (chunk content string), {metadata} (chunk metadata dict).
Default: "Result {index}\\nContent: {chunk.content}\\nMetadata: {metadata}\\n"
- :param mode: Search mode for retrieval—either "vector" or "keyword". Default "vector".
+ :param mode: Search mode for retrieval—either "vector", "keyword", or "hybrid". Default "vector".
+ :param ranker: Configuration for the ranker to use in hybrid search. Defaults to RRF ranker.
"""
# This config defines how a query is generated using the messages
@@ -86,6 +129,7 @@ class RAGQueryConfig(BaseModel):
max_chunks: int = 5
chunk_template: str = "Result {index}\nContent: {chunk.content}\nMetadata: {metadata}\n"
mode: str | None = None
+ ranker: Ranker | None = Field(default=None) # Only used for hybrid mode
@field_validator("chunk_template")
def validate_chunk_template(cls, v: str) -> str:
diff --git a/llama_stack/providers/inline/tool_runtime/rag/memory.py b/llama_stack/providers/inline/tool_runtime/rag/memory.py
index e15d067a7..7f4fe5dbd 100644
--- a/llama_stack/providers/inline/tool_runtime/rag/memory.py
+++ b/llama_stack/providers/inline/tool_runtime/rag/memory.py
@@ -121,8 +121,10 @@ class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, RAGToolRunti
vector_db_id=vector_db_id,
query=query,
params={
- "max_chunks": query_config.max_chunks,
"mode": query_config.mode,
+ "max_chunks": query_config.max_chunks,
+ "score_threshold": 0.0,
+ "ranker": query_config.ranker,
},
)
for vector_db_id in vector_db_ids
diff --git a/llama_stack/providers/inline/vector_io/faiss/faiss.py b/llama_stack/providers/inline/vector_io/faiss/faiss.py
index afb911726..a2f4417e0 100644
--- a/llama_stack/providers/inline/vector_io/faiss/faiss.py
+++ b/llama_stack/providers/inline/vector_io/faiss/faiss.py
@@ -131,6 +131,17 @@ class FaissIndex(EmbeddingIndex):
) -> QueryChunksResponse:
raise NotImplementedError("Keyword search is not supported in FAISS")
+ async def query_hybrid(
+ self,
+ embedding: NDArray,
+ query_string: str,
+ k: int,
+ score_threshold: float,
+ reranker_type: str,
+ reranker_params: dict[str, Any] | None = None,
+ ) -> QueryChunksResponse:
+ raise NotImplementedError("Hybrid search is not supported in FAISS")
+
class FaissVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPrivate):
def __init__(self, config: FaissVectorIOConfig, inference_api: Inference, files_api: Files | None) -> None:
diff --git a/llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py b/llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py
index f69cf8a32..c6712882a 100644
--- a/llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py
+++ b/llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py
@@ -27,14 +27,20 @@ from llama_stack.apis.vector_io import (
)
from llama_stack.providers.datatypes import VectorDBsProtocolPrivate
from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin
-from llama_stack.providers.utils.memory.vector_store import EmbeddingIndex, VectorDBWithIndex
+from llama_stack.providers.utils.memory.vector_store import (
+ RERANKER_TYPE_RRF,
+ RERANKER_TYPE_WEIGHTED,
+ EmbeddingIndex,
+ VectorDBWithIndex,
+)
logger = logging.getLogger(__name__)
# Specifying search mode is dependent on the VectorIO provider.
VECTOR_SEARCH = "vector"
KEYWORD_SEARCH = "keyword"
-SEARCH_MODES = {VECTOR_SEARCH, KEYWORD_SEARCH}
+HYBRID_SEARCH = "hybrid"
+SEARCH_MODES = {VECTOR_SEARCH, KEYWORD_SEARCH, HYBRID_SEARCH}
def serialize_vector(vector: list[float]) -> bytes:
@@ -51,6 +57,59 @@ def _create_sqlite_connection(db_path):
return connection
+def _normalize_scores(scores: dict[str, float]) -> dict[str, float]:
+ """Normalize scores to [0,1] range using min-max normalization."""
+ if not scores:
+ return {}
+ min_score = min(scores.values())
+ max_score = max(scores.values())
+ score_range = max_score - min_score
+ if score_range > 0:
+ return {doc_id: (score - min_score) / score_range for doc_id, score in scores.items()}
+ return {doc_id: 1.0 for doc_id in scores}
+
+
+def _weighted_rerank(
+ vector_scores: dict[str, float],
+ keyword_scores: dict[str, float],
+ alpha: float = 0.5,
+) -> dict[str, float]:
+ """ReRanker that uses weighted average of scores."""
+ all_ids = set(vector_scores.keys()) | set(keyword_scores.keys())
+ normalized_vector_scores = _normalize_scores(vector_scores)
+ normalized_keyword_scores = _normalize_scores(keyword_scores)
+
+ return {
+        doc_id: (alpha * normalized_vector_scores.get(doc_id, 0.0))
+        + ((1 - alpha) * normalized_keyword_scores.get(doc_id, 0.0))
+ for doc_id in all_ids
+ }
+
+
+def _rrf_rerank(
+ vector_scores: dict[str, float],
+ keyword_scores: dict[str, float],
+ impact_factor: float = 60.0,
+) -> dict[str, float]:
+ """ReRanker that uses Reciprocal Rank Fusion."""
+ # Convert scores to ranks
+ vector_ranks = {
+ doc_id: i + 1 for i, (doc_id, _) in enumerate(sorted(vector_scores.items(), key=lambda x: x[1], reverse=True))
+ }
+ keyword_ranks = {
+ doc_id: i + 1 for i, (doc_id, _) in enumerate(sorted(keyword_scores.items(), key=lambda x: x[1], reverse=True))
+ }
+
+ all_ids = set(vector_scores.keys()) | set(keyword_scores.keys())
+ rrf_scores = {}
+ for doc_id in all_ids:
+ vector_rank = vector_ranks.get(doc_id, float("inf"))
+ keyword_rank = keyword_ranks.get(doc_id, float("inf"))
+ # RRF formula: score = 1/(k + r) where k is impact_factor and r is the rank
+ rrf_scores[doc_id] = (1.0 / (impact_factor + vector_rank)) + (1.0 / (impact_factor + keyword_rank))
+ return rrf_scores
+
+
class SQLiteVecIndex(EmbeddingIndex):
"""
An index implementation that stores embeddings in a SQLite virtual table using sqlite-vec.
@@ -255,8 +314,6 @@ class SQLiteVecIndex(EmbeddingIndex):
"""
Performs keyword-based search using SQLite FTS5 for relevance-ranked full-text search.
"""
- if query_string is None:
- raise ValueError("query_string is required for keyword search.")
def _execute_query():
connection = _create_sqlite_connection(self.db_path)
@@ -294,6 +351,81 @@ class SQLiteVecIndex(EmbeddingIndex):
scores.append(score)
return QueryChunksResponse(chunks=chunks, scores=scores)
+ async def query_hybrid(
+ self,
+ embedding: NDArray,
+ query_string: str,
+ k: int,
+ score_threshold: float,
+ reranker_type: str = RERANKER_TYPE_RRF,
+ reranker_params: dict[str, Any] | None = None,
+ ) -> QueryChunksResponse:
+ """
+ Hybrid search using a configurable re-ranking strategy.
+
+ Args:
+ embedding: The query embedding vector
+ query_string: The text query for keyword search
+ k: Number of results to return
+ score_threshold: Minimum similarity score threshold
+ reranker_type: Type of reranker to use ("rrf" or "weighted")
+ reranker_params: Parameters for the reranker
+
+ Returns:
+ QueryChunksResponse with combined results
+ """
+ if reranker_params is None:
+ reranker_params = {}
+
+ # Get results from both search methods
+ vector_response = await self.query_vector(embedding, k, score_threshold)
+ keyword_response = await self.query_keyword(query_string, k, score_threshold)
+
+ # Convert responses to score dictionaries using generate_chunk_id
+ vector_scores = {
+ generate_chunk_id(chunk.metadata["document_id"], str(chunk.content)): score
+ for chunk, score in zip(vector_response.chunks, vector_response.scores, strict=False)
+ }
+ keyword_scores = {
+ generate_chunk_id(chunk.metadata["document_id"], str(chunk.content)): score
+ for chunk, score in zip(keyword_response.chunks, keyword_response.scores, strict=False)
+ }
+
+ # Combine scores using the specified reranker
+ if reranker_type == RERANKER_TYPE_WEIGHTED:
+ alpha = reranker_params.get("alpha", 0.5)
+ combined_scores = _weighted_rerank(vector_scores, keyword_scores, alpha)
+ else:
+            # Fall back to RRF for RERANKER_TYPE_RRF or any unrecognized reranker type
+ impact_factor = reranker_params.get("impact_factor", 60.0)
+ combined_scores = _rrf_rerank(vector_scores, keyword_scores, impact_factor)
+
+ # Sort by combined score and get top k results
+ sorted_items = sorted(combined_scores.items(), key=lambda x: x[1], reverse=True)
+ top_k_items = sorted_items[:k]
+
+ # Filter by score threshold
+ filtered_items = [(doc_id, score) for doc_id, score in top_k_items if score >= score_threshold]
+
+ # Create a map of chunk_id to chunk for both responses
+ chunk_map = {}
+ for c in vector_response.chunks:
+ chunk_id = generate_chunk_id(c.metadata["document_id"], str(c.content))
+ chunk_map[chunk_id] = c
+ for c in keyword_response.chunks:
+ chunk_id = generate_chunk_id(c.metadata["document_id"], str(c.content))
+ chunk_map[chunk_id] = c
+
+ # Use the map to look up chunks by their IDs
+ chunks = []
+ scores = []
+ for doc_id, score in filtered_items:
+ if doc_id in chunk_map:
+ chunks.append(chunk_map[doc_id])
+ scores.append(score)
+
+ return QueryChunksResponse(chunks=chunks, scores=scores)
+
class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPrivate):
"""
@@ -345,7 +477,9 @@ class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtoc
vector_db_data = row[0]
vector_db = VectorDB.model_validate_json(vector_db_data)
index = await SQLiteVecIndex.create(
- vector_db.embedding_dimension, self.config.db_path, vector_db.identifier
+ vector_db.embedding_dimension,
+ self.config.db_path,
+ vector_db.identifier,
)
self.cache[vector_db.identifier] = VectorDBWithIndex(vector_db, index, self.inference_api)
@@ -371,7 +505,11 @@ class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtoc
connection.close()
await asyncio.to_thread(_register_db)
- index = await SQLiteVecIndex.create(vector_db.embedding_dimension, self.config.db_path, vector_db.identifier)
+ index = await SQLiteVecIndex.create(
+ vector_db.embedding_dimension,
+ self.config.db_path,
+ vector_db.identifier,
+ )
self.cache[vector_db.identifier] = VectorDBWithIndex(vector_db, index, self.inference_api)
async def list_vector_dbs(self) -> list[VectorDB]:
diff --git a/llama_stack/providers/remote/vector_io/chroma/chroma.py b/llama_stack/providers/remote/vector_io/chroma/chroma.py
index fee29cfd9..027cdcb11 100644
--- a/llama_stack/providers/remote/vector_io/chroma/chroma.py
+++ b/llama_stack/providers/remote/vector_io/chroma/chroma.py
@@ -105,6 +105,17 @@ class ChromaIndex(EmbeddingIndex):
) -> QueryChunksResponse:
raise NotImplementedError("Keyword search is not supported in Chroma")
+ async def query_hybrid(
+ self,
+ embedding: NDArray,
+ query_string: str,
+ k: int,
+ score_threshold: float,
+ reranker_type: str,
+ reranker_params: dict[str, Any] | None = None,
+ ) -> QueryChunksResponse:
+ raise NotImplementedError("Hybrid search is not supported in Chroma")
+
class ChromaVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
def __init__(
diff --git a/llama_stack/providers/remote/vector_io/milvus/milvus.py b/llama_stack/providers/remote/vector_io/milvus/milvus.py
index 51c541c02..42ab4fa3e 100644
--- a/llama_stack/providers/remote/vector_io/milvus/milvus.py
+++ b/llama_stack/providers/remote/vector_io/milvus/milvus.py
@@ -103,6 +103,17 @@ class MilvusIndex(EmbeddingIndex):
) -> QueryChunksResponse:
raise NotImplementedError("Keyword search is not supported in Milvus")
+ async def query_hybrid(
+ self,
+ embedding: NDArray,
+ query_string: str,
+ k: int,
+ score_threshold: float,
+ reranker_type: str,
+ reranker_params: dict[str, Any] | None = None,
+ ) -> QueryChunksResponse:
+ raise NotImplementedError("Hybrid search is not supported in Milvus")
+
class MilvusVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
def __init__(
diff --git a/llama_stack/providers/remote/vector_io/pgvector/pgvector.py b/llama_stack/providers/remote/vector_io/pgvector/pgvector.py
index 7d58a49f3..1917af086 100644
--- a/llama_stack/providers/remote/vector_io/pgvector/pgvector.py
+++ b/llama_stack/providers/remote/vector_io/pgvector/pgvector.py
@@ -128,6 +128,17 @@ class PGVectorIndex(EmbeddingIndex):
) -> QueryChunksResponse:
raise NotImplementedError("Keyword search is not supported in PGVector")
+ async def query_hybrid(
+ self,
+ embedding: NDArray,
+ query_string: str,
+ k: int,
+ score_threshold: float,
+ reranker_type: str,
+ reranker_params: dict[str, Any] | None = None,
+ ) -> QueryChunksResponse:
+ raise NotImplementedError("Hybrid search is not supported in PGVector")
+
async def delete(self):
with self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
cur.execute(f"DROP TABLE IF EXISTS {self.table_name}")
diff --git a/llama_stack/providers/remote/vector_io/qdrant/qdrant.py b/llama_stack/providers/remote/vector_io/qdrant/qdrant.py
index 1631a7a2a..fa7782f04 100644
--- a/llama_stack/providers/remote/vector_io/qdrant/qdrant.py
+++ b/llama_stack/providers/remote/vector_io/qdrant/qdrant.py
@@ -112,6 +112,17 @@ class QdrantIndex(EmbeddingIndex):
) -> QueryChunksResponse:
raise NotImplementedError("Keyword search is not supported in Qdrant")
+ async def query_hybrid(
+ self,
+ embedding: NDArray,
+ query_string: str,
+ k: int,
+ score_threshold: float,
+ reranker_type: str,
+ reranker_params: dict[str, Any] | None = None,
+ ) -> QueryChunksResponse:
+ raise NotImplementedError("Hybrid search is not supported in Qdrant")
+
async def delete(self):
await self.client.delete_collection(collection_name=self.collection_name)
diff --git a/llama_stack/providers/remote/vector_io/weaviate/weaviate.py b/llama_stack/providers/remote/vector_io/weaviate/weaviate.py
index 6f2027dad..c63dd70c6 100644
--- a/llama_stack/providers/remote/vector_io/weaviate/weaviate.py
+++ b/llama_stack/providers/remote/vector_io/weaviate/weaviate.py
@@ -92,6 +92,17 @@ class WeaviateIndex(EmbeddingIndex):
) -> QueryChunksResponse:
raise NotImplementedError("Keyword search is not supported in Weaviate")
+ async def query_hybrid(
+ self,
+ embedding: NDArray,
+ query_string: str,
+ k: int,
+ score_threshold: float,
+ reranker_type: str,
+ reranker_params: dict[str, Any] | None = None,
+ ) -> QueryChunksResponse:
+ raise NotImplementedError("Hybrid search is not supported in Weaviate")
+
class WeaviateVectorIOAdapter(
VectorIO,
diff --git a/llama_stack/providers/utils/memory/vector_store.py b/llama_stack/providers/utils/memory/vector_store.py
index 2c0c7c8e9..a6e420feb 100644
--- a/llama_stack/providers/utils/memory/vector_store.py
+++ b/llama_stack/providers/utils/memory/vector_store.py
@@ -32,6 +32,10 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
log = logging.getLogger(__name__)
+# Constants for reranker types
+RERANKER_TYPE_RRF = "rrf"
+RERANKER_TYPE_WEIGHTED = "weighted"
+
def parse_pdf(data: bytes) -> str:
# For PDF and DOC/DOCX files, we can't reliably convert to string
@@ -202,6 +206,18 @@ class EmbeddingIndex(ABC):
async def query_keyword(self, query_string: str, k: int, score_threshold: float) -> QueryChunksResponse:
raise NotImplementedError()
+ @abstractmethod
+ async def query_hybrid(
+ self,
+ embedding: NDArray,
+ query_string: str,
+ k: int,
+ score_threshold: float,
+ reranker_type: str,
+ reranker_params: dict[str, Any] | None = None,
+ ) -> QueryChunksResponse:
+ raise NotImplementedError()
+
@abstractmethod
async def delete(self):
raise NotImplementedError()
@@ -245,10 +261,29 @@ class VectorDBWithIndex:
k = params.get("max_chunks", 3)
mode = params.get("mode")
score_threshold = params.get("score_threshold", 0.0)
+
+ # Get ranker configuration
+ ranker = params.get("ranker")
+ if ranker is None:
+ # Default to RRF with impact_factor=60.0
+ reranker_type = RERANKER_TYPE_RRF
+ reranker_params = {"impact_factor": 60.0}
+ else:
+ reranker_type = ranker.type
+ reranker_params = (
+ {"impact_factor": ranker.impact_factor} if ranker.type == RERANKER_TYPE_RRF else {"alpha": ranker.alpha}
+ )
+
query_string = interleaved_content_as_str(query)
if mode == "keyword":
return await self.index.query_keyword(query_string, k, score_threshold)
+
+ # Calculate embeddings for both vector and hybrid modes
+ embeddings_response = await self.inference_api.embeddings(self.vector_db.embedding_model, [query_string])
+ query_vector = np.array(embeddings_response.embeddings[0], dtype=np.float32)
+ if mode == "hybrid":
+ return await self.index.query_hybrid(
+ query_vector, query_string, k, score_threshold, reranker_type, reranker_params
+ )
else:
- embeddings_response = await self.inference_api.embeddings(self.vector_db.embedding_model, [query_string])
- query_vector = np.array(embeddings_response.embeddings[0], dtype=np.float32)
return await self.index.query_vector(query_vector, k, score_threshold)
diff --git a/tests/unit/providers/vector_io/test_sqlite_vec.py b/tests/unit/providers/vector_io/test_sqlite_vec.py
index 010a0ca42..6424b9e86 100644
--- a/tests/unit/providers/vector_io/test_sqlite_vec.py
+++ b/tests/unit/providers/vector_io/test_sqlite_vec.py
@@ -84,6 +84,28 @@ async def test_query_chunks_full_text_search(sqlite_vec_index, sample_chunks, sa
assert len(response_no_results.chunks) == 0, f"Expected 0 results, but got {len(response_no_results.chunks)}"
+@pytest.mark.asyncio
+async def test_query_chunks_hybrid(sqlite_vec_index, sample_chunks, sample_embeddings):
+ await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings)
+
+ # Create a query embedding that's similar to the first chunk
+ query_embedding = sample_embeddings[0]
+ query_string = "Sentence 5"
+
+ response = await sqlite_vec_index.query_hybrid(
+ embedding=query_embedding,
+ query_string=query_string,
+ k=3,
+ score_threshold=0.0,
+ reranker_type="rrf",
+ reranker_params={"impact_factor": 60.0},
+ )
+
+ assert len(response.chunks) == 3, f"Expected 3 results, got {len(response.chunks)}"
+ # Verify scores are in descending order (higher is better)
+ assert all(response.scores[i] >= response.scores[i + 1] for i in range(len(response.scores) - 1))
+
+
@pytest.mark.asyncio
async def test_query_chunks_full_text_search_k_greater_than_results(sqlite_vec_index, sample_chunks, sample_embeddings):
# Re-initialize with a clean index
@@ -141,3 +163,355 @@ def test_generate_chunk_id():
"bc744db3-1b25-0a9c-cdff-b6ba3df73c36",
"f68df25d-d9aa-ab4d-5684-64a233add20d",
]
+
+
+@pytest.mark.asyncio
+async def test_query_chunks_hybrid_no_keyword_matches(sqlite_vec_index, sample_chunks, sample_embeddings):
+ """Test hybrid search when keyword search returns no matches - should still return vector results."""
+ await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings)
+
+ # Use a non-existent keyword but a valid vector query
+ query_embedding = sample_embeddings[0]
+ query_string = "Sentence 499"
+
+ # First verify keyword search returns no results
+ keyword_response = await sqlite_vec_index.query_keyword(query_string, k=5, score_threshold=0.0)
+ assert len(keyword_response.chunks) == 0, "Keyword search should return no results"
+
+ # Get hybrid results
+ response = await sqlite_vec_index.query_hybrid(
+ embedding=query_embedding,
+ query_string=query_string,
+ k=3,
+ score_threshold=0.0,
+ reranker_type="rrf",
+ reranker_params={"impact_factor": 60.0},
+ )
+
+ # Should still get results from vector search
+ assert len(response.chunks) > 0, "Should get results from vector search even with no keyword matches"
+ # Verify scores are in descending order
+ assert all(response.scores[i] >= response.scores[i + 1] for i in range(len(response.scores) - 1))
+
+
+@pytest.mark.asyncio
+async def test_query_chunks_hybrid_score_threshold(sqlite_vec_index, sample_chunks, sample_embeddings):
+ """Test hybrid search with a high score threshold."""
+ await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings)
+
+ # Use a very high score threshold that no results will meet
+ query_embedding = sample_embeddings[0]
+ query_string = "Sentence 5"
+
+ response = await sqlite_vec_index.query_hybrid(
+ embedding=query_embedding,
+ query_string=query_string,
+ k=3,
+ score_threshold=1000.0, # Very high threshold
+ reranker_type="rrf",
+ reranker_params={"impact_factor": 60.0},
+ )
+
+ # Should return no results due to high threshold
+ assert len(response.chunks) == 0
+
+
+@pytest.mark.asyncio
+async def test_query_chunks_hybrid_different_embedding(
+ sqlite_vec_index, sample_chunks, sample_embeddings, embedding_dimension
+):
+ """Test hybrid search with a different embedding than the stored ones."""
+ await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings)
+
+ # Create a random embedding that's different from stored ones
+ query_embedding = np.random.rand(embedding_dimension).astype(np.float32)
+ query_string = "Sentence 5"
+
+ response = await sqlite_vec_index.query_hybrid(
+ embedding=query_embedding,
+ query_string=query_string,
+ k=3,
+ score_threshold=0.0,
+ reranker_type="rrf",
+ reranker_params={"impact_factor": 60.0},
+ )
+
+ # Should still get results if keyword matches exist
+ assert len(response.chunks) > 0
+ # Verify scores are in descending order
+ assert all(response.scores[i] >= response.scores[i + 1] for i in range(len(response.scores) - 1))
+
+
+@pytest.mark.asyncio
+async def test_query_chunks_hybrid_rrf_ranking(sqlite_vec_index, sample_chunks, sample_embeddings):
+ """Test that RRF properly combines rankings when documents appear in both search methods."""
+ await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings)
+
+ # Create a query embedding that's similar to the first chunk
+ query_embedding = sample_embeddings[0]
+ # Use a keyword that appears in multiple documents
+ query_string = "Sentence 5"
+
+ response = await sqlite_vec_index.query_hybrid(
+ embedding=query_embedding,
+ query_string=query_string,
+ k=5,
+ score_threshold=0.0,
+ reranker_type="rrf",
+ reranker_params={"impact_factor": 60.0},
+ )
+
+ # Verify we get results from both search methods
+ assert len(response.chunks) > 0
+ # Verify scores are in descending order (RRF should maintain this)
+ assert all(response.scores[i] >= response.scores[i + 1] for i in range(len(response.scores) - 1))
+
+
+@pytest.mark.asyncio
+async def test_query_chunks_hybrid_score_selection(sqlite_vec_index, sample_chunks, sample_embeddings):
+ await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings)
+
+ # Create a query embedding that's similar to the first chunk
+ query_embedding = sample_embeddings[0]
+ # Use a keyword that appears in the first document
+ query_string = "Sentence 0 from document 0"
+
+ # Test weighted re-ranking
+ response = await sqlite_vec_index.query_hybrid(
+ embedding=query_embedding,
+ query_string=query_string,
+ k=1,
+ score_threshold=0.0,
+ reranker_type="weighted",
+ reranker_params={"alpha": 0.5},
+ )
+ assert len(response.chunks) == 1
+ # Score should be weighted average of normalized keyword score and vector score
+ assert response.scores[0] > 0.5 # Both scores should be high
+
+ # Test RRF re-ranking
+ response = await sqlite_vec_index.query_hybrid(
+ embedding=query_embedding,
+ query_string=query_string,
+ k=1,
+ score_threshold=0.0,
+ reranker_type="rrf",
+ reranker_params={"impact_factor": 60.0},
+ )
+ assert len(response.chunks) == 1
+ # RRF score should be sum of reciprocal ranks
+ assert response.scores[0] == pytest.approx(2.0 / 61.0, rel=1e-6) # 1/(60+1) + 1/(60+1)
+
+    # Test default re-ranking (no reranker args passed): should behave like RRF with impact_factor=60.0
+    response = await sqlite_vec_index.query_hybrid(
+        embedding=query_embedding,
+        query_string=query_string,
+        k=1,
+        score_threshold=0.0,
+    )
+ assert len(response.chunks) == 1
+ assert response.scores[0] == pytest.approx(2.0 / 61.0, rel=1e-6) # Should behave like RRF
+
+
+@pytest.mark.asyncio
+async def test_query_chunks_hybrid_mixed_results(sqlite_vec_index, sample_chunks, sample_embeddings):
+ """Test hybrid search with documents that appear in only one search method."""
+ await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings)
+
+ # Create a query embedding that's similar to the first chunk
+ query_embedding = sample_embeddings[0]
+ # Use a keyword that appears in a different document
+ query_string = "Sentence 9 from document 2"
+
+ response = await sqlite_vec_index.query_hybrid(
+ embedding=query_embedding,
+ query_string=query_string,
+ k=3,
+ score_threshold=0.0,
+ reranker_type="rrf",
+ reranker_params={"impact_factor": 60.0},
+ )
+
+ # Should get results from both search methods
+ assert len(response.chunks) > 0
+ # Verify scores are in descending order
+ assert all(response.scores[i] >= response.scores[i + 1] for i in range(len(response.scores) - 1))
+ # Verify we get results from both the vector-similar document and keyword-matched document
+ doc_ids = {chunk.metadata["document_id"] for chunk in response.chunks}
+ assert "document-0" in doc_ids # From vector search
+ assert "document-2" in doc_ids # From keyword search
+
+
+@pytest.mark.asyncio
+async def test_query_chunks_hybrid_weighted_reranker_parametrization(
+ sqlite_vec_index, sample_chunks, sample_embeddings
+):
+ """Test WeightedReRanker with different alpha values."""
+ # Re-add data before each search to ensure test isolation
+ await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings)
+ query_embedding = sample_embeddings[0]
+ query_string = "Sentence 0 from document 0"
+
+    # alpha=1.0 (should behave like pure vector search)
+ response = await sqlite_vec_index.query_hybrid(
+ embedding=query_embedding,
+ query_string=query_string,
+ k=1,
+ score_threshold=0.0,
+ reranker_type="weighted",
+ reranker_params={"alpha": 1.0},
+ )
+ assert len(response.chunks) > 0 # Should get at least one result
+ assert any("document-0" in chunk.metadata["document_id"] for chunk in response.chunks)
+
+    # alpha=0.0 (should behave like pure keyword search)
+ response = await sqlite_vec_index.query_hybrid(
+ embedding=query_embedding,
+ query_string=query_string,
+ k=1,
+ score_threshold=0.0,
+ reranker_type="weighted",
+ reranker_params={"alpha": 0.0},
+ )
+ assert len(response.chunks) > 0 # Should get at least one result
+ assert any("document-0" in chunk.metadata["document_id"] for chunk in response.chunks)
+
+ await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings)
+ # alpha=0.7 (should be a mix)
+ response = await sqlite_vec_index.query_hybrid(
+ embedding=query_embedding,
+ query_string=query_string,
+ k=1,
+ score_threshold=0.0,
+ reranker_type="weighted",
+ reranker_params={"alpha": 0.7},
+ )
+ assert len(response.chunks) > 0 # Should get at least one result
+ assert any("document-0" in chunk.metadata["document_id"] for chunk in response.chunks)
+
+
+@pytest.mark.asyncio
+async def test_query_chunks_hybrid_rrf_impact_factor(sqlite_vec_index, sample_chunks, sample_embeddings):
+ """Test RRFReRanker with different impact factors."""
+ await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings)
+ query_embedding = sample_embeddings[0]
+ query_string = "Sentence 0 from document 0"
+
+ # impact_factor=10
+ response = await sqlite_vec_index.query_hybrid(
+ embedding=query_embedding,
+ query_string=query_string,
+ k=1,
+ score_threshold=0.0,
+ reranker_type="rrf",
+ reranker_params={"impact_factor": 10.0},
+ )
+ assert len(response.chunks) == 1
+ assert response.scores[0] == pytest.approx(2.0 / 11.0, rel=1e-6)
+
+ # impact_factor=100
+ response = await sqlite_vec_index.query_hybrid(
+ embedding=query_embedding,
+ query_string=query_string,
+ k=1,
+ score_threshold=0.0,
+ reranker_type="rrf",
+ reranker_params={"impact_factor": 100.0},
+ )
+ assert len(response.chunks) == 1
+ assert response.scores[0] == pytest.approx(2.0 / 101.0, rel=1e-6)
+
+
+@pytest.mark.asyncio
+async def test_query_chunks_hybrid_edge_cases(sqlite_vec_index, sample_chunks, sample_embeddings):
+ await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings)
+
+ # No results from either search - use a completely different embedding and a nonzero threshold
+ query_embedding = np.ones_like(sample_embeddings[0]) * -1 # Very different from sample embeddings
+ query_string = "no_such_keyword_that_will_never_match"
+ response = await sqlite_vec_index.query_hybrid(
+ embedding=query_embedding,
+ query_string=query_string,
+ k=3,
+ score_threshold=0.1, # Nonzero threshold to filter out low-similarity matches
+ reranker_type="rrf",
+ reranker_params={"impact_factor": 60.0},
+ )
+ assert len(response.chunks) == 0
+
+ # All results below threshold
+ query_embedding = sample_embeddings[0]
+ query_string = "Sentence 0 from document 0"
+ response = await sqlite_vec_index.query_hybrid(
+ embedding=query_embedding,
+ query_string=query_string,
+ k=3,
+ score_threshold=1000.0,
+ reranker_type="rrf",
+ reranker_params={"impact_factor": 60.0},
+ )
+ assert len(response.chunks) == 0
+
+ # Large k value
+ response = await sqlite_vec_index.query_hybrid(
+ embedding=query_embedding,
+ query_string=query_string,
+ k=100,
+ score_threshold=0.0,
+ reranker_type="rrf",
+ reranker_params={"impact_factor": 60.0},
+ )
+ # Should not error, should return all available results
+ assert len(response.chunks) > 0
+ assert len(response.chunks) <= 100
+
+
+@pytest.mark.asyncio
+async def test_query_chunks_hybrid_tie_breaking(
+ sqlite_vec_index, sample_embeddings, embedding_dimension, tmp_path_factory
+):
+ """Test tie-breaking and determinism when scores are equal."""
+ # Create two chunks with the same content and embedding
+ chunk1 = Chunk(content="identical", metadata={"document_id": "docA"})
+ chunk2 = Chunk(content="identical", metadata={"document_id": "docB"})
+ chunks = [chunk1, chunk2]
+ # Use the same embedding for both chunks to ensure equal scores
+ same_embedding = sample_embeddings[0]
+ embeddings = np.array([same_embedding, same_embedding])
+
+ # Clear existing data and recreate index
+ await sqlite_vec_index.delete()
+ temp_dir = tmp_path_factory.getbasetemp()
+ db_path = str(temp_dir / "test_sqlite.db")
+ sqlite_vec_index = await SQLiteVecIndex.create(dimension=embedding_dimension, db_path=db_path, bank_id="test_bank")
+ await sqlite_vec_index.add_chunks(chunks, embeddings)
+
+ # Query with the same embedding and content to ensure equal scores
+ query_embedding = same_embedding
+ query_string = "identical"
+
+ # Run multiple queries to verify determinism
+ responses = []
+ for _ in range(3):
+ response = await sqlite_vec_index.query_hybrid(
+ embedding=query_embedding,
+ query_string=query_string,
+ k=2,
+ score_threshold=0.0,
+ reranker_type="rrf",
+ reranker_params={"impact_factor": 60.0},
+ )
+ responses.append(response)
+
+ # Verify all responses are identical
+ first_response = responses[0]
+ for response in responses[1:]:
+ assert response.chunks == first_response.chunks
+ assert response.scores == first_response.scores
+
+ # Verify both chunks are returned with equal scores
+ assert len(first_response.chunks) == 2
+ assert first_response.scores[0] == first_response.scores[1]
+ assert {chunk.metadata["document_id"] for chunk in first_response.chunks} == {"docA", "docB"}