mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-29 15:23:51 +00:00
implement hybrid search
This commit is contained in:
parent
0108bb1aa5
commit
d0fa1d88f5
1 changed file with 71 additions and 2 deletions
|
@ -28,6 +28,8 @@ from llama_stack.providers.utils.kvstore import kvstore_impl
|
||||||
from llama_stack.providers.utils.kvstore.api import KVStore
|
from llama_stack.providers.utils.kvstore.api import KVStore
|
||||||
from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin
|
from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin
|
||||||
from llama_stack.providers.utils.memory.vector_store import (
|
from llama_stack.providers.utils.memory.vector_store import (
|
||||||
|
RERANKER_TYPE_RRF,
|
||||||
|
RERANKER_TYPE_WEIGHTED,
|
||||||
EmbeddingIndex,
|
EmbeddingIndex,
|
||||||
VectorDBWithIndex,
|
VectorDBWithIndex,
|
||||||
)
|
)
|
||||||
|
@ -37,6 +39,8 @@ from .config import MilvusVectorIOConfig as RemoteMilvusVectorIOConfig
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
VERSION = "v3"
|
VERSION = "v3"
|
||||||
|
|
||||||
|
|
||||||
VECTOR_DBS_PREFIX = f"vector_dbs:milvus:{VERSION}::"
|
VECTOR_DBS_PREFIX = f"vector_dbs:milvus:{VERSION}::"
|
||||||
VECTOR_INDEX_PREFIX = f"vector_index:milvus:{VERSION}::"
|
VECTOR_INDEX_PREFIX = f"vector_index:milvus:{VERSION}::"
|
||||||
OPENAI_VECTOR_STORES_PREFIX = f"openai_vector_stores:milvus:{VERSION}::"
|
OPENAI_VECTOR_STORES_PREFIX = f"openai_vector_stores:milvus:{VERSION}::"
|
||||||
|
@ -229,10 +233,75 @@ class MilvusIndex(EmbeddingIndex):
|
||||||
query_string: str,
|
query_string: str,
|
||||||
k: int,
|
k: int,
|
||||||
score_threshold: float,
|
score_threshold: float,
|
||||||
reranker_type: str,
|
reranker_type: str = RERANKER_TYPE_RRF,
|
||||||
reranker_params: dict[str, Any] | None = None,
|
reranker_params: dict[str, Any] | None = None,
|
||||||
) -> QueryChunksResponse:
|
) -> QueryChunksResponse:
|
||||||
raise NotImplementedError("Hybrid search is not supported in Milvus")
|
"""
|
||||||
|
Hybrid search using Milvus's native multi-vector search capabilities.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
embedding: The query embedding vector
|
||||||
|
query_string: The text query for keyword search
|
||||||
|
k: Number of results to return
|
||||||
|
score_threshold: Minimum similarity score threshold
|
||||||
|
reranker_type: Type of reranker to use ("rrf" or "weighted")
|
||||||
|
reranker_params: Parameters for the reranker
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
QueryChunksResponse with combined results
|
||||||
|
"""
|
||||||
|
if reranker_params is None:
|
||||||
|
reranker_params = {}
|
||||||
|
|
||||||
|
# Create search requests for both vector and keyword search
|
||||||
|
search_requests = []
|
||||||
|
|
||||||
|
# Vector search request
|
||||||
|
search_requests.append(
|
||||||
|
{
|
||||||
|
"data": [embedding.tolist()],
|
||||||
|
"anns_field": "vector",
|
||||||
|
"param": {"metric_type": "COSINE", "params": {"score_threshold": score_threshold}},
|
||||||
|
"limit": k,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# Keyword search request (BM25)
|
||||||
|
search_requests.append(
|
||||||
|
{"data": [query_string], "anns_field": "sparse", "param": {"drop_ratio_search": 0.2}, "limit": k}
|
||||||
|
)
|
||||||
|
|
||||||
|
# Configure reranker
|
||||||
|
if reranker_type == RERANKER_TYPE_WEIGHTED:
|
||||||
|
alpha = reranker_params.get("alpha", 0.5)
|
||||||
|
ranker = {"strategy": "weighted", "params": {"weights": [1 - alpha, alpha]}}
|
||||||
|
else:
|
||||||
|
# Default to RRF
|
||||||
|
impact_factor = reranker_params.get("impact_factor", 60.0)
|
||||||
|
ranker = {"strategy": "rrf", "params": {"k": impact_factor}}
|
||||||
|
|
||||||
|
# Perform native Milvus hybrid search
|
||||||
|
search_res = await asyncio.to_thread(
|
||||||
|
self.client.hybrid_search,
|
||||||
|
collection_name=self.collection_name,
|
||||||
|
reqs=search_requests,
|
||||||
|
ranker=ranker,
|
||||||
|
limit=k,
|
||||||
|
output_fields=["chunk_content"],
|
||||||
|
)
|
||||||
|
|
||||||
|
chunks = []
|
||||||
|
scores = []
|
||||||
|
for res in search_res[0]:
|
||||||
|
chunk = Chunk(**res["entity"]["chunk_content"])
|
||||||
|
chunks.append(chunk)
|
||||||
|
scores.append(res["distance"])
|
||||||
|
|
||||||
|
# Filter by score threshold
|
||||||
|
filtered_chunks = [chunk for chunk, score in zip(chunks, scores, strict=False) if score >= score_threshold]
|
||||||
|
filtered_scores = [score for score in scores if score >= score_threshold]
|
||||||
|
|
||||||
|
return QueryChunksResponse(chunks=filtered_chunks, scores=filtered_scores)
|
||||||
|
|
||||||
|
|
||||||
class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPrivate):
|
class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPrivate):
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue