feat: Implement hybrid search in Milvus (#2644)
Some checks failed
Integration Tests (Replay) / discover-tests (push) Successful in 5s
Integration Auth Tests / test-matrix (oauth2_token) (push) Failing after 6s
Vector IO Integration Tests / test-matrix (3.12, inline::faiss) (push) Failing after 10s
Test External Providers Installed via Module / test-external-providers-from-module (venv) (push) Has been skipped
Python Package Build Test / build (3.13) (push) Failing after 6s
Vector IO Integration Tests / test-matrix (3.12, inline::milvus) (push) Failing after 9s
Vector IO Integration Tests / test-matrix (3.12, remote::pgvector) (push) Failing after 10s
SqlStore Integration Tests / test-postgres (3.13) (push) Failing after 15s
Vector IO Integration Tests / test-matrix (3.12, inline::sqlite-vec) (push) Failing after 16s
Python Package Build Test / build (3.12) (push) Failing after 10s
SqlStore Integration Tests / test-postgres (3.12) (push) Failing after 21s
Vector IO Integration Tests / test-matrix (3.13, remote::weaviate) (push) Failing after 7s
Vector IO Integration Tests / test-matrix (3.12, remote::chromadb) (push) Failing after 15s
Unit Tests / unit-tests (3.13) (push) Failing after 10s
Vector IO Integration Tests / test-matrix (3.13, inline::sqlite-vec) (push) Failing after 15s
Vector IO Integration Tests / test-matrix (3.13, remote::chromadb) (push) Failing after 12s
Vector IO Integration Tests / test-matrix (3.13, remote::qdrant) (push) Failing after 12s
Vector IO Integration Tests / test-matrix (3.13, inline::faiss) (push) Failing after 8s
Integration Tests (Replay) / Integration Tests (, , , client=, vision=) (push) Failing after 8s
Unit Tests / unit-tests (3.12) (push) Failing after 19s
Vector IO Integration Tests / test-matrix (3.13, remote::pgvector) (push) Failing after 12s
Vector IO Integration Tests / test-matrix (3.12, remote::weaviate) (push) Failing after 11s
Vector IO Integration Tests / test-matrix (3.12, remote::qdrant) (push) Failing after 11s
Test External API and Providers / test-external (venv) (push) Failing after 21s
Vector IO Integration Tests / test-matrix (3.13, inline::milvus) (push) Failing after 19s
Pre-commit / pre-commit (push) Successful in 57s

# What does this PR do?
This PR implements hybrid search for the Milvus vector DB provider using
Milvus's built-in hybrid search support.
   
To test:
```
pytest tests/unit/providers/vector_io/remote/test_milvus.py -v -s \
  --tb=long --disable-warnings --asyncio-mode=auto
```

Signed-off-by: Varsha Prasad Narsing <varshaprasad96@gmail.com>
This commit is contained in:
Varsha 2025-08-07 00:42:03 -07:00 committed by GitHub
parent 5a2d323eca
commit e3928e6a29
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 204 additions and 9 deletions

View file

@ -10,7 +10,7 @@ import os
from typing import Any from typing import Any
from numpy.typing import NDArray from numpy.typing import NDArray
from pymilvus import DataType, Function, FunctionType, MilvusClient from pymilvus import AnnSearchRequest, DataType, Function, FunctionType, MilvusClient, RRFRanker, WeightedRanker
from llama_stack.apis.common.errors import VectorStoreNotFoundError from llama_stack.apis.common.errors import VectorStoreNotFoundError
from llama_stack.apis.files.files import Files from llama_stack.apis.files.files import Files
@ -27,6 +27,7 @@ from llama_stack.providers.utils.kvstore import kvstore_impl
from llama_stack.providers.utils.kvstore.api import KVStore from llama_stack.providers.utils.kvstore.api import KVStore
from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin
from llama_stack.providers.utils.memory.vector_store import ( from llama_stack.providers.utils.memory.vector_store import (
RERANKER_TYPE_WEIGHTED,
EmbeddingIndex, EmbeddingIndex,
VectorDBWithIndex, VectorDBWithIndex,
) )
@ -238,7 +239,53 @@ class MilvusIndex(EmbeddingIndex):
reranker_type: str, reranker_type: str,
reranker_params: dict[str, Any] | None = None, reranker_params: dict[str, Any] | None = None,
) -> QueryChunksResponse: ) -> QueryChunksResponse:
raise NotImplementedError("Hybrid search is not supported in Milvus") """
Hybrid search using Milvus's native hybrid search capabilities.
This implementation uses Milvus's hybrid_search method which combines
vector search and BM25 search with configurable reranking strategies.
"""
search_requests = []
# nprobe: Controls search accuracy vs performance trade-off
# 10 balances these trade-offs for RAG applications
search_requests.append(
AnnSearchRequest(data=[embedding.tolist()], anns_field="vector", param={"nprobe": 10}, limit=k)
)
# drop_ratio_search: Filters low-importance terms to improve search performance
# 0.2 balances noise reduction with recall
search_requests.append(
AnnSearchRequest(data=[query_string], anns_field="sparse", param={"drop_ratio_search": 0.2}, limit=k)
)
if reranker_type == RERANKER_TYPE_WEIGHTED:
alpha = (reranker_params or {}).get("alpha", 0.5)
rerank = WeightedRanker(alpha, 1 - alpha)
else:
impact_factor = (reranker_params or {}).get("impact_factor", 60.0)
rerank = RRFRanker(impact_factor)
search_res = await asyncio.to_thread(
self.client.hybrid_search,
collection_name=self.collection_name,
reqs=search_requests,
ranker=rerank,
limit=k,
output_fields=["chunk_content"],
)
chunks = []
scores = []
for res in search_res[0]:
chunk = Chunk(**res["entity"]["chunk_content"])
chunks.append(chunk)
scores.append(res["distance"])
filtered_chunks = [chunk for chunk, score in zip(chunks, scores, strict=False) if score >= score_threshold]
filtered_scores = [score for score in scores if score >= score_threshold]
return QueryChunksResponse(chunks=filtered_chunks, scores=filtered_scores)
async def delete_chunk(self, chunk_id: str) -> None: async def delete_chunk(self, chunk_id: str) -> None:
"""Remove a chunk from the Milvus collection.""" """Remove a chunk from the Milvus collection."""

View file

@ -302,23 +302,25 @@ class VectorDBWithIndex:
mode = params.get("mode") mode = params.get("mode")
score_threshold = params.get("score_threshold", 0.0) score_threshold = params.get("score_threshold", 0.0)
# Get ranker configuration
ranker = params.get("ranker") ranker = params.get("ranker")
if ranker is None: if ranker is None:
# Default to RRF with impact_factor=60.0
reranker_type = RERANKER_TYPE_RRF reranker_type = RERANKER_TYPE_RRF
reranker_params = {"impact_factor": 60.0} reranker_params = {"impact_factor": 60.0}
else: else:
reranker_type = ranker.type strategy = ranker.get("strategy", "rrf")
reranker_params = ( if strategy == "weighted":
{"impact_factor": ranker.impact_factor} if ranker.type == RERANKER_TYPE_RRF else {"alpha": ranker.alpha} weights = ranker.get("params", {}).get("weights", [0.5, 0.5])
) reranker_type = RERANKER_TYPE_WEIGHTED
reranker_params = {"alpha": weights[0] if len(weights) > 0 else 0.5}
else:
reranker_type = RERANKER_TYPE_RRF
k_value = ranker.get("params", {}).get("k", 60.0)
reranker_params = {"impact_factor": k_value}
query_string = interleaved_content_as_str(query) query_string = interleaved_content_as_str(query)
if mode == "keyword": if mode == "keyword":
return await self.index.query_keyword(query_string, k, score_threshold) return await self.index.query_keyword(query_string, k, score_threshold)
# Calculate embeddings for both vector and hybrid modes
embeddings_response = await self.inference_api.embeddings(self.vector_db.embedding_model, [query_string]) embeddings_response = await self.inference_api.embeddings(self.vector_db.embedding_model, [query_string])
query_vector = np.array(embeddings_response.embeddings[0], dtype=np.float32) query_vector = np.array(embeddings_response.embeddings[0], dtype=np.float32)
if mode == "hybrid": if mode == "hybrid":

View file

@ -30,6 +30,7 @@ def skip_if_provider_doesnt_support_openai_vector_stores(client_with_models):
"remote::qdrant", "remote::qdrant",
"inline::qdrant", "inline::qdrant",
"remote::weaviate", "remote::weaviate",
"remote::milvus",
]: ]:
return return
@ -49,12 +50,16 @@ def skip_if_provider_doesnt_support_openai_vector_stores_search(client_with_mode
"remote::chromadb", "remote::chromadb",
"remote::weaviate", "remote::weaviate",
"remote::qdrant", "remote::qdrant",
"remote::milvus",
], ],
"keyword": [ "keyword": [
"inline::sqlite-vec", "inline::sqlite-vec",
"remote::milvus",
], ],
"hybrid": [ "hybrid": [
"inline::sqlite-vec", "inline::sqlite-vec",
"inline::milvus",
"remote::milvus",
], ],
} }
supported_providers = search_mode_support.get(search_mode, []) supported_providers = search_mode_support.get(search_mode, [])

View file

@ -15,6 +15,9 @@ from llama_stack.apis.vector_io import QueryChunksResponse
pymilvus_mock = MagicMock() pymilvus_mock = MagicMock()
pymilvus_mock.DataType = MagicMock() pymilvus_mock.DataType = MagicMock()
pymilvus_mock.MilvusClient = MagicMock pymilvus_mock.MilvusClient = MagicMock
pymilvus_mock.RRFRanker = MagicMock
pymilvus_mock.WeightedRanker = MagicMock
pymilvus_mock.AnnSearchRequest = MagicMock
# Apply the mock before importing MilvusIndex # Apply the mock before importing MilvusIndex
with patch.dict("sys.modules", {"pymilvus": pymilvus_mock}): with patch.dict("sys.modules", {"pymilvus": pymilvus_mock}):
@ -183,3 +186,141 @@ async def test_delete_collection(milvus_index, mock_milvus_client):
await milvus_index.delete() await milvus_index.delete()
mock_milvus_client.drop_collection.assert_called_once_with(collection_name=milvus_index.collection_name) mock_milvus_client.drop_collection.assert_called_once_with(collection_name=milvus_index.collection_name)
async def test_query_hybrid_search_rrf(
    milvus_index, sample_chunks, sample_embeddings, embedding_dimension, mock_milvus_client
):
    """Check ``query_hybrid`` with the RRF reranker against a mocked Milvus client.

    Verifies that both mocked hits come back as chunks/scores and that
    ``hybrid_search`` is invoked once with a "vector" request followed by a
    "sparse" request, plus a ranker.
    """
    mock_milvus_client.has_collection.return_value = True
    await milvus_index.add_chunks(sample_chunks, sample_embeddings)

    # Canned hybrid_search payload: a list wrapping one list of two hits.
    first_hit = {
        "id": 0,
        "distance": 0.1,
        "entity": {"chunk_content": {"content": "mock chunk 1", "metadata": {"document_id": "doc1"}}},
    }
    second_hit = {
        "id": 1,
        "distance": 0.2,
        "entity": {"chunk_content": {"content": "mock chunk 2", "metadata": {"document_id": "doc2"}}},
    }
    mock_milvus_client.hybrid_search.return_value = [[first_hit, second_hit]]

    result = await milvus_index.query_hybrid(
        embedding=np.random.rand(embedding_dimension).astype(np.float32),
        query_string="test query",
        k=2,
        score_threshold=0.0,
        reranker_type="rrf",
        reranker_params={"impact_factor": 60.0},
    )

    assert isinstance(result, QueryChunksResponse)
    assert len(result.chunks) == 2
    assert len(result.scores) == 2

    # Exactly one hybrid_search call, carrying the dense and sparse requests in order.
    mock_milvus_client.hybrid_search.assert_called_once()
    passed_kwargs = mock_milvus_client.hybrid_search.call_args.kwargs

    search_requests = passed_kwargs["reqs"]
    assert len(search_requests) == 2
    dense_request, sparse_request = search_requests
    assert dense_request.anns_field == "vector"
    assert sparse_request.anns_field == "sparse"

    assert passed_kwargs["ranker"] is not None
async def test_query_hybrid_search_weighted(
    milvus_index, sample_chunks, sample_embeddings, embedding_dimension, mock_milvus_client
):
    """Check ``query_hybrid`` with the weighted reranker (alpha=0.7) against a mocked client.

    Verifies that both mocked hits are returned and that a ranker object is
    handed to the single ``hybrid_search`` call.
    """
    mock_milvus_client.has_collection.return_value = True
    await milvus_index.add_chunks(sample_chunks, sample_embeddings)

    # Canned hybrid_search payload: a list wrapping one list of two hits.
    canned_hits = [
        {
            "id": 0,
            "distance": 0.1,
            "entity": {"chunk_content": {"content": "mock chunk 1", "metadata": {"document_id": "doc1"}}},
        },
        {
            "id": 1,
            "distance": 0.2,
            "entity": {"chunk_content": {"content": "mock chunk 2", "metadata": {"document_id": "doc2"}}},
        },
    ]
    mock_milvus_client.hybrid_search.return_value = [canned_hits]

    result = await milvus_index.query_hybrid(
        embedding=np.random.rand(embedding_dimension).astype(np.float32),
        query_string="test query",
        k=2,
        score_threshold=0.0,
        reranker_type="weighted",
        reranker_params={"alpha": 0.7},
    )

    assert isinstance(result, QueryChunksResponse)
    assert len(result.chunks) == 2
    assert len(result.scores) == 2

    # The client must have been hit exactly once, with a ranker supplied.
    mock_milvus_client.hybrid_search.assert_called_once()
    passed_kwargs = mock_milvus_client.hybrid_search.call_args.kwargs
    assert passed_kwargs["ranker"] is not None
async def test_query_hybrid_search_default_rrf(
    milvus_index, sample_chunks, sample_embeddings, embedding_dimension, mock_milvus_client
):
    """Check ``query_hybrid`` when given an unrecognized reranker type and no params.

    The call passes ``reranker_type="unknown_type"`` with ``reranker_params=None``
    and expects the query to still succeed with a ranker supplied to
    ``hybrid_search`` (the RRF fallback path is the intended behaviour).
    """
    mock_milvus_client.has_collection.return_value = True
    await milvus_index.add_chunks(sample_chunks, sample_embeddings)

    # Canned hybrid_search payload: a list wrapping one list with a single hit.
    only_hit = {
        "id": 0,
        "distance": 0.1,
        "entity": {"chunk_content": {"content": "mock chunk 1", "metadata": {"document_id": "doc1"}}},
    }
    mock_milvus_client.hybrid_search.return_value = [[only_hit]]

    result = await milvus_index.query_hybrid(
        embedding=np.random.rand(embedding_dimension).astype(np.float32),
        query_string="test query",
        k=1,
        score_threshold=0.0,
        reranker_type="unknown_type",  # not a known type; expected to fall back to RRF
        reranker_params=None,  # expected to fall back to the default impact_factor
    )

    assert isinstance(result, QueryChunksResponse)
    assert len(result.chunks) == 1

    # The client must have been hit exactly once, with a ranker supplied.
    mock_milvus_client.hybrid_search.assert_called_once()
    passed_kwargs = mock_milvus_client.hybrid_search.call_args.kwargs
    assert passed_kwargs["ranker"] is not None