Merge branch 'main' into chroma

This commit is contained in:
Bwook (Byoungwook) Kim 2025-08-07 19:57:56 +09:00 committed by GitHub
commit 4809b3352e
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 204 additions and 9 deletions

View file

@ -10,7 +10,7 @@ import os
from typing import Any
from numpy.typing import NDArray
from pymilvus import DataType, Function, FunctionType, MilvusClient
from pymilvus import AnnSearchRequest, DataType, Function, FunctionType, MilvusClient, RRFRanker, WeightedRanker
from llama_stack.apis.common.errors import VectorStoreNotFoundError
from llama_stack.apis.files.files import Files
@ -27,6 +27,7 @@ from llama_stack.providers.utils.kvstore import kvstore_impl
from llama_stack.providers.utils.kvstore.api import KVStore
from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin
from llama_stack.providers.utils.memory.vector_store import (
RERANKER_TYPE_WEIGHTED,
EmbeddingIndex,
VectorDBWithIndex,
)
@ -238,7 +239,53 @@ class MilvusIndex(EmbeddingIndex):
reranker_type: str,
reranker_params: dict[str, Any] | None = None,
) -> QueryChunksResponse:
raise NotImplementedError("Hybrid search is not supported in Milvus")
"""
Hybrid search using Milvus's native hybrid search capabilities.
This implementation uses Milvus's hybrid_search method which combines
vector search and BM25 search with configurable reranking strategies.
"""
search_requests = []
# nprobe: Controls search accuracy vs performance trade-off
# 10 balances these trade-offs for RAG applications
search_requests.append(
AnnSearchRequest(data=[embedding.tolist()], anns_field="vector", param={"nprobe": 10}, limit=k)
)
# drop_ratio_search: Filters low-importance terms to improve search performance
# 0.2 balances noise reduction with recall
search_requests.append(
AnnSearchRequest(data=[query_string], anns_field="sparse", param={"drop_ratio_search": 0.2}, limit=k)
)
if reranker_type == RERANKER_TYPE_WEIGHTED:
alpha = (reranker_params or {}).get("alpha", 0.5)
rerank = WeightedRanker(alpha, 1 - alpha)
else:
impact_factor = (reranker_params or {}).get("impact_factor", 60.0)
rerank = RRFRanker(impact_factor)
search_res = await asyncio.to_thread(
self.client.hybrid_search,
collection_name=self.collection_name,
reqs=search_requests,
ranker=rerank,
limit=k,
output_fields=["chunk_content"],
)
chunks = []
scores = []
for res in search_res[0]:
chunk = Chunk(**res["entity"]["chunk_content"])
chunks.append(chunk)
scores.append(res["distance"])
filtered_chunks = [chunk for chunk, score in zip(chunks, scores, strict=False) if score >= score_threshold]
filtered_scores = [score for score in scores if score >= score_threshold]
return QueryChunksResponse(chunks=filtered_chunks, scores=filtered_scores)
async def delete_chunk(self, chunk_id: str) -> None:
"""Remove a chunk from the Milvus collection."""

View file

@ -302,23 +302,25 @@ class VectorDBWithIndex:
mode = params.get("mode")
score_threshold = params.get("score_threshold", 0.0)
# Get ranker configuration
ranker = params.get("ranker")
if ranker is None:
# Default to RRF with impact_factor=60.0
reranker_type = RERANKER_TYPE_RRF
reranker_params = {"impact_factor": 60.0}
else:
reranker_type = ranker.type
reranker_params = (
{"impact_factor": ranker.impact_factor} if ranker.type == RERANKER_TYPE_RRF else {"alpha": ranker.alpha}
)
strategy = ranker.get("strategy", "rrf")
if strategy == "weighted":
weights = ranker.get("params", {}).get("weights", [0.5, 0.5])
reranker_type = RERANKER_TYPE_WEIGHTED
reranker_params = {"alpha": weights[0] if len(weights) > 0 else 0.5}
else:
reranker_type = RERANKER_TYPE_RRF
k_value = ranker.get("params", {}).get("k", 60.0)
reranker_params = {"impact_factor": k_value}
query_string = interleaved_content_as_str(query)
if mode == "keyword":
return await self.index.query_keyword(query_string, k, score_threshold)
# Calculate embeddings for both vector and hybrid modes
embeddings_response = await self.inference_api.embeddings(self.vector_db.embedding_model, [query_string])
query_vector = np.array(embeddings_response.embeddings[0], dtype=np.float32)
if mode == "hybrid":