mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-12 04:50:39 +00:00
feat: Implement hybrid search in Milvus (#2644)
Some checks failed
Integration Tests (Replay) / discover-tests (push) Successful in 5s
Integration Auth Tests / test-matrix (oauth2_token) (push) Failing after 6s
Vector IO Integration Tests / test-matrix (3.12, inline::faiss) (push) Failing after 10s
Test External Providers Installed via Module / test-external-providers-from-module (venv) (push) Has been skipped
Python Package Build Test / build (3.13) (push) Failing after 6s
Vector IO Integration Tests / test-matrix (3.12, inline::milvus) (push) Failing after 9s
Vector IO Integration Tests / test-matrix (3.12, remote::pgvector) (push) Failing after 10s
SqlStore Integration Tests / test-postgres (3.13) (push) Failing after 15s
Vector IO Integration Tests / test-matrix (3.12, inline::sqlite-vec) (push) Failing after 16s
Python Package Build Test / build (3.12) (push) Failing after 10s
SqlStore Integration Tests / test-postgres (3.12) (push) Failing after 21s
Vector IO Integration Tests / test-matrix (3.13, remote::weaviate) (push) Failing after 7s
Vector IO Integration Tests / test-matrix (3.12, remote::chromadb) (push) Failing after 15s
Unit Tests / unit-tests (3.13) (push) Failing after 10s
Vector IO Integration Tests / test-matrix (3.13, inline::sqlite-vec) (push) Failing after 15s
Vector IO Integration Tests / test-matrix (3.13, remote::chromadb) (push) Failing after 12s
Vector IO Integration Tests / test-matrix (3.13, remote::qdrant) (push) Failing after 12s
Vector IO Integration Tests / test-matrix (3.13, inline::faiss) (push) Failing after 8s
Integration Tests (Replay) / Integration Tests (, , , client=, vision=) (push) Failing after 8s
Unit Tests / unit-tests (3.12) (push) Failing after 19s
Vector IO Integration Tests / test-matrix (3.13, remote::pgvector) (push) Failing after 12s
Vector IO Integration Tests / test-matrix (3.12, remote::weaviate) (push) Failing after 11s
Vector IO Integration Tests / test-matrix (3.12, remote::qdrant) (push) Failing after 11s
Test External API and Providers / test-external (venv) (push) Failing after 21s
Vector IO Integration Tests / test-matrix (3.13, inline::milvus) (push) Failing after 19s
Pre-commit / pre-commit (push) Successful in 57s
Some checks failed
Integration Tests (Replay) / discover-tests (push) Successful in 5s
Integration Auth Tests / test-matrix (oauth2_token) (push) Failing after 6s
Vector IO Integration Tests / test-matrix (3.12, inline::faiss) (push) Failing after 10s
Test External Providers Installed via Module / test-external-providers-from-module (venv) (push) Has been skipped
Python Package Build Test / build (3.13) (push) Failing after 6s
Vector IO Integration Tests / test-matrix (3.12, inline::milvus) (push) Failing after 9s
Vector IO Integration Tests / test-matrix (3.12, remote::pgvector) (push) Failing after 10s
SqlStore Integration Tests / test-postgres (3.13) (push) Failing after 15s
Vector IO Integration Tests / test-matrix (3.12, inline::sqlite-vec) (push) Failing after 16s
Python Package Build Test / build (3.12) (push) Failing after 10s
SqlStore Integration Tests / test-postgres (3.12) (push) Failing after 21s
Vector IO Integration Tests / test-matrix (3.13, remote::weaviate) (push) Failing after 7s
Vector IO Integration Tests / test-matrix (3.12, remote::chromadb) (push) Failing after 15s
Unit Tests / unit-tests (3.13) (push) Failing after 10s
Vector IO Integration Tests / test-matrix (3.13, inline::sqlite-vec) (push) Failing after 15s
Vector IO Integration Tests / test-matrix (3.13, remote::chromadb) (push) Failing after 12s
Vector IO Integration Tests / test-matrix (3.13, remote::qdrant) (push) Failing after 12s
Vector IO Integration Tests / test-matrix (3.13, inline::faiss) (push) Failing after 8s
Integration Tests (Replay) / Integration Tests (, , , client=, vision=) (push) Failing after 8s
Unit Tests / unit-tests (3.12) (push) Failing after 19s
Vector IO Integration Tests / test-matrix (3.13, remote::pgvector) (push) Failing after 12s
Vector IO Integration Tests / test-matrix (3.12, remote::weaviate) (push) Failing after 11s
Vector IO Integration Tests / test-matrix (3.12, remote::qdrant) (push) Failing after 11s
Test External API and Providers / test-external (venv) (push) Failing after 21s
Vector IO Integration Tests / test-matrix (3.13, inline::milvus) (push) Failing after 19s
Pre-commit / pre-commit (push) Successful in 57s
# What does this PR do? This PR implements hybrid search for Milvus DB based on the inbuilt milvus support. To test: ``` pytest tests/unit/providers/vector_io/remote/test_milvus.py -v -s --tb=long --disable-warnings --asyncio-mode=auto ``` Signed-off-by: Varsha Prasad Narsing <varshaprasad96@gmail.com>
This commit is contained in:
parent
5a2d323eca
commit
e3928e6a29
4 changed files with 204 additions and 9 deletions
|
@ -10,7 +10,7 @@ import os
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from numpy.typing import NDArray
|
from numpy.typing import NDArray
|
||||||
from pymilvus import DataType, Function, FunctionType, MilvusClient
|
from pymilvus import AnnSearchRequest, DataType, Function, FunctionType, MilvusClient, RRFRanker, WeightedRanker
|
||||||
|
|
||||||
from llama_stack.apis.common.errors import VectorStoreNotFoundError
|
from llama_stack.apis.common.errors import VectorStoreNotFoundError
|
||||||
from llama_stack.apis.files.files import Files
|
from llama_stack.apis.files.files import Files
|
||||||
|
@ -27,6 +27,7 @@ from llama_stack.providers.utils.kvstore import kvstore_impl
|
||||||
from llama_stack.providers.utils.kvstore.api import KVStore
|
from llama_stack.providers.utils.kvstore.api import KVStore
|
||||||
from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin
|
from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin
|
||||||
from llama_stack.providers.utils.memory.vector_store import (
|
from llama_stack.providers.utils.memory.vector_store import (
|
||||||
|
RERANKER_TYPE_WEIGHTED,
|
||||||
EmbeddingIndex,
|
EmbeddingIndex,
|
||||||
VectorDBWithIndex,
|
VectorDBWithIndex,
|
||||||
)
|
)
|
||||||
|
@ -238,7 +239,53 @@ class MilvusIndex(EmbeddingIndex):
|
||||||
reranker_type: str,
|
reranker_type: str,
|
||||||
reranker_params: dict[str, Any] | None = None,
|
reranker_params: dict[str, Any] | None = None,
|
||||||
) -> QueryChunksResponse:
|
) -> QueryChunksResponse:
|
||||||
raise NotImplementedError("Hybrid search is not supported in Milvus")
|
"""
|
||||||
|
Hybrid search using Milvus's native hybrid search capabilities.
|
||||||
|
|
||||||
|
This implementation uses Milvus's hybrid_search method which combines
|
||||||
|
vector search and BM25 search with configurable reranking strategies.
|
||||||
|
"""
|
||||||
|
search_requests = []
|
||||||
|
|
||||||
|
# nprobe: Controls search accuracy vs performance trade-off
|
||||||
|
# 10 balances these trade-offs for RAG applications
|
||||||
|
search_requests.append(
|
||||||
|
AnnSearchRequest(data=[embedding.tolist()], anns_field="vector", param={"nprobe": 10}, limit=k)
|
||||||
|
)
|
||||||
|
|
||||||
|
# drop_ratio_search: Filters low-importance terms to improve search performance
|
||||||
|
# 0.2 balances noise reduction with recall
|
||||||
|
search_requests.append(
|
||||||
|
AnnSearchRequest(data=[query_string], anns_field="sparse", param={"drop_ratio_search": 0.2}, limit=k)
|
||||||
|
)
|
||||||
|
|
||||||
|
if reranker_type == RERANKER_TYPE_WEIGHTED:
|
||||||
|
alpha = (reranker_params or {}).get("alpha", 0.5)
|
||||||
|
rerank = WeightedRanker(alpha, 1 - alpha)
|
||||||
|
else:
|
||||||
|
impact_factor = (reranker_params or {}).get("impact_factor", 60.0)
|
||||||
|
rerank = RRFRanker(impact_factor)
|
||||||
|
|
||||||
|
search_res = await asyncio.to_thread(
|
||||||
|
self.client.hybrid_search,
|
||||||
|
collection_name=self.collection_name,
|
||||||
|
reqs=search_requests,
|
||||||
|
ranker=rerank,
|
||||||
|
limit=k,
|
||||||
|
output_fields=["chunk_content"],
|
||||||
|
)
|
||||||
|
|
||||||
|
chunks = []
|
||||||
|
scores = []
|
||||||
|
for res in search_res[0]:
|
||||||
|
chunk = Chunk(**res["entity"]["chunk_content"])
|
||||||
|
chunks.append(chunk)
|
||||||
|
scores.append(res["distance"])
|
||||||
|
|
||||||
|
filtered_chunks = [chunk for chunk, score in zip(chunks, scores, strict=False) if score >= score_threshold]
|
||||||
|
filtered_scores = [score for score in scores if score >= score_threshold]
|
||||||
|
|
||||||
|
return QueryChunksResponse(chunks=filtered_chunks, scores=filtered_scores)
|
||||||
|
|
||||||
async def delete_chunk(self, chunk_id: str) -> None:
|
async def delete_chunk(self, chunk_id: str) -> None:
|
||||||
"""Remove a chunk from the Milvus collection."""
|
"""Remove a chunk from the Milvus collection."""
|
||||||
|
|
|
@ -302,23 +302,25 @@ class VectorDBWithIndex:
|
||||||
mode = params.get("mode")
|
mode = params.get("mode")
|
||||||
score_threshold = params.get("score_threshold", 0.0)
|
score_threshold = params.get("score_threshold", 0.0)
|
||||||
|
|
||||||
# Get ranker configuration
|
|
||||||
ranker = params.get("ranker")
|
ranker = params.get("ranker")
|
||||||
if ranker is None:
|
if ranker is None:
|
||||||
# Default to RRF with impact_factor=60.0
|
|
||||||
reranker_type = RERANKER_TYPE_RRF
|
reranker_type = RERANKER_TYPE_RRF
|
||||||
reranker_params = {"impact_factor": 60.0}
|
reranker_params = {"impact_factor": 60.0}
|
||||||
else:
|
else:
|
||||||
reranker_type = ranker.type
|
strategy = ranker.get("strategy", "rrf")
|
||||||
reranker_params = (
|
if strategy == "weighted":
|
||||||
{"impact_factor": ranker.impact_factor} if ranker.type == RERANKER_TYPE_RRF else {"alpha": ranker.alpha}
|
weights = ranker.get("params", {}).get("weights", [0.5, 0.5])
|
||||||
)
|
reranker_type = RERANKER_TYPE_WEIGHTED
|
||||||
|
reranker_params = {"alpha": weights[0] if len(weights) > 0 else 0.5}
|
||||||
|
else:
|
||||||
|
reranker_type = RERANKER_TYPE_RRF
|
||||||
|
k_value = ranker.get("params", {}).get("k", 60.0)
|
||||||
|
reranker_params = {"impact_factor": k_value}
|
||||||
|
|
||||||
query_string = interleaved_content_as_str(query)
|
query_string = interleaved_content_as_str(query)
|
||||||
if mode == "keyword":
|
if mode == "keyword":
|
||||||
return await self.index.query_keyword(query_string, k, score_threshold)
|
return await self.index.query_keyword(query_string, k, score_threshold)
|
||||||
|
|
||||||
# Calculate embeddings for both vector and hybrid modes
|
|
||||||
embeddings_response = await self.inference_api.embeddings(self.vector_db.embedding_model, [query_string])
|
embeddings_response = await self.inference_api.embeddings(self.vector_db.embedding_model, [query_string])
|
||||||
query_vector = np.array(embeddings_response.embeddings[0], dtype=np.float32)
|
query_vector = np.array(embeddings_response.embeddings[0], dtype=np.float32)
|
||||||
if mode == "hybrid":
|
if mode == "hybrid":
|
||||||
|
|
|
@ -30,6 +30,7 @@ def skip_if_provider_doesnt_support_openai_vector_stores(client_with_models):
|
||||||
"remote::qdrant",
|
"remote::qdrant",
|
||||||
"inline::qdrant",
|
"inline::qdrant",
|
||||||
"remote::weaviate",
|
"remote::weaviate",
|
||||||
|
"remote::milvus",
|
||||||
]:
|
]:
|
||||||
return
|
return
|
||||||
|
|
||||||
|
@ -49,12 +50,16 @@ def skip_if_provider_doesnt_support_openai_vector_stores_search(client_with_mode
|
||||||
"remote::chromadb",
|
"remote::chromadb",
|
||||||
"remote::weaviate",
|
"remote::weaviate",
|
||||||
"remote::qdrant",
|
"remote::qdrant",
|
||||||
|
"remote::milvus",
|
||||||
],
|
],
|
||||||
"keyword": [
|
"keyword": [
|
||||||
"inline::sqlite-vec",
|
"inline::sqlite-vec",
|
||||||
|
"remote::milvus",
|
||||||
],
|
],
|
||||||
"hybrid": [
|
"hybrid": [
|
||||||
"inline::sqlite-vec",
|
"inline::sqlite-vec",
|
||||||
|
"inline::milvus",
|
||||||
|
"remote::milvus",
|
||||||
],
|
],
|
||||||
}
|
}
|
||||||
supported_providers = search_mode_support.get(search_mode, [])
|
supported_providers = search_mode_support.get(search_mode, [])
|
||||||
|
|
|
@ -15,6 +15,9 @@ from llama_stack.apis.vector_io import QueryChunksResponse
|
||||||
pymilvus_mock = MagicMock()
|
pymilvus_mock = MagicMock()
|
||||||
pymilvus_mock.DataType = MagicMock()
|
pymilvus_mock.DataType = MagicMock()
|
||||||
pymilvus_mock.MilvusClient = MagicMock
|
pymilvus_mock.MilvusClient = MagicMock
|
||||||
|
pymilvus_mock.RRFRanker = MagicMock
|
||||||
|
pymilvus_mock.WeightedRanker = MagicMock
|
||||||
|
pymilvus_mock.AnnSearchRequest = MagicMock
|
||||||
|
|
||||||
# Apply the mock before importing MilvusIndex
|
# Apply the mock before importing MilvusIndex
|
||||||
with patch.dict("sys.modules", {"pymilvus": pymilvus_mock}):
|
with patch.dict("sys.modules", {"pymilvus": pymilvus_mock}):
|
||||||
|
@ -183,3 +186,141 @@ async def test_delete_collection(milvus_index, mock_milvus_client):
|
||||||
await milvus_index.delete()
|
await milvus_index.delete()
|
||||||
|
|
||||||
mock_milvus_client.drop_collection.assert_called_once_with(collection_name=milvus_index.collection_name)
|
mock_milvus_client.drop_collection.assert_called_once_with(collection_name=milvus_index.collection_name)
|
||||||
|
|
||||||
|
|
||||||
|
async def test_query_hybrid_search_rrf(
|
||||||
|
milvus_index, sample_chunks, sample_embeddings, embedding_dimension, mock_milvus_client
|
||||||
|
):
|
||||||
|
"""Test hybrid search with RRF reranker."""
|
||||||
|
mock_milvus_client.has_collection.return_value = True
|
||||||
|
await milvus_index.add_chunks(sample_chunks, sample_embeddings)
|
||||||
|
|
||||||
|
# Mock hybrid search results
|
||||||
|
mock_milvus_client.hybrid_search.return_value = [
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"id": 0,
|
||||||
|
"distance": 0.1,
|
||||||
|
"entity": {"chunk_content": {"content": "mock chunk 1", "metadata": {"document_id": "doc1"}}},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 1,
|
||||||
|
"distance": 0.2,
|
||||||
|
"entity": {"chunk_content": {"content": "mock chunk 2", "metadata": {"document_id": "doc2"}}},
|
||||||
|
},
|
||||||
|
]
|
||||||
|
]
|
||||||
|
|
||||||
|
# Test hybrid search with RRF reranker
|
||||||
|
query_embedding = np.random.rand(embedding_dimension).astype(np.float32)
|
||||||
|
query_string = "test query"
|
||||||
|
response = await milvus_index.query_hybrid(
|
||||||
|
embedding=query_embedding,
|
||||||
|
query_string=query_string,
|
||||||
|
k=2,
|
||||||
|
score_threshold=0.0,
|
||||||
|
reranker_type="rrf",
|
||||||
|
reranker_params={"impact_factor": 60.0},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert isinstance(response, QueryChunksResponse)
|
||||||
|
assert len(response.chunks) == 2
|
||||||
|
assert len(response.scores) == 2
|
||||||
|
|
||||||
|
# Verify hybrid search was called with correct parameters
|
||||||
|
mock_milvus_client.hybrid_search.assert_called_once()
|
||||||
|
call_args = mock_milvus_client.hybrid_search.call_args
|
||||||
|
|
||||||
|
# Check that the request contains both vector and BM25 search requests
|
||||||
|
reqs = call_args[1]["reqs"]
|
||||||
|
assert len(reqs) == 2
|
||||||
|
assert reqs[0].anns_field == "vector"
|
||||||
|
assert reqs[1].anns_field == "sparse"
|
||||||
|
ranker = call_args[1]["ranker"]
|
||||||
|
assert ranker is not None
|
||||||
|
|
||||||
|
|
||||||
|
async def test_query_hybrid_search_weighted(
|
||||||
|
milvus_index, sample_chunks, sample_embeddings, embedding_dimension, mock_milvus_client
|
||||||
|
):
|
||||||
|
"""Test hybrid search with weighted reranker."""
|
||||||
|
mock_milvus_client.has_collection.return_value = True
|
||||||
|
await milvus_index.add_chunks(sample_chunks, sample_embeddings)
|
||||||
|
|
||||||
|
# Mock hybrid search results
|
||||||
|
mock_milvus_client.hybrid_search.return_value = [
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"id": 0,
|
||||||
|
"distance": 0.1,
|
||||||
|
"entity": {"chunk_content": {"content": "mock chunk 1", "metadata": {"document_id": "doc1"}}},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 1,
|
||||||
|
"distance": 0.2,
|
||||||
|
"entity": {"chunk_content": {"content": "mock chunk 2", "metadata": {"document_id": "doc2"}}},
|
||||||
|
},
|
||||||
|
]
|
||||||
|
]
|
||||||
|
|
||||||
|
# Test hybrid search with weighted reranker
|
||||||
|
query_embedding = np.random.rand(embedding_dimension).astype(np.float32)
|
||||||
|
query_string = "test query"
|
||||||
|
response = await milvus_index.query_hybrid(
|
||||||
|
embedding=query_embedding,
|
||||||
|
query_string=query_string,
|
||||||
|
k=2,
|
||||||
|
score_threshold=0.0,
|
||||||
|
reranker_type="weighted",
|
||||||
|
reranker_params={"alpha": 0.7},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert isinstance(response, QueryChunksResponse)
|
||||||
|
assert len(response.chunks) == 2
|
||||||
|
assert len(response.scores) == 2
|
||||||
|
|
||||||
|
# Verify hybrid search was called with correct parameters
|
||||||
|
mock_milvus_client.hybrid_search.assert_called_once()
|
||||||
|
call_args = mock_milvus_client.hybrid_search.call_args
|
||||||
|
ranker = call_args[1]["ranker"]
|
||||||
|
assert ranker is not None
|
||||||
|
|
||||||
|
|
||||||
|
async def test_query_hybrid_search_default_rrf(
|
||||||
|
milvus_index, sample_chunks, sample_embeddings, embedding_dimension, mock_milvus_client
|
||||||
|
):
|
||||||
|
"""Test hybrid search with default RRF reranker (no reranker_type specified)."""
|
||||||
|
mock_milvus_client.has_collection.return_value = True
|
||||||
|
await milvus_index.add_chunks(sample_chunks, sample_embeddings)
|
||||||
|
|
||||||
|
# Mock hybrid search results
|
||||||
|
mock_milvus_client.hybrid_search.return_value = [
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"id": 0,
|
||||||
|
"distance": 0.1,
|
||||||
|
"entity": {"chunk_content": {"content": "mock chunk 1", "metadata": {"document_id": "doc1"}}},
|
||||||
|
},
|
||||||
|
]
|
||||||
|
]
|
||||||
|
|
||||||
|
# Test hybrid search with default reranker (should be RRF)
|
||||||
|
query_embedding = np.random.rand(embedding_dimension).astype(np.float32)
|
||||||
|
query_string = "test query"
|
||||||
|
response = await milvus_index.query_hybrid(
|
||||||
|
embedding=query_embedding,
|
||||||
|
query_string=query_string,
|
||||||
|
k=1,
|
||||||
|
score_threshold=0.0,
|
||||||
|
reranker_type="unknown_type", # Should default to RRF
|
||||||
|
reranker_params=None, # Should use default impact_factor
|
||||||
|
)
|
||||||
|
|
||||||
|
assert isinstance(response, QueryChunksResponse)
|
||||||
|
assert len(response.chunks) == 1
|
||||||
|
|
||||||
|
# Verify hybrid search was called with RRF reranker
|
||||||
|
mock_milvus_client.hybrid_search.assert_called_once()
|
||||||
|
call_args = mock_milvus_client.hybrid_search.call_args
|
||||||
|
ranker = call_args[1]["ranker"]
|
||||||
|
assert ranker is not None
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue