mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-06-27 18:50:41 +00:00
feat: Implement hybrid search in SQLite-vec (#2312)
Some checks failed
Integration Auth Tests / test-matrix (oauth2_token) (push) Failing after 1s
Integration Tests / test-matrix (http, 3.10, datasets) (push) Failing after 4s
Integration Tests / test-matrix (http, 3.10, providers) (push) Failing after 6s
Integration Tests / test-matrix (http, 3.10, scoring) (push) Failing after 6s
Integration Tests / test-matrix (http, 3.10, agents) (push) Failing after 8s
Integration Tests / test-matrix (http, 3.11, datasets) (push) Failing after 5s
Integration Tests / test-matrix (http, 3.10, inference) (push) Failing after 9s
Integration Tests / test-matrix (http, 3.11, inference) (push) Failing after 5s
Integration Tests / test-matrix (http, 3.11, inspect) (push) Failing after 8s
Integration Tests / test-matrix (http, 3.10, post_training) (push) Failing after 10s
Integration Tests / test-matrix (http, 3.11, tool_runtime) (push) Failing after 5s
Integration Tests / test-matrix (http, 3.10, vector_io) (push) Failing after 7s
Integration Tests / test-matrix (http, 3.11, agents) (push) Failing after 7s
Integration Tests / test-matrix (http, 3.10, inspect) (push) Failing after 9s
Integration Tests / test-matrix (http, 3.12, agents) (push) Failing after 10s
Integration Tests / test-matrix (http, 3.12, post_training) (push) Failing after 8s
Integration Tests / test-matrix (http, 3.12, providers) (push) Failing after 8s
Integration Tests / test-matrix (http, 3.10, tool_runtime) (push) Failing after 7s
Integration Tests / test-matrix (http, 3.11, post_training) (push) Failing after 6s
Integration Tests / test-matrix (http, 3.12, scoring) (push) Failing after 8s
Integration Tests / test-matrix (library, 3.10, agents) (push) Failing after 7s
Integration Tests / test-matrix (http, 3.11, scoring) (push) Failing after 6s
Integration Tests / test-matrix (http, 3.11, providers) (push) Failing after 8s
Integration Tests / test-matrix (http, 3.12, inference) (push) Failing after 7s
Integration Tests / test-matrix (http, 3.12, datasets) (push) Failing after 7s
Integration Tests / test-matrix (library, 3.10, inference) (push) Failing after 8s
Integration Tests / test-matrix (http, 3.12, vector_io) (push) Failing after 7s
Integration Tests / test-matrix (http, 3.12, inspect) (push) Failing after 10s
Integration Tests / test-matrix (library, 3.10, post_training) (push) Failing after 9s
Integration Tests / test-matrix (http, 3.12, tool_runtime) (push) Failing after 10s
Integration Tests / test-matrix (http, 3.11, vector_io) (push) Failing after 11s
Integration Tests / test-matrix (library, 3.10, inspect) (push) Failing after 11s
Integration Tests / test-matrix (library, 3.10, datasets) (push) Failing after 13s
Integration Tests / test-matrix (library, 3.10, providers) (push) Failing after 11s
Integration Tests / test-matrix (library, 3.10, scoring) (push) Failing after 10s
Integration Tests / test-matrix (library, 3.10, vector_io) (push) Failing after 9s
Integration Tests / test-matrix (library, 3.10, tool_runtime) (push) Failing after 12s
Integration Tests / test-matrix (library, 3.11, agents) (push) Failing after 8s
Integration Tests / test-matrix (library, 3.11, datasets) (push) Failing after 7s
Integration Tests / test-matrix (library, 3.11, inspect) (push) Failing after 15s
Integration Tests / test-matrix (library, 3.11, inference) (push) Failing after 16s
Integration Tests / test-matrix (library, 3.11, vector_io) (push) Failing after 10s
Integration Tests / test-matrix (library, 3.11, post_training) (push) Failing after 25s
Integration Tests / test-matrix (library, 3.11, providers) (push) Failing after 24s
Integration Tests / test-matrix (library, 3.11, scoring) (push) Failing after 22s
Integration Tests / test-matrix (library, 3.11, tool_runtime) (push) Failing after 14s
Integration Tests / test-matrix (library, 3.12, agents) (push) Failing after 6s
Integration Tests / test-matrix (library, 3.12, datasets) (push) Failing after 9s
Integration Tests / test-matrix (library, 3.12, inference) (push) Failing after 9s
Integration Tests / test-matrix (library, 3.12, inspect) (push) Failing after 7s
Integration Tests / test-matrix (library, 3.12, post_training) (push) Failing after 7s
Integration Tests / test-matrix (library, 3.12, providers) (push) Failing after 7s
Integration Tests / test-matrix (library, 3.12, scoring) (push) Failing after 9s
Integration Tests / test-matrix (library, 3.12, tool_runtime) (push) Failing after 7s
Integration Tests / test-matrix (library, 3.12, vector_io) (push) Failing after 41s
Test Llama Stack Build / generate-matrix (push) Successful in 37s
Test Llama Stack Build / build-single-provider (push) Failing after 37s
Test Llama Stack Build / build-custom-container-distribution (push) Failing after 35s
Test External Providers / test-external-providers (venv) (push) Failing after 5s
Update ReadTheDocs / update-readthedocs (push) Failing after 5s
Unit Tests / unit-tests (3.11) (push) Failing after 6s
Unit Tests / unit-tests (3.12) (push) Failing after 6s
Unit Tests / unit-tests (3.13) (push) Failing after 6s
Test Llama Stack Build / build (push) Failing after 7s
Test Llama Stack Build / build-ubi9-container-distribution (push) Failing after 18s
Unit Tests / unit-tests (3.10) (push) Failing after 17s
Pre-commit / pre-commit (push) Successful in 2m0s
Some checks failed
Integration Auth Tests / test-matrix (oauth2_token) (push) Failing after 1s
Integration Tests / test-matrix (http, 3.10, datasets) (push) Failing after 4s
Integration Tests / test-matrix (http, 3.10, providers) (push) Failing after 6s
Integration Tests / test-matrix (http, 3.10, scoring) (push) Failing after 6s
Integration Tests / test-matrix (http, 3.10, agents) (push) Failing after 8s
Integration Tests / test-matrix (http, 3.11, datasets) (push) Failing after 5s
Integration Tests / test-matrix (http, 3.10, inference) (push) Failing after 9s
Integration Tests / test-matrix (http, 3.11, inference) (push) Failing after 5s
Integration Tests / test-matrix (http, 3.11, inspect) (push) Failing after 8s
Integration Tests / test-matrix (http, 3.10, post_training) (push) Failing after 10s
Integration Tests / test-matrix (http, 3.11, tool_runtime) (push) Failing after 5s
Integration Tests / test-matrix (http, 3.10, vector_io) (push) Failing after 7s
Integration Tests / test-matrix (http, 3.11, agents) (push) Failing after 7s
Integration Tests / test-matrix (http, 3.10, inspect) (push) Failing after 9s
Integration Tests / test-matrix (http, 3.12, agents) (push) Failing after 10s
Integration Tests / test-matrix (http, 3.12, post_training) (push) Failing after 8s
Integration Tests / test-matrix (http, 3.12, providers) (push) Failing after 8s
Integration Tests / test-matrix (http, 3.10, tool_runtime) (push) Failing after 7s
Integration Tests / test-matrix (http, 3.11, post_training) (push) Failing after 6s
Integration Tests / test-matrix (http, 3.12, scoring) (push) Failing after 8s
Integration Tests / test-matrix (library, 3.10, agents) (push) Failing after 7s
Integration Tests / test-matrix (http, 3.11, scoring) (push) Failing after 6s
Integration Tests / test-matrix (http, 3.11, providers) (push) Failing after 8s
Integration Tests / test-matrix (http, 3.12, inference) (push) Failing after 7s
Integration Tests / test-matrix (http, 3.12, datasets) (push) Failing after 7s
Integration Tests / test-matrix (library, 3.10, inference) (push) Failing after 8s
Integration Tests / test-matrix (http, 3.12, vector_io) (push) Failing after 7s
Integration Tests / test-matrix (http, 3.12, inspect) (push) Failing after 10s
Integration Tests / test-matrix (library, 3.10, post_training) (push) Failing after 9s
Integration Tests / test-matrix (http, 3.12, tool_runtime) (push) Failing after 10s
Integration Tests / test-matrix (http, 3.11, vector_io) (push) Failing after 11s
Integration Tests / test-matrix (library, 3.10, inspect) (push) Failing after 11s
Integration Tests / test-matrix (library, 3.10, datasets) (push) Failing after 13s
Integration Tests / test-matrix (library, 3.10, providers) (push) Failing after 11s
Integration Tests / test-matrix (library, 3.10, scoring) (push) Failing after 10s
Integration Tests / test-matrix (library, 3.10, vector_io) (push) Failing after 9s
Integration Tests / test-matrix (library, 3.10, tool_runtime) (push) Failing after 12s
Integration Tests / test-matrix (library, 3.11, agents) (push) Failing after 8s
Integration Tests / test-matrix (library, 3.11, datasets) (push) Failing after 7s
Integration Tests / test-matrix (library, 3.11, inspect) (push) Failing after 15s
Integration Tests / test-matrix (library, 3.11, inference) (push) Failing after 16s
Integration Tests / test-matrix (library, 3.11, vector_io) (push) Failing after 10s
Integration Tests / test-matrix (library, 3.11, post_training) (push) Failing after 25s
Integration Tests / test-matrix (library, 3.11, providers) (push) Failing after 24s
Integration Tests / test-matrix (library, 3.11, scoring) (push) Failing after 22s
Integration Tests / test-matrix (library, 3.11, tool_runtime) (push) Failing after 14s
Integration Tests / test-matrix (library, 3.12, agents) (push) Failing after 6s
Integration Tests / test-matrix (library, 3.12, datasets) (push) Failing after 9s
Integration Tests / test-matrix (library, 3.12, inference) (push) Failing after 9s
Integration Tests / test-matrix (library, 3.12, inspect) (push) Failing after 7s
Integration Tests / test-matrix (library, 3.12, post_training) (push) Failing after 7s
Integration Tests / test-matrix (library, 3.12, providers) (push) Failing after 7s
Integration Tests / test-matrix (library, 3.12, scoring) (push) Failing after 9s
Integration Tests / test-matrix (library, 3.12, tool_runtime) (push) Failing after 7s
Integration Tests / test-matrix (library, 3.12, vector_io) (push) Failing after 41s
Test Llama Stack Build / generate-matrix (push) Successful in 37s
Test Llama Stack Build / build-single-provider (push) Failing after 37s
Test Llama Stack Build / build-custom-container-distribution (push) Failing after 35s
Test External Providers / test-external-providers (venv) (push) Failing after 5s
Update ReadTheDocs / update-readthedocs (push) Failing after 5s
Unit Tests / unit-tests (3.11) (push) Failing after 6s
Unit Tests / unit-tests (3.12) (push) Failing after 6s
Unit Tests / unit-tests (3.13) (push) Failing after 6s
Test Llama Stack Build / build (push) Failing after 7s
Test Llama Stack Build / build-ubi9-container-distribution (push) Failing after 18s
Unit Tests / unit-tests (3.10) (push) Failing after 17s
Pre-commit / pre-commit (push) Successful in 2m0s
# What does this PR do? Add support for hybrid search mode in SQLite-vec provider, which combines keyword and vector search for better results. The implementation: - Adds hybrid search mode as a new option alongside vector and keyword search - Implements query_hybrid method in SQLiteVecIndex that: - First performs keyword search to get candidate matches - Then applies vector similarity search on those candidates - Updates documentation to reflect the new search mode This change improves search quality by leveraging both semantic similarity and keyword matching, while maintaining backward compatibility with existing vector and keyword search modes. ## Test Plan ``` pytest tests/unit/providers/vector_io/test_sqlite_vec.py -v -s --tb=short /Users/vnarsing/miniconda3/envs/stack-client/lib/python3.10/site-packages/pytest_asyncio/plugin.py:217: PytestDeprecationWarning: The configuration option "asyncio_default_fixture_loop_scope" is unset. The event loop scope for asynchronous fixtures will default to the fixture caching scope. Future versions of pytest-asyncio will default the loop scope for asynchronous fixtures to function scope. Set the default fixture loop scope explicitly in order to avoid unexpected behavior in the future. Valid fixture loop scopes are: "function", "class", "module", "package", "session" warnings.warn(PytestDeprecationWarning(_DEFAULT_FIXTURE_LOOP_SCOPE_UNSET)) =============================================================================================== test session starts =============================================================================================== platform darwin -- Python 3.10.16, pytest-8.3.5, pluggy-1.5.0 -- /Users/vnarsing/miniconda3/envs/stack-client/bin/python cachedir: .pytest_cache metadata: {'Python': '3.10.16', 'Platform': 'macOS-14.7.6-arm64-arm-64bit', 'Packages': {'pytest': '8.3.5', 'pluggy': '1.5.0'}, 'Plugins': {'html': '4.1.1', 'json-report': '1.5.0', 'timeout': '2.4.0', 'metadata': '3.1.1', 'anyio': '4.8.0', 'asyncio': '0.26.0', 'nbval': '0.11.0', 'cov': '6.1.1'}} rootdir: /Users/vnarsing/go/src/github/meta-llama/llama-stack configfile: pyproject.toml plugins: html-4.1.1, json-report-1.5.0, timeout-2.4.0, metadata-3.1.1, anyio-4.8.0, asyncio-0.26.0, nbval-0.11.0, cov-6.1.1 asyncio: mode=strict, asyncio_default_fixture_loop_scope=None, asyncio_default_test_loop_scope=function collected 10 items tests/unit/providers/vector_io/test_sqlite_vec.py::test_add_chunks PASSED tests/unit/providers/vector_io/test_sqlite_vec.py::test_query_chunks_vector PASSED tests/unit/providers/vector_io/test_sqlite_vec.py::test_query_chunks_full_text_search PASSED tests/unit/providers/vector_io/test_sqlite_vec.py::test_query_chunks_hybrid PASSED tests/unit/providers/vector_io/test_sqlite_vec.py::test_query_chunks_full_text_search_k_greater_than_results PASSED tests/unit/providers/vector_io/test_sqlite_vec.py::test_chunk_id_conflict PASSED tests/unit/providers/vector_io/test_sqlite_vec.py::test_generate_chunk_id PASSED tests/unit/providers/vector_io/test_sqlite_vec.py::test_query_chunks_hybrid_no_keyword_matches PASSED tests/unit/providers/vector_io/test_sqlite_vec.py::test_query_chunks_hybrid_score_threshold PASSED tests/unit/providers/vector_io/test_sqlite_vec.py::test_query_chunks_hybrid_different_embedding PASSED ``` --------- Signed-off-by: Varsha Prasad Narsing <varshaprasad96@gmail.com>
This commit is contained in:
parent
941f505eb0
commit
2e8054bede
14 changed files with 910 additions and 23 deletions
|
@ -15,6 +15,48 @@ from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
|
|||
from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class RRFRanker(BaseModel):
|
||||
"""
|
||||
Reciprocal Rank Fusion (RRF) ranker configuration.
|
||||
|
||||
:param type: The type of ranker, always "rrf"
|
||||
:param impact_factor: The impact factor for RRF scoring. Higher values give more weight to higher-ranked results.
|
||||
Must be greater than 0. Default of 60 is from the original RRF paper (Cormack et al., 2009).
|
||||
"""
|
||||
|
||||
type: Literal["rrf"] = "rrf"
|
||||
impact_factor: float = Field(default=60.0, gt=0.0) # default of 60 for optimal performance
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class WeightedRanker(BaseModel):
|
||||
"""
|
||||
Weighted ranker configuration that combines vector and keyword scores.
|
||||
|
||||
:param type: The type of ranker, always "weighted"
|
||||
:param alpha: Weight factor between 0 and 1.
|
||||
0 means only use keyword scores,
|
||||
1 means only use vector scores,
|
||||
values in between blend both scores.
|
||||
"""
|
||||
|
||||
type: Literal["weighted"] = "weighted"
|
||||
alpha: float = Field(
|
||||
default=0.5,
|
||||
ge=0.0,
|
||||
le=1.0,
|
||||
description="Weight factor between 0 and 1. 0 means only keyword scores, 1 means only vector scores.",
|
||||
)
|
||||
|
||||
|
||||
Ranker = Annotated[
|
||||
RRFRanker | WeightedRanker,
|
||||
Field(discriminator="type"),
|
||||
]
|
||||
register_schema(Ranker, name="Ranker")
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class RAGDocument(BaseModel):
|
||||
"""
|
||||
|
@ -76,7 +118,8 @@ class RAGQueryConfig(BaseModel):
|
|||
:param chunk_template: Template for formatting each retrieved chunk in the context.
|
||||
Available placeholders: {index} (1-based chunk ordinal), {chunk.content} (chunk content string), {metadata} (chunk metadata dict).
|
||||
Default: "Result {index}\\nContent: {chunk.content}\\nMetadata: {metadata}\\n"
|
||||
:param mode: Search mode for retrieval—either "vector" or "keyword". Default "vector".
|
||||
:param mode: Search mode for retrieval—either "vector", "keyword", or "hybrid". Default "vector".
|
||||
:param ranker: Configuration for the ranker to use in hybrid search. Defaults to RRF ranker.
|
||||
"""
|
||||
|
||||
# This config defines how a query is generated using the messages
|
||||
|
@ -86,6 +129,7 @@ class RAGQueryConfig(BaseModel):
|
|||
max_chunks: int = 5
|
||||
chunk_template: str = "Result {index}\nContent: {chunk.content}\nMetadata: {metadata}\n"
|
||||
mode: str | None = None
|
||||
ranker: Ranker | None = Field(default=None) # Only used for hybrid mode
|
||||
|
||||
@field_validator("chunk_template")
|
||||
def validate_chunk_template(cls, v: str) -> str:
|
||||
|
|
|
@ -121,8 +121,10 @@ class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, RAGToolRunti
|
|||
vector_db_id=vector_db_id,
|
||||
query=query,
|
||||
params={
|
||||
"max_chunks": query_config.max_chunks,
|
||||
"mode": query_config.mode,
|
||||
"max_chunks": query_config.max_chunks,
|
||||
"score_threshold": 0.0,
|
||||
"ranker": query_config.ranker,
|
||||
},
|
||||
)
|
||||
for vector_db_id in vector_db_ids
|
||||
|
|
|
@ -131,6 +131,17 @@ class FaissIndex(EmbeddingIndex):
|
|||
) -> QueryChunksResponse:
|
||||
raise NotImplementedError("Keyword search is not supported in FAISS")
|
||||
|
||||
async def query_hybrid(
|
||||
self,
|
||||
embedding: NDArray,
|
||||
query_string: str,
|
||||
k: int,
|
||||
score_threshold: float,
|
||||
reranker_type: str,
|
||||
reranker_params: dict[str, Any] | None = None,
|
||||
) -> QueryChunksResponse:
|
||||
raise NotImplementedError("Hybrid search is not supported in FAISS")
|
||||
|
||||
|
||||
class FaissVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPrivate):
|
||||
def __init__(self, config: FaissVectorIOConfig, inference_api: Inference, files_api: Files | None) -> None:
|
||||
|
|
|
@ -27,14 +27,20 @@ from llama_stack.apis.vector_io import (
|
|||
)
|
||||
from llama_stack.providers.datatypes import VectorDBsProtocolPrivate
|
||||
from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin
|
||||
from llama_stack.providers.utils.memory.vector_store import EmbeddingIndex, VectorDBWithIndex
|
||||
from llama_stack.providers.utils.memory.vector_store import (
|
||||
RERANKER_TYPE_RRF,
|
||||
RERANKER_TYPE_WEIGHTED,
|
||||
EmbeddingIndex,
|
||||
VectorDBWithIndex,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Specifying search mode is dependent on the VectorIO provider.
|
||||
VECTOR_SEARCH = "vector"
|
||||
KEYWORD_SEARCH = "keyword"
|
||||
SEARCH_MODES = {VECTOR_SEARCH, KEYWORD_SEARCH}
|
||||
HYBRID_SEARCH = "hybrid"
|
||||
SEARCH_MODES = {VECTOR_SEARCH, KEYWORD_SEARCH, HYBRID_SEARCH}
|
||||
|
||||
|
||||
def serialize_vector(vector: list[float]) -> bytes:
|
||||
|
@ -51,6 +57,59 @@ def _create_sqlite_connection(db_path):
|
|||
return connection
|
||||
|
||||
|
||||
def _normalize_scores(scores: dict[str, float]) -> dict[str, float]:
|
||||
"""Normalize scores to [0,1] range using min-max normalization."""
|
||||
if not scores:
|
||||
return {}
|
||||
min_score = min(scores.values())
|
||||
max_score = max(scores.values())
|
||||
score_range = max_score - min_score
|
||||
if score_range > 0:
|
||||
return {doc_id: (score - min_score) / score_range for doc_id, score in scores.items()}
|
||||
return {doc_id: 1.0 for doc_id in scores}
|
||||
|
||||
|
||||
def _weighted_rerank(
|
||||
vector_scores: dict[str, float],
|
||||
keyword_scores: dict[str, float],
|
||||
alpha: float = 0.5,
|
||||
) -> dict[str, float]:
|
||||
"""ReRanker that uses weighted average of scores."""
|
||||
all_ids = set(vector_scores.keys()) | set(keyword_scores.keys())
|
||||
normalized_vector_scores = _normalize_scores(vector_scores)
|
||||
normalized_keyword_scores = _normalize_scores(keyword_scores)
|
||||
|
||||
return {
|
||||
doc_id: (alpha * normalized_keyword_scores.get(doc_id, 0.0))
|
||||
+ ((1 - alpha) * normalized_vector_scores.get(doc_id, 0.0))
|
||||
for doc_id in all_ids
|
||||
}
|
||||
|
||||
|
||||
def _rrf_rerank(
|
||||
vector_scores: dict[str, float],
|
||||
keyword_scores: dict[str, float],
|
||||
impact_factor: float = 60.0,
|
||||
) -> dict[str, float]:
|
||||
"""ReRanker that uses Reciprocal Rank Fusion."""
|
||||
# Convert scores to ranks
|
||||
vector_ranks = {
|
||||
doc_id: i + 1 for i, (doc_id, _) in enumerate(sorted(vector_scores.items(), key=lambda x: x[1], reverse=True))
|
||||
}
|
||||
keyword_ranks = {
|
||||
doc_id: i + 1 for i, (doc_id, _) in enumerate(sorted(keyword_scores.items(), key=lambda x: x[1], reverse=True))
|
||||
}
|
||||
|
||||
all_ids = set(vector_scores.keys()) | set(keyword_scores.keys())
|
||||
rrf_scores = {}
|
||||
for doc_id in all_ids:
|
||||
vector_rank = vector_ranks.get(doc_id, float("inf"))
|
||||
keyword_rank = keyword_ranks.get(doc_id, float("inf"))
|
||||
# RRF formula: score = 1/(k + r) where k is impact_factor and r is the rank
|
||||
rrf_scores[doc_id] = (1.0 / (impact_factor + vector_rank)) + (1.0 / (impact_factor + keyword_rank))
|
||||
return rrf_scores
|
||||
|
||||
|
||||
class SQLiteVecIndex(EmbeddingIndex):
|
||||
"""
|
||||
An index implementation that stores embeddings in a SQLite virtual table using sqlite-vec.
|
||||
|
@ -255,8 +314,6 @@ class SQLiteVecIndex(EmbeddingIndex):
|
|||
"""
|
||||
Performs keyword-based search using SQLite FTS5 for relevance-ranked full-text search.
|
||||
"""
|
||||
if query_string is None:
|
||||
raise ValueError("query_string is required for keyword search.")
|
||||
|
||||
def _execute_query():
|
||||
connection = _create_sqlite_connection(self.db_path)
|
||||
|
@ -294,6 +351,81 @@ class SQLiteVecIndex(EmbeddingIndex):
|
|||
scores.append(score)
|
||||
return QueryChunksResponse(chunks=chunks, scores=scores)
|
||||
|
||||
async def query_hybrid(
|
||||
self,
|
||||
embedding: NDArray,
|
||||
query_string: str,
|
||||
k: int,
|
||||
score_threshold: float,
|
||||
reranker_type: str = RERANKER_TYPE_RRF,
|
||||
reranker_params: dict[str, Any] | None = None,
|
||||
) -> QueryChunksResponse:
|
||||
"""
|
||||
Hybrid search using a configurable re-ranking strategy.
|
||||
|
||||
Args:
|
||||
embedding: The query embedding vector
|
||||
query_string: The text query for keyword search
|
||||
k: Number of results to return
|
||||
score_threshold: Minimum similarity score threshold
|
||||
reranker_type: Type of reranker to use ("rrf" or "weighted")
|
||||
reranker_params: Parameters for the reranker
|
||||
|
||||
Returns:
|
||||
QueryChunksResponse with combined results
|
||||
"""
|
||||
if reranker_params is None:
|
||||
reranker_params = {}
|
||||
|
||||
# Get results from both search methods
|
||||
vector_response = await self.query_vector(embedding, k, score_threshold)
|
||||
keyword_response = await self.query_keyword(query_string, k, score_threshold)
|
||||
|
||||
# Convert responses to score dictionaries using generate_chunk_id
|
||||
vector_scores = {
|
||||
generate_chunk_id(chunk.metadata["document_id"], str(chunk.content)): score
|
||||
for chunk, score in zip(vector_response.chunks, vector_response.scores, strict=False)
|
||||
}
|
||||
keyword_scores = {
|
||||
generate_chunk_id(chunk.metadata["document_id"], str(chunk.content)): score
|
||||
for chunk, score in zip(keyword_response.chunks, keyword_response.scores, strict=False)
|
||||
}
|
||||
|
||||
# Combine scores using the specified reranker
|
||||
if reranker_type == RERANKER_TYPE_WEIGHTED:
|
||||
alpha = reranker_params.get("alpha", 0.5)
|
||||
combined_scores = _weighted_rerank(vector_scores, keyword_scores, alpha)
|
||||
else:
|
||||
# Default to RRF for None, RRF, or any unknown types
|
||||
impact_factor = reranker_params.get("impact_factor", 60.0)
|
||||
combined_scores = _rrf_rerank(vector_scores, keyword_scores, impact_factor)
|
||||
|
||||
# Sort by combined score and get top k results
|
||||
sorted_items = sorted(combined_scores.items(), key=lambda x: x[1], reverse=True)
|
||||
top_k_items = sorted_items[:k]
|
||||
|
||||
# Filter by score threshold
|
||||
filtered_items = [(doc_id, score) for doc_id, score in top_k_items if score >= score_threshold]
|
||||
|
||||
# Create a map of chunk_id to chunk for both responses
|
||||
chunk_map = {}
|
||||
for c in vector_response.chunks:
|
||||
chunk_id = generate_chunk_id(c.metadata["document_id"], str(c.content))
|
||||
chunk_map[chunk_id] = c
|
||||
for c in keyword_response.chunks:
|
||||
chunk_id = generate_chunk_id(c.metadata["document_id"], str(c.content))
|
||||
chunk_map[chunk_id] = c
|
||||
|
||||
# Use the map to look up chunks by their IDs
|
||||
chunks = []
|
||||
scores = []
|
||||
for doc_id, score in filtered_items:
|
||||
if doc_id in chunk_map:
|
||||
chunks.append(chunk_map[doc_id])
|
||||
scores.append(score)
|
||||
|
||||
return QueryChunksResponse(chunks=chunks, scores=scores)
|
||||
|
||||
|
||||
class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPrivate):
|
||||
"""
|
||||
|
@ -345,7 +477,9 @@ class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtoc
|
|||
vector_db_data = row[0]
|
||||
vector_db = VectorDB.model_validate_json(vector_db_data)
|
||||
index = await SQLiteVecIndex.create(
|
||||
vector_db.embedding_dimension, self.config.db_path, vector_db.identifier
|
||||
vector_db.embedding_dimension,
|
||||
self.config.db_path,
|
||||
vector_db.identifier,
|
||||
)
|
||||
self.cache[vector_db.identifier] = VectorDBWithIndex(vector_db, index, self.inference_api)
|
||||
|
||||
|
@ -371,7 +505,11 @@ class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtoc
|
|||
connection.close()
|
||||
|
||||
await asyncio.to_thread(_register_db)
|
||||
index = await SQLiteVecIndex.create(vector_db.embedding_dimension, self.config.db_path, vector_db.identifier)
|
||||
index = await SQLiteVecIndex.create(
|
||||
vector_db.embedding_dimension,
|
||||
self.config.db_path,
|
||||
vector_db.identifier,
|
||||
)
|
||||
self.cache[vector_db.identifier] = VectorDBWithIndex(vector_db, index, self.inference_api)
|
||||
|
||||
async def list_vector_dbs(self) -> list[VectorDB]:
|
||||
|
|
|
@ -105,6 +105,17 @@ class ChromaIndex(EmbeddingIndex):
|
|||
) -> QueryChunksResponse:
|
||||
raise NotImplementedError("Keyword search is not supported in Chroma")
|
||||
|
||||
async def query_hybrid(
|
||||
self,
|
||||
embedding: NDArray,
|
||||
query_string: str,
|
||||
k: int,
|
||||
score_threshold: float,
|
||||
reranker_type: str,
|
||||
reranker_params: dict[str, Any] | None = None,
|
||||
) -> QueryChunksResponse:
|
||||
raise NotImplementedError("Hybrid search is not supported in Chroma")
|
||||
|
||||
|
||||
class ChromaVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
|
||||
def __init__(
|
||||
|
|
|
@ -103,6 +103,17 @@ class MilvusIndex(EmbeddingIndex):
|
|||
) -> QueryChunksResponse:
|
||||
raise NotImplementedError("Keyword search is not supported in Milvus")
|
||||
|
||||
async def query_hybrid(
|
||||
self,
|
||||
embedding: NDArray,
|
||||
query_string: str,
|
||||
k: int,
|
||||
score_threshold: float,
|
||||
reranker_type: str,
|
||||
reranker_params: dict[str, Any] | None = None,
|
||||
) -> QueryChunksResponse:
|
||||
raise NotImplementedError("Hybrid search is not supported in Milvus")
|
||||
|
||||
|
||||
class MilvusVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
|
||||
def __init__(
|
||||
|
|
|
@ -128,6 +128,17 @@ class PGVectorIndex(EmbeddingIndex):
|
|||
) -> QueryChunksResponse:
|
||||
raise NotImplementedError("Keyword search is not supported in PGVector")
|
||||
|
||||
async def query_hybrid(
|
||||
self,
|
||||
embedding: NDArray,
|
||||
query_string: str,
|
||||
k: int,
|
||||
score_threshold: float,
|
||||
reranker_type: str,
|
||||
reranker_params: dict[str, Any] | None = None,
|
||||
) -> QueryChunksResponse:
|
||||
raise NotImplementedError("Hybrid search is not supported in PGVector")
|
||||
|
||||
async def delete(self):
|
||||
with self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
|
||||
cur.execute(f"DROP TABLE IF EXISTS {self.table_name}")
|
||||
|
|
|
@ -112,6 +112,17 @@ class QdrantIndex(EmbeddingIndex):
|
|||
) -> QueryChunksResponse:
|
||||
raise NotImplementedError("Keyword search is not supported in Qdrant")
|
||||
|
||||
async def query_hybrid(
|
||||
self,
|
||||
embedding: NDArray,
|
||||
query_string: str,
|
||||
k: int,
|
||||
score_threshold: float,
|
||||
reranker_type: str,
|
||||
reranker_params: dict[str, Any] | None = None,
|
||||
) -> QueryChunksResponse:
|
||||
raise NotImplementedError("Hybrid search is not supported in Qdrant")
|
||||
|
||||
async def delete(self):
|
||||
await self.client.delete_collection(collection_name=self.collection_name)
|
||||
|
||||
|
|
|
@ -92,6 +92,17 @@ class WeaviateIndex(EmbeddingIndex):
|
|||
) -> QueryChunksResponse:
|
||||
raise NotImplementedError("Keyword search is not supported in Weaviate")
|
||||
|
||||
async def query_hybrid(
|
||||
self,
|
||||
embedding: NDArray,
|
||||
query_string: str,
|
||||
k: int,
|
||||
score_threshold: float,
|
||||
reranker_type: str,
|
||||
reranker_params: dict[str, Any] | None = None,
|
||||
) -> QueryChunksResponse:
|
||||
raise NotImplementedError("Hybrid search is not supported in Weaviate")
|
||||
|
||||
|
||||
class WeaviateVectorIOAdapter(
|
||||
VectorIO,
|
||||
|
|
|
@ -32,6 +32,10 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
|
|||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
# Constants for reranker types
|
||||
RERANKER_TYPE_RRF = "rrf"
|
||||
RERANKER_TYPE_WEIGHTED = "weighted"
|
||||
|
||||
|
||||
def parse_pdf(data: bytes) -> str:
|
||||
# For PDF and DOC/DOCX files, we can't reliably convert to string
|
||||
|
@ -202,6 +206,18 @@ class EmbeddingIndex(ABC):
|
|||
async def query_keyword(self, query_string: str, k: int, score_threshold: float) -> QueryChunksResponse:
|
||||
raise NotImplementedError()
|
||||
|
||||
@abstractmethod
|
||||
async def query_hybrid(
|
||||
self,
|
||||
embedding: NDArray,
|
||||
query_string: str,
|
||||
k: int,
|
||||
score_threshold: float,
|
||||
reranker_type: str,
|
||||
reranker_params: dict[str, Any] | None = None,
|
||||
) -> QueryChunksResponse:
|
||||
raise NotImplementedError()
|
||||
|
||||
@abstractmethod
|
||||
async def delete(self):
|
||||
raise NotImplementedError()
|
||||
|
@ -245,10 +261,29 @@ class VectorDBWithIndex:
|
|||
k = params.get("max_chunks", 3)
|
||||
mode = params.get("mode")
|
||||
score_threshold = params.get("score_threshold", 0.0)
|
||||
|
||||
# Get ranker configuration
|
||||
ranker = params.get("ranker")
|
||||
if ranker is None:
|
||||
# Default to RRF with impact_factor=60.0
|
||||
reranker_type = RERANKER_TYPE_RRF
|
||||
reranker_params = {"impact_factor": 60.0}
|
||||
else:
|
||||
reranker_type = ranker.type
|
||||
reranker_params = (
|
||||
{"impact_factor": ranker.impact_factor} if ranker.type == RERANKER_TYPE_RRF else {"alpha": ranker.alpha}
|
||||
)
|
||||
|
||||
query_string = interleaved_content_as_str(query)
|
||||
if mode == "keyword":
|
||||
return await self.index.query_keyword(query_string, k, score_threshold)
|
||||
|
||||
# Calculate embeddings for both vector and hybrid modes
|
||||
embeddings_response = await self.inference_api.embeddings(self.vector_db.embedding_model, [query_string])
|
||||
query_vector = np.array(embeddings_response.embeddings[0], dtype=np.float32)
|
||||
if mode == "hybrid":
|
||||
return await self.index.query_hybrid(
|
||||
query_vector, query_string, k, score_threshold, reranker_type, reranker_params
|
||||
)
|
||||
else:
|
||||
embeddings_response = await self.inference_api.embeddings(self.vector_db.embedding_model, [query_string])
|
||||
query_vector = np.array(embeddings_response.embeddings[0], dtype=np.float32)
|
||||
return await self.index.query_vector(query_vector, k, score_threshold)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue