Mirror of https://github.com/meta-llama/llama-stack.git, synced 2025-08-12 04:50:39 +00:00
feat: Implement hybrid search in Milvus (#2644)
Some checks failed
Integration Tests (Replay) / discover-tests (push) Successful in 5s
Integration Auth Tests / test-matrix (oauth2_token) (push) Failing after 6s
Vector IO Integration Tests / test-matrix (3.12, inline::faiss) (push) Failing after 10s
Test External Providers Installed via Module / test-external-providers-from-module (venv) (push) Has been skipped
Python Package Build Test / build (3.13) (push) Failing after 6s
Vector IO Integration Tests / test-matrix (3.12, inline::milvus) (push) Failing after 9s
Vector IO Integration Tests / test-matrix (3.12, remote::pgvector) (push) Failing after 10s
SqlStore Integration Tests / test-postgres (3.13) (push) Failing after 15s
Vector IO Integration Tests / test-matrix (3.12, inline::sqlite-vec) (push) Failing after 16s
Python Package Build Test / build (3.12) (push) Failing after 10s
SqlStore Integration Tests / test-postgres (3.12) (push) Failing after 21s
Vector IO Integration Tests / test-matrix (3.13, remote::weaviate) (push) Failing after 7s
Vector IO Integration Tests / test-matrix (3.12, remote::chromadb) (push) Failing after 15s
Unit Tests / unit-tests (3.13) (push) Failing after 10s
Vector IO Integration Tests / test-matrix (3.13, inline::sqlite-vec) (push) Failing after 15s
Vector IO Integration Tests / test-matrix (3.13, remote::chromadb) (push) Failing after 12s
Vector IO Integration Tests / test-matrix (3.13, remote::qdrant) (push) Failing after 12s
Vector IO Integration Tests / test-matrix (3.13, inline::faiss) (push) Failing after 8s
Integration Tests (Replay) / Integration Tests (, , , client=, vision=) (push) Failing after 8s
Unit Tests / unit-tests (3.12) (push) Failing after 19s
Vector IO Integration Tests / test-matrix (3.13, remote::pgvector) (push) Failing after 12s
Vector IO Integration Tests / test-matrix (3.12, remote::weaviate) (push) Failing after 11s
Vector IO Integration Tests / test-matrix (3.12, remote::qdrant) (push) Failing after 11s
Test External API and Providers / test-external (venv) (push) Failing after 21s
Vector IO Integration Tests / test-matrix (3.13, inline::milvus) (push) Failing after 19s
Pre-commit / pre-commit (push) Successful in 57s
# What does this PR do?

This PR implements hybrid search for the Milvus DB provider, building on Milvus's built-in hybrid search support.

To test:

```
pytest tests/unit/providers/vector_io/remote/test_milvus.py -v -s --tb=long --disable-warnings --asyncio-mode=auto
```

Signed-off-by: Varsha Prasad Narsing <varshaprasad96@gmail.com>
Parent: 5a2d323eca · Commit: e3928e6a29
4 changed files with 204 additions and 9 deletions
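For background on the approach: the provider delegates fusion to Milvus's native hybrid-search API rather than merging dense and keyword result lists client-side. Below is a minimal, standalone sketch of that pymilvus call pattern. It assumes a running Milvus instance and a collection that already defines a dense `vector` field and a BM25-backed `sparse` field; every name and value in it is illustrative rather than taken from this PR.

```
# Sketch only (not code from this PR): the bare pymilvus hybrid-search pattern
# the provider builds on. Collection name, field names, and query values are
# placeholders.
from pymilvus import AnnSearchRequest, MilvusClient, RRFRanker

client = MilvusClient(uri="http://localhost:19530")  # assumes a running Milvus instance

query_text = "how do I enable hybrid search?"
query_vector = [0.1] * 384  # stand-in for an embedding of query_text

requests = [
    # dense ANN request against the embedding field
    AnnSearchRequest(data=[query_vector], anns_field="vector", param={"nprobe": 10}, limit=5),
    # sparse/BM25 full-text request
    AnnSearchRequest(data=[query_text], anns_field="sparse", param={"drop_ratio_search": 0.2}, limit=5),
]

results = client.hybrid_search(
    collection_name="demo_collection",
    reqs=requests,
    ranker=RRFRanker(60),  # rank-based fusion; WeightedRanker(0.7, 0.3) is the score-based alternative
    limit=5,
    output_fields=["chunk_content"],
)
for hit in results[0]:
    print(hit["distance"], hit["entity"])
```

The diff below wires this same pattern into `MilvusIndex.query_hybrid` and into the shared query-parameter parsing.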
@@ -10,7 +10,7 @@ import os
 from typing import Any

 from numpy.typing import NDArray
-from pymilvus import DataType, Function, FunctionType, MilvusClient
+from pymilvus import AnnSearchRequest, DataType, Function, FunctionType, MilvusClient, RRFRanker, WeightedRanker

 from llama_stack.apis.common.errors import VectorStoreNotFoundError
 from llama_stack.apis.files.files import Files
@@ -27,6 +27,7 @@ from llama_stack.providers.utils.kvstore import kvstore_impl
 from llama_stack.providers.utils.kvstore.api import KVStore
 from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin
 from llama_stack.providers.utils.memory.vector_store import (
+    RERANKER_TYPE_WEIGHTED,
     EmbeddingIndex,
     VectorDBWithIndex,
 )
@@ -238,7 +239,53 @@ class MilvusIndex(EmbeddingIndex):
         reranker_type: str,
         reranker_params: dict[str, Any] | None = None,
     ) -> QueryChunksResponse:
-        raise NotImplementedError("Hybrid search is not supported in Milvus")
+        """
+        Hybrid search using Milvus's native hybrid search capabilities.
+
+        This implementation uses Milvus's hybrid_search method which combines
+        vector search and BM25 search with configurable reranking strategies.
+        """
+        search_requests = []
+
+        # nprobe: Controls search accuracy vs performance trade-off
+        # 10 balances these trade-offs for RAG applications
+        search_requests.append(
+            AnnSearchRequest(data=[embedding.tolist()], anns_field="vector", param={"nprobe": 10}, limit=k)
+        )
+
+        # drop_ratio_search: Filters low-importance terms to improve search performance
+        # 0.2 balances noise reduction with recall
+        search_requests.append(
+            AnnSearchRequest(data=[query_string], anns_field="sparse", param={"drop_ratio_search": 0.2}, limit=k)
+        )
+
+        if reranker_type == RERANKER_TYPE_WEIGHTED:
+            alpha = (reranker_params or {}).get("alpha", 0.5)
+            rerank = WeightedRanker(alpha, 1 - alpha)
+        else:
+            impact_factor = (reranker_params or {}).get("impact_factor", 60.0)
+            rerank = RRFRanker(impact_factor)
+
+        search_res = await asyncio.to_thread(
+            self.client.hybrid_search,
+            collection_name=self.collection_name,
+            reqs=search_requests,
+            ranker=rerank,
+            limit=k,
+            output_fields=["chunk_content"],
+        )
+
+        chunks = []
+        scores = []
+        for res in search_res[0]:
+            chunk = Chunk(**res["entity"]["chunk_content"])
+            chunks.append(chunk)
+            scores.append(res["distance"])
+
+        filtered_chunks = [chunk for chunk, score in zip(chunks, scores, strict=False) if score >= score_threshold]
+        filtered_scores = [score for score in scores if score >= score_threshold]
+
+        return QueryChunksResponse(chunks=filtered_chunks, scores=filtered_scores)

     async def delete_chunk(self, chunk_id: str) -> None:
         """Remove a chunk from the Milvus collection."""
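A note on the two ranker branches above: `WeightedRanker(alpha, 1 - alpha)` blends the normalized scores of the dense and BM25 requests linearly, while `RRFRanker(impact_factor)` ignores raw scores and fuses results by rank position. The following is a rough, self-contained sketch of the reciprocal-rank-fusion formula, shown only to make the `impact_factor` (the RRF `k`) concrete; Milvus performs this fusion server-side, and the document IDs here are made up.

```
# Conceptual sketch of reciprocal-rank fusion (RRF). This is NOT the provider's
# code path; it only illustrates what the server-side RRFRanker computes.
def rrf_fuse(ranked_lists: list[list[str]], k: float = 60.0) -> list[tuple[str, float]]:
    scores: dict[str, float] = {}
    for ranked in ranked_lists:
        for rank, doc_id in enumerate(ranked, start=1):
            # each list contributes 1 / (k + rank) for every document it returns
            scores[doc_id] = scores.get(doc_id, 0.0) + 1.0 / (k + rank)
    return sorted(scores.items(), key=lambda item: item[1], reverse=True)

dense_hits = ["doc2", "doc1", "doc3"]    # order from the vector search
keyword_hits = ["doc1", "doc4", "doc2"]  # order from the BM25 search
print(rrf_fuse([dense_hits, keyword_hits]))  # doc1 and doc2 rise to the top
```

A larger `k` flattens the contribution of top-ranked hits; the default of 60 matches the value commonly used in the RRF literature.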
@@ -302,23 +302,25 @@ class VectorDBWithIndex:
         mode = params.get("mode")
         score_threshold = params.get("score_threshold", 0.0)

         # Get ranker configuration
         ranker = params.get("ranker")
         if ranker is None:
             # Default to RRF with impact_factor=60.0
             reranker_type = RERANKER_TYPE_RRF
             reranker_params = {"impact_factor": 60.0}
         else:
-            reranker_type = ranker.type
-            reranker_params = (
-                {"impact_factor": ranker.impact_factor} if ranker.type == RERANKER_TYPE_RRF else {"alpha": ranker.alpha}
-            )
+            strategy = ranker.get("strategy", "rrf")
+            if strategy == "weighted":
+                weights = ranker.get("params", {}).get("weights", [0.5, 0.5])
+                reranker_type = RERANKER_TYPE_WEIGHTED
+                reranker_params = {"alpha": weights[0] if len(weights) > 0 else 0.5}
+            else:
+                reranker_type = RERANKER_TYPE_RRF
+                k_value = ranker.get("params", {}).get("k", 60.0)
+                reranker_params = {"impact_factor": k_value}

         query_string = interleaved_content_as_str(query)
         if mode == "keyword":
             return await self.index.query_keyword(query_string, k, score_threshold)

         # Calculate embeddings for both vector and hybrid modes
         embeddings_response = await self.inference_api.embeddings(self.vector_db.embedding_model, [query_string])
         query_vector = np.array(embeddings_response.embeddings[0], dtype=np.float32)
         if mode == "hybrid":
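With the parsing above, a hybrid query selects its reranking strategy through a plain `ranker` dict inside the query params. A sketch of the two accepted shapes, with illustrative values (only `mode`, `score_threshold`, and `ranker` are read by the code shown; in the Milvus provider the first weight ends up applied to the dense request and its complement to the BM25 request):

```
# Illustrative query params for the two reranking strategies parsed above.
rrf_query_params = {
    "mode": "hybrid",
    "score_threshold": 0.0,
    "ranker": {"strategy": "rrf", "params": {"k": 60.0}},
}

weighted_query_params = {
    "mode": "hybrid",
    "score_threshold": 0.0,
    "ranker": {"strategy": "weighted", "params": {"weights": [0.7, 0.3]}},
}
```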
@@ -30,6 +30,7 @@ def skip_if_provider_doesnt_support_openai_vector_stores(client_with_models):
             "remote::qdrant",
             "inline::qdrant",
             "remote::weaviate",
+            "remote::milvus",
         ]:
             return
@@ -49,12 +50,16 @@ def skip_if_provider_doesnt_support_openai_vector_stores_search(client_with_mode
             "remote::chromadb",
             "remote::weaviate",
             "remote::qdrant",
+            "remote::milvus",
         ],
         "keyword": [
             "inline::sqlite-vec",
+            "remote::milvus",
         ],
         "hybrid": [
             "inline::sqlite-vec",
+            "inline::milvus",
+            "remote::milvus",
         ],
     }
     supported_providers = search_mode_support.get(search_mode, [])
@@ -15,6 +15,9 @@ from llama_stack.apis.vector_io import QueryChunksResponse
 pymilvus_mock = MagicMock()
 pymilvus_mock.DataType = MagicMock()
 pymilvus_mock.MilvusClient = MagicMock
+pymilvus_mock.RRFRanker = MagicMock
+pymilvus_mock.WeightedRanker = MagicMock
+pymilvus_mock.AnnSearchRequest = MagicMock

 # Apply the mock before importing MilvusIndex
 with patch.dict("sys.modules", {"pymilvus": pymilvus_mock}):
@@ -183,3 +186,141 @@ async def test_delete_collection(milvus_index, mock_milvus_client):
     await milvus_index.delete()

     mock_milvus_client.drop_collection.assert_called_once_with(collection_name=milvus_index.collection_name)
+
+
+async def test_query_hybrid_search_rrf(
+    milvus_index, sample_chunks, sample_embeddings, embedding_dimension, mock_milvus_client
+):
+    """Test hybrid search with RRF reranker."""
+    mock_milvus_client.has_collection.return_value = True
+    await milvus_index.add_chunks(sample_chunks, sample_embeddings)
+
+    # Mock hybrid search results
+    mock_milvus_client.hybrid_search.return_value = [
+        [
+            {
+                "id": 0,
+                "distance": 0.1,
+                "entity": {"chunk_content": {"content": "mock chunk 1", "metadata": {"document_id": "doc1"}}},
+            },
+            {
+                "id": 1,
+                "distance": 0.2,
+                "entity": {"chunk_content": {"content": "mock chunk 2", "metadata": {"document_id": "doc2"}}},
+            },
+        ]
+    ]
+
+    # Test hybrid search with RRF reranker
+    query_embedding = np.random.rand(embedding_dimension).astype(np.float32)
+    query_string = "test query"
+    response = await milvus_index.query_hybrid(
+        embedding=query_embedding,
+        query_string=query_string,
+        k=2,
+        score_threshold=0.0,
+        reranker_type="rrf",
+        reranker_params={"impact_factor": 60.0},
+    )
+
+    assert isinstance(response, QueryChunksResponse)
+    assert len(response.chunks) == 2
+    assert len(response.scores) == 2
+
+    # Verify hybrid search was called with correct parameters
+    mock_milvus_client.hybrid_search.assert_called_once()
+    call_args = mock_milvus_client.hybrid_search.call_args
+
+    # Check that the request contains both vector and BM25 search requests
+    reqs = call_args[1]["reqs"]
+    assert len(reqs) == 2
+    assert reqs[0].anns_field == "vector"
+    assert reqs[1].anns_field == "sparse"
+    ranker = call_args[1]["ranker"]
+    assert ranker is not None
+
+
+async def test_query_hybrid_search_weighted(
+    milvus_index, sample_chunks, sample_embeddings, embedding_dimension, mock_milvus_client
+):
+    """Test hybrid search with weighted reranker."""
+    mock_milvus_client.has_collection.return_value = True
+    await milvus_index.add_chunks(sample_chunks, sample_embeddings)
+
+    # Mock hybrid search results
+    mock_milvus_client.hybrid_search.return_value = [
+        [
+            {
+                "id": 0,
+                "distance": 0.1,
+                "entity": {"chunk_content": {"content": "mock chunk 1", "metadata": {"document_id": "doc1"}}},
+            },
+            {
+                "id": 1,
+                "distance": 0.2,
+                "entity": {"chunk_content": {"content": "mock chunk 2", "metadata": {"document_id": "doc2"}}},
+            },
+        ]
+    ]
+
+    # Test hybrid search with weighted reranker
+    query_embedding = np.random.rand(embedding_dimension).astype(np.float32)
+    query_string = "test query"
+    response = await milvus_index.query_hybrid(
+        embedding=query_embedding,
+        query_string=query_string,
+        k=2,
+        score_threshold=0.0,
+        reranker_type="weighted",
+        reranker_params={"alpha": 0.7},
+    )
+
+    assert isinstance(response, QueryChunksResponse)
+    assert len(response.chunks) == 2
+    assert len(response.scores) == 2
+
+    # Verify hybrid search was called with correct parameters
+    mock_milvus_client.hybrid_search.assert_called_once()
+    call_args = mock_milvus_client.hybrid_search.call_args
+    ranker = call_args[1]["ranker"]
+    assert ranker is not None
+
+
+async def test_query_hybrid_search_default_rrf(
+    milvus_index, sample_chunks, sample_embeddings, embedding_dimension, mock_milvus_client
+):
+    """Test hybrid search with default RRF reranker (no reranker_type specified)."""
+    mock_milvus_client.has_collection.return_value = True
+    await milvus_index.add_chunks(sample_chunks, sample_embeddings)
+
+    # Mock hybrid search results
+    mock_milvus_client.hybrid_search.return_value = [
+        [
+            {
+                "id": 0,
+                "distance": 0.1,
+                "entity": {"chunk_content": {"content": "mock chunk 1", "metadata": {"document_id": "doc1"}}},
+            },
+        ]
+    ]
+
+    # Test hybrid search with default reranker (should be RRF)
+    query_embedding = np.random.rand(embedding_dimension).astype(np.float32)
+    query_string = "test query"
+    response = await milvus_index.query_hybrid(
+        embedding=query_embedding,
+        query_string=query_string,
+        k=1,
+        score_threshold=0.0,
+        reranker_type="unknown_type",  # Should default to RRF
+        reranker_params=None,  # Should use default impact_factor
+    )
+
+    assert isinstance(response, QueryChunksResponse)
+    assert len(response.chunks) == 1
+
+    # Verify hybrid search was called with RRF reranker
+    mock_milvus_client.hybrid_search.assert_called_once()
+    call_args = mock_milvus_client.hybrid_search.call_args
+    ranker = call_args[1]["ranker"]
+    assert ranker is not None