From e3928e6a298226ac73b11a9e3874622f61072626 Mon Sep 17 00:00:00 2001 From: Varsha Date: Thu, 7 Aug 2025 00:42:03 -0700 Subject: [PATCH] feat: Implement hybrid search in Milvus (#2644) # What does this PR do? This PR implements hybrid search for the Milvus vector DB using Milvus's built-in hybrid search support. To test: ``` pytest tests/unit/providers/vector_io/remote/test_milvus.py -v -s --tb=long --disable-warnings --asyncio-mode=auto ``` Signed-off-by: Varsha Prasad Narsing --- .../remote/vector_io/milvus/milvus.py | 51 ++++++- .../providers/utils/memory/vector_store.py | 16 +- .../vector_io/test_openai_vector_stores.py | 5 + .../providers/vector_io/remote/test_milvus.py | 141 ++++++++++++++++++ 4 files changed, 204 insertions(+), 9 deletions(-) diff --git a/llama_stack/providers/remote/vector_io/milvus/milvus.py b/llama_stack/providers/remote/vector_io/milvus/milvus.py index db58bf6d3..b09edb65c 100644 --- a/llama_stack/providers/remote/vector_io/milvus/milvus.py +++ b/llama_stack/providers/remote/vector_io/milvus/milvus.py @@ -10,7 +10,7 @@ import os from typing import Any from numpy.typing import NDArray -from pymilvus import DataType, Function, FunctionType, MilvusClient +from pymilvus import AnnSearchRequest, DataType, Function, FunctionType, MilvusClient, RRFRanker, WeightedRanker from llama_stack.apis.common.errors import VectorStoreNotFoundError from llama_stack.apis.files.files import Files @@ -27,6 +27,7 @@ from llama_stack.providers.utils.kvstore import kvstore_impl from llama_stack.providers.utils.kvstore.api import KVStore from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin from llama_stack.providers.utils.memory.vector_store import ( + RERANKER_TYPE_WEIGHTED, EmbeddingIndex, VectorDBWithIndex, ) @@ -238,7 +239,53 @@ class MilvusIndex(EmbeddingIndex): reranker_type: str, reranker_params: dict[str, Any] | None = None, ) -> QueryChunksResponse: - raise NotImplementedError("Hybrid search is not supported in Milvus") + """
Hybrid search using Milvus's native hybrid search capabilities. + + This implementation uses Milvus's hybrid_search method which combines + vector search and BM25 search with configurable reranking strategies. + """ + search_requests = [] + + # nprobe: Controls search accuracy vs performance trade-off + # 10 balances these trade-offs for RAG applications + search_requests.append( + AnnSearchRequest(data=[embedding.tolist()], anns_field="vector", param={"nprobe": 10}, limit=k) + ) + + # drop_ratio_search: Filters low-importance terms to improve search performance + # 0.2 balances noise reduction with recall + search_requests.append( + AnnSearchRequest(data=[query_string], anns_field="sparse", param={"drop_ratio_search": 0.2}, limit=k) + ) + + if reranker_type == RERANKER_TYPE_WEIGHTED: + alpha = (reranker_params or {}).get("alpha", 0.5) + rerank = WeightedRanker(alpha, 1 - alpha) + else: + impact_factor = (reranker_params or {}).get("impact_factor", 60.0) + rerank = RRFRanker(impact_factor) + + search_res = await asyncio.to_thread( + self.client.hybrid_search, + collection_name=self.collection_name, + reqs=search_requests, + ranker=rerank, + limit=k, + output_fields=["chunk_content"], + ) + + chunks = [] + scores = [] + for res in search_res[0]: + chunk = Chunk(**res["entity"]["chunk_content"]) + chunks.append(chunk) + scores.append(res["distance"]) + + filtered_chunks = [chunk for chunk, score in zip(chunks, scores, strict=False) if score >= score_threshold] + filtered_scores = [score for score in scores if score >= score_threshold] + + return QueryChunksResponse(chunks=filtered_chunks, scores=filtered_scores) async def delete_chunk(self, chunk_id: str) -> None: """Remove a chunk from the Milvus collection.""" diff --git a/llama_stack/providers/utils/memory/vector_store.py b/llama_stack/providers/utils/memory/vector_store.py index 484475e9d..bb9002f30 100644 --- a/llama_stack/providers/utils/memory/vector_store.py +++ 
b/llama_stack/providers/utils/memory/vector_store.py @@ -302,23 +302,25 @@ class VectorDBWithIndex: mode = params.get("mode") score_threshold = params.get("score_threshold", 0.0) - # Get ranker configuration ranker = params.get("ranker") if ranker is None: - # Default to RRF with impact_factor=60.0 reranker_type = RERANKER_TYPE_RRF reranker_params = {"impact_factor": 60.0} else: - reranker_type = ranker.type - reranker_params = ( - {"impact_factor": ranker.impact_factor} if ranker.type == RERANKER_TYPE_RRF else {"alpha": ranker.alpha} - ) + strategy = ranker.get("strategy", "rrf") + if strategy == "weighted": + weights = ranker.get("params", {}).get("weights", [0.5, 0.5]) + reranker_type = RERANKER_TYPE_WEIGHTED + reranker_params = {"alpha": weights[0] if len(weights) > 0 else 0.5} + else: + reranker_type = RERANKER_TYPE_RRF + k_value = ranker.get("params", {}).get("k", 60.0) + reranker_params = {"impact_factor": k_value} query_string = interleaved_content_as_str(query) if mode == "keyword": return await self.index.query_keyword(query_string, k, score_threshold) - # Calculate embeddings for both vector and hybrid modes embeddings_response = await self.inference_api.embeddings(self.vector_db.embedding_model, [query_string]) query_vector = np.array(embeddings_response.embeddings[0], dtype=np.float32) if mode == "hybrid": diff --git a/tests/integration/vector_io/test_openai_vector_stores.py b/tests/integration/vector_io/test_openai_vector_stores.py index 1c9ef92b6..3212a7568 100644 --- a/tests/integration/vector_io/test_openai_vector_stores.py +++ b/tests/integration/vector_io/test_openai_vector_stores.py @@ -30,6 +30,7 @@ def skip_if_provider_doesnt_support_openai_vector_stores(client_with_models): "remote::qdrant", "inline::qdrant", "remote::weaviate", + "remote::milvus", ]: return @@ -49,12 +50,16 @@ def skip_if_provider_doesnt_support_openai_vector_stores_search(client_with_mode "remote::chromadb", "remote::weaviate", "remote::qdrant", + "remote::milvus", ], 
"keyword": [ "inline::sqlite-vec", + "remote::milvus", ], "hybrid": [ "inline::sqlite-vec", + "inline::milvus", + "remote::milvus", ], } supported_providers = search_mode_support.get(search_mode, []) diff --git a/tests/unit/providers/vector_io/remote/test_milvus.py b/tests/unit/providers/vector_io/remote/test_milvus.py index 145edf7fb..ca5f45fa2 100644 --- a/tests/unit/providers/vector_io/remote/test_milvus.py +++ b/tests/unit/providers/vector_io/remote/test_milvus.py @@ -15,6 +15,9 @@ from llama_stack.apis.vector_io import QueryChunksResponse pymilvus_mock = MagicMock() pymilvus_mock.DataType = MagicMock() pymilvus_mock.MilvusClient = MagicMock +pymilvus_mock.RRFRanker = MagicMock +pymilvus_mock.WeightedRanker = MagicMock +pymilvus_mock.AnnSearchRequest = MagicMock # Apply the mock before importing MilvusIndex with patch.dict("sys.modules", {"pymilvus": pymilvus_mock}): @@ -183,3 +186,141 @@ async def test_delete_collection(milvus_index, mock_milvus_client): await milvus_index.delete() mock_milvus_client.drop_collection.assert_called_once_with(collection_name=milvus_index.collection_name) + + +async def test_query_hybrid_search_rrf( + milvus_index, sample_chunks, sample_embeddings, embedding_dimension, mock_milvus_client +): + """Test hybrid search with RRF reranker.""" + mock_milvus_client.has_collection.return_value = True + await milvus_index.add_chunks(sample_chunks, sample_embeddings) + + # Mock hybrid search results + mock_milvus_client.hybrid_search.return_value = [ + [ + { + "id": 0, + "distance": 0.1, + "entity": {"chunk_content": {"content": "mock chunk 1", "metadata": {"document_id": "doc1"}}}, + }, + { + "id": 1, + "distance": 0.2, + "entity": {"chunk_content": {"content": "mock chunk 2", "metadata": {"document_id": "doc2"}}}, + }, + ] + ] + + # Test hybrid search with RRF reranker + query_embedding = np.random.rand(embedding_dimension).astype(np.float32) + query_string = "test query" + response = await milvus_index.query_hybrid( + 
embedding=query_embedding, + query_string=query_string, + k=2, + score_threshold=0.0, + reranker_type="rrf", + reranker_params={"impact_factor": 60.0}, + ) + + assert isinstance(response, QueryChunksResponse) + assert len(response.chunks) == 2 + assert len(response.scores) == 2 + + # Verify hybrid search was called with correct parameters + mock_milvus_client.hybrid_search.assert_called_once() + call_args = mock_milvus_client.hybrid_search.call_args + + # Check that the request contains both vector and BM25 search requests + reqs = call_args[1]["reqs"] + assert len(reqs) == 2 + assert reqs[0].anns_field == "vector" + assert reqs[1].anns_field == "sparse" + ranker = call_args[1]["ranker"] + assert ranker is not None + + +async def test_query_hybrid_search_weighted( + milvus_index, sample_chunks, sample_embeddings, embedding_dimension, mock_milvus_client +): + """Test hybrid search with weighted reranker.""" + mock_milvus_client.has_collection.return_value = True + await milvus_index.add_chunks(sample_chunks, sample_embeddings) + + # Mock hybrid search results + mock_milvus_client.hybrid_search.return_value = [ + [ + { + "id": 0, + "distance": 0.1, + "entity": {"chunk_content": {"content": "mock chunk 1", "metadata": {"document_id": "doc1"}}}, + }, + { + "id": 1, + "distance": 0.2, + "entity": {"chunk_content": {"content": "mock chunk 2", "metadata": {"document_id": "doc2"}}}, + }, + ] + ] + + # Test hybrid search with weighted reranker + query_embedding = np.random.rand(embedding_dimension).astype(np.float32) + query_string = "test query" + response = await milvus_index.query_hybrid( + embedding=query_embedding, + query_string=query_string, + k=2, + score_threshold=0.0, + reranker_type="weighted", + reranker_params={"alpha": 0.7}, + ) + + assert isinstance(response, QueryChunksResponse) + assert len(response.chunks) == 2 + assert len(response.scores) == 2 + + # Verify hybrid search was called with correct parameters + 
mock_milvus_client.hybrid_search.assert_called_once() + call_args = mock_milvus_client.hybrid_search.call_args + ranker = call_args[1]["ranker"] + assert ranker is not None + + +async def test_query_hybrid_search_default_rrf( + milvus_index, sample_chunks, sample_embeddings, embedding_dimension, mock_milvus_client +): + """Test hybrid search with default RRF reranker (no reranker_type specified).""" + mock_milvus_client.has_collection.return_value = True + await milvus_index.add_chunks(sample_chunks, sample_embeddings) + + # Mock hybrid search results + mock_milvus_client.hybrid_search.return_value = [ + [ + { + "id": 0, + "distance": 0.1, + "entity": {"chunk_content": {"content": "mock chunk 1", "metadata": {"document_id": "doc1"}}}, + }, + ] + ] + + # Test hybrid search with default reranker (should be RRF) + query_embedding = np.random.rand(embedding_dimension).astype(np.float32) + query_string = "test query" + response = await milvus_index.query_hybrid( + embedding=query_embedding, + query_string=query_string, + k=1, + score_threshold=0.0, + reranker_type="unknown_type", # Should default to RRF + reranker_params=None, # Should use default impact_factor + ) + + assert isinstance(response, QueryChunksResponse) + assert len(response.chunks) == 1 + + # Verify hybrid search was called with RRF reranker + mock_milvus_client.hybrid_search.assert_called_once() + call_args = mock_milvus_client.hybrid_search.call_args + ranker = call_args[1]["ranker"] + assert ranker is not None