mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-06-27 18:50:41 +00:00
feat: Implement hybrid search in SQLite-vec (#2312)
Some checks failed
Integration Auth Tests / test-matrix (oauth2_token) (push) Failing after 1s
Integration Tests / test-matrix (http, 3.10, datasets) (push) Failing after 4s
Integration Tests / test-matrix (http, 3.10, providers) (push) Failing after 6s
Integration Tests / test-matrix (http, 3.10, scoring) (push) Failing after 6s
Integration Tests / test-matrix (http, 3.10, agents) (push) Failing after 8s
Integration Tests / test-matrix (http, 3.11, datasets) (push) Failing after 5s
Integration Tests / test-matrix (http, 3.10, inference) (push) Failing after 9s
Integration Tests / test-matrix (http, 3.11, inference) (push) Failing after 5s
Integration Tests / test-matrix (http, 3.11, inspect) (push) Failing after 8s
Integration Tests / test-matrix (http, 3.10, post_training) (push) Failing after 10s
Integration Tests / test-matrix (http, 3.11, tool_runtime) (push) Failing after 5s
Integration Tests / test-matrix (http, 3.10, vector_io) (push) Failing after 7s
Integration Tests / test-matrix (http, 3.11, agents) (push) Failing after 7s
Integration Tests / test-matrix (http, 3.10, inspect) (push) Failing after 9s
Integration Tests / test-matrix (http, 3.12, agents) (push) Failing after 10s
Integration Tests / test-matrix (http, 3.12, post_training) (push) Failing after 8s
Integration Tests / test-matrix (http, 3.12, providers) (push) Failing after 8s
Integration Tests / test-matrix (http, 3.10, tool_runtime) (push) Failing after 7s
Integration Tests / test-matrix (http, 3.11, post_training) (push) Failing after 6s
Integration Tests / test-matrix (http, 3.12, scoring) (push) Failing after 8s
Integration Tests / test-matrix (library, 3.10, agents) (push) Failing after 7s
Integration Tests / test-matrix (http, 3.11, scoring) (push) Failing after 6s
Integration Tests / test-matrix (http, 3.11, providers) (push) Failing after 8s
Integration Tests / test-matrix (http, 3.12, inference) (push) Failing after 7s
Integration Tests / test-matrix (http, 3.12, datasets) (push) Failing after 7s
Integration Tests / test-matrix (library, 3.10, inference) (push) Failing after 8s
Integration Tests / test-matrix (http, 3.12, vector_io) (push) Failing after 7s
Integration Tests / test-matrix (http, 3.12, inspect) (push) Failing after 10s
Integration Tests / test-matrix (library, 3.10, post_training) (push) Failing after 9s
Integration Tests / test-matrix (http, 3.12, tool_runtime) (push) Failing after 10s
Integration Tests / test-matrix (http, 3.11, vector_io) (push) Failing after 11s
Integration Tests / test-matrix (library, 3.10, inspect) (push) Failing after 11s
Integration Tests / test-matrix (library, 3.10, datasets) (push) Failing after 13s
Integration Tests / test-matrix (library, 3.10, providers) (push) Failing after 11s
Integration Tests / test-matrix (library, 3.10, scoring) (push) Failing after 10s
Integration Tests / test-matrix (library, 3.10, vector_io) (push) Failing after 9s
Integration Tests / test-matrix (library, 3.10, tool_runtime) (push) Failing after 12s
Integration Tests / test-matrix (library, 3.11, agents) (push) Failing after 8s
Integration Tests / test-matrix (library, 3.11, datasets) (push) Failing after 7s
Integration Tests / test-matrix (library, 3.11, inspect) (push) Failing after 15s
Integration Tests / test-matrix (library, 3.11, inference) (push) Failing after 16s
Integration Tests / test-matrix (library, 3.11, vector_io) (push) Failing after 10s
Integration Tests / test-matrix (library, 3.11, post_training) (push) Failing after 25s
Integration Tests / test-matrix (library, 3.11, providers) (push) Failing after 24s
Integration Tests / test-matrix (library, 3.11, scoring) (push) Failing after 22s
Integration Tests / test-matrix (library, 3.11, tool_runtime) (push) Failing after 14s
Integration Tests / test-matrix (library, 3.12, agents) (push) Failing after 6s
Integration Tests / test-matrix (library, 3.12, datasets) (push) Failing after 9s
Integration Tests / test-matrix (library, 3.12, inference) (push) Failing after 9s
Integration Tests / test-matrix (library, 3.12, inspect) (push) Failing after 7s
Integration Tests / test-matrix (library, 3.12, post_training) (push) Failing after 7s
Integration Tests / test-matrix (library, 3.12, providers) (push) Failing after 7s
Integration Tests / test-matrix (library, 3.12, scoring) (push) Failing after 9s
Integration Tests / test-matrix (library, 3.12, tool_runtime) (push) Failing after 7s
Integration Tests / test-matrix (library, 3.12, vector_io) (push) Failing after 41s
Test Llama Stack Build / generate-matrix (push) Successful in 37s
Test Llama Stack Build / build-single-provider (push) Failing after 37s
Test Llama Stack Build / build-custom-container-distribution (push) Failing after 35s
Test External Providers / test-external-providers (venv) (push) Failing after 5s
Update ReadTheDocs / update-readthedocs (push) Failing after 5s
Unit Tests / unit-tests (3.11) (push) Failing after 6s
Unit Tests / unit-tests (3.12) (push) Failing after 6s
Unit Tests / unit-tests (3.13) (push) Failing after 6s
Test Llama Stack Build / build (push) Failing after 7s
Test Llama Stack Build / build-ubi9-container-distribution (push) Failing after 18s
Unit Tests / unit-tests (3.10) (push) Failing after 17s
Pre-commit / pre-commit (push) Successful in 2m0s
Some checks failed
Integration Auth Tests / test-matrix (oauth2_token) (push) Failing after 1s
Integration Tests / test-matrix (http, 3.10, datasets) (push) Failing after 4s
Integration Tests / test-matrix (http, 3.10, providers) (push) Failing after 6s
Integration Tests / test-matrix (http, 3.10, scoring) (push) Failing after 6s
Integration Tests / test-matrix (http, 3.10, agents) (push) Failing after 8s
Integration Tests / test-matrix (http, 3.11, datasets) (push) Failing after 5s
Integration Tests / test-matrix (http, 3.10, inference) (push) Failing after 9s
Integration Tests / test-matrix (http, 3.11, inference) (push) Failing after 5s
Integration Tests / test-matrix (http, 3.11, inspect) (push) Failing after 8s
Integration Tests / test-matrix (http, 3.10, post_training) (push) Failing after 10s
Integration Tests / test-matrix (http, 3.11, tool_runtime) (push) Failing after 5s
Integration Tests / test-matrix (http, 3.10, vector_io) (push) Failing after 7s
Integration Tests / test-matrix (http, 3.11, agents) (push) Failing after 7s
Integration Tests / test-matrix (http, 3.10, inspect) (push) Failing after 9s
Integration Tests / test-matrix (http, 3.12, agents) (push) Failing after 10s
Integration Tests / test-matrix (http, 3.12, post_training) (push) Failing after 8s
Integration Tests / test-matrix (http, 3.12, providers) (push) Failing after 8s
Integration Tests / test-matrix (http, 3.10, tool_runtime) (push) Failing after 7s
Integration Tests / test-matrix (http, 3.11, post_training) (push) Failing after 6s
Integration Tests / test-matrix (http, 3.12, scoring) (push) Failing after 8s
Integration Tests / test-matrix (library, 3.10, agents) (push) Failing after 7s
Integration Tests / test-matrix (http, 3.11, scoring) (push) Failing after 6s
Integration Tests / test-matrix (http, 3.11, providers) (push) Failing after 8s
Integration Tests / test-matrix (http, 3.12, inference) (push) Failing after 7s
Integration Tests / test-matrix (http, 3.12, datasets) (push) Failing after 7s
Integration Tests / test-matrix (library, 3.10, inference) (push) Failing after 8s
Integration Tests / test-matrix (http, 3.12, vector_io) (push) Failing after 7s
Integration Tests / test-matrix (http, 3.12, inspect) (push) Failing after 10s
Integration Tests / test-matrix (library, 3.10, post_training) (push) Failing after 9s
Integration Tests / test-matrix (http, 3.12, tool_runtime) (push) Failing after 10s
Integration Tests / test-matrix (http, 3.11, vector_io) (push) Failing after 11s
Integration Tests / test-matrix (library, 3.10, inspect) (push) Failing after 11s
Integration Tests / test-matrix (library, 3.10, datasets) (push) Failing after 13s
Integration Tests / test-matrix (library, 3.10, providers) (push) Failing after 11s
Integration Tests / test-matrix (library, 3.10, scoring) (push) Failing after 10s
Integration Tests / test-matrix (library, 3.10, vector_io) (push) Failing after 9s
Integration Tests / test-matrix (library, 3.10, tool_runtime) (push) Failing after 12s
Integration Tests / test-matrix (library, 3.11, agents) (push) Failing after 8s
Integration Tests / test-matrix (library, 3.11, datasets) (push) Failing after 7s
Integration Tests / test-matrix (library, 3.11, inspect) (push) Failing after 15s
Integration Tests / test-matrix (library, 3.11, inference) (push) Failing after 16s
Integration Tests / test-matrix (library, 3.11, vector_io) (push) Failing after 10s
Integration Tests / test-matrix (library, 3.11, post_training) (push) Failing after 25s
Integration Tests / test-matrix (library, 3.11, providers) (push) Failing after 24s
Integration Tests / test-matrix (library, 3.11, scoring) (push) Failing after 22s
Integration Tests / test-matrix (library, 3.11, tool_runtime) (push) Failing after 14s
Integration Tests / test-matrix (library, 3.12, agents) (push) Failing after 6s
Integration Tests / test-matrix (library, 3.12, datasets) (push) Failing after 9s
Integration Tests / test-matrix (library, 3.12, inference) (push) Failing after 9s
Integration Tests / test-matrix (library, 3.12, inspect) (push) Failing after 7s
Integration Tests / test-matrix (library, 3.12, post_training) (push) Failing after 7s
Integration Tests / test-matrix (library, 3.12, providers) (push) Failing after 7s
Integration Tests / test-matrix (library, 3.12, scoring) (push) Failing after 9s
Integration Tests / test-matrix (library, 3.12, tool_runtime) (push) Failing after 7s
Integration Tests / test-matrix (library, 3.12, vector_io) (push) Failing after 41s
Test Llama Stack Build / generate-matrix (push) Successful in 37s
Test Llama Stack Build / build-single-provider (push) Failing after 37s
Test Llama Stack Build / build-custom-container-distribution (push) Failing after 35s
Test External Providers / test-external-providers (venv) (push) Failing after 5s
Update ReadTheDocs / update-readthedocs (push) Failing after 5s
Unit Tests / unit-tests (3.11) (push) Failing after 6s
Unit Tests / unit-tests (3.12) (push) Failing after 6s
Unit Tests / unit-tests (3.13) (push) Failing after 6s
Test Llama Stack Build / build (push) Failing after 7s
Test Llama Stack Build / build-ubi9-container-distribution (push) Failing after 18s
Unit Tests / unit-tests (3.10) (push) Failing after 17s
Pre-commit / pre-commit (push) Successful in 2m0s
# What does this PR do? Add support for hybrid search mode in SQLite-vec provider, which combines keyword and vector search for better results. The implementation: - Adds hybrid search mode as a new option alongside vector and keyword search - Implements query_hybrid method in SQLiteVecIndex that: - First performs keyword search to get candidate matches - Then applies vector similarity search on those candidates - Updates documentation to reflect the new search mode This change improves search quality by leveraging both semantic similarity and keyword matching, while maintaining backward compatibility with existing vector and keyword search modes. ## Test Plan ``` pytest tests/unit/providers/vector_io/test_sqlite_vec.py -v -s --tb=short /Users/vnarsing/miniconda3/envs/stack-client/lib/python3.10/site-packages/pytest_asyncio/plugin.py:217: PytestDeprecationWarning: The configuration option "asyncio_default_fixture_loop_scope" is unset. The event loop scope for asynchronous fixtures will default to the fixture caching scope. Future versions of pytest-asyncio will default the loop scope for asynchronous fixtures to function scope. Set the default fixture loop scope explicitly in order to avoid unexpected behavior in the future. Valid fixture loop scopes are: "function", "class", "module", "package", "session" warnings.warn(PytestDeprecationWarning(_DEFAULT_FIXTURE_LOOP_SCOPE_UNSET)) =============================================================================================== test session starts =============================================================================================== platform darwin -- Python 3.10.16, pytest-8.3.5, pluggy-1.5.0 -- /Users/vnarsing/miniconda3/envs/stack-client/bin/python cachedir: .pytest_cache metadata: {'Python': '3.10.16', 'Platform': 'macOS-14.7.6-arm64-arm-64bit', 'Packages': {'pytest': '8.3.5', 'pluggy': '1.5.0'}, 'Plugins': {'html': '4.1.1', 'json-report': '1.5.0', 'timeout': '2.4.0', 'metadata': '3.1.1', 'anyio': '4.8.0', 'asyncio': '0.26.0', 'nbval': '0.11.0', 'cov': '6.1.1'}} rootdir: /Users/vnarsing/go/src/github/meta-llama/llama-stack configfile: pyproject.toml plugins: html-4.1.1, json-report-1.5.0, timeout-2.4.0, metadata-3.1.1, anyio-4.8.0, asyncio-0.26.0, nbval-0.11.0, cov-6.1.1 asyncio: mode=strict, asyncio_default_fixture_loop_scope=None, asyncio_default_test_loop_scope=function collected 10 items tests/unit/providers/vector_io/test_sqlite_vec.py::test_add_chunks PASSED tests/unit/providers/vector_io/test_sqlite_vec.py::test_query_chunks_vector PASSED tests/unit/providers/vector_io/test_sqlite_vec.py::test_query_chunks_full_text_search PASSED tests/unit/providers/vector_io/test_sqlite_vec.py::test_query_chunks_hybrid PASSED tests/unit/providers/vector_io/test_sqlite_vec.py::test_query_chunks_full_text_search_k_greater_than_results PASSED tests/unit/providers/vector_io/test_sqlite_vec.py::test_chunk_id_conflict PASSED tests/unit/providers/vector_io/test_sqlite_vec.py::test_generate_chunk_id PASSED tests/unit/providers/vector_io/test_sqlite_vec.py::test_query_chunks_hybrid_no_keyword_matches PASSED tests/unit/providers/vector_io/test_sqlite_vec.py::test_query_chunks_hybrid_score_threshold PASSED tests/unit/providers/vector_io/test_sqlite_vec.py::test_query_chunks_hybrid_different_embedding PASSED ``` --------- Signed-off-by: Varsha Prasad Narsing <varshaprasad96@gmail.com>
This commit is contained in:
parent
941f505eb0
commit
2e8054bede
14 changed files with 910 additions and 23 deletions
69
docs/_static/llama-stack-spec.html
vendored
69
docs/_static/llama-stack-spec.html
vendored
|
@ -13994,7 +13994,11 @@
|
||||||
},
|
},
|
||||||
"mode": {
|
"mode": {
|
||||||
"type": "string",
|
"type": "string",
|
||||||
"description": "Search mode for retrieval—either \"vector\" or \"keyword\". Default \"vector\"."
|
"description": "Search mode for retrieval—either \"vector\", \"keyword\", or \"hybrid\". Default \"vector\"."
|
||||||
|
},
|
||||||
|
"ranker": {
|
||||||
|
"$ref": "#/components/schemas/Ranker",
|
||||||
|
"description": "Configuration for the ranker to use in hybrid search. Defaults to RRF ranker."
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"additionalProperties": false,
|
"additionalProperties": false,
|
||||||
|
@ -14024,6 +14028,69 @@
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
"RRFRanker": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"type": {
|
||||||
|
"type": "string",
|
||||||
|
"const": "rrf",
|
||||||
|
"default": "rrf",
|
||||||
|
"description": "The type of ranker, always \"rrf\""
|
||||||
|
},
|
||||||
|
"impact_factor": {
|
||||||
|
"type": "number",
|
||||||
|
"default": 60.0,
|
||||||
|
"description": "The impact factor for RRF scoring. Higher values give more weight to higher-ranked results. Must be greater than 0. Default of 60 is from the original RRF paper (Cormack et al., 2009)."
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"additionalProperties": false,
|
||||||
|
"required": [
|
||||||
|
"type",
|
||||||
|
"impact_factor"
|
||||||
|
],
|
||||||
|
"title": "RRFRanker",
|
||||||
|
"description": "Reciprocal Rank Fusion (RRF) ranker configuration."
|
||||||
|
},
|
||||||
|
"Ranker": {
|
||||||
|
"oneOf": [
|
||||||
|
{
|
||||||
|
"$ref": "#/components/schemas/RRFRanker"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"$ref": "#/components/schemas/WeightedRanker"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"discriminator": {
|
||||||
|
"propertyName": "type",
|
||||||
|
"mapping": {
|
||||||
|
"rrf": "#/components/schemas/RRFRanker",
|
||||||
|
"weighted": "#/components/schemas/WeightedRanker"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"WeightedRanker": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"type": {
|
||||||
|
"type": "string",
|
||||||
|
"const": "weighted",
|
||||||
|
"default": "weighted",
|
||||||
|
"description": "The type of ranker, always \"weighted\""
|
||||||
|
},
|
||||||
|
"alpha": {
|
||||||
|
"type": "number",
|
||||||
|
"default": 0.5,
|
||||||
|
"description": "Weight factor between 0 and 1. 0 means only use keyword scores, 1 means only use vector scores, values in between blend both scores."
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"additionalProperties": false,
|
||||||
|
"required": [
|
||||||
|
"type",
|
||||||
|
"alpha"
|
||||||
|
],
|
||||||
|
"title": "WeightedRanker",
|
||||||
|
"description": "Weighted ranker configuration that combines vector and keyword scores."
|
||||||
|
},
|
||||||
"QueryRequest": {
|
"QueryRequest": {
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"properties": {
|
"properties": {
|
||||||
|
|
60
docs/_static/llama-stack-spec.yaml
vendored
60
docs/_static/llama-stack-spec.yaml
vendored
|
@ -9756,7 +9756,13 @@ components:
|
||||||
mode:
|
mode:
|
||||||
type: string
|
type: string
|
||||||
description: >-
|
description: >-
|
||||||
Search mode for retrieval—either "vector" or "keyword". Default "vector".
|
Search mode for retrieval—either "vector", "keyword", or "hybrid". Default
|
||||||
|
"vector".
|
||||||
|
ranker:
|
||||||
|
$ref: '#/components/schemas/Ranker'
|
||||||
|
description: >-
|
||||||
|
Configuration for the ranker to use in hybrid search. Defaults to RRF
|
||||||
|
ranker.
|
||||||
additionalProperties: false
|
additionalProperties: false
|
||||||
required:
|
required:
|
||||||
- query_generator_config
|
- query_generator_config
|
||||||
|
@ -9775,6 +9781,58 @@ components:
|
||||||
mapping:
|
mapping:
|
||||||
default: '#/components/schemas/DefaultRAGQueryGeneratorConfig'
|
default: '#/components/schemas/DefaultRAGQueryGeneratorConfig'
|
||||||
llm: '#/components/schemas/LLMRAGQueryGeneratorConfig'
|
llm: '#/components/schemas/LLMRAGQueryGeneratorConfig'
|
||||||
|
RRFRanker:
|
||||||
|
type: object
|
||||||
|
properties:
|
||||||
|
type:
|
||||||
|
type: string
|
||||||
|
const: rrf
|
||||||
|
default: rrf
|
||||||
|
description: The type of ranker, always "rrf"
|
||||||
|
impact_factor:
|
||||||
|
type: number
|
||||||
|
default: 60.0
|
||||||
|
description: >-
|
||||||
|
The impact factor for RRF scoring. Higher values give more weight to higher-ranked
|
||||||
|
results. Must be greater than 0. Default of 60 is from the original RRF
|
||||||
|
paper (Cormack et al., 2009).
|
||||||
|
additionalProperties: false
|
||||||
|
required:
|
||||||
|
- type
|
||||||
|
- impact_factor
|
||||||
|
title: RRFRanker
|
||||||
|
description: >-
|
||||||
|
Reciprocal Rank Fusion (RRF) ranker configuration.
|
||||||
|
Ranker:
|
||||||
|
oneOf:
|
||||||
|
- $ref: '#/components/schemas/RRFRanker'
|
||||||
|
- $ref: '#/components/schemas/WeightedRanker'
|
||||||
|
discriminator:
|
||||||
|
propertyName: type
|
||||||
|
mapping:
|
||||||
|
rrf: '#/components/schemas/RRFRanker'
|
||||||
|
weighted: '#/components/schemas/WeightedRanker'
|
||||||
|
WeightedRanker:
|
||||||
|
type: object
|
||||||
|
properties:
|
||||||
|
type:
|
||||||
|
type: string
|
||||||
|
const: weighted
|
||||||
|
default: weighted
|
||||||
|
description: The type of ranker, always "weighted"
|
||||||
|
alpha:
|
||||||
|
type: number
|
||||||
|
default: 0.5
|
||||||
|
description: >-
|
||||||
|
Weight factor between 0 and 1. 0 means only use keyword scores, 1 means
|
||||||
|
only use vector scores, values in between blend both scores.
|
||||||
|
additionalProperties: false
|
||||||
|
required:
|
||||||
|
- type
|
||||||
|
- alpha
|
||||||
|
title: WeightedRanker
|
||||||
|
description: >-
|
||||||
|
Weighted ranker configuration that combines vector and keyword scores.
|
||||||
QueryRequest:
|
QueryRequest:
|
||||||
type: object
|
type: object
|
||||||
properties:
|
properties:
|
||||||
|
|
|
@ -66,25 +66,126 @@ To use sqlite-vec in your Llama Stack project, follow these steps:
|
||||||
2. Configure your Llama Stack project to use SQLite-Vec.
|
2. Configure your Llama Stack project to use SQLite-Vec.
|
||||||
3. Start storing and querying vectors.
|
3. Start storing and querying vectors.
|
||||||
|
|
||||||
## Supported Search Modes
|
The SQLite-vec provider supports three search modes:
|
||||||
|
|
||||||
The sqlite-vec provider supports both vector-based and keyword-based (full-text) search modes.
|
1. **Vector Search** (`mode="vector"`): Performs pure vector similarity search using the embeddings.
|
||||||
|
2. **Keyword Search** (`mode="keyword"`): Performs full-text search using SQLite's FTS5.
|
||||||
When using the RAGTool interface, you can specify the desired search behavior via the `mode` parameter in
|
3. **Hybrid Search** (`mode="hybrid"`): Combines both vector and keyword search for better results. First performs keyword search to get candidate matches, then applies vector similarity search on those candidates.
|
||||||
`RAGQueryConfig`. For example:
|
|
||||||
|
|
||||||
|
Example with hybrid search:
|
||||||
```python
|
```python
|
||||||
from llama_stack.apis.tool_runtime.rag import RAGQueryConfig
|
response = await vector_io.query_chunks(
|
||||||
|
vector_db_id="my_db",
|
||||||
|
query="your query here",
|
||||||
|
params={"mode": "hybrid", "max_chunks": 3, "score_threshold": 0.7},
|
||||||
|
)
|
||||||
|
|
||||||
query_config = RAGQueryConfig(max_chunks=6, mode="vector")
|
# Using RRF ranker
|
||||||
|
response = await vector_io.query_chunks(
|
||||||
|
vector_db_id="my_db",
|
||||||
|
query="your query here",
|
||||||
|
params={
|
||||||
|
"mode": "hybrid",
|
||||||
|
"max_chunks": 3,
|
||||||
|
"score_threshold": 0.7,
|
||||||
|
"ranker": {"type": "rrf", "impact_factor": 60.0},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
results = client.tool_runtime.rag_tool.query(
|
# Using weighted ranker
|
||||||
vector_db_ids=[vector_db_id],
|
response = await vector_io.query_chunks(
|
||||||
content="what is torchtune",
|
vector_db_id="my_db",
|
||||||
query_config=query_config,
|
query="your query here",
|
||||||
|
params={
|
||||||
|
"mode": "hybrid",
|
||||||
|
"max_chunks": 3,
|
||||||
|
"score_threshold": 0.7,
|
||||||
|
"ranker": {"type": "weighted", "alpha": 0.7}, # 70% vector, 30% keyword
|
||||||
|
},
|
||||||
)
|
)
|
||||||
```
|
```
|
||||||
|
|
||||||
|
Example with explicit vector search:
|
||||||
|
```python
|
||||||
|
response = await vector_io.query_chunks(
|
||||||
|
vector_db_id="my_db",
|
||||||
|
query="your query here",
|
||||||
|
params={"mode": "vector", "max_chunks": 3, "score_threshold": 0.7},
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
Example with keyword search:
|
||||||
|
```python
|
||||||
|
response = await vector_io.query_chunks(
|
||||||
|
vector_db_id="my_db",
|
||||||
|
query="your query here",
|
||||||
|
params={"mode": "keyword", "max_chunks": 3, "score_threshold": 0.7},
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
## Supported Search Modes
|
||||||
|
|
||||||
|
The SQLite vector store supports three search modes:
|
||||||
|
|
||||||
|
1. **Vector Search** (`mode="vector"`): Uses vector similarity to find relevant chunks
|
||||||
|
2. **Keyword Search** (`mode="keyword"`): Uses keyword matching to find relevant chunks
|
||||||
|
3. **Hybrid Search** (`mode="hybrid"`): Combines both vector and keyword scores using a ranker
|
||||||
|
|
||||||
|
### Hybrid Search
|
||||||
|
|
||||||
|
Hybrid search combines the strengths of both vector and keyword search by:
|
||||||
|
- Computing vector similarity scores
|
||||||
|
- Computing keyword match scores
|
||||||
|
- Using a ranker to combine these scores
|
||||||
|
|
||||||
|
Two ranker types are supported:
|
||||||
|
|
||||||
|
1. **RRF (Reciprocal Rank Fusion)**:
|
||||||
|
- Combines ranks from both vector and keyword results
|
||||||
|
- Uses an impact factor (default: 60.0) to control the weight of higher-ranked results
|
||||||
|
- Good for balancing between vector and keyword results
|
||||||
|
- The default impact factor of 60.0 comes from the original RRF paper by Cormack et al. (2009) [^1], which found this value to provide optimal performance across various retrieval tasks
|
||||||
|
|
||||||
|
2. **Weighted**:
|
||||||
|
- Linearly combines normalized vector and keyword scores
|
||||||
|
- Uses an alpha parameter (0-1) to control the blend:
|
||||||
|
- alpha=0: Only use keyword scores
|
||||||
|
- alpha=1: Only use vector scores
|
||||||
|
- alpha=0.5: Equal weight to both (default)
|
||||||
|
|
||||||
|
Example using RAGQueryConfig with different search modes:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from llama_stack.apis.tools import RAGQueryConfig, RRFRanker, WeightedRanker
|
||||||
|
|
||||||
|
# Vector search
|
||||||
|
config = RAGQueryConfig(mode="vector", max_chunks=5)
|
||||||
|
|
||||||
|
# Keyword search
|
||||||
|
config = RAGQueryConfig(mode="keyword", max_chunks=5)
|
||||||
|
|
||||||
|
# Hybrid search with custom RRF ranker
|
||||||
|
config = RAGQueryConfig(
|
||||||
|
mode="hybrid",
|
||||||
|
max_chunks=5,
|
||||||
|
ranker=RRFRanker(impact_factor=50.0), # Custom impact factor
|
||||||
|
)
|
||||||
|
|
||||||
|
# Hybrid search with weighted ranker
|
||||||
|
config = RAGQueryConfig(
|
||||||
|
mode="hybrid",
|
||||||
|
max_chunks=5,
|
||||||
|
ranker=WeightedRanker(alpha=0.7), # 70% vector, 30% keyword
|
||||||
|
)
|
||||||
|
|
||||||
|
# Hybrid search with default RRF ranker
|
||||||
|
config = RAGQueryConfig(
|
||||||
|
mode="hybrid", max_chunks=5
|
||||||
|
) # Will use RRF with impact_factor=60.0
|
||||||
|
```
|
||||||
|
|
||||||
|
Note: The ranker configuration is only used in hybrid mode. For vector or keyword modes, the ranker parameter is ignored.
|
||||||
|
|
||||||
## Installation
|
## Installation
|
||||||
|
|
||||||
You can install SQLite-Vec using pip:
|
You can install SQLite-Vec using pip:
|
||||||
|
@ -96,3 +197,5 @@ pip install sqlite-vec
|
||||||
## Documentation
|
## Documentation
|
||||||
|
|
||||||
See [sqlite-vec's GitHub repo](https://github.com/asg017/sqlite-vec/tree/main) for more details about sqlite-vec in general.
|
See [sqlite-vec's GitHub repo](https://github.com/asg017/sqlite-vec/tree/main) for more details about sqlite-vec in general.
|
||||||
|
|
||||||
|
[^1]: Cormack, G. V., Clarke, C. L., & Buettcher, S. (2009). [Reciprocal rank fusion outperforms condorcet and individual rank learning methods](https://dl.acm.org/doi/10.1145/1571941.1572114). In Proceedings of the 32nd international ACM SIGIR conference on Research and development in information retrieval (pp. 758-759).
|
||||||
|
|
|
@ -15,6 +15,48 @@ from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
|
||||||
from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
|
from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
|
||||||
|
|
||||||
|
|
||||||
|
@json_schema_type
|
||||||
|
class RRFRanker(BaseModel):
|
||||||
|
"""
|
||||||
|
Reciprocal Rank Fusion (RRF) ranker configuration.
|
||||||
|
|
||||||
|
:param type: The type of ranker, always "rrf"
|
||||||
|
:param impact_factor: The impact factor for RRF scoring. Higher values give more weight to higher-ranked results.
|
||||||
|
Must be greater than 0. Default of 60 is from the original RRF paper (Cormack et al., 2009).
|
||||||
|
"""
|
||||||
|
|
||||||
|
type: Literal["rrf"] = "rrf"
|
||||||
|
impact_factor: float = Field(default=60.0, gt=0.0) # default of 60 for optimal performance
|
||||||
|
|
||||||
|
|
||||||
|
@json_schema_type
|
||||||
|
class WeightedRanker(BaseModel):
|
||||||
|
"""
|
||||||
|
Weighted ranker configuration that combines vector and keyword scores.
|
||||||
|
|
||||||
|
:param type: The type of ranker, always "weighted"
|
||||||
|
:param alpha: Weight factor between 0 and 1.
|
||||||
|
0 means only use keyword scores,
|
||||||
|
1 means only use vector scores,
|
||||||
|
values in between blend both scores.
|
||||||
|
"""
|
||||||
|
|
||||||
|
type: Literal["weighted"] = "weighted"
|
||||||
|
alpha: float = Field(
|
||||||
|
default=0.5,
|
||||||
|
ge=0.0,
|
||||||
|
le=1.0,
|
||||||
|
description="Weight factor between 0 and 1. 0 means only keyword scores, 1 means only vector scores.",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
Ranker = Annotated[
|
||||||
|
RRFRanker | WeightedRanker,
|
||||||
|
Field(discriminator="type"),
|
||||||
|
]
|
||||||
|
register_schema(Ranker, name="Ranker")
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
class RAGDocument(BaseModel):
|
class RAGDocument(BaseModel):
|
||||||
"""
|
"""
|
||||||
|
@ -76,7 +118,8 @@ class RAGQueryConfig(BaseModel):
|
||||||
:param chunk_template: Template for formatting each retrieved chunk in the context.
|
:param chunk_template: Template for formatting each retrieved chunk in the context.
|
||||||
Available placeholders: {index} (1-based chunk ordinal), {chunk.content} (chunk content string), {metadata} (chunk metadata dict).
|
Available placeholders: {index} (1-based chunk ordinal), {chunk.content} (chunk content string), {metadata} (chunk metadata dict).
|
||||||
Default: "Result {index}\\nContent: {chunk.content}\\nMetadata: {metadata}\\n"
|
Default: "Result {index}\\nContent: {chunk.content}\\nMetadata: {metadata}\\n"
|
||||||
:param mode: Search mode for retrieval—either "vector" or "keyword". Default "vector".
|
:param mode: Search mode for retrieval—either "vector", "keyword", or "hybrid". Default "vector".
|
||||||
|
:param ranker: Configuration for the ranker to use in hybrid search. Defaults to RRF ranker.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# This config defines how a query is generated using the messages
|
# This config defines how a query is generated using the messages
|
||||||
|
@ -86,6 +129,7 @@ class RAGQueryConfig(BaseModel):
|
||||||
max_chunks: int = 5
|
max_chunks: int = 5
|
||||||
chunk_template: str = "Result {index}\nContent: {chunk.content}\nMetadata: {metadata}\n"
|
chunk_template: str = "Result {index}\nContent: {chunk.content}\nMetadata: {metadata}\n"
|
||||||
mode: str | None = None
|
mode: str | None = None
|
||||||
|
ranker: Ranker | None = Field(default=None) # Only used for hybrid mode
|
||||||
|
|
||||||
@field_validator("chunk_template")
|
@field_validator("chunk_template")
|
||||||
def validate_chunk_template(cls, v: str) -> str:
|
def validate_chunk_template(cls, v: str) -> str:
|
||||||
|
|
|
@ -121,8 +121,10 @@ class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, RAGToolRunti
|
||||||
vector_db_id=vector_db_id,
|
vector_db_id=vector_db_id,
|
||||||
query=query,
|
query=query,
|
||||||
params={
|
params={
|
||||||
"max_chunks": query_config.max_chunks,
|
|
||||||
"mode": query_config.mode,
|
"mode": query_config.mode,
|
||||||
|
"max_chunks": query_config.max_chunks,
|
||||||
|
"score_threshold": 0.0,
|
||||||
|
"ranker": query_config.ranker,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
for vector_db_id in vector_db_ids
|
for vector_db_id in vector_db_ids
|
||||||
|
|
|
@ -131,6 +131,17 @@ class FaissIndex(EmbeddingIndex):
|
||||||
) -> QueryChunksResponse:
|
) -> QueryChunksResponse:
|
||||||
raise NotImplementedError("Keyword search is not supported in FAISS")
|
raise NotImplementedError("Keyword search is not supported in FAISS")
|
||||||
|
|
||||||
|
async def query_hybrid(
|
||||||
|
self,
|
||||||
|
embedding: NDArray,
|
||||||
|
query_string: str,
|
||||||
|
k: int,
|
||||||
|
score_threshold: float,
|
||||||
|
reranker_type: str,
|
||||||
|
reranker_params: dict[str, Any] | None = None,
|
||||||
|
) -> QueryChunksResponse:
|
||||||
|
raise NotImplementedError("Hybrid search is not supported in FAISS")
|
||||||
|
|
||||||
|
|
||||||
class FaissVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPrivate):
|
class FaissVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPrivate):
|
||||||
def __init__(self, config: FaissVectorIOConfig, inference_api: Inference, files_api: Files | None) -> None:
|
def __init__(self, config: FaissVectorIOConfig, inference_api: Inference, files_api: Files | None) -> None:
|
||||||
|
|
|
@ -27,14 +27,20 @@ from llama_stack.apis.vector_io import (
|
||||||
)
|
)
|
||||||
from llama_stack.providers.datatypes import VectorDBsProtocolPrivate
|
from llama_stack.providers.datatypes import VectorDBsProtocolPrivate
|
||||||
from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin
|
from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin
|
||||||
from llama_stack.providers.utils.memory.vector_store import EmbeddingIndex, VectorDBWithIndex
|
from llama_stack.providers.utils.memory.vector_store import (
|
||||||
|
RERANKER_TYPE_RRF,
|
||||||
|
RERANKER_TYPE_WEIGHTED,
|
||||||
|
EmbeddingIndex,
|
||||||
|
VectorDBWithIndex,
|
||||||
|
)
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
# Specifying search mode is dependent on the VectorIO provider.
|
# Specifying search mode is dependent on the VectorIO provider.
|
||||||
VECTOR_SEARCH = "vector"
|
VECTOR_SEARCH = "vector"
|
||||||
KEYWORD_SEARCH = "keyword"
|
KEYWORD_SEARCH = "keyword"
|
||||||
SEARCH_MODES = {VECTOR_SEARCH, KEYWORD_SEARCH}
|
HYBRID_SEARCH = "hybrid"
|
||||||
|
SEARCH_MODES = {VECTOR_SEARCH, KEYWORD_SEARCH, HYBRID_SEARCH}
|
||||||
|
|
||||||
|
|
||||||
def serialize_vector(vector: list[float]) -> bytes:
|
def serialize_vector(vector: list[float]) -> bytes:
|
||||||
|
@ -51,6 +57,59 @@ def _create_sqlite_connection(db_path):
|
||||||
return connection
|
return connection
|
||||||
|
|
||||||
|
|
||||||
|
def _normalize_scores(scores: dict[str, float]) -> dict[str, float]:
|
||||||
|
"""Normalize scores to [0,1] range using min-max normalization."""
|
||||||
|
if not scores:
|
||||||
|
return {}
|
||||||
|
min_score = min(scores.values())
|
||||||
|
max_score = max(scores.values())
|
||||||
|
score_range = max_score - min_score
|
||||||
|
if score_range > 0:
|
||||||
|
return {doc_id: (score - min_score) / score_range for doc_id, score in scores.items()}
|
||||||
|
return {doc_id: 1.0 for doc_id in scores}
|
||||||
|
|
||||||
|
|
||||||
|
def _weighted_rerank(
|
||||||
|
vector_scores: dict[str, float],
|
||||||
|
keyword_scores: dict[str, float],
|
||||||
|
alpha: float = 0.5,
|
||||||
|
) -> dict[str, float]:
|
||||||
|
"""ReRanker that uses weighted average of scores."""
|
||||||
|
all_ids = set(vector_scores.keys()) | set(keyword_scores.keys())
|
||||||
|
normalized_vector_scores = _normalize_scores(vector_scores)
|
||||||
|
normalized_keyword_scores = _normalize_scores(keyword_scores)
|
||||||
|
|
||||||
|
return {
|
||||||
|
doc_id: (alpha * normalized_keyword_scores.get(doc_id, 0.0))
|
||||||
|
+ ((1 - alpha) * normalized_vector_scores.get(doc_id, 0.0))
|
||||||
|
for doc_id in all_ids
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _rrf_rerank(
|
||||||
|
vector_scores: dict[str, float],
|
||||||
|
keyword_scores: dict[str, float],
|
||||||
|
impact_factor: float = 60.0,
|
||||||
|
) -> dict[str, float]:
|
||||||
|
"""ReRanker that uses Reciprocal Rank Fusion."""
|
||||||
|
# Convert scores to ranks
|
||||||
|
vector_ranks = {
|
||||||
|
doc_id: i + 1 for i, (doc_id, _) in enumerate(sorted(vector_scores.items(), key=lambda x: x[1], reverse=True))
|
||||||
|
}
|
||||||
|
keyword_ranks = {
|
||||||
|
doc_id: i + 1 for i, (doc_id, _) in enumerate(sorted(keyword_scores.items(), key=lambda x: x[1], reverse=True))
|
||||||
|
}
|
||||||
|
|
||||||
|
all_ids = set(vector_scores.keys()) | set(keyword_scores.keys())
|
||||||
|
rrf_scores = {}
|
||||||
|
for doc_id in all_ids:
|
||||||
|
vector_rank = vector_ranks.get(doc_id, float("inf"))
|
||||||
|
keyword_rank = keyword_ranks.get(doc_id, float("inf"))
|
||||||
|
# RRF formula: score = 1/(k + r) where k is impact_factor and r is the rank
|
||||||
|
rrf_scores[doc_id] = (1.0 / (impact_factor + vector_rank)) + (1.0 / (impact_factor + keyword_rank))
|
||||||
|
return rrf_scores
|
||||||
|
|
||||||
|
|
||||||
class SQLiteVecIndex(EmbeddingIndex):
|
class SQLiteVecIndex(EmbeddingIndex):
|
||||||
"""
|
"""
|
||||||
An index implementation that stores embeddings in a SQLite virtual table using sqlite-vec.
|
An index implementation that stores embeddings in a SQLite virtual table using sqlite-vec.
|
||||||
|
@ -255,8 +314,6 @@ class SQLiteVecIndex(EmbeddingIndex):
|
||||||
"""
|
"""
|
||||||
Performs keyword-based search using SQLite FTS5 for relevance-ranked full-text search.
|
Performs keyword-based search using SQLite FTS5 for relevance-ranked full-text search.
|
||||||
"""
|
"""
|
||||||
if query_string is None:
|
|
||||||
raise ValueError("query_string is required for keyword search.")
|
|
||||||
|
|
||||||
def _execute_query():
|
def _execute_query():
|
||||||
connection = _create_sqlite_connection(self.db_path)
|
connection = _create_sqlite_connection(self.db_path)
|
||||||
|
@ -294,6 +351,81 @@ class SQLiteVecIndex(EmbeddingIndex):
|
||||||
scores.append(score)
|
scores.append(score)
|
||||||
return QueryChunksResponse(chunks=chunks, scores=scores)
|
return QueryChunksResponse(chunks=chunks, scores=scores)
|
||||||
|
|
||||||
|
async def query_hybrid(
|
||||||
|
self,
|
||||||
|
embedding: NDArray,
|
||||||
|
query_string: str,
|
||||||
|
k: int,
|
||||||
|
score_threshold: float,
|
||||||
|
reranker_type: str = RERANKER_TYPE_RRF,
|
||||||
|
reranker_params: dict[str, Any] | None = None,
|
||||||
|
) -> QueryChunksResponse:
|
||||||
|
"""
|
||||||
|
Hybrid search using a configurable re-ranking strategy.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
embedding: The query embedding vector
|
||||||
|
query_string: The text query for keyword search
|
||||||
|
k: Number of results to return
|
||||||
|
score_threshold: Minimum similarity score threshold
|
||||||
|
reranker_type: Type of reranker to use ("rrf" or "weighted")
|
||||||
|
reranker_params: Parameters for the reranker
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
QueryChunksResponse with combined results
|
||||||
|
"""
|
||||||
|
if reranker_params is None:
|
||||||
|
reranker_params = {}
|
||||||
|
|
||||||
|
# Get results from both search methods
|
||||||
|
vector_response = await self.query_vector(embedding, k, score_threshold)
|
||||||
|
keyword_response = await self.query_keyword(query_string, k, score_threshold)
|
||||||
|
|
||||||
|
# Convert responses to score dictionaries using generate_chunk_id
|
||||||
|
vector_scores = {
|
||||||
|
generate_chunk_id(chunk.metadata["document_id"], str(chunk.content)): score
|
||||||
|
for chunk, score in zip(vector_response.chunks, vector_response.scores, strict=False)
|
||||||
|
}
|
||||||
|
keyword_scores = {
|
||||||
|
generate_chunk_id(chunk.metadata["document_id"], str(chunk.content)): score
|
||||||
|
for chunk, score in zip(keyword_response.chunks, keyword_response.scores, strict=False)
|
||||||
|
}
|
||||||
|
|
||||||
|
# Combine scores using the specified reranker
|
||||||
|
if reranker_type == RERANKER_TYPE_WEIGHTED:
|
||||||
|
alpha = reranker_params.get("alpha", 0.5)
|
||||||
|
combined_scores = _weighted_rerank(vector_scores, keyword_scores, alpha)
|
||||||
|
else:
|
||||||
|
# Default to RRF for None, RRF, or any unknown types
|
||||||
|
impact_factor = reranker_params.get("impact_factor", 60.0)
|
||||||
|
combined_scores = _rrf_rerank(vector_scores, keyword_scores, impact_factor)
|
||||||
|
|
||||||
|
# Sort by combined score and get top k results
|
||||||
|
sorted_items = sorted(combined_scores.items(), key=lambda x: x[1], reverse=True)
|
||||||
|
top_k_items = sorted_items[:k]
|
||||||
|
|
||||||
|
# Filter by score threshold
|
||||||
|
filtered_items = [(doc_id, score) for doc_id, score in top_k_items if score >= score_threshold]
|
||||||
|
|
||||||
|
# Create a map of chunk_id to chunk for both responses
|
||||||
|
chunk_map = {}
|
||||||
|
for c in vector_response.chunks:
|
||||||
|
chunk_id = generate_chunk_id(c.metadata["document_id"], str(c.content))
|
||||||
|
chunk_map[chunk_id] = c
|
||||||
|
for c in keyword_response.chunks:
|
||||||
|
chunk_id = generate_chunk_id(c.metadata["document_id"], str(c.content))
|
||||||
|
chunk_map[chunk_id] = c
|
||||||
|
|
||||||
|
# Use the map to look up chunks by their IDs
|
||||||
|
chunks = []
|
||||||
|
scores = []
|
||||||
|
for doc_id, score in filtered_items:
|
||||||
|
if doc_id in chunk_map:
|
||||||
|
chunks.append(chunk_map[doc_id])
|
||||||
|
scores.append(score)
|
||||||
|
|
||||||
|
return QueryChunksResponse(chunks=chunks, scores=scores)
|
||||||
|
|
||||||
|
|
||||||
class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPrivate):
|
class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPrivate):
|
||||||
"""
|
"""
|
||||||
|
@ -345,7 +477,9 @@ class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtoc
|
||||||
vector_db_data = row[0]
|
vector_db_data = row[0]
|
||||||
vector_db = VectorDB.model_validate_json(vector_db_data)
|
vector_db = VectorDB.model_validate_json(vector_db_data)
|
||||||
index = await SQLiteVecIndex.create(
|
index = await SQLiteVecIndex.create(
|
||||||
vector_db.embedding_dimension, self.config.db_path, vector_db.identifier
|
vector_db.embedding_dimension,
|
||||||
|
self.config.db_path,
|
||||||
|
vector_db.identifier,
|
||||||
)
|
)
|
||||||
self.cache[vector_db.identifier] = VectorDBWithIndex(vector_db, index, self.inference_api)
|
self.cache[vector_db.identifier] = VectorDBWithIndex(vector_db, index, self.inference_api)
|
||||||
|
|
||||||
|
@ -371,7 +505,11 @@ class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtoc
|
||||||
connection.close()
|
connection.close()
|
||||||
|
|
||||||
await asyncio.to_thread(_register_db)
|
await asyncio.to_thread(_register_db)
|
||||||
index = await SQLiteVecIndex.create(vector_db.embedding_dimension, self.config.db_path, vector_db.identifier)
|
index = await SQLiteVecIndex.create(
|
||||||
|
vector_db.embedding_dimension,
|
||||||
|
self.config.db_path,
|
||||||
|
vector_db.identifier,
|
||||||
|
)
|
||||||
self.cache[vector_db.identifier] = VectorDBWithIndex(vector_db, index, self.inference_api)
|
self.cache[vector_db.identifier] = VectorDBWithIndex(vector_db, index, self.inference_api)
|
||||||
|
|
||||||
async def list_vector_dbs(self) -> list[VectorDB]:
|
async def list_vector_dbs(self) -> list[VectorDB]:
|
||||||
|
|
|
@ -105,6 +105,17 @@ class ChromaIndex(EmbeddingIndex):
|
||||||
) -> QueryChunksResponse:
|
) -> QueryChunksResponse:
|
||||||
raise NotImplementedError("Keyword search is not supported in Chroma")
|
raise NotImplementedError("Keyword search is not supported in Chroma")
|
||||||
|
|
||||||
|
async def query_hybrid(
|
||||||
|
self,
|
||||||
|
embedding: NDArray,
|
||||||
|
query_string: str,
|
||||||
|
k: int,
|
||||||
|
score_threshold: float,
|
||||||
|
reranker_type: str,
|
||||||
|
reranker_params: dict[str, Any] | None = None,
|
||||||
|
) -> QueryChunksResponse:
|
||||||
|
raise NotImplementedError("Hybrid search is not supported in Chroma")
|
||||||
|
|
||||||
|
|
||||||
class ChromaVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
|
class ChromaVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
|
||||||
def __init__(
|
def __init__(
|
||||||
|
|
|
@ -103,6 +103,17 @@ class MilvusIndex(EmbeddingIndex):
|
||||||
) -> QueryChunksResponse:
|
) -> QueryChunksResponse:
|
||||||
raise NotImplementedError("Keyword search is not supported in Milvus")
|
raise NotImplementedError("Keyword search is not supported in Milvus")
|
||||||
|
|
||||||
|
async def query_hybrid(
|
||||||
|
self,
|
||||||
|
embedding: NDArray,
|
||||||
|
query_string: str,
|
||||||
|
k: int,
|
||||||
|
score_threshold: float,
|
||||||
|
reranker_type: str,
|
||||||
|
reranker_params: dict[str, Any] | None = None,
|
||||||
|
) -> QueryChunksResponse:
|
||||||
|
raise NotImplementedError("Hybrid search is not supported in Milvus")
|
||||||
|
|
||||||
|
|
||||||
class MilvusVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
|
class MilvusVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
|
||||||
def __init__(
|
def __init__(
|
||||||
|
|
|
@ -128,6 +128,17 @@ class PGVectorIndex(EmbeddingIndex):
|
||||||
) -> QueryChunksResponse:
|
) -> QueryChunksResponse:
|
||||||
raise NotImplementedError("Keyword search is not supported in PGVector")
|
raise NotImplementedError("Keyword search is not supported in PGVector")
|
||||||
|
|
||||||
|
async def query_hybrid(
|
||||||
|
self,
|
||||||
|
embedding: NDArray,
|
||||||
|
query_string: str,
|
||||||
|
k: int,
|
||||||
|
score_threshold: float,
|
||||||
|
reranker_type: str,
|
||||||
|
reranker_params: dict[str, Any] | None = None,
|
||||||
|
) -> QueryChunksResponse:
|
||||||
|
raise NotImplementedError("Hybrid search is not supported in PGVector")
|
||||||
|
|
||||||
async def delete(self):
|
async def delete(self):
|
||||||
with self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
|
with self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
|
||||||
cur.execute(f"DROP TABLE IF EXISTS {self.table_name}")
|
cur.execute(f"DROP TABLE IF EXISTS {self.table_name}")
|
||||||
|
|
|
@ -112,6 +112,17 @@ class QdrantIndex(EmbeddingIndex):
|
||||||
) -> QueryChunksResponse:
|
) -> QueryChunksResponse:
|
||||||
raise NotImplementedError("Keyword search is not supported in Qdrant")
|
raise NotImplementedError("Keyword search is not supported in Qdrant")
|
||||||
|
|
||||||
|
async def query_hybrid(
|
||||||
|
self,
|
||||||
|
embedding: NDArray,
|
||||||
|
query_string: str,
|
||||||
|
k: int,
|
||||||
|
score_threshold: float,
|
||||||
|
reranker_type: str,
|
||||||
|
reranker_params: dict[str, Any] | None = None,
|
||||||
|
) -> QueryChunksResponse:
|
||||||
|
raise NotImplementedError("Hybrid search is not supported in Qdrant")
|
||||||
|
|
||||||
async def delete(self):
|
async def delete(self):
|
||||||
await self.client.delete_collection(collection_name=self.collection_name)
|
await self.client.delete_collection(collection_name=self.collection_name)
|
||||||
|
|
||||||
|
|
|
@ -92,6 +92,17 @@ class WeaviateIndex(EmbeddingIndex):
|
||||||
) -> QueryChunksResponse:
|
) -> QueryChunksResponse:
|
||||||
raise NotImplementedError("Keyword search is not supported in Weaviate")
|
raise NotImplementedError("Keyword search is not supported in Weaviate")
|
||||||
|
|
||||||
|
async def query_hybrid(
|
||||||
|
self,
|
||||||
|
embedding: NDArray,
|
||||||
|
query_string: str,
|
||||||
|
k: int,
|
||||||
|
score_threshold: float,
|
||||||
|
reranker_type: str,
|
||||||
|
reranker_params: dict[str, Any] | None = None,
|
||||||
|
) -> QueryChunksResponse:
|
||||||
|
raise NotImplementedError("Hybrid search is not supported in Weaviate")
|
||||||
|
|
||||||
|
|
||||||
class WeaviateVectorIOAdapter(
|
class WeaviateVectorIOAdapter(
|
||||||
VectorIO,
|
VectorIO,
|
||||||
|
|
|
@ -32,6 +32,10 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Constants for reranker types
|
||||||
|
RERANKER_TYPE_RRF = "rrf"
|
||||||
|
RERANKER_TYPE_WEIGHTED = "weighted"
|
||||||
|
|
||||||
|
|
||||||
def parse_pdf(data: bytes) -> str:
|
def parse_pdf(data: bytes) -> str:
|
||||||
# For PDF and DOC/DOCX files, we can't reliably convert to string
|
# For PDF and DOC/DOCX files, we can't reliably convert to string
|
||||||
|
@ -202,6 +206,18 @@ class EmbeddingIndex(ABC):
|
||||||
async def query_keyword(self, query_string: str, k: int, score_threshold: float) -> QueryChunksResponse:
|
async def query_keyword(self, query_string: str, k: int, score_threshold: float) -> QueryChunksResponse:
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def query_hybrid(
|
||||||
|
self,
|
||||||
|
embedding: NDArray,
|
||||||
|
query_string: str,
|
||||||
|
k: int,
|
||||||
|
score_threshold: float,
|
||||||
|
reranker_type: str,
|
||||||
|
reranker_params: dict[str, Any] | None = None,
|
||||||
|
) -> QueryChunksResponse:
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
async def delete(self):
|
async def delete(self):
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
@ -245,10 +261,29 @@ class VectorDBWithIndex:
|
||||||
k = params.get("max_chunks", 3)
|
k = params.get("max_chunks", 3)
|
||||||
mode = params.get("mode")
|
mode = params.get("mode")
|
||||||
score_threshold = params.get("score_threshold", 0.0)
|
score_threshold = params.get("score_threshold", 0.0)
|
||||||
|
|
||||||
|
# Get ranker configuration
|
||||||
|
ranker = params.get("ranker")
|
||||||
|
if ranker is None:
|
||||||
|
# Default to RRF with impact_factor=60.0
|
||||||
|
reranker_type = RERANKER_TYPE_RRF
|
||||||
|
reranker_params = {"impact_factor": 60.0}
|
||||||
|
else:
|
||||||
|
reranker_type = ranker.type
|
||||||
|
reranker_params = (
|
||||||
|
{"impact_factor": ranker.impact_factor} if ranker.type == RERANKER_TYPE_RRF else {"alpha": ranker.alpha}
|
||||||
|
)
|
||||||
|
|
||||||
query_string = interleaved_content_as_str(query)
|
query_string = interleaved_content_as_str(query)
|
||||||
if mode == "keyword":
|
if mode == "keyword":
|
||||||
return await self.index.query_keyword(query_string, k, score_threshold)
|
return await self.index.query_keyword(query_string, k, score_threshold)
|
||||||
|
|
||||||
|
# Calculate embeddings for both vector and hybrid modes
|
||||||
|
embeddings_response = await self.inference_api.embeddings(self.vector_db.embedding_model, [query_string])
|
||||||
|
query_vector = np.array(embeddings_response.embeddings[0], dtype=np.float32)
|
||||||
|
if mode == "hybrid":
|
||||||
|
return await self.index.query_hybrid(
|
||||||
|
query_vector, query_string, k, score_threshold, reranker_type, reranker_params
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
embeddings_response = await self.inference_api.embeddings(self.vector_db.embedding_model, [query_string])
|
|
||||||
query_vector = np.array(embeddings_response.embeddings[0], dtype=np.float32)
|
|
||||||
return await self.index.query_vector(query_vector, k, score_threshold)
|
return await self.index.query_vector(query_vector, k, score_threshold)
|
||||||
|
|
|
@ -84,6 +84,28 @@ async def test_query_chunks_full_text_search(sqlite_vec_index, sample_chunks, sa
|
||||||
assert len(response_no_results.chunks) == 0, f"Expected 0 results, but got {len(response_no_results.chunks)}"
|
assert len(response_no_results.chunks) == 0, f"Expected 0 results, but got {len(response_no_results.chunks)}"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_query_chunks_hybrid(sqlite_vec_index, sample_chunks, sample_embeddings):
|
||||||
|
await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings)
|
||||||
|
|
||||||
|
# Create a query embedding that's similar to the first chunk
|
||||||
|
query_embedding = sample_embeddings[0]
|
||||||
|
query_string = "Sentence 5"
|
||||||
|
|
||||||
|
response = await sqlite_vec_index.query_hybrid(
|
||||||
|
embedding=query_embedding,
|
||||||
|
query_string=query_string,
|
||||||
|
k=3,
|
||||||
|
score_threshold=0.0,
|
||||||
|
reranker_type="rrf",
|
||||||
|
reranker_params={"impact_factor": 60.0},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert len(response.chunks) == 3, f"Expected 3 results, got {len(response.chunks)}"
|
||||||
|
# Verify scores are in descending order (higher is better)
|
||||||
|
assert all(response.scores[i] >= response.scores[i + 1] for i in range(len(response.scores) - 1))
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_query_chunks_full_text_search_k_greater_than_results(sqlite_vec_index, sample_chunks, sample_embeddings):
|
async def test_query_chunks_full_text_search_k_greater_than_results(sqlite_vec_index, sample_chunks, sample_embeddings):
|
||||||
# Re-initialize with a clean index
|
# Re-initialize with a clean index
|
||||||
|
@ -141,3 +163,355 @@ def test_generate_chunk_id():
|
||||||
"bc744db3-1b25-0a9c-cdff-b6ba3df73c36",
|
"bc744db3-1b25-0a9c-cdff-b6ba3df73c36",
|
||||||
"f68df25d-d9aa-ab4d-5684-64a233add20d",
|
"f68df25d-d9aa-ab4d-5684-64a233add20d",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_query_chunks_hybrid_no_keyword_matches(sqlite_vec_index, sample_chunks, sample_embeddings):
|
||||||
|
"""Test hybrid search when keyword search returns no matches - should still return vector results."""
|
||||||
|
await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings)
|
||||||
|
|
||||||
|
# Use a non-existent keyword but a valid vector query
|
||||||
|
query_embedding = sample_embeddings[0]
|
||||||
|
query_string = "Sentence 499"
|
||||||
|
|
||||||
|
# First verify keyword search returns no results
|
||||||
|
keyword_response = await sqlite_vec_index.query_keyword(query_string, k=5, score_threshold=0.0)
|
||||||
|
assert len(keyword_response.chunks) == 0, "Keyword search should return no results"
|
||||||
|
|
||||||
|
# Get hybrid results
|
||||||
|
response = await sqlite_vec_index.query_hybrid(
|
||||||
|
embedding=query_embedding,
|
||||||
|
query_string=query_string,
|
||||||
|
k=3,
|
||||||
|
score_threshold=0.0,
|
||||||
|
reranker_type="rrf",
|
||||||
|
reranker_params={"impact_factor": 60.0},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Should still get results from vector search
|
||||||
|
assert len(response.chunks) > 0, "Should get results from vector search even with no keyword matches"
|
||||||
|
# Verify scores are in descending order
|
||||||
|
assert all(response.scores[i] >= response.scores[i + 1] for i in range(len(response.scores) - 1))
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_query_chunks_hybrid_score_threshold(sqlite_vec_index, sample_chunks, sample_embeddings):
|
||||||
|
"""Test hybrid search with a high score threshold."""
|
||||||
|
await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings)
|
||||||
|
|
||||||
|
# Use a very high score threshold that no results will meet
|
||||||
|
query_embedding = sample_embeddings[0]
|
||||||
|
query_string = "Sentence 5"
|
||||||
|
|
||||||
|
response = await sqlite_vec_index.query_hybrid(
|
||||||
|
embedding=query_embedding,
|
||||||
|
query_string=query_string,
|
||||||
|
k=3,
|
||||||
|
score_threshold=1000.0, # Very high threshold
|
||||||
|
reranker_type="rrf",
|
||||||
|
reranker_params={"impact_factor": 60.0},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Should return no results due to high threshold
|
||||||
|
assert len(response.chunks) == 0
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_query_chunks_hybrid_different_embedding(
|
||||||
|
sqlite_vec_index, sample_chunks, sample_embeddings, embedding_dimension
|
||||||
|
):
|
||||||
|
"""Test hybrid search with a different embedding than the stored ones."""
|
||||||
|
await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings)
|
||||||
|
|
||||||
|
# Create a random embedding that's different from stored ones
|
||||||
|
query_embedding = np.random.rand(embedding_dimension).astype(np.float32)
|
||||||
|
query_string = "Sentence 5"
|
||||||
|
|
||||||
|
response = await sqlite_vec_index.query_hybrid(
|
||||||
|
embedding=query_embedding,
|
||||||
|
query_string=query_string,
|
||||||
|
k=3,
|
||||||
|
score_threshold=0.0,
|
||||||
|
reranker_type="rrf",
|
||||||
|
reranker_params={"impact_factor": 60.0},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Should still get results if keyword matches exist
|
||||||
|
assert len(response.chunks) > 0
|
||||||
|
# Verify scores are in descending order
|
||||||
|
assert all(response.scores[i] >= response.scores[i + 1] for i in range(len(response.scores) - 1))
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_query_chunks_hybrid_rrf_ranking(sqlite_vec_index, sample_chunks, sample_embeddings):
|
||||||
|
"""Test that RRF properly combines rankings when documents appear in both search methods."""
|
||||||
|
await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings)
|
||||||
|
|
||||||
|
# Create a query embedding that's similar to the first chunk
|
||||||
|
query_embedding = sample_embeddings[0]
|
||||||
|
# Use a keyword that appears in multiple documents
|
||||||
|
query_string = "Sentence 5"
|
||||||
|
|
||||||
|
response = await sqlite_vec_index.query_hybrid(
|
||||||
|
embedding=query_embedding,
|
||||||
|
query_string=query_string,
|
||||||
|
k=5,
|
||||||
|
score_threshold=0.0,
|
||||||
|
reranker_type="rrf",
|
||||||
|
reranker_params={"impact_factor": 60.0},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify we get results from both search methods
|
||||||
|
assert len(response.chunks) > 0
|
||||||
|
# Verify scores are in descending order (RRF should maintain this)
|
||||||
|
assert all(response.scores[i] >= response.scores[i + 1] for i in range(len(response.scores) - 1))
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_query_chunks_hybrid_score_selection(sqlite_vec_index, sample_chunks, sample_embeddings):
|
||||||
|
await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings)
|
||||||
|
|
||||||
|
# Create a query embedding that's similar to the first chunk
|
||||||
|
query_embedding = sample_embeddings[0]
|
||||||
|
# Use a keyword that appears in the first document
|
||||||
|
query_string = "Sentence 0 from document 0"
|
||||||
|
|
||||||
|
# Test weighted re-ranking
|
||||||
|
response = await sqlite_vec_index.query_hybrid(
|
||||||
|
embedding=query_embedding,
|
||||||
|
query_string=query_string,
|
||||||
|
k=1,
|
||||||
|
score_threshold=0.0,
|
||||||
|
reranker_type="weighted",
|
||||||
|
reranker_params={"alpha": 0.5},
|
||||||
|
)
|
||||||
|
assert len(response.chunks) == 1
|
||||||
|
# Score should be weighted average of normalized keyword score and vector score
|
||||||
|
assert response.scores[0] > 0.5 # Both scores should be high
|
||||||
|
|
||||||
|
# Test RRF re-ranking
|
||||||
|
response = await sqlite_vec_index.query_hybrid(
|
||||||
|
embedding=query_embedding,
|
||||||
|
query_string=query_string,
|
||||||
|
k=1,
|
||||||
|
score_threshold=0.0,
|
||||||
|
reranker_type="rrf",
|
||||||
|
reranker_params={"impact_factor": 60.0},
|
||||||
|
)
|
||||||
|
assert len(response.chunks) == 1
|
||||||
|
# RRF score should be sum of reciprocal ranks
|
||||||
|
assert response.scores[0] == pytest.approx(2.0 / 61.0, rel=1e-6) # 1/(60+1) + 1/(60+1)
|
||||||
|
|
||||||
|
# Test default re-ranking (should be RRF)
|
||||||
|
response = await sqlite_vec_index.query_hybrid(
|
||||||
|
embedding=query_embedding,
|
||||||
|
query_string=query_string,
|
||||||
|
k=1,
|
||||||
|
score_threshold=0.0,
|
||||||
|
reranker_type="rrf",
|
||||||
|
reranker_params={"impact_factor": 60.0},
|
||||||
|
)
|
||||||
|
assert len(response.chunks) == 1
|
||||||
|
assert response.scores[0] == pytest.approx(2.0 / 61.0, rel=1e-6) # Should behave like RRF
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_query_chunks_hybrid_mixed_results(sqlite_vec_index, sample_chunks, sample_embeddings):
|
||||||
|
"""Test hybrid search with documents that appear in only one search method."""
|
||||||
|
await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings)
|
||||||
|
|
||||||
|
# Create a query embedding that's similar to the first chunk
|
||||||
|
query_embedding = sample_embeddings[0]
|
||||||
|
# Use a keyword that appears in a different document
|
||||||
|
query_string = "Sentence 9 from document 2"
|
||||||
|
|
||||||
|
response = await sqlite_vec_index.query_hybrid(
|
||||||
|
embedding=query_embedding,
|
||||||
|
query_string=query_string,
|
||||||
|
k=3,
|
||||||
|
score_threshold=0.0,
|
||||||
|
reranker_type="rrf",
|
||||||
|
reranker_params={"impact_factor": 60.0},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Should get results from both search methods
|
||||||
|
assert len(response.chunks) > 0
|
||||||
|
# Verify scores are in descending order
|
||||||
|
assert all(response.scores[i] >= response.scores[i + 1] for i in range(len(response.scores) - 1))
|
||||||
|
# Verify we get results from both the vector-similar document and keyword-matched document
|
||||||
|
doc_ids = {chunk.metadata["document_id"] for chunk in response.chunks}
|
||||||
|
assert "document-0" in doc_ids # From vector search
|
||||||
|
assert "document-2" in doc_ids # From keyword search
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_query_chunks_hybrid_weighted_reranker_parametrization(
|
||||||
|
sqlite_vec_index, sample_chunks, sample_embeddings
|
||||||
|
):
|
||||||
|
"""Test WeightedReRanker with different alpha values."""
|
||||||
|
# Re-add data before each search to ensure test isolation
|
||||||
|
await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings)
|
||||||
|
query_embedding = sample_embeddings[0]
|
||||||
|
query_string = "Sentence 0 from document 0"
|
||||||
|
|
||||||
|
# alpha=1.0 (should behave like pure keyword)
|
||||||
|
response = await sqlite_vec_index.query_hybrid(
|
||||||
|
embedding=query_embedding,
|
||||||
|
query_string=query_string,
|
||||||
|
k=1,
|
||||||
|
score_threshold=0.0,
|
||||||
|
reranker_type="weighted",
|
||||||
|
reranker_params={"alpha": 1.0},
|
||||||
|
)
|
||||||
|
assert len(response.chunks) > 0 # Should get at least one result
|
||||||
|
assert any("document-0" in chunk.metadata["document_id"] for chunk in response.chunks)
|
||||||
|
|
||||||
|
# alpha=0.0 (should behave like pure vector)
|
||||||
|
response = await sqlite_vec_index.query_hybrid(
|
||||||
|
embedding=query_embedding,
|
||||||
|
query_string=query_string,
|
||||||
|
k=1,
|
||||||
|
score_threshold=0.0,
|
||||||
|
reranker_type="weighted",
|
||||||
|
reranker_params={"alpha": 0.0},
|
||||||
|
)
|
||||||
|
assert len(response.chunks) > 0 # Should get at least one result
|
||||||
|
assert any("document-0" in chunk.metadata["document_id"] for chunk in response.chunks)
|
||||||
|
|
||||||
|
await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings)
|
||||||
|
# alpha=0.7 (should be a mix)
|
||||||
|
response = await sqlite_vec_index.query_hybrid(
|
||||||
|
embedding=query_embedding,
|
||||||
|
query_string=query_string,
|
||||||
|
k=1,
|
||||||
|
score_threshold=0.0,
|
||||||
|
reranker_type="weighted",
|
||||||
|
reranker_params={"alpha": 0.7},
|
||||||
|
)
|
||||||
|
assert len(response.chunks) > 0 # Should get at least one result
|
||||||
|
assert any("document-0" in chunk.metadata["document_id"] for chunk in response.chunks)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_query_chunks_hybrid_rrf_impact_factor(sqlite_vec_index, sample_chunks, sample_embeddings):
|
||||||
|
"""Test RRFReRanker with different impact factors."""
|
||||||
|
await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings)
|
||||||
|
query_embedding = sample_embeddings[0]
|
||||||
|
query_string = "Sentence 0 from document 0"
|
||||||
|
|
||||||
|
# impact_factor=10
|
||||||
|
response = await sqlite_vec_index.query_hybrid(
|
||||||
|
embedding=query_embedding,
|
||||||
|
query_string=query_string,
|
||||||
|
k=1,
|
||||||
|
score_threshold=0.0,
|
||||||
|
reranker_type="rrf",
|
||||||
|
reranker_params={"impact_factor": 10.0},
|
||||||
|
)
|
||||||
|
assert len(response.chunks) == 1
|
||||||
|
assert response.scores[0] == pytest.approx(2.0 / 11.0, rel=1e-6)
|
||||||
|
|
||||||
|
# impact_factor=100
|
||||||
|
response = await sqlite_vec_index.query_hybrid(
|
||||||
|
embedding=query_embedding,
|
||||||
|
query_string=query_string,
|
||||||
|
k=1,
|
||||||
|
score_threshold=0.0,
|
||||||
|
reranker_type="rrf",
|
||||||
|
reranker_params={"impact_factor": 100.0},
|
||||||
|
)
|
||||||
|
assert len(response.chunks) == 1
|
||||||
|
assert response.scores[0] == pytest.approx(2.0 / 101.0, rel=1e-6)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_query_chunks_hybrid_edge_cases(sqlite_vec_index, sample_chunks, sample_embeddings):
|
||||||
|
await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings)
|
||||||
|
|
||||||
|
# No results from either search - use a completely different embedding and a nonzero threshold
|
||||||
|
query_embedding = np.ones_like(sample_embeddings[0]) * -1 # Very different from sample embeddings
|
||||||
|
query_string = "no_such_keyword_that_will_never_match"
|
||||||
|
response = await sqlite_vec_index.query_hybrid(
|
||||||
|
embedding=query_embedding,
|
||||||
|
query_string=query_string,
|
||||||
|
k=3,
|
||||||
|
score_threshold=0.1, # Nonzero threshold to filter out low-similarity matches
|
||||||
|
reranker_type="rrf",
|
||||||
|
reranker_params={"impact_factor": 60.0},
|
||||||
|
)
|
||||||
|
assert len(response.chunks) == 0
|
||||||
|
|
||||||
|
# All results below threshold
|
||||||
|
query_embedding = sample_embeddings[0]
|
||||||
|
query_string = "Sentence 0 from document 0"
|
||||||
|
response = await sqlite_vec_index.query_hybrid(
|
||||||
|
embedding=query_embedding,
|
||||||
|
query_string=query_string,
|
||||||
|
k=3,
|
||||||
|
score_threshold=1000.0,
|
||||||
|
reranker_type="rrf",
|
||||||
|
reranker_params={"impact_factor": 60.0},
|
||||||
|
)
|
||||||
|
assert len(response.chunks) == 0
|
||||||
|
|
||||||
|
# Large k value
|
||||||
|
response = await sqlite_vec_index.query_hybrid(
|
||||||
|
embedding=query_embedding,
|
||||||
|
query_string=query_string,
|
||||||
|
k=100,
|
||||||
|
score_threshold=0.0,
|
||||||
|
reranker_type="rrf",
|
||||||
|
reranker_params={"impact_factor": 60.0},
|
||||||
|
)
|
||||||
|
# Should not error, should return all available results
|
||||||
|
assert len(response.chunks) > 0
|
||||||
|
assert len(response.chunks) <= 100
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_query_chunks_hybrid_tie_breaking(
|
||||||
|
sqlite_vec_index, sample_embeddings, embedding_dimension, tmp_path_factory
|
||||||
|
):
|
||||||
|
"""Test tie-breaking and determinism when scores are equal."""
|
||||||
|
# Create two chunks with the same content and embedding
|
||||||
|
chunk1 = Chunk(content="identical", metadata={"document_id": "docA"})
|
||||||
|
chunk2 = Chunk(content="identical", metadata={"document_id": "docB"})
|
||||||
|
chunks = [chunk1, chunk2]
|
||||||
|
# Use the same embedding for both chunks to ensure equal scores
|
||||||
|
same_embedding = sample_embeddings[0]
|
||||||
|
embeddings = np.array([same_embedding, same_embedding])
|
||||||
|
|
||||||
|
# Clear existing data and recreate index
|
||||||
|
await sqlite_vec_index.delete()
|
||||||
|
temp_dir = tmp_path_factory.getbasetemp()
|
||||||
|
db_path = str(temp_dir / "test_sqlite.db")
|
||||||
|
sqlite_vec_index = await SQLiteVecIndex.create(dimension=embedding_dimension, db_path=db_path, bank_id="test_bank")
|
||||||
|
await sqlite_vec_index.add_chunks(chunks, embeddings)
|
||||||
|
|
||||||
|
# Query with the same embedding and content to ensure equal scores
|
||||||
|
query_embedding = same_embedding
|
||||||
|
query_string = "identical"
|
||||||
|
|
||||||
|
# Run multiple queries to verify determinism
|
||||||
|
responses = []
|
||||||
|
for _ in range(3):
|
||||||
|
response = await sqlite_vec_index.query_hybrid(
|
||||||
|
embedding=query_embedding,
|
||||||
|
query_string=query_string,
|
||||||
|
k=2,
|
||||||
|
score_threshold=0.0,
|
||||||
|
reranker_type="rrf",
|
||||||
|
reranker_params={"impact_factor": 60.0},
|
||||||
|
)
|
||||||
|
responses.append(response)
|
||||||
|
|
||||||
|
# Verify all responses are identical
|
||||||
|
first_response = responses[0]
|
||||||
|
for response in responses[1:]:
|
||||||
|
assert response.chunks == first_response.chunks
|
||||||
|
assert response.scores == first_response.scores
|
||||||
|
|
||||||
|
# Verify both chunks are returned with equal scores
|
||||||
|
assert len(first_response.chunks) == 2
|
||||||
|
assert first_response.scores[0] == first_response.scores[1]
|
||||||
|
assert {chunk.metadata["document_id"] for chunk in first_response.chunks} == {"docA", "docB"}
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue