feat: Introduce weighted and rrf reranker implementations

Signed-off-by: Varsha Prasad Narsing <varshaprasad96@gmail.com>
This commit is contained in:
Varsha Prasad Narsing 2025-06-04 15:59:44 -07:00
parent eab85a7121
commit 6ea5c10d48
14 changed files with 637 additions and 75 deletions

View file

@ -13995,6 +13995,10 @@
"mode": { "mode": {
"type": "string", "type": "string",
"description": "Search mode for retrieval—either \"vector\", \"keyword\", or \"hybrid\". Default \"vector\"." "description": "Search mode for retrieval—either \"vector\", \"keyword\", or \"hybrid\". Default \"vector\"."
},
"ranker": {
"$ref": "#/components/schemas/Ranker",
"description": "Configuration for the ranker to use in hybrid search. Defaults to RRF ranker."
} }
}, },
"additionalProperties": false, "additionalProperties": false,
@ -14024,6 +14028,69 @@
} }
} }
}, },
"RRFRanker": {
"type": "object",
"properties": {
"type": {
"type": "string",
"const": "rrf",
"default": "rrf",
"description": "The type of ranker, always \"rrf\""
},
"impact_factor": {
"type": "number",
"default": 60.0,
"description": "The impact factor for RRF scoring. Higher values give more weight to higher-ranked results. Must be greater than 0. Default of 60 is from the original RRF paper (Cormack et al., 2009)."
}
},
"additionalProperties": false,
"required": [
"type",
"impact_factor"
],
"title": "RRFRanker",
"description": "Reciprocal Rank Fusion (RRF) ranker configuration."
},
"Ranker": {
"oneOf": [
{
"$ref": "#/components/schemas/RRFRanker"
},
{
"$ref": "#/components/schemas/WeightedRanker"
}
],
"discriminator": {
"propertyName": "type",
"mapping": {
"rrf": "#/components/schemas/RRFRanker",
"weighted": "#/components/schemas/WeightedRanker"
}
}
},
"WeightedRanker": {
"type": "object",
"properties": {
"type": {
"type": "string",
"const": "weighted",
"default": "weighted",
"description": "The type of ranker, always \"weighted\""
},
"alpha": {
"type": "number",
"default": 0.5,
"description": "Weight factor between 0 and 1. 0 means only use keyword scores, 1 means only use vector scores, values in between blend both scores."
}
},
"additionalProperties": false,
"required": [
"type",
"alpha"
],
"title": "WeightedRanker",
"description": "Weighted ranker configuration that combines vector and keyword scores."
},
"QueryRequest": { "QueryRequest": {
"type": "object", "type": "object",
"properties": { "properties": {

View file

@ -9758,6 +9758,11 @@ components:
description: >- description: >-
Search mode for retrieval—either "vector", "keyword", or "hybrid". Default Search mode for retrieval—either "vector", "keyword", or "hybrid". Default
"vector". "vector".
ranker:
$ref: '#/components/schemas/Ranker'
description: >-
Configuration for the ranker to use in hybrid search. Defaults to RRF
ranker.
additionalProperties: false additionalProperties: false
required: required:
- query_generator_config - query_generator_config
@ -9776,6 +9781,58 @@ components:
mapping: mapping:
default: '#/components/schemas/DefaultRAGQueryGeneratorConfig' default: '#/components/schemas/DefaultRAGQueryGeneratorConfig'
llm: '#/components/schemas/LLMRAGQueryGeneratorConfig' llm: '#/components/schemas/LLMRAGQueryGeneratorConfig'
RRFRanker:
type: object
properties:
type:
type: string
const: rrf
default: rrf
description: The type of ranker, always "rrf"
impact_factor:
type: number
default: 60.0
description: >-
The impact factor for RRF scoring. Higher values give more weight to higher-ranked
results. Must be greater than 0. Default of 60 is from the original RRF
paper (Cormack et al., 2009).
additionalProperties: false
required:
- type
- impact_factor
title: RRFRanker
description: >-
Reciprocal Rank Fusion (RRF) ranker configuration.
Ranker:
oneOf:
- $ref: '#/components/schemas/RRFRanker'
- $ref: '#/components/schemas/WeightedRanker'
discriminator:
propertyName: type
mapping:
rrf: '#/components/schemas/RRFRanker'
weighted: '#/components/schemas/WeightedRanker'
WeightedRanker:
type: object
properties:
type:
type: string
const: weighted
default: weighted
description: The type of ranker, always "weighted"
alpha:
type: number
default: 0.5
description: >-
Weight factor between 0 and 1. 0 means only use keyword scores, 1 means
only use vector scores, values in between blend both scores.
additionalProperties: false
required:
- type
- alpha
title: WeightedRanker
description: >-
Weighted ranker configuration that combines vector and keyword scores.
QueryRequest: QueryRequest:
type: object type: object
properties: properties:

View file

@ -79,6 +79,30 @@ response = await vector_io.query_chunks(
query="your query here", query="your query here",
params={"mode": "hybrid", "max_chunks": 3, "score_threshold": 0.7}, params={"mode": "hybrid", "max_chunks": 3, "score_threshold": 0.7},
) )
# Using RRF ranker
response = await vector_io.query_chunks(
vector_db_id="my_db",
query="your query here",
params={
"mode": "hybrid",
"max_chunks": 3,
"score_threshold": 0.7,
"ranker": {"type": "rrf", "impact_factor": 60.0},
},
)
# Using weighted ranker
response = await vector_io.query_chunks(
vector_db_id="my_db",
query="your query here",
params={
"mode": "hybrid",
"max_chunks": 3,
"score_threshold": 0.7,
"ranker": {"type": "weighted", "alpha": 0.7}, # 70% vector, 30% keyword
},
)
``` ```
Example with explicit vector search: Example with explicit vector search:
@ -101,23 +125,67 @@ response = await vector_io.query_chunks(
## Supported Search Modes ## Supported Search Modes
The sqlite-vec provider supports both vector-based and keyword-based (full-text) search modes. The SQLite vector store supports three search modes:
When using the RAGTool interface, you can specify the desired search behavior via the `mode` parameter in 1. **Vector Search** (`mode="vector"`): Uses vector similarity to find relevant chunks
`RAGQueryConfig`. For example: 2. **Keyword Search** (`mode="keyword"`): Uses keyword matching to find relevant chunks
3. **Hybrid Search** (`mode="hybrid"`): Combines both vector and keyword scores using a ranker
### Hybrid Search
Hybrid search combines the strengths of both vector and keyword search by:
- Computing vector similarity scores
- Computing keyword match scores
- Using a ranker to combine these scores
Two ranker types are supported:
1. **RRF (Reciprocal Rank Fusion)**:
- Combines ranks from both vector and keyword results
- Uses an impact factor (default: 60.0) to control the weight of higher-ranked results
- Good for balancing between vector and keyword results
- The default impact factor of 60.0 comes from the original RRF paper by Cormack et al. (2009) [^1], which found this value to provide optimal performance across various retrieval tasks
2. **Weighted**:
- Linearly combines normalized vector and keyword scores
- Uses an alpha parameter (0-1) to control the blend:
- alpha=0: Only use keyword scores
- alpha=1: Only use vector scores
- alpha=0.5: Equal weight to both (default)
Example using RAGQueryConfig with different search modes:
```python ```python
from llama_stack.apis.tool_runtime.rag import RAGQueryConfig from llama_stack.apis.tools import RAGQueryConfig, RRFRanker, WeightedRanker
query_config = RAGQueryConfig(max_chunks=6, mode="vector") # Vector search
config = RAGQueryConfig(mode="vector", max_chunks=5)
results = client.tool_runtime.rag_tool.query( # Keyword search
vector_db_ids=[vector_db_id], config = RAGQueryConfig(mode="keyword", max_chunks=5)
content="what is torchtune",
query_config=query_config, # Hybrid search with custom RRF ranker
config = RAGQueryConfig(
mode="hybrid",
max_chunks=5,
ranker=RRFRanker(impact_factor=50.0), # Custom impact factor
) )
# Hybrid search with weighted ranker
config = RAGQueryConfig(
mode="hybrid",
max_chunks=5,
ranker=WeightedRanker(alpha=0.7), # 70% vector, 30% keyword
)
# Hybrid search with default RRF ranker
config = RAGQueryConfig(
mode="hybrid", max_chunks=5
) # Will use RRF with impact_factor=60.0
``` ```
Note: The ranker configuration is only used in hybrid mode. For vector or keyword modes, the ranker parameter is ignored.
## Installation ## Installation
You can install SQLite-Vec using pip: You can install SQLite-Vec using pip:
@ -129,3 +197,5 @@ pip install sqlite-vec
## Documentation ## Documentation
See [sqlite-vec's GitHub repo](https://github.com/asg017/sqlite-vec/tree/main) for more details about sqlite-vec in general. See [sqlite-vec's GitHub repo](https://github.com/asg017/sqlite-vec/tree/main) for more details about sqlite-vec in general.
[^1]: Cormack, G. V., Clarke, C. L., & Buettcher, S. (2009). [Reciprocal rank fusion outperforms condorcet and individual rank learning methods](https://dl.acm.org/doi/10.1145/1571941.1572114). In Proceedings of the 32nd international ACM SIGIR conference on Research and development in information retrieval (pp. 758-759).

View file

@ -15,6 +15,48 @@ from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
from llama_stack.schema_utils import json_schema_type, register_schema, webmethod from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
@json_schema_type
class RRFRanker(BaseModel):
"""
Reciprocal Rank Fusion (RRF) ranker configuration.
:param type: The type of ranker, always "rrf"
:param impact_factor: The impact factor for RRF scoring. Higher values give more weight to higher-ranked results.
Must be greater than 0. Default of 60 is from the original RRF paper (Cormack et al., 2009).
"""
type: Literal["rrf"] = "rrf"
impact_factor: float = Field(default=60.0, gt=0.0) # default of 60 for optimal performance
@json_schema_type
class WeightedRanker(BaseModel):
"""
Weighted ranker configuration that combines vector and keyword scores.
:param type: The type of ranker, always "weighted"
:param alpha: Weight factor between 0 and 1.
0 means only use keyword scores,
1 means only use vector scores,
values in between blend both scores.
"""
type: Literal["weighted"] = "weighted"
alpha: float = Field(
default=0.5,
ge=0.0,
le=1.0,
description="Weight factor between 0 and 1. 0 means only keyword scores, 1 means only vector scores.",
)
Ranker = Annotated[
RRFRanker | WeightedRanker,
Field(discriminator="type"),
]
register_schema(Ranker, name="Ranker")
@json_schema_type @json_schema_type
class RAGDocument(BaseModel): class RAGDocument(BaseModel):
""" """
@ -77,6 +119,7 @@ class RAGQueryConfig(BaseModel):
Available placeholders: {index} (1-based chunk ordinal), {chunk.content} (chunk content string), {metadata} (chunk metadata dict). Available placeholders: {index} (1-based chunk ordinal), {chunk.content} (chunk content string), {metadata} (chunk metadata dict).
Default: "Result {index}\\nContent: {chunk.content}\\nMetadata: {metadata}\\n" Default: "Result {index}\\nContent: {chunk.content}\\nMetadata: {metadata}\\n"
:param mode: Search mode for retrievaleither "vector", "keyword", or "hybrid". Default "vector". :param mode: Search mode for retrievaleither "vector", "keyword", or "hybrid". Default "vector".
:param ranker: Configuration for the ranker to use in hybrid search. Defaults to RRF ranker.
""" """
# This config defines how a query is generated using the messages # This config defines how a query is generated using the messages
@ -86,6 +129,7 @@ class RAGQueryConfig(BaseModel):
max_chunks: int = 5 max_chunks: int = 5
chunk_template: str = "Result {index}\nContent: {chunk.content}\nMetadata: {metadata}\n" chunk_template: str = "Result {index}\nContent: {chunk.content}\nMetadata: {metadata}\n"
mode: str | None = None mode: str | None = None
ranker: Ranker | None = Field(default=None) # Only used for hybrid mode
@field_validator("chunk_template") @field_validator("chunk_template")
def validate_chunk_template(cls, v: str) -> str: def validate_chunk_template(cls, v: str) -> str:

View file

@ -121,8 +121,10 @@ class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, RAGToolRunti
vector_db_id=vector_db_id, vector_db_id=vector_db_id,
query=query, query=query,
params={ params={
"max_chunks": query_config.max_chunks,
"mode": query_config.mode, "mode": query_config.mode,
"max_chunks": query_config.max_chunks,
"score_threshold": 0.0,
"ranker": query_config.ranker,
}, },
) )
for vector_db_id in vector_db_ids for vector_db_id in vector_db_ids

View file

@ -137,6 +137,8 @@ class FaissIndex(EmbeddingIndex):
query_string: str, query_string: str,
k: int, k: int,
score_threshold: float, score_threshold: float,
reranker_type: str,
reranker_params: dict[str, Any] | None = None,
) -> QueryChunksResponse: ) -> QueryChunksResponse:
raise NotImplementedError("Hybrid search is not supported in FAISS") raise NotImplementedError("Hybrid search is not supported in FAISS")

View file

@ -27,7 +27,12 @@ from llama_stack.apis.vector_io import (
) )
from llama_stack.providers.datatypes import VectorDBsProtocolPrivate from llama_stack.providers.datatypes import VectorDBsProtocolPrivate
from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin
from llama_stack.providers.utils.memory.vector_store import EmbeddingIndex, VectorDBWithIndex from llama_stack.providers.utils.memory.vector_store import (
RERANKER_TYPE_RRF,
RERANKER_TYPE_WEIGHTED,
EmbeddingIndex,
VectorDBWithIndex,
)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -52,6 +57,59 @@ def _create_sqlite_connection(db_path):
return connection return connection
def _normalize_scores(scores: dict[str, float]) -> dict[str, float]:
"""Normalize scores to [0,1] range using min-max normalization."""
if not scores:
return {}
min_score = min(scores.values())
max_score = max(scores.values())
score_range = max_score - min_score
if score_range > 0:
return {doc_id: (score - min_score) / score_range for doc_id, score in scores.items()}
return {doc_id: 1.0 for doc_id in scores}
def _weighted_rerank(
vector_scores: dict[str, float],
keyword_scores: dict[str, float],
alpha: float = 0.5,
) -> dict[str, float]:
"""ReRanker that uses weighted average of scores."""
all_ids = set(vector_scores.keys()) | set(keyword_scores.keys())
normalized_vector_scores = _normalize_scores(vector_scores)
normalized_keyword_scores = _normalize_scores(keyword_scores)
return {
doc_id: (alpha * normalized_keyword_scores.get(doc_id, 0.0))
+ ((1 - alpha) * normalized_vector_scores.get(doc_id, 0.0))
for doc_id in all_ids
}
def _rrf_rerank(
vector_scores: dict[str, float],
keyword_scores: dict[str, float],
impact_factor: float = 60.0,
) -> dict[str, float]:
"""ReRanker that uses Reciprocal Rank Fusion."""
# Convert scores to ranks
vector_ranks = {
doc_id: i + 1 for i, (doc_id, _) in enumerate(sorted(vector_scores.items(), key=lambda x: x[1], reverse=True))
}
keyword_ranks = {
doc_id: i + 1 for i, (doc_id, _) in enumerate(sorted(keyword_scores.items(), key=lambda x: x[1], reverse=True))
}
all_ids = set(vector_scores.keys()) | set(keyword_scores.keys())
rrf_scores = {}
for doc_id in all_ids:
vector_rank = vector_ranks.get(doc_id, float("inf"))
keyword_rank = keyword_ranks.get(doc_id, float("inf"))
# RRF formula: score = 1/(k + r) where k is impact_factor and r is the rank
rrf_scores[doc_id] = (1.0 / (impact_factor + vector_rank)) + (1.0 / (impact_factor + keyword_rank))
return rrf_scores
class SQLiteVecIndex(EmbeddingIndex): class SQLiteVecIndex(EmbeddingIndex):
""" """
An index implementation that stores embeddings in a SQLite virtual table using sqlite-vec. An index implementation that stores embeddings in a SQLite virtual table using sqlite-vec.
@ -299,60 +357,72 @@ class SQLiteVecIndex(EmbeddingIndex):
query_string: str, query_string: str,
k: int, k: int,
score_threshold: float, score_threshold: float,
reranker_type: str = RERANKER_TYPE_RRF,
reranker_params: dict[str, Any] | None = None,
) -> QueryChunksResponse: ) -> QueryChunksResponse:
""" """
Hybrid search using Reciprocal Rank Fusion (RRF) to combine vector and keyword search results. Hybrid search using a configurable re-ranking strategy.
RRF assigns scores based on the reciprocal of the rank position in each search method,
then combines these scores to get a final ranking.
Args: Args:
embedding: The query embedding vector embedding: The query embedding vector
query_string: The text query for keyword search query_string: The text query for keyword search
k: Number of results to return k: Number of results to return
score_threshold: Minimum similarity score threshold score_threshold: Minimum similarity score threshold
reranker_type: Type of reranker to use ("rrf" or "weighted")
reranker_params: Parameters for the reranker
Returns: Returns:
QueryChunksResponse with combined results QueryChunksResponse with combined results
""" """
if reranker_params is None:
reranker_params = {}
# Get results from both search methods # Get results from both search methods
vector_response = await self.query_vector(embedding, k * 2, score_threshold) vector_response = await self.query_vector(embedding, k, score_threshold)
keyword_response = await self.query_keyword(query_string, k * 2, score_threshold) keyword_response = await self.query_keyword(query_string, k, score_threshold)
# Create dictionaries to store ranks for each method # Convert responses to score dictionaries using generate_chunk_id
vector_ranks = {chunk.metadata["document_id"]: i + 1 for i, chunk in enumerate(vector_response.chunks)} vector_scores = {
keyword_ranks = {chunk.metadata["document_id"]: i + 1 for i, chunk in enumerate(keyword_response.chunks)} generate_chunk_id(chunk.metadata["document_id"], str(chunk.content)): score
for chunk, score in zip(vector_response.chunks, vector_response.scores, strict=False)
}
keyword_scores = {
generate_chunk_id(chunk.metadata["document_id"], str(chunk.content)): score
for chunk, score in zip(keyword_response.chunks, keyword_response.scores, strict=False)
}
# Calculate RRF scores for all unique document IDs # Combine scores using the specified reranker
all_ids = set(vector_ranks.keys()) | set(keyword_ranks.keys()) if reranker_type == RERANKER_TYPE_WEIGHTED:
rrf_scores = {} alpha = reranker_params.get("alpha", 0.5)
for doc_id in all_ids: combined_scores = _weighted_rerank(vector_scores, keyword_scores, alpha)
vector_rank = vector_ranks.get(doc_id, float("inf")) else:
keyword_rank = keyword_ranks.get(doc_id, float("inf")) # Default to RRF for None, RRF, or any unknown types
# RRF formula: score = 1/(k + r) where k is a constant and r is the rank impact_factor = reranker_params.get("impact_factor", 60.0)
rrf_scores[doc_id] = (1.0 / (60 + vector_rank)) + (1.0 / (60 + keyword_rank)) combined_scores = _rrf_rerank(vector_scores, keyword_scores, impact_factor)
# Sort by RRF score and get top k results # Sort by combined score and get top k results
sorted_ids = sorted(rrf_scores.keys(), key=lambda x: rrf_scores[x], reverse=True)[:k] sorted_items = sorted(combined_scores.items(), key=lambda x: x[1], reverse=True)
top_k_items = sorted_items[:k]
# Combine results maintaining RRF scores # Filter by score threshold
filtered_items = [(doc_id, score) for doc_id, score in top_k_items if score >= score_threshold]
# Create a map of chunk_id to chunk for both responses
chunk_map = {}
for c in vector_response.chunks:
chunk_id = generate_chunk_id(c.metadata["document_id"], str(c.content))
chunk_map[chunk_id] = c
for c in keyword_response.chunks:
chunk_id = generate_chunk_id(c.metadata["document_id"], str(c.content))
chunk_map[chunk_id] = c
# Use the map to look up chunks by their IDs
chunks = [] chunks = []
scores = [] scores = []
for doc_id in sorted_ids: for doc_id, score in filtered_items:
score = rrf_scores[doc_id] if doc_id in chunk_map:
if score >= score_threshold: chunks.append(chunk_map[doc_id])
# Try to get from vector results first scores.append(score)
for chunk in vector_response.chunks:
if chunk.metadata["document_id"] == doc_id:
chunks.append(chunk)
scores.append(score)
break
else:
# If not in vector results, get from keyword results
for chunk in keyword_response.chunks:
if chunk.metadata["document_id"] == doc_id:
chunks.append(chunk)
scores.append(score)
break
return QueryChunksResponse(chunks=chunks, scores=scores) return QueryChunksResponse(chunks=chunks, scores=scores)

View file

@ -111,6 +111,8 @@ class ChromaIndex(EmbeddingIndex):
query_string: str, query_string: str,
k: int, k: int,
score_threshold: float, score_threshold: float,
reranker_type: str,
reranker_params: dict[str, Any] | None = None,
) -> QueryChunksResponse: ) -> QueryChunksResponse:
raise NotImplementedError("Hybrid search is not supported in Chroma") raise NotImplementedError("Hybrid search is not supported in Chroma")

View file

@ -109,6 +109,8 @@ class MilvusIndex(EmbeddingIndex):
query_string: str, query_string: str,
k: int, k: int,
score_threshold: float, score_threshold: float,
reranker_type: str,
reranker_params: dict[str, Any] | None = None,
) -> QueryChunksResponse: ) -> QueryChunksResponse:
raise NotImplementedError("Hybrid search is not supported in Milvus") raise NotImplementedError("Hybrid search is not supported in Milvus")

View file

@ -134,6 +134,8 @@ class PGVectorIndex(EmbeddingIndex):
query_string: str, query_string: str,
k: int, k: int,
score_threshold: float, score_threshold: float,
reranker_type: str,
reranker_params: dict[str, Any] | None = None,
) -> QueryChunksResponse: ) -> QueryChunksResponse:
raise NotImplementedError("Hybrid search is not supported in PGVector") raise NotImplementedError("Hybrid search is not supported in PGVector")

View file

@ -118,6 +118,8 @@ class QdrantIndex(EmbeddingIndex):
query_string: str, query_string: str,
k: int, k: int,
score_threshold: float, score_threshold: float,
reranker_type: str,
reranker_params: dict[str, Any] | None = None,
) -> QueryChunksResponse: ) -> QueryChunksResponse:
raise NotImplementedError("Hybrid search is not supported in Qdrant") raise NotImplementedError("Hybrid search is not supported in Qdrant")

View file

@ -98,6 +98,8 @@ class WeaviateIndex(EmbeddingIndex):
query_string: str, query_string: str,
k: int, k: int,
score_threshold: float, score_threshold: float,
reranker_type: str,
reranker_params: dict[str, Any] | None = None,
) -> QueryChunksResponse: ) -> QueryChunksResponse:
raise NotImplementedError("Hybrid search is not supported in Weaviate") raise NotImplementedError("Hybrid search is not supported in Weaviate")

View file

@ -32,6 +32,10 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
# Constants for reranker types
RERANKER_TYPE_RRF = "rrf"
RERANKER_TYPE_WEIGHTED = "weighted"
def parse_pdf(data: bytes) -> str: def parse_pdf(data: bytes) -> str:
# For PDF and DOC/DOCX files, we can't reliably convert to string # For PDF and DOC/DOCX files, we can't reliably convert to string
@ -204,7 +208,13 @@ class EmbeddingIndex(ABC):
@abstractmethod @abstractmethod
async def query_hybrid( async def query_hybrid(
self, embedding: NDArray, query_string: str, k: int, score_threshold: float self,
embedding: NDArray,
query_string: str,
k: int,
score_threshold: float,
reranker_type: str,
reranker_params: dict[str, Any] | None = None,
) -> QueryChunksResponse: ) -> QueryChunksResponse:
raise NotImplementedError() raise NotImplementedError()
@ -251,15 +261,29 @@ class VectorDBWithIndex:
k = params.get("max_chunks", 3) k = params.get("max_chunks", 3)
mode = params.get("mode") mode = params.get("mode")
score_threshold = params.get("score_threshold", 0.0) score_threshold = params.get("score_threshold", 0.0)
# Get ranker configuration
ranker = params.get("ranker")
if ranker is None:
# Default to RRF with impact_factor=60.0
reranker_type = RERANKER_TYPE_RRF
reranker_params = {"impact_factor": 60.0}
else:
reranker_type = ranker.type
reranker_params = (
{"impact_factor": ranker.impact_factor} if ranker.type == RERANKER_TYPE_RRF else {"alpha": ranker.alpha}
)
query_string = interleaved_content_as_str(query) query_string = interleaved_content_as_str(query)
if mode == "keyword":
return await self.index.query_keyword(query_string, k, score_threshold)
# Calculate embeddings for both vector and hybrid modes # Calculate embeddings for both vector and hybrid modes
embeddings_response = await self.inference_api.embeddings(self.vector_db.embedding_model, [query_string]) embeddings_response = await self.inference_api.embeddings(self.vector_db.embedding_model, [query_string])
query_vector = np.array(embeddings_response.embeddings[0], dtype=np.float32) query_vector = np.array(embeddings_response.embeddings[0], dtype=np.float32)
if mode == "hybrid":
if mode == "keyword": return await self.index.query_hybrid(
return await self.index.query_keyword(query_string, k, score_threshold) query_vector, query_string, k, score_threshold, reranker_type, reranker_params
elif mode == "hybrid": )
return await self.index.query_hybrid(query_vector, query_string, k, score_threshold)
else: else:
return await self.index.query_vector(query_vector, k, score_threshold) return await self.index.query_vector(query_vector, k, score_threshold)

View file

@ -93,7 +93,12 @@ async def test_query_chunks_hybrid(sqlite_vec_index, sample_chunks, sample_embed
query_string = "Sentence 5" query_string = "Sentence 5"
response = await sqlite_vec_index.query_hybrid( response = await sqlite_vec_index.query_hybrid(
embedding=query_embedding, query_string=query_string, k=3, score_threshold=0.0 embedding=query_embedding,
query_string=query_string,
k=3,
score_threshold=0.0,
reranker_type="rrf",
reranker_params={"impact_factor": 60.0},
) )
assert len(response.chunks) == 3, f"Expected 3 results, got {len(response.chunks)}" assert len(response.chunks) == 3, f"Expected 3 results, got {len(response.chunks)}"
@ -175,7 +180,12 @@ async def test_query_chunks_hybrid_no_keyword_matches(sqlite_vec_index, sample_c
# Get hybrid results # Get hybrid results
response = await sqlite_vec_index.query_hybrid( response = await sqlite_vec_index.query_hybrid(
embedding=query_embedding, query_string=query_string, k=3, score_threshold=0.0 embedding=query_embedding,
query_string=query_string,
k=3,
score_threshold=0.0,
reranker_type="rrf",
reranker_params={"impact_factor": 60.0},
) )
# Should still get results from vector search # Should still get results from vector search
@ -198,6 +208,8 @@ async def test_query_chunks_hybrid_score_threshold(sqlite_vec_index, sample_chun
query_string=query_string, query_string=query_string,
k=3, k=3,
score_threshold=1000.0, # Very high threshold score_threshold=1000.0, # Very high threshold
reranker_type="rrf",
reranker_params={"impact_factor": 60.0},
) )
# Should return no results due to high threshold # Should return no results due to high threshold
@ -216,7 +228,12 @@ async def test_query_chunks_hybrid_different_embedding(
query_string = "Sentence 5" query_string = "Sentence 5"
response = await sqlite_vec_index.query_hybrid( response = await sqlite_vec_index.query_hybrid(
embedding=query_embedding, query_string=query_string, k=3, score_threshold=0.0 embedding=query_embedding,
query_string=query_string,
k=3,
score_threshold=0.0,
reranker_type="rrf",
reranker_params={"impact_factor": 60.0},
) )
# Should still get results if keyword matches exist # Should still get results if keyword matches exist
@ -236,7 +253,12 @@ async def test_query_chunks_hybrid_rrf_ranking(sqlite_vec_index, sample_chunks,
query_string = "Sentence 5" query_string = "Sentence 5"
response = await sqlite_vec_index.query_hybrid( response = await sqlite_vec_index.query_hybrid(
embedding=query_embedding, query_string=query_string, k=5, score_threshold=0.0 embedding=query_embedding,
query_string=query_string,
k=5,
score_threshold=0.0,
reranker_type="rrf",
reranker_params={"impact_factor": 60.0},
) )
# Verify we get results from both search methods # Verify we get results from both search methods
@ -247,7 +269,6 @@ async def test_query_chunks_hybrid_rrf_ranking(sqlite_vec_index, sample_chunks,
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_query_chunks_hybrid_score_selection(sqlite_vec_index, sample_chunks, sample_embeddings): async def test_query_chunks_hybrid_score_selection(sqlite_vec_index, sample_chunks, sample_embeddings):
"""Test that we correctly rank documents that appear in both search methods."""
await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings) await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings)
# Create a query embedding that's similar to the first chunk # Create a query embedding that's similar to the first chunk
@ -255,26 +276,43 @@ async def test_query_chunks_hybrid_score_selection(sqlite_vec_index, sample_chun
# Use a keyword that appears in the first document # Use a keyword that appears in the first document
query_string = "Sentence 0 from document 0" query_string = "Sentence 0 from document 0"
# First get individual results to verify ranks # Test weighted re-ranking
vector_response = await sqlite_vec_index.query_vector(query_embedding, k=5, score_threshold=0.0)
keyword_response = await sqlite_vec_index.query_keyword(query_string, k=5, score_threshold=0.0)
# Verify document-0 appears in both results
assert any(chunk.metadata["document_id"] == "document-0" for chunk in vector_response.chunks), (
"document-0 not found in vector search results"
)
assert any(chunk.metadata["document_id"] == "document-0" for chunk in keyword_response.chunks), (
"document-0 not found in keyword search results"
)
# Now get hybrid results
response = await sqlite_vec_index.query_hybrid( response = await sqlite_vec_index.query_hybrid(
embedding=query_embedding, query_string=query_string, k=1, score_threshold=0.0 embedding=query_embedding,
query_string=query_string,
k=1,
score_threshold=0.0,
reranker_type="weighted",
reranker_params={"alpha": 0.5},
) )
# Verify document-0 is ranked first in hybrid results
assert len(response.chunks) == 1 assert len(response.chunks) == 1
assert response.chunks[0].metadata["document_id"] == "document-0", "document-0 not ranked first in hybrid results" # Score should be weighted average of normalized keyword score and vector score
assert response.scores[0] > 0.5 # Both scores should be high
# Test RRF re-ranking
response = await sqlite_vec_index.query_hybrid(
embedding=query_embedding,
query_string=query_string,
k=1,
score_threshold=0.0,
reranker_type="rrf",
reranker_params={"impact_factor": 60.0},
)
assert len(response.chunks) == 1
# RRF score should be sum of reciprocal ranks
assert response.scores[0] == pytest.approx(2.0 / 61.0, rel=1e-6) # 1/(60+1) + 1/(60+1)
# Test default re-ranking (should be RRF)
response = await sqlite_vec_index.query_hybrid(
embedding=query_embedding,
query_string=query_string,
k=1,
score_threshold=0.0,
reranker_type="rrf",
reranker_params={"impact_factor": 60.0},
)
assert len(response.chunks) == 1
assert response.scores[0] == pytest.approx(2.0 / 61.0, rel=1e-6) # Should behave like RRF
@pytest.mark.asyncio @pytest.mark.asyncio
@ -288,7 +326,12 @@ async def test_query_chunks_hybrid_mixed_results(sqlite_vec_index, sample_chunks
query_string = "Sentence 9 from document 2" query_string = "Sentence 9 from document 2"
response = await sqlite_vec_index.query_hybrid( response = await sqlite_vec_index.query_hybrid(
embedding=query_embedding, query_string=query_string, k=3, score_threshold=0.0 embedding=query_embedding,
query_string=query_string,
k=3,
score_threshold=0.0,
reranker_type="rrf",
reranker_params={"impact_factor": 60.0},
) )
# Should get results from both search methods # Should get results from both search methods
@ -299,3 +342,176 @@ async def test_query_chunks_hybrid_mixed_results(sqlite_vec_index, sample_chunks
doc_ids = {chunk.metadata["document_id"] for chunk in response.chunks} doc_ids = {chunk.metadata["document_id"] for chunk in response.chunks}
assert "document-0" in doc_ids # From vector search assert "document-0" in doc_ids # From vector search
assert "document-2" in doc_ids # From keyword search assert "document-2" in doc_ids # From keyword search
@pytest.mark.asyncio
async def test_query_chunks_hybrid_weighted_reranker_parametrization(
sqlite_vec_index, sample_chunks, sample_embeddings
):
"""Test WeightedReRanker with different alpha values."""
# Re-add data before each search to ensure test isolation
await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings)
query_embedding = sample_embeddings[0]
query_string = "Sentence 0 from document 0"
# alpha=1.0 (should behave like pure keyword)
response = await sqlite_vec_index.query_hybrid(
embedding=query_embedding,
query_string=query_string,
k=1,
score_threshold=0.0,
reranker_type="weighted",
reranker_params={"alpha": 1.0},
)
assert len(response.chunks) > 0 # Should get at least one result
assert any("document-0" in chunk.metadata["document_id"] for chunk in response.chunks)
# alpha=0.0 (should behave like pure vector)
response = await sqlite_vec_index.query_hybrid(
embedding=query_embedding,
query_string=query_string,
k=1,
score_threshold=0.0,
reranker_type="weighted",
reranker_params={"alpha": 0.0},
)
assert len(response.chunks) > 0 # Should get at least one result
assert any("document-0" in chunk.metadata["document_id"] for chunk in response.chunks)
await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings)
# alpha=0.7 (should be a mix)
response = await sqlite_vec_index.query_hybrid(
embedding=query_embedding,
query_string=query_string,
k=1,
score_threshold=0.0,
reranker_type="weighted",
reranker_params={"alpha": 0.7},
)
assert len(response.chunks) > 0 # Should get at least one result
assert any("document-0" in chunk.metadata["document_id"] for chunk in response.chunks)
@pytest.mark.asyncio
async def test_query_chunks_hybrid_rrf_impact_factor(sqlite_vec_index, sample_chunks, sample_embeddings):
"""Test RRFReRanker with different impact factors."""
await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings)
query_embedding = sample_embeddings[0]
query_string = "Sentence 0 from document 0"
# impact_factor=10
response = await sqlite_vec_index.query_hybrid(
embedding=query_embedding,
query_string=query_string,
k=1,
score_threshold=0.0,
reranker_type="rrf",
reranker_params={"impact_factor": 10.0},
)
assert len(response.chunks) == 1
assert response.scores[0] == pytest.approx(2.0 / 11.0, rel=1e-6)
# impact_factor=100
response = await sqlite_vec_index.query_hybrid(
embedding=query_embedding,
query_string=query_string,
k=1,
score_threshold=0.0,
reranker_type="rrf",
reranker_params={"impact_factor": 100.0},
)
assert len(response.chunks) == 1
assert response.scores[0] == pytest.approx(2.0 / 101.0, rel=1e-6)
@pytest.mark.asyncio
async def test_query_chunks_hybrid_edge_cases(sqlite_vec_index, sample_chunks, sample_embeddings):
await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings)
# No results from either search - use a completely different embedding and a nonzero threshold
query_embedding = np.ones_like(sample_embeddings[0]) * -1 # Very different from sample embeddings
query_string = "no_such_keyword_that_will_never_match"
response = await sqlite_vec_index.query_hybrid(
embedding=query_embedding,
query_string=query_string,
k=3,
score_threshold=0.1, # Nonzero threshold to filter out low-similarity matches
reranker_type="rrf",
reranker_params={"impact_factor": 60.0},
)
assert len(response.chunks) == 0
# All results below threshold
query_embedding = sample_embeddings[0]
query_string = "Sentence 0 from document 0"
response = await sqlite_vec_index.query_hybrid(
embedding=query_embedding,
query_string=query_string,
k=3,
score_threshold=1000.0,
reranker_type="rrf",
reranker_params={"impact_factor": 60.0},
)
assert len(response.chunks) == 0
# Large k value
response = await sqlite_vec_index.query_hybrid(
embedding=query_embedding,
query_string=query_string,
k=100,
score_threshold=0.0,
reranker_type="rrf",
reranker_params={"impact_factor": 60.0},
)
# Should not error, should return all available results
assert len(response.chunks) > 0
assert len(response.chunks) <= 100
@pytest.mark.asyncio
async def test_query_chunks_hybrid_tie_breaking(
sqlite_vec_index, sample_embeddings, embedding_dimension, tmp_path_factory
):
"""Test tie-breaking and determinism when scores are equal."""
# Create two chunks with the same content and embedding
chunk1 = Chunk(content="identical", metadata={"document_id": "docA"})
chunk2 = Chunk(content="identical", metadata={"document_id": "docB"})
chunks = [chunk1, chunk2]
# Use the same embedding for both chunks to ensure equal scores
same_embedding = sample_embeddings[0]
embeddings = np.array([same_embedding, same_embedding])
# Clear existing data and recreate index
await sqlite_vec_index.delete()
temp_dir = tmp_path_factory.getbasetemp()
db_path = str(temp_dir / "test_sqlite.db")
sqlite_vec_index = await SQLiteVecIndex.create(dimension=embedding_dimension, db_path=db_path, bank_id="test_bank")
await sqlite_vec_index.add_chunks(chunks, embeddings)
# Query with the same embedding and content to ensure equal scores
query_embedding = same_embedding
query_string = "identical"
# Run multiple queries to verify determinism
responses = []
for _ in range(3):
response = await sqlite_vec_index.query_hybrid(
embedding=query_embedding,
query_string=query_string,
k=2,
score_threshold=0.0,
reranker_type="rrf",
reranker_params={"impact_factor": 60.0},
)
responses.append(response)
# Verify all responses are identical
first_response = responses[0]
for response in responses[1:]:
assert response.chunks == first_response.chunks
assert response.scores == first_response.scores
# Verify both chunks are returned with equal scores
assert len(first_response.chunks) == 2
assert first_response.scores[0] == first_response.scores[1]
assert {chunk.metadata["document_id"] for chunk in first_response.chunks} == {"docA", "docB"}