Mirror of https://github.com/meta-llama/llama-stack.git, synced 2025-07-29 07:14:20 +00:00.

Commit d0fa1d88f5 (parent 0108bb1aa5): implement hybrid search.
1 changed file with 71 additions and 2 deletions.
|
@ -28,6 +28,8 @@ from llama_stack.providers.utils.kvstore import kvstore_impl
|
|||
# Project-internal imports for the Milvus vector-IO provider.
# NOTE(review): reconstructed from diff context lines that were garbled by
# page-gutter artifacts; the commit adds the two RERANKER_TYPE_* constants
# used by the new hybrid-search reranker configuration.
from llama_stack.providers.utils.kvstore.api import KVStore
from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin
from llama_stack.providers.utils.memory.vector_store import (
    RERANKER_TYPE_RRF,
    RERANKER_TYPE_WEIGHTED,
    EmbeddingIndex,
    VectorDBWithIndex,
)
|
@ -37,6 +39,8 @@ from .config import MilvusVectorIOConfig as RemoteMilvusVectorIOConfig
|
|||
# Module-level logger for this provider (standard per-module logger idiom).
logger = logging.getLogger(__name__)

# Version tag baked into every key prefix below, so a change in on-disk
# layout can bump VERSION without colliding with keys written by older code.
VERSION = "v3"


# Key-namespace prefixes for the Milvus provider's persisted metadata
# (presumably stored via the KVStore imported above — confirm against callers).
VECTOR_DBS_PREFIX = f"vector_dbs:milvus:{VERSION}::"
VECTOR_INDEX_PREFIX = f"vector_index:milvus:{VERSION}::"
OPENAI_VECTOR_STORES_PREFIX = f"openai_vector_stores:milvus:{VERSION}::"
|
@ -229,10 +233,75 @@ class MilvusIndex(EmbeddingIndex):
|
|||
query_string: str,
|
||||
k: int,
|
||||
score_threshold: float,
|
||||
reranker_type: str,
|
||||
reranker_type: str = RERANKER_TYPE_RRF,
|
||||
reranker_params: dict[str, Any] | None = None,
|
||||
) -> QueryChunksResponse:
|
||||
raise NotImplementedError("Hybrid search is not supported in Milvus")
|
||||
"""
|
||||
Hybrid search using Milvus's native multi-vector search capabilities.
|
||||
|
||||
Args:
|
||||
embedding: The query embedding vector
|
||||
query_string: The text query for keyword search
|
||||
k: Number of results to return
|
||||
score_threshold: Minimum similarity score threshold
|
||||
reranker_type: Type of reranker to use ("rrf" or "weighted")
|
||||
reranker_params: Parameters for the reranker
|
||||
|
||||
Returns:
|
||||
QueryChunksResponse with combined results
|
||||
"""
|
||||
if reranker_params is None:
|
||||
reranker_params = {}
|
||||
|
||||
# Create search requests for both vector and keyword search
|
||||
search_requests = []
|
||||
|
||||
# Vector search request
|
||||
search_requests.append(
|
||||
{
|
||||
"data": [embedding.tolist()],
|
||||
"anns_field": "vector",
|
||||
"param": {"metric_type": "COSINE", "params": {"score_threshold": score_threshold}},
|
||||
"limit": k,
|
||||
}
|
||||
)
|
||||
|
||||
# Keyword search request (BM25)
|
||||
search_requests.append(
|
||||
{"data": [query_string], "anns_field": "sparse", "param": {"drop_ratio_search": 0.2}, "limit": k}
|
||||
)
|
||||
|
||||
# Configure reranker
|
||||
if reranker_type == RERANKER_TYPE_WEIGHTED:
|
||||
alpha = reranker_params.get("alpha", 0.5)
|
||||
ranker = {"strategy": "weighted", "params": {"weights": [1 - alpha, alpha]}}
|
||||
else:
|
||||
# Default to RRF
|
||||
impact_factor = reranker_params.get("impact_factor", 60.0)
|
||||
ranker = {"strategy": "rrf", "params": {"k": impact_factor}}
|
||||
|
||||
# Perform native Milvus hybrid search
|
||||
search_res = await asyncio.to_thread(
|
||||
self.client.hybrid_search,
|
||||
collection_name=self.collection_name,
|
||||
reqs=search_requests,
|
||||
ranker=ranker,
|
||||
limit=k,
|
||||
output_fields=["chunk_content"],
|
||||
)
|
||||
|
||||
chunks = []
|
||||
scores = []
|
||||
for res in search_res[0]:
|
||||
chunk = Chunk(**res["entity"]["chunk_content"])
|
||||
chunks.append(chunk)
|
||||
scores.append(res["distance"])
|
||||
|
||||
# Filter by score threshold
|
||||
filtered_chunks = [chunk for chunk, score in zip(chunks, scores, strict=False) if score >= score_threshold]
|
||||
filtered_scores = [score for score in scores if score >= score_threshold]
|
||||
|
||||
return QueryChunksResponse(chunks=filtered_chunks, scores=filtered_scores)
|
||||
|
||||
|
||||
class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPrivate):
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue