mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-16 09:58:10 +00:00
feat: Introduce weighted and rrf reranker implementations
Signed-off-by: Varsha Prasad Narsing <varshaprasad96@gmail.com>
This commit is contained in:
parent
eab85a7121
commit
6ea5c10d48
14 changed files with 637 additions and 75 deletions
67
docs/_static/llama-stack-spec.html
vendored
67
docs/_static/llama-stack-spec.html
vendored
|
@ -13995,6 +13995,10 @@
|
||||||
"mode": {
|
"mode": {
|
||||||
"type": "string",
|
"type": "string",
|
||||||
"description": "Search mode for retrieval—either \"vector\", \"keyword\", or \"hybrid\". Default \"vector\"."
|
"description": "Search mode for retrieval—either \"vector\", \"keyword\", or \"hybrid\". Default \"vector\"."
|
||||||
|
},
|
||||||
|
"ranker": {
|
||||||
|
"$ref": "#/components/schemas/Ranker",
|
||||||
|
"description": "Configuration for the ranker to use in hybrid search. Defaults to RRF ranker."
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"additionalProperties": false,
|
"additionalProperties": false,
|
||||||
|
@ -14024,6 +14028,69 @@
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
"RRFRanker": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"type": {
|
||||||
|
"type": "string",
|
||||||
|
"const": "rrf",
|
||||||
|
"default": "rrf",
|
||||||
|
"description": "The type of ranker, always \"rrf\""
|
||||||
|
},
|
||||||
|
"impact_factor": {
|
||||||
|
"type": "number",
|
||||||
|
"default": 60.0,
|
||||||
|
"description": "The impact factor for RRF scoring. Higher values reduce the advantage of top-ranked results. Must be greater than 0. Default of 60 is from the original RRF paper (Cormack et al., 2009)."
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"additionalProperties": false,
|
||||||
|
"required": [
|
||||||
|
"type",
|
||||||
|
"impact_factor"
|
||||||
|
],
|
||||||
|
"title": "RRFRanker",
|
||||||
|
"description": "Reciprocal Rank Fusion (RRF) ranker configuration."
|
||||||
|
},
|
||||||
|
"Ranker": {
|
||||||
|
"oneOf": [
|
||||||
|
{
|
||||||
|
"$ref": "#/components/schemas/RRFRanker"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"$ref": "#/components/schemas/WeightedRanker"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"discriminator": {
|
||||||
|
"propertyName": "type",
|
||||||
|
"mapping": {
|
||||||
|
"rrf": "#/components/schemas/RRFRanker",
|
||||||
|
"weighted": "#/components/schemas/WeightedRanker"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"WeightedRanker": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"type": {
|
||||||
|
"type": "string",
|
||||||
|
"const": "weighted",
|
||||||
|
"default": "weighted",
|
||||||
|
"description": "The type of ranker, always \"weighted\""
|
||||||
|
},
|
||||||
|
"alpha": {
|
||||||
|
"type": "number",
|
||||||
|
"default": 0.5,
|
||||||
|
"description": "Weight factor between 0 and 1. 0 means only use keyword scores, 1 means only use vector scores, values in between blend both scores."
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"additionalProperties": false,
|
||||||
|
"required": [
|
||||||
|
"type",
|
||||||
|
"alpha"
|
||||||
|
],
|
||||||
|
"title": "WeightedRanker",
|
||||||
|
"description": "Weighted ranker configuration that combines vector and keyword scores."
|
||||||
|
},
|
||||||
"QueryRequest": {
|
"QueryRequest": {
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"properties": {
|
"properties": {
|
||||||
|
|
57
docs/_static/llama-stack-spec.yaml
vendored
57
docs/_static/llama-stack-spec.yaml
vendored
|
@ -9758,6 +9758,11 @@ components:
|
||||||
description: >-
|
description: >-
|
||||||
Search mode for retrieval—either "vector", "keyword", or "hybrid". Default
|
Search mode for retrieval—either "vector", "keyword", or "hybrid". Default
|
||||||
"vector".
|
"vector".
|
||||||
|
ranker:
|
||||||
|
$ref: '#/components/schemas/Ranker'
|
||||||
|
description: >-
|
||||||
|
Configuration for the ranker to use in hybrid search. Defaults to RRF
|
||||||
|
ranker.
|
||||||
additionalProperties: false
|
additionalProperties: false
|
||||||
required:
|
required:
|
||||||
- query_generator_config
|
- query_generator_config
|
||||||
|
@ -9776,6 +9781,58 @@ components:
|
||||||
mapping:
|
mapping:
|
||||||
default: '#/components/schemas/DefaultRAGQueryGeneratorConfig'
|
default: '#/components/schemas/DefaultRAGQueryGeneratorConfig'
|
||||||
llm: '#/components/schemas/LLMRAGQueryGeneratorConfig'
|
llm: '#/components/schemas/LLMRAGQueryGeneratorConfig'
|
||||||
|
RRFRanker:
|
||||||
|
type: object
|
||||||
|
properties:
|
||||||
|
type:
|
||||||
|
type: string
|
||||||
|
const: rrf
|
||||||
|
default: rrf
|
||||||
|
description: The type of ranker, always "rrf"
|
||||||
|
impact_factor:
|
||||||
|
type: number
|
||||||
|
default: 60.0
|
||||||
|
description: >-
|
||||||
|
The impact factor for RRF scoring. Higher values reduce the advantage of top-ranked
|
||||||
|
results. Must be greater than 0. Default of 60 is from the original RRF
|
||||||
|
paper (Cormack et al., 2009).
|
||||||
|
additionalProperties: false
|
||||||
|
required:
|
||||||
|
- type
|
||||||
|
- impact_factor
|
||||||
|
title: RRFRanker
|
||||||
|
description: >-
|
||||||
|
Reciprocal Rank Fusion (RRF) ranker configuration.
|
||||||
|
Ranker:
|
||||||
|
oneOf:
|
||||||
|
- $ref: '#/components/schemas/RRFRanker'
|
||||||
|
- $ref: '#/components/schemas/WeightedRanker'
|
||||||
|
discriminator:
|
||||||
|
propertyName: type
|
||||||
|
mapping:
|
||||||
|
rrf: '#/components/schemas/RRFRanker'
|
||||||
|
weighted: '#/components/schemas/WeightedRanker'
|
||||||
|
WeightedRanker:
|
||||||
|
type: object
|
||||||
|
properties:
|
||||||
|
type:
|
||||||
|
type: string
|
||||||
|
const: weighted
|
||||||
|
default: weighted
|
||||||
|
description: The type of ranker, always "weighted"
|
||||||
|
alpha:
|
||||||
|
type: number
|
||||||
|
default: 0.5
|
||||||
|
description: >-
|
||||||
|
Weight factor between 0 and 1. 0 means only use keyword scores, 1 means
|
||||||
|
only use vector scores, values in between blend both scores.
|
||||||
|
additionalProperties: false
|
||||||
|
required:
|
||||||
|
- type
|
||||||
|
- alpha
|
||||||
|
title: WeightedRanker
|
||||||
|
description: >-
|
||||||
|
Weighted ranker configuration that combines vector and keyword scores.
|
||||||
QueryRequest:
|
QueryRequest:
|
||||||
type: object
|
type: object
|
||||||
properties:
|
properties:
|
||||||
|
|
|
@ -79,6 +79,30 @@ response = await vector_io.query_chunks(
|
||||||
query="your query here",
|
query="your query here",
|
||||||
params={"mode": "hybrid", "max_chunks": 3, "score_threshold": 0.7},
|
params={"mode": "hybrid", "max_chunks": 3, "score_threshold": 0.7},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Using RRF ranker
|
||||||
|
response = await vector_io.query_chunks(
|
||||||
|
vector_db_id="my_db",
|
||||||
|
query="your query here",
|
||||||
|
params={
|
||||||
|
"mode": "hybrid",
|
||||||
|
"max_chunks": 3,
|
||||||
|
"score_threshold": 0.7,
|
||||||
|
"ranker": {"type": "rrf", "impact_factor": 60.0},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Using weighted ranker
|
||||||
|
response = await vector_io.query_chunks(
|
||||||
|
vector_db_id="my_db",
|
||||||
|
query="your query here",
|
||||||
|
params={
|
||||||
|
"mode": "hybrid",
|
||||||
|
"max_chunks": 3,
|
||||||
|
"score_threshold": 0.7,
|
||||||
|
"ranker": {"type": "weighted", "alpha": 0.7}, # 70% vector, 30% keyword
|
||||||
|
},
|
||||||
|
)
|
||||||
```
|
```
|
||||||
|
|
||||||
Example with explicit vector search:
|
Example with explicit vector search:
|
||||||
|
@ -101,23 +125,67 @@ response = await vector_io.query_chunks(
|
||||||
|
|
||||||
## Supported Search Modes
|
## Supported Search Modes
|
||||||
|
|
||||||
The sqlite-vec provider supports both vector-based and keyword-based (full-text) search modes.
|
The SQLite vector store supports three search modes:
|
||||||
|
|
||||||
When using the RAGTool interface, you can specify the desired search behavior via the `mode` parameter in
|
1. **Vector Search** (`mode="vector"`): Uses vector similarity to find relevant chunks
|
||||||
`RAGQueryConfig`. For example:
|
2. **Keyword Search** (`mode="keyword"`): Uses keyword matching to find relevant chunks
|
||||||
|
3. **Hybrid Search** (`mode="hybrid"`): Combines both vector and keyword scores using a ranker
|
||||||
|
|
||||||
|
### Hybrid Search
|
||||||
|
|
||||||
|
Hybrid search combines the strengths of both vector and keyword search by:
|
||||||
|
- Computing vector similarity scores
|
||||||
|
- Computing keyword match scores
|
||||||
|
- Using a ranker to combine these scores
|
||||||
|
|
||||||
|
Two ranker types are supported:
|
||||||
|
|
||||||
|
1. **RRF (Reciprocal Rank Fusion)**:
|
||||||
|
- Combines ranks from both vector and keyword results
|
||||||
|
- Uses an impact factor (default: 60.0) to control the weight of higher-ranked results
|
||||||
|
- Good for balancing between vector and keyword results
|
||||||
|
- The default impact factor of 60.0 comes from the original RRF paper by Cormack et al. (2009) [^1], which found this value to perform well across various retrieval tasks
|
||||||
|
|
||||||
|
2. **Weighted**:
|
||||||
|
- Linearly combines normalized vector and keyword scores
|
||||||
|
- Uses an alpha parameter (0-1) to control the blend:
|
||||||
|
- alpha=0: Only use keyword scores
|
||||||
|
- alpha=1: Only use vector scores
|
||||||
|
- alpha=0.5: Equal weight to both (default)
|
||||||
|
|
||||||
|
Example using RAGQueryConfig with different search modes:
|
||||||
|
|
||||||
```python
|
```python
|
||||||
from llama_stack.apis.tool_runtime.rag import RAGQueryConfig
|
from llama_stack.apis.tools import RAGQueryConfig, RRFRanker, WeightedRanker
|
||||||
|
|
||||||
query_config = RAGQueryConfig(max_chunks=6, mode="vector")
|
# Vector search
|
||||||
|
config = RAGQueryConfig(mode="vector", max_chunks=5)
|
||||||
|
|
||||||
results = client.tool_runtime.rag_tool.query(
|
# Keyword search
|
||||||
vector_db_ids=[vector_db_id],
|
config = RAGQueryConfig(mode="keyword", max_chunks=5)
|
||||||
content="what is torchtune",
|
|
||||||
query_config=query_config,
|
# Hybrid search with custom RRF ranker
|
||||||
|
config = RAGQueryConfig(
|
||||||
|
mode="hybrid",
|
||||||
|
max_chunks=5,
|
||||||
|
ranker=RRFRanker(impact_factor=50.0), # Custom impact factor
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Hybrid search with weighted ranker
|
||||||
|
config = RAGQueryConfig(
|
||||||
|
mode="hybrid",
|
||||||
|
max_chunks=5,
|
||||||
|
ranker=WeightedRanker(alpha=0.7), # 70% vector, 30% keyword
|
||||||
|
)
|
||||||
|
|
||||||
|
# Hybrid search with default RRF ranker
|
||||||
|
config = RAGQueryConfig(
|
||||||
|
mode="hybrid", max_chunks=5
|
||||||
|
) # Will use RRF with impact_factor=60.0
|
||||||
```
|
```
|
||||||
|
|
||||||
|
Note: The ranker configuration is only used in hybrid mode. For vector or keyword modes, the ranker parameter is ignored.
|
||||||
|
|
||||||
## Installation
|
## Installation
|
||||||
|
|
||||||
You can install SQLite-Vec using pip:
|
You can install SQLite-Vec using pip:
|
||||||
|
@ -129,3 +197,5 @@ pip install sqlite-vec
|
||||||
## Documentation
|
## Documentation
|
||||||
|
|
||||||
See [sqlite-vec's GitHub repo](https://github.com/asg017/sqlite-vec/tree/main) for more details about sqlite-vec in general.
|
See [sqlite-vec's GitHub repo](https://github.com/asg017/sqlite-vec/tree/main) for more details about sqlite-vec in general.
|
||||||
|
|
||||||
|
[^1]: Cormack, G. V., Clarke, C. L., & Buettcher, S. (2009). [Reciprocal rank fusion outperforms condorcet and individual rank learning methods](https://dl.acm.org/doi/10.1145/1571941.1572114). In Proceedings of the 32nd international ACM SIGIR conference on Research and development in information retrieval (pp. 758-759).
|
||||||
|
|
|
@ -15,6 +15,48 @@ from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
|
||||||
from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
|
from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
|
||||||
|
|
||||||
|
|
||||||
|
@json_schema_type
class RRFRanker(BaseModel):
    """
    Reciprocal Rank Fusion (RRF) ranker configuration.

    :param type: The type of ranker, always "rrf"
    :param impact_factor: The impact factor for RRF scoring. Higher values reduce the advantage of top-ranked results.
                          Must be greater than 0. Default of 60 is from the original RRF paper (Cormack et al., 2009).
    """

    type: Literal["rrf"] = "rrf"
    # Each result list contributes 1 / (impact_factor + rank) to a document's fused score,
    # so a larger impact_factor shrinks the gap between adjacent ranks.
    # gt=0.0 makes validation reject zero and negative values.
    impact_factor: float = Field(default=60.0, gt=0.0)
|
||||||
|
|
||||||
|
|
||||||
|
@json_schema_type
class WeightedRanker(BaseModel):
    """
    Weighted ranker configuration that combines vector and keyword scores.

    :param type: The type of ranker, always "weighted"
    :param alpha: Weight factor between 0 and 1.
                  0 means only use keyword scores,
                  1 means only use vector scores,
                  values in between blend both scores.
    """

    type: Literal["weighted"] = "weighted"
    # Both bounds are inclusive (ge / le), so 0.0 and 1.0 are accepted.
    alpha: float = Field(
        description="Weight factor between 0 and 1. 0 means only keyword scores, 1 means only vector scores.",
        default=0.5,
        le=1.0,
        ge=0.0,
    )
|
||||||
|
|
||||||
|
|
||||||
|
# Discriminated union of the supported ranker configurations: the "type" field
# ("rrf" or "weighted") selects which model a payload is validated against.
Ranker = Annotated[RRFRanker | WeightedRanker, Field(discriminator="type")]
register_schema(Ranker, name="Ranker")
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
class RAGDocument(BaseModel):
|
class RAGDocument(BaseModel):
|
||||||
"""
|
"""
|
||||||
|
@ -77,6 +119,7 @@ class RAGQueryConfig(BaseModel):
|
||||||
Available placeholders: {index} (1-based chunk ordinal), {chunk.content} (chunk content string), {metadata} (chunk metadata dict).
|
Available placeholders: {index} (1-based chunk ordinal), {chunk.content} (chunk content string), {metadata} (chunk metadata dict).
|
||||||
Default: "Result {index}\\nContent: {chunk.content}\\nMetadata: {metadata}\\n"
|
Default: "Result {index}\\nContent: {chunk.content}\\nMetadata: {metadata}\\n"
|
||||||
:param mode: Search mode for retrieval—either "vector", "keyword", or "hybrid". Default "vector".
|
:param mode: Search mode for retrieval—either "vector", "keyword", or "hybrid". Default "vector".
|
||||||
|
:param ranker: Configuration for the ranker to use in hybrid search. Defaults to RRF ranker.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# This config defines how a query is generated using the messages
|
# This config defines how a query is generated using the messages
|
||||||
|
@ -86,6 +129,7 @@ class RAGQueryConfig(BaseModel):
|
||||||
max_chunks: int = 5
|
max_chunks: int = 5
|
||||||
chunk_template: str = "Result {index}\nContent: {chunk.content}\nMetadata: {metadata}\n"
|
chunk_template: str = "Result {index}\nContent: {chunk.content}\nMetadata: {metadata}\n"
|
||||||
mode: str | None = None
|
mode: str | None = None
|
||||||
|
ranker: Ranker | None = Field(default=None) # Only used for hybrid mode
|
||||||
|
|
||||||
@field_validator("chunk_template")
|
@field_validator("chunk_template")
|
||||||
def validate_chunk_template(cls, v: str) -> str:
|
def validate_chunk_template(cls, v: str) -> str:
|
||||||
|
|
|
@ -121,8 +121,10 @@ class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, RAGToolRunti
|
||||||
vector_db_id=vector_db_id,
|
vector_db_id=vector_db_id,
|
||||||
query=query,
|
query=query,
|
||||||
params={
|
params={
|
||||||
"max_chunks": query_config.max_chunks,
|
|
||||||
"mode": query_config.mode,
|
"mode": query_config.mode,
|
||||||
|
"max_chunks": query_config.max_chunks,
|
||||||
|
"score_threshold": 0.0,
|
||||||
|
"ranker": query_config.ranker,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
for vector_db_id in vector_db_ids
|
for vector_db_id in vector_db_ids
|
||||||
|
|
|
@ -137,6 +137,8 @@ class FaissIndex(EmbeddingIndex):
|
||||||
query_string: str,
|
query_string: str,
|
||||||
k: int,
|
k: int,
|
||||||
score_threshold: float,
|
score_threshold: float,
|
||||||
|
reranker_type: str,
|
||||||
|
reranker_params: dict[str, Any] | None = None,
|
||||||
) -> QueryChunksResponse:
|
) -> QueryChunksResponse:
|
||||||
raise NotImplementedError("Hybrid search is not supported in FAISS")
|
raise NotImplementedError("Hybrid search is not supported in FAISS")
|
||||||
|
|
||||||
|
|
|
@ -27,7 +27,12 @@ from llama_stack.apis.vector_io import (
|
||||||
)
|
)
|
||||||
from llama_stack.providers.datatypes import VectorDBsProtocolPrivate
|
from llama_stack.providers.datatypes import VectorDBsProtocolPrivate
|
||||||
from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin
|
from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin
|
||||||
from llama_stack.providers.utils.memory.vector_store import EmbeddingIndex, VectorDBWithIndex
|
from llama_stack.providers.utils.memory.vector_store import (
|
||||||
|
RERANKER_TYPE_RRF,
|
||||||
|
RERANKER_TYPE_WEIGHTED,
|
||||||
|
EmbeddingIndex,
|
||||||
|
VectorDBWithIndex,
|
||||||
|
)
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
@ -52,6 +57,59 @@ def _create_sqlite_connection(db_path):
|
||||||
return connection
|
return connection
|
||||||
|
|
||||||
|
|
||||||
|
def _normalize_scores(scores: dict[str, float]) -> dict[str, float]:
|
||||||
|
"""Normalize scores to [0,1] range using min-max normalization."""
|
||||||
|
if not scores:
|
||||||
|
return {}
|
||||||
|
min_score = min(scores.values())
|
||||||
|
max_score = max(scores.values())
|
||||||
|
score_range = max_score - min_score
|
||||||
|
if score_range > 0:
|
||||||
|
return {doc_id: (score - min_score) / score_range for doc_id, score in scores.items()}
|
||||||
|
return {doc_id: 1.0 for doc_id in scores}
|
||||||
|
|
||||||
|
|
||||||
|
def _weighted_rerank(
|
||||||
|
vector_scores: dict[str, float],
|
||||||
|
keyword_scores: dict[str, float],
|
||||||
|
alpha: float = 0.5,
|
||||||
|
) -> dict[str, float]:
|
||||||
|
"""ReRanker that uses weighted average of scores."""
|
||||||
|
all_ids = set(vector_scores.keys()) | set(keyword_scores.keys())
|
||||||
|
normalized_vector_scores = _normalize_scores(vector_scores)
|
||||||
|
normalized_keyword_scores = _normalize_scores(keyword_scores)
|
||||||
|
|
||||||
|
return {
|
||||||
|
doc_id: (alpha * normalized_keyword_scores.get(doc_id, 0.0))
|
||||||
|
+ ((1 - alpha) * normalized_vector_scores.get(doc_id, 0.0))
|
||||||
|
for doc_id in all_ids
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _rrf_rerank(
|
||||||
|
vector_scores: dict[str, float],
|
||||||
|
keyword_scores: dict[str, float],
|
||||||
|
impact_factor: float = 60.0,
|
||||||
|
) -> dict[str, float]:
|
||||||
|
"""ReRanker that uses Reciprocal Rank Fusion."""
|
||||||
|
# Convert scores to ranks
|
||||||
|
vector_ranks = {
|
||||||
|
doc_id: i + 1 for i, (doc_id, _) in enumerate(sorted(vector_scores.items(), key=lambda x: x[1], reverse=True))
|
||||||
|
}
|
||||||
|
keyword_ranks = {
|
||||||
|
doc_id: i + 1 for i, (doc_id, _) in enumerate(sorted(keyword_scores.items(), key=lambda x: x[1], reverse=True))
|
||||||
|
}
|
||||||
|
|
||||||
|
all_ids = set(vector_scores.keys()) | set(keyword_scores.keys())
|
||||||
|
rrf_scores = {}
|
||||||
|
for doc_id in all_ids:
|
||||||
|
vector_rank = vector_ranks.get(doc_id, float("inf"))
|
||||||
|
keyword_rank = keyword_ranks.get(doc_id, float("inf"))
|
||||||
|
# RRF formula: score = 1/(k + r) where k is impact_factor and r is the rank
|
||||||
|
rrf_scores[doc_id] = (1.0 / (impact_factor + vector_rank)) + (1.0 / (impact_factor + keyword_rank))
|
||||||
|
return rrf_scores
|
||||||
|
|
||||||
|
|
||||||
class SQLiteVecIndex(EmbeddingIndex):
|
class SQLiteVecIndex(EmbeddingIndex):
|
||||||
"""
|
"""
|
||||||
An index implementation that stores embeddings in a SQLite virtual table using sqlite-vec.
|
An index implementation that stores embeddings in a SQLite virtual table using sqlite-vec.
|
||||||
|
@ -299,60 +357,72 @@ class SQLiteVecIndex(EmbeddingIndex):
|
||||||
query_string: str,
|
query_string: str,
|
||||||
k: int,
|
k: int,
|
||||||
score_threshold: float,
|
score_threshold: float,
|
||||||
|
reranker_type: str = RERANKER_TYPE_RRF,
|
||||||
|
reranker_params: dict[str, Any] | None = None,
|
||||||
) -> QueryChunksResponse:
|
) -> QueryChunksResponse:
|
||||||
"""
|
"""
|
||||||
Hybrid search using Reciprocal Rank Fusion (RRF) to combine vector and keyword search results.
|
Hybrid search using a configurable re-ranking strategy.
|
||||||
RRF assigns scores based on the reciprocal of the rank position in each search method,
|
|
||||||
then combines these scores to get a final ranking.
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
embedding: The query embedding vector
|
embedding: The query embedding vector
|
||||||
query_string: The text query for keyword search
|
query_string: The text query for keyword search
|
||||||
k: Number of results to return
|
k: Number of results to return
|
||||||
score_threshold: Minimum similarity score threshold
|
score_threshold: Minimum similarity score threshold
|
||||||
|
reranker_type: Type of reranker to use ("rrf" or "weighted")
|
||||||
|
reranker_params: Parameters for the reranker
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
QueryChunksResponse with combined results
|
QueryChunksResponse with combined results
|
||||||
"""
|
"""
|
||||||
|
if reranker_params is None:
|
||||||
|
reranker_params = {}
|
||||||
|
|
||||||
# Get results from both search methods
|
# Get results from both search methods
|
||||||
vector_response = await self.query_vector(embedding, k * 2, score_threshold)
|
vector_response = await self.query_vector(embedding, k, score_threshold)
|
||||||
keyword_response = await self.query_keyword(query_string, k * 2, score_threshold)
|
keyword_response = await self.query_keyword(query_string, k, score_threshold)
|
||||||
|
|
||||||
# Create dictionaries to store ranks for each method
|
# Convert responses to score dictionaries using generate_chunk_id
|
||||||
vector_ranks = {chunk.metadata["document_id"]: i + 1 for i, chunk in enumerate(vector_response.chunks)}
|
vector_scores = {
|
||||||
keyword_ranks = {chunk.metadata["document_id"]: i + 1 for i, chunk in enumerate(keyword_response.chunks)}
|
generate_chunk_id(chunk.metadata["document_id"], str(chunk.content)): score
|
||||||
|
for chunk, score in zip(vector_response.chunks, vector_response.scores, strict=False)
|
||||||
|
}
|
||||||
|
keyword_scores = {
|
||||||
|
generate_chunk_id(chunk.metadata["document_id"], str(chunk.content)): score
|
||||||
|
for chunk, score in zip(keyword_response.chunks, keyword_response.scores, strict=False)
|
||||||
|
}
|
||||||
|
|
||||||
# Calculate RRF scores for all unique document IDs
|
# Combine scores using the specified reranker
|
||||||
all_ids = set(vector_ranks.keys()) | set(keyword_ranks.keys())
|
if reranker_type == RERANKER_TYPE_WEIGHTED:
|
||||||
rrf_scores = {}
|
alpha = reranker_params.get("alpha", 0.5)
|
||||||
for doc_id in all_ids:
|
combined_scores = _weighted_rerank(vector_scores, keyword_scores, alpha)
|
||||||
vector_rank = vector_ranks.get(doc_id, float("inf"))
|
else:
|
||||||
keyword_rank = keyword_ranks.get(doc_id, float("inf"))
|
# Default to RRF for None, RRF, or any unknown types
|
||||||
# RRF formula: score = 1/(k + r) where k is a constant and r is the rank
|
impact_factor = reranker_params.get("impact_factor", 60.0)
|
||||||
rrf_scores[doc_id] = (1.0 / (60 + vector_rank)) + (1.0 / (60 + keyword_rank))
|
combined_scores = _rrf_rerank(vector_scores, keyword_scores, impact_factor)
|
||||||
|
|
||||||
# Sort by RRF score and get top k results
|
# Sort by combined score and get top k results
|
||||||
sorted_ids = sorted(rrf_scores.keys(), key=lambda x: rrf_scores[x], reverse=True)[:k]
|
sorted_items = sorted(combined_scores.items(), key=lambda x: x[1], reverse=True)
|
||||||
|
top_k_items = sorted_items[:k]
|
||||||
|
|
||||||
# Combine results maintaining RRF scores
|
# Filter by score threshold
|
||||||
|
filtered_items = [(doc_id, score) for doc_id, score in top_k_items if score >= score_threshold]
|
||||||
|
|
||||||
|
# Create a map of chunk_id to chunk for both responses
|
||||||
|
chunk_map = {}
|
||||||
|
for c in vector_response.chunks:
|
||||||
|
chunk_id = generate_chunk_id(c.metadata["document_id"], str(c.content))
|
||||||
|
chunk_map[chunk_id] = c
|
||||||
|
for c in keyword_response.chunks:
|
||||||
|
chunk_id = generate_chunk_id(c.metadata["document_id"], str(c.content))
|
||||||
|
chunk_map[chunk_id] = c
|
||||||
|
|
||||||
|
# Use the map to look up chunks by their IDs
|
||||||
chunks = []
|
chunks = []
|
||||||
scores = []
|
scores = []
|
||||||
for doc_id in sorted_ids:
|
for doc_id, score in filtered_items:
|
||||||
score = rrf_scores[doc_id]
|
if doc_id in chunk_map:
|
||||||
if score >= score_threshold:
|
chunks.append(chunk_map[doc_id])
|
||||||
# Try to get from vector results first
|
scores.append(score)
|
||||||
for chunk in vector_response.chunks:
|
|
||||||
if chunk.metadata["document_id"] == doc_id:
|
|
||||||
chunks.append(chunk)
|
|
||||||
scores.append(score)
|
|
||||||
break
|
|
||||||
else:
|
|
||||||
# If not in vector results, get from keyword results
|
|
||||||
for chunk in keyword_response.chunks:
|
|
||||||
if chunk.metadata["document_id"] == doc_id:
|
|
||||||
chunks.append(chunk)
|
|
||||||
scores.append(score)
|
|
||||||
break
|
|
||||||
|
|
||||||
return QueryChunksResponse(chunks=chunks, scores=scores)
|
return QueryChunksResponse(chunks=chunks, scores=scores)
|
||||||
|
|
||||||
|
|
|
@ -111,6 +111,8 @@ class ChromaIndex(EmbeddingIndex):
|
||||||
query_string: str,
|
query_string: str,
|
||||||
k: int,
|
k: int,
|
||||||
score_threshold: float,
|
score_threshold: float,
|
||||||
|
reranker_type: str,
|
||||||
|
reranker_params: dict[str, Any] | None = None,
|
||||||
) -> QueryChunksResponse:
|
) -> QueryChunksResponse:
|
||||||
raise NotImplementedError("Hybrid search is not supported in Chroma")
|
raise NotImplementedError("Hybrid search is not supported in Chroma")
|
||||||
|
|
||||||
|
|
|
@ -109,6 +109,8 @@ class MilvusIndex(EmbeddingIndex):
|
||||||
query_string: str,
|
query_string: str,
|
||||||
k: int,
|
k: int,
|
||||||
score_threshold: float,
|
score_threshold: float,
|
||||||
|
reranker_type: str,
|
||||||
|
reranker_params: dict[str, Any] | None = None,
|
||||||
) -> QueryChunksResponse:
|
) -> QueryChunksResponse:
|
||||||
raise NotImplementedError("Hybrid search is not supported in Milvus")
|
raise NotImplementedError("Hybrid search is not supported in Milvus")
|
||||||
|
|
||||||
|
|
|
@ -134,6 +134,8 @@ class PGVectorIndex(EmbeddingIndex):
|
||||||
query_string: str,
|
query_string: str,
|
||||||
k: int,
|
k: int,
|
||||||
score_threshold: float,
|
score_threshold: float,
|
||||||
|
reranker_type: str,
|
||||||
|
reranker_params: dict[str, Any] | None = None,
|
||||||
) -> QueryChunksResponse:
|
) -> QueryChunksResponse:
|
||||||
raise NotImplementedError("Hybrid search is not supported in PGVector")
|
raise NotImplementedError("Hybrid search is not supported in PGVector")
|
||||||
|
|
||||||
|
|
|
@ -118,6 +118,8 @@ class QdrantIndex(EmbeddingIndex):
|
||||||
query_string: str,
|
query_string: str,
|
||||||
k: int,
|
k: int,
|
||||||
score_threshold: float,
|
score_threshold: float,
|
||||||
|
reranker_type: str,
|
||||||
|
reranker_params: dict[str, Any] | None = None,
|
||||||
) -> QueryChunksResponse:
|
) -> QueryChunksResponse:
|
||||||
raise NotImplementedError("Hybrid search is not supported in Qdrant")
|
raise NotImplementedError("Hybrid search is not supported in Qdrant")
|
||||||
|
|
||||||
|
|
|
@ -98,6 +98,8 @@ class WeaviateIndex(EmbeddingIndex):
|
||||||
query_string: str,
|
query_string: str,
|
||||||
k: int,
|
k: int,
|
||||||
score_threshold: float,
|
score_threshold: float,
|
||||||
|
reranker_type: str,
|
||||||
|
reranker_params: dict[str, Any] | None = None,
|
||||||
) -> QueryChunksResponse:
|
) -> QueryChunksResponse:
|
||||||
raise NotImplementedError("Hybrid search is not supported in Weaviate")
|
raise NotImplementedError("Hybrid search is not supported in Weaviate")
|
||||||
|
|
||||||
|
|
|
@ -32,6 +32,10 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Constants for reranker types
|
||||||
|
RERANKER_TYPE_RRF = "rrf"
|
||||||
|
RERANKER_TYPE_WEIGHTED = "weighted"
|
||||||
|
|
||||||
|
|
||||||
def parse_pdf(data: bytes) -> str:
|
def parse_pdf(data: bytes) -> str:
|
||||||
# For PDF and DOC/DOCX files, we can't reliably convert to string
|
# For PDF and DOC/DOCX files, we can't reliably convert to string
|
||||||
|
@ -204,7 +208,13 @@ class EmbeddingIndex(ABC):
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
async def query_hybrid(
|
async def query_hybrid(
|
||||||
self, embedding: NDArray, query_string: str, k: int, score_threshold: float
|
self,
|
||||||
|
embedding: NDArray,
|
||||||
|
query_string: str,
|
||||||
|
k: int,
|
||||||
|
score_threshold: float,
|
||||||
|
reranker_type: str,
|
||||||
|
reranker_params: dict[str, Any] | None = None,
|
||||||
) -> QueryChunksResponse:
|
) -> QueryChunksResponse:
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
@ -251,15 +261,29 @@ class VectorDBWithIndex:
|
||||||
k = params.get("max_chunks", 3)
|
k = params.get("max_chunks", 3)
|
||||||
mode = params.get("mode")
|
mode = params.get("mode")
|
||||||
score_threshold = params.get("score_threshold", 0.0)
|
score_threshold = params.get("score_threshold", 0.0)
|
||||||
|
|
||||||
|
# Get ranker configuration
|
||||||
|
ranker = params.get("ranker")
|
||||||
|
if ranker is None:
|
||||||
|
# Default to RRF with impact_factor=60.0
|
||||||
|
reranker_type = RERANKER_TYPE_RRF
|
||||||
|
reranker_params = {"impact_factor": 60.0}
|
||||||
|
else:
|
||||||
|
reranker_type = ranker.type
|
||||||
|
reranker_params = (
|
||||||
|
{"impact_factor": ranker.impact_factor} if ranker.type == RERANKER_TYPE_RRF else {"alpha": ranker.alpha}
|
||||||
|
)
|
||||||
|
|
||||||
query_string = interleaved_content_as_str(query)
|
query_string = interleaved_content_as_str(query)
|
||||||
|
if mode == "keyword":
|
||||||
|
return await self.index.query_keyword(query_string, k, score_threshold)
|
||||||
|
|
||||||
# Calculate embeddings for both vector and hybrid modes
|
# Calculate embeddings for both vector and hybrid modes
|
||||||
embeddings_response = await self.inference_api.embeddings(self.vector_db.embedding_model, [query_string])
|
embeddings_response = await self.inference_api.embeddings(self.vector_db.embedding_model, [query_string])
|
||||||
query_vector = np.array(embeddings_response.embeddings[0], dtype=np.float32)
|
query_vector = np.array(embeddings_response.embeddings[0], dtype=np.float32)
|
||||||
|
if mode == "hybrid":
|
||||||
if mode == "keyword":
|
return await self.index.query_hybrid(
|
||||||
return await self.index.query_keyword(query_string, k, score_threshold)
|
query_vector, query_string, k, score_threshold, reranker_type, reranker_params
|
||||||
elif mode == "hybrid":
|
)
|
||||||
return await self.index.query_hybrid(query_vector, query_string, k, score_threshold)
|
|
||||||
else:
|
else:
|
||||||
return await self.index.query_vector(query_vector, k, score_threshold)
|
return await self.index.query_vector(query_vector, k, score_threshold)
|
||||||
|
|
|
@ -93,7 +93,12 @@ async def test_query_chunks_hybrid(sqlite_vec_index, sample_chunks, sample_embed
|
||||||
query_string = "Sentence 5"
|
query_string = "Sentence 5"
|
||||||
|
|
||||||
response = await sqlite_vec_index.query_hybrid(
|
response = await sqlite_vec_index.query_hybrid(
|
||||||
embedding=query_embedding, query_string=query_string, k=3, score_threshold=0.0
|
embedding=query_embedding,
|
||||||
|
query_string=query_string,
|
||||||
|
k=3,
|
||||||
|
score_threshold=0.0,
|
||||||
|
reranker_type="rrf",
|
||||||
|
reranker_params={"impact_factor": 60.0},
|
||||||
)
|
)
|
||||||
|
|
||||||
assert len(response.chunks) == 3, f"Expected 3 results, got {len(response.chunks)}"
|
assert len(response.chunks) == 3, f"Expected 3 results, got {len(response.chunks)}"
|
||||||
|
@ -175,7 +180,12 @@ async def test_query_chunks_hybrid_no_keyword_matches(sqlite_vec_index, sample_c
|
||||||
|
|
||||||
# Get hybrid results
|
# Get hybrid results
|
||||||
response = await sqlite_vec_index.query_hybrid(
|
response = await sqlite_vec_index.query_hybrid(
|
||||||
embedding=query_embedding, query_string=query_string, k=3, score_threshold=0.0
|
embedding=query_embedding,
|
||||||
|
query_string=query_string,
|
||||||
|
k=3,
|
||||||
|
score_threshold=0.0,
|
||||||
|
reranker_type="rrf",
|
||||||
|
reranker_params={"impact_factor": 60.0},
|
||||||
)
|
)
|
||||||
|
|
||||||
# Should still get results from vector search
|
# Should still get results from vector search
|
||||||
|
@ -198,6 +208,8 @@ async def test_query_chunks_hybrid_score_threshold(sqlite_vec_index, sample_chun
|
||||||
query_string=query_string,
|
query_string=query_string,
|
||||||
k=3,
|
k=3,
|
||||||
score_threshold=1000.0, # Very high threshold
|
score_threshold=1000.0, # Very high threshold
|
||||||
|
reranker_type="rrf",
|
||||||
|
reranker_params={"impact_factor": 60.0},
|
||||||
)
|
)
|
||||||
|
|
||||||
# Should return no results due to high threshold
|
# Should return no results due to high threshold
|
||||||
|
@ -216,7 +228,12 @@ async def test_query_chunks_hybrid_different_embedding(
|
||||||
query_string = "Sentence 5"
|
query_string = "Sentence 5"
|
||||||
|
|
||||||
response = await sqlite_vec_index.query_hybrid(
|
response = await sqlite_vec_index.query_hybrid(
|
||||||
embedding=query_embedding, query_string=query_string, k=3, score_threshold=0.0
|
embedding=query_embedding,
|
||||||
|
query_string=query_string,
|
||||||
|
k=3,
|
||||||
|
score_threshold=0.0,
|
||||||
|
reranker_type="rrf",
|
||||||
|
reranker_params={"impact_factor": 60.0},
|
||||||
)
|
)
|
||||||
|
|
||||||
# Should still get results if keyword matches exist
|
# Should still get results if keyword matches exist
|
||||||
|
@ -236,7 +253,12 @@ async def test_query_chunks_hybrid_rrf_ranking(sqlite_vec_index, sample_chunks,
|
||||||
query_string = "Sentence 5"
|
query_string = "Sentence 5"
|
||||||
|
|
||||||
response = await sqlite_vec_index.query_hybrid(
|
response = await sqlite_vec_index.query_hybrid(
|
||||||
embedding=query_embedding, query_string=query_string, k=5, score_threshold=0.0
|
embedding=query_embedding,
|
||||||
|
query_string=query_string,
|
||||||
|
k=5,
|
||||||
|
score_threshold=0.0,
|
||||||
|
reranker_type="rrf",
|
||||||
|
reranker_params={"impact_factor": 60.0},
|
||||||
)
|
)
|
||||||
|
|
||||||
# Verify we get results from both search methods
|
# Verify we get results from both search methods
|
||||||
|
@ -247,7 +269,6 @@ async def test_query_chunks_hybrid_rrf_ranking(sqlite_vec_index, sample_chunks,
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_query_chunks_hybrid_score_selection(sqlite_vec_index, sample_chunks, sample_embeddings):
|
async def test_query_chunks_hybrid_score_selection(sqlite_vec_index, sample_chunks, sample_embeddings):
|
||||||
"""Test that we correctly rank documents that appear in both search methods."""
|
|
||||||
await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings)
|
await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings)
|
||||||
|
|
||||||
# Create a query embedding that's similar to the first chunk
|
# Create a query embedding that's similar to the first chunk
|
||||||
|
@ -255,26 +276,43 @@ async def test_query_chunks_hybrid_score_selection(sqlite_vec_index, sample_chun
|
||||||
# Use a keyword that appears in the first document
|
# Use a keyword that appears in the first document
|
||||||
query_string = "Sentence 0 from document 0"
|
query_string = "Sentence 0 from document 0"
|
||||||
|
|
||||||
# First get individual results to verify ranks
|
# Test weighted re-ranking
|
||||||
vector_response = await sqlite_vec_index.query_vector(query_embedding, k=5, score_threshold=0.0)
|
|
||||||
keyword_response = await sqlite_vec_index.query_keyword(query_string, k=5, score_threshold=0.0)
|
|
||||||
|
|
||||||
# Verify document-0 appears in both results
|
|
||||||
assert any(chunk.metadata["document_id"] == "document-0" for chunk in vector_response.chunks), (
|
|
||||||
"document-0 not found in vector search results"
|
|
||||||
)
|
|
||||||
assert any(chunk.metadata["document_id"] == "document-0" for chunk in keyword_response.chunks), (
|
|
||||||
"document-0 not found in keyword search results"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Now get hybrid results
|
|
||||||
response = await sqlite_vec_index.query_hybrid(
|
response = await sqlite_vec_index.query_hybrid(
|
||||||
embedding=query_embedding, query_string=query_string, k=1, score_threshold=0.0
|
embedding=query_embedding,
|
||||||
|
query_string=query_string,
|
||||||
|
k=1,
|
||||||
|
score_threshold=0.0,
|
||||||
|
reranker_type="weighted",
|
||||||
|
reranker_params={"alpha": 0.5},
|
||||||
)
|
)
|
||||||
|
|
||||||
# Verify document-0 is ranked first in hybrid results
|
|
||||||
assert len(response.chunks) == 1
|
assert len(response.chunks) == 1
|
||||||
assert response.chunks[0].metadata["document_id"] == "document-0", "document-0 not ranked first in hybrid results"
|
# Score should be weighted average of normalized keyword score and vector score
|
||||||
|
assert response.scores[0] > 0.5 # Both scores should be high
|
||||||
|
|
||||||
|
# Test RRF re-ranking
|
||||||
|
response = await sqlite_vec_index.query_hybrid(
|
||||||
|
embedding=query_embedding,
|
||||||
|
query_string=query_string,
|
||||||
|
k=1,
|
||||||
|
score_threshold=0.0,
|
||||||
|
reranker_type="rrf",
|
||||||
|
reranker_params={"impact_factor": 60.0},
|
||||||
|
)
|
||||||
|
assert len(response.chunks) == 1
|
||||||
|
# RRF score should be sum of reciprocal ranks
|
||||||
|
assert response.scores[0] == pytest.approx(2.0 / 61.0, rel=1e-6) # 1/(60+1) + 1/(60+1)
|
||||||
|
|
||||||
|
# Test default re-ranking (should be RRF)
|
||||||
|
response = await sqlite_vec_index.query_hybrid(
|
||||||
|
embedding=query_embedding,
|
||||||
|
query_string=query_string,
|
||||||
|
k=1,
|
||||||
|
score_threshold=0.0,
|
||||||
|
reranker_type="rrf",
|
||||||
|
reranker_params={"impact_factor": 60.0},
|
||||||
|
)
|
||||||
|
assert len(response.chunks) == 1
|
||||||
|
assert response.scores[0] == pytest.approx(2.0 / 61.0, rel=1e-6) # Should behave like RRF
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
|
@ -288,7 +326,12 @@ async def test_query_chunks_hybrid_mixed_results(sqlite_vec_index, sample_chunks
|
||||||
query_string = "Sentence 9 from document 2"
|
query_string = "Sentence 9 from document 2"
|
||||||
|
|
||||||
response = await sqlite_vec_index.query_hybrid(
|
response = await sqlite_vec_index.query_hybrid(
|
||||||
embedding=query_embedding, query_string=query_string, k=3, score_threshold=0.0
|
embedding=query_embedding,
|
||||||
|
query_string=query_string,
|
||||||
|
k=3,
|
||||||
|
score_threshold=0.0,
|
||||||
|
reranker_type="rrf",
|
||||||
|
reranker_params={"impact_factor": 60.0},
|
||||||
)
|
)
|
||||||
|
|
||||||
# Should get results from both search methods
|
# Should get results from both search methods
|
||||||
|
@ -299,3 +342,176 @@ async def test_query_chunks_hybrid_mixed_results(sqlite_vec_index, sample_chunks
|
||||||
doc_ids = {chunk.metadata["document_id"] for chunk in response.chunks}
|
doc_ids = {chunk.metadata["document_id"] for chunk in response.chunks}
|
||||||
assert "document-0" in doc_ids # From vector search
|
assert "document-0" in doc_ids # From vector search
|
||||||
assert "document-2" in doc_ids # From keyword search
|
assert "document-2" in doc_ids # From keyword search
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_query_chunks_hybrid_weighted_reranker_parametrization(
    sqlite_vec_index, sample_chunks, sample_embeddings
):
    """Run the weighted hybrid search at alpha 1.0, 0.0 and 0.7 and expect a document-0 hit each time.

    Every search requests a single result (k=1, score_threshold=0.0) using the
    first sample embedding together with the text "Sentence 0 from document 0".
    The sample chunks are indexed once up front and once more right before the
    alpha=0.7 search.

    NOTE(review): the assertions are identical for every alpha, so this test
    shows that each setting returns a document-0 chunk, not that the blend of
    keyword and vector scores actually changes with alpha.
    """
    await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings)
    query_embedding = sample_embeddings[0]
    query_string = "Sentence 0 from document 0"

    async def search_and_check(alpha):
        # One weighted hybrid search followed by the checks shared by all alpha values.
        result = await sqlite_vec_index.query_hybrid(
            embedding=query_embedding,
            query_string=query_string,
            k=1,
            score_threshold=0.0,
            reranker_type="weighted",
            reranker_params={"alpha": alpha},
        )
        assert len(result.chunks) > 0  # Should get at least one result
        assert any("document-0" in hit.metadata["document_id"] for hit in result.chunks)

    # NOTE(review): the original comments describe alpha=1.0 as "pure keyword" and
    # alpha=0.0 as "pure vector" -- confirm that direction against the reranker.
    await search_and_check(1.0)
    await search_and_check(0.0)

    # The chunks are added a second time before the blended (alpha=0.7) search.
    await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings)
    await search_and_check(0.7)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_query_chunks_hybrid_rrf_impact_factor(sqlite_vec_index, sample_chunks, sample_embeddings):
    """Check the top RRF hybrid score for impact factors 10 and 100.

    After indexing the sample chunks, the first sample embedding and the text
    "Sentence 0 from document 0" are searched with k=1. For each impact factor
    the single returned score must equal 2 / (impact_factor + 1), i.e.
    2/11 and 2/101.
    """
    await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings)
    query_embedding = sample_embeddings[0]
    query_string = "Sentence 0 from document 0"

    for impact_factor in (10.0, 100.0):
        response = await sqlite_vec_index.query_hybrid(
            embedding=query_embedding,
            query_string=query_string,
            k=1,
            score_threshold=0.0,
            reranker_type="rrf",
            reranker_params={"impact_factor": impact_factor},
        )
        assert len(response.chunks) == 1
        # Presumably the hit ranks first in both result lists, contributing
        # 1 / (impact_factor + 1) twice -- verify against the RRF implementation.
        assert response.scores[0] == pytest.approx(2.0 / (impact_factor + 1.0), rel=1e-6)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_query_chunks_hybrid_edge_cases(sqlite_vec_index, sample_chunks, sample_embeddings):
    """Exercise RRF hybrid search at its boundaries.

    Three cases run against the indexed sample chunks:
      1. an embedding of all -1 values plus a keyword that is not expected to
         match, with score_threshold=0.1 -> no chunks;
      2. a matching query with score_threshold=1000.0 -> no chunks;
      3. a matching query with k=100 -> at least one and at most 100 chunks.
    """
    await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings)

    def rrf_options():
        # Build fresh dicts per call so no two searches share a mutable argument.
        return {"reranker_type": "rrf", "reranker_params": {"impact_factor": 60.0}}

    # Case 1: nothing should match -- the embedding is very different from the
    # samples and the nonzero threshold filters out low-similarity matches.
    unrelated_embedding = np.ones_like(sample_embeddings[0]) * -1
    no_match = await sqlite_vec_index.query_hybrid(
        embedding=unrelated_embedding,
        query_string="no_such_keyword_that_will_never_match",
        k=3,
        score_threshold=0.1,
        **rrf_options(),
    )
    assert len(no_match.chunks) == 0

    query_embedding = sample_embeddings[0]
    query_string = "Sentence 0 from document 0"

    # Case 2: a matching query, but every result falls below the threshold.
    filtered_out = await sqlite_vec_index.query_hybrid(
        embedding=query_embedding,
        query_string=query_string,
        k=3,
        score_threshold=1000.0,
        **rrf_options(),
    )
    assert len(filtered_out.chunks) == 0

    # Case 3: a large k must not error and should return all available results.
    large_k = await sqlite_vec_index.query_hybrid(
        embedding=query_embedding,
        query_string=query_string,
        k=100,
        score_threshold=0.0,
        **rrf_options(),
    )
    assert len(large_k.chunks) > 0
    assert len(large_k.chunks) <= 100
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_query_chunks_hybrid_tie_breaking(
    sqlite_vec_index, sample_embeddings, embedding_dimension, tmp_path_factory
):
    """Check that equal-scoring chunks are returned deterministically.

    Two chunks with the same content ("identical") and the same embedding are
    stored in a freshly created index. The same RRF hybrid query is then run
    three times; all responses must be equal, contain both chunks, and carry
    two equal scores.
    """
    # Two chunks that differ only in document_id, sharing one embedding so
    # that their scores tie.
    chunks = [Chunk(content="identical", metadata={"document_id": doc_id}) for doc_id in ("docA", "docB")]
    same_embedding = sample_embeddings[0]
    embeddings = np.array([same_embedding] * 2)

    # Drop the fixture's data and build a new index under pytest's base temp directory.
    await sqlite_vec_index.delete()
    db_path = str(tmp_path_factory.getbasetemp() / "test_sqlite.db")
    sqlite_vec_index = await SQLiteVecIndex.create(dimension=embedding_dimension, db_path=db_path, bank_id="test_bank")
    await sqlite_vec_index.add_chunks(chunks, embeddings)

    # Repeat the identical query to observe whether the output is stable.
    responses = [
        await sqlite_vec_index.query_hybrid(
            embedding=same_embedding,
            query_string="identical",
            k=2,
            score_threshold=0.0,
            reranker_type="rrf",
            reranker_params={"impact_factor": 60.0},
        )
        for _ in range(3)
    ]

    # Every repeat must match the first response exactly.
    first_response, *repeats = responses
    for repeat in repeats:
        assert repeat.chunks == first_response.chunks
        assert repeat.scores == first_response.scores

    # Both chunks come back, tied on score.
    assert len(first_response.chunks) == 2
    assert first_response.scores[0] == first_response.scores[1]
    assert {chunk.metadata["document_id"] for chunk in first_response.chunks} == {"docA", "docB"}
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue