diff --git a/docs/source/providers/vector_io/remote_milvus.md b/docs/source/providers/vector_io/remote_milvus.md
index 9017f0e22..b81716192 100644
--- a/docs/source/providers/vector_io/remote_milvus.md
+++ b/docs/source/providers/vector_io/remote_milvus.md
@@ -101,6 +101,15 @@ vector_io:
 - **`client_pem_path`**: Path to the **client certificate** file (required for mTLS).
 - **`client_key_path`**: Path to the **client private key** file (required for mTLS).
 
+## Supported Search Modes
+
+The Milvus provider supports both vector-based and keyword-based (full-text) search modes, but with some limitations:
+
+- Remote Milvus supports both vector-based and keyword-based search modes.
+- Inline Milvus (Milvus-Lite) only supports vector-based search. Keyword search is not supported as Milvus-Lite has not implemented this functionality yet. For updates on this feature, see [Milvus GitHub Issue #40848](https://github.com/milvus-io/milvus/issues/40848).
+
+When using the RAGTool interface, you can specify the desired search behavior via the `mode` parameter in `RAGQueryConfig`. For more details on Milvus's implementation of keyword search modes, refer to the [Milvus documentation](https://milvus.io/docs/full_text_search_with_milvus.md).
+
 ## Documentation
 See the [Milvus documentation](https://milvus.io/docs/install-overview.md) for more details about Milvus in general.
 
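To make the docs addition concrete, here is a minimal sketch of selecting keyword search from a client. It assumes the `llama-stack-client` RAG tool interface; the base URL, vector DB id, and query text are illustrative and not part of this change.

```python
# Illustrative only -- not part of this diff. Assumes a running llama-stack
# server backed by a remote Milvus provider and an already-registered vector DB.
from llama_stack_client import LlamaStackClient

client = LlamaStackClient(base_url="http://localhost:8321")  # assumed local server

# "keyword" routes the query to Milvus BM25 full-text search; the default
# "vector" mode keeps the existing embedding-based search.
result = client.tool_runtime.rag_tool.query(
    vector_db_ids=["my-milvus-db"],  # hypothetical vector DB id
    content="How do I configure mTLS for remote Milvus?",
    query_config={"mode": "keyword", "max_chunks": 5},
)
```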
diff --git a/llama_stack/providers/remote/vector_io/milvus/milvus.py b/llama_stack/providers/remote/vector_io/milvus/milvus.py
index 25fe237c0..387e22382 100644
--- a/llama_stack/providers/remote/vector_io/milvus/milvus.py
+++ b/llama_stack/providers/remote/vector_io/milvus/milvus.py
@@ -70,11 +70,58 @@ class MilvusIndex(EmbeddingIndex):
                 f"Chunk length {len(chunks)} does not match embedding length {len(embeddings)}"
             )
         if not await asyncio.to_thread(self.client.has_collection, self.collection_name):
+            # Create schema for vector search
+            schema = self.client.create_schema()
+            schema.add_field(
+                field_name="chunk_id",
+                datatype=DataType.VARCHAR,
+                is_primary=True,
+                max_length=100,
+            )
+            schema.add_field(
+                field_name="content",
+                datatype=DataType.VARCHAR,
+                max_length=65535,
+                enable_analyzer=True,  # Enable text analysis for BM25
+            )
+            schema.add_field(
+                field_name="vector",
+                datatype=DataType.FLOAT_VECTOR,
+                dim=len(embeddings[0]),
+            )
+            schema.add_field(
+                field_name="chunk_content",
+                datatype=DataType.JSON,
+            )
+            # Add sparse vector field for BM25
+            schema.add_field(
+                field_name="sparse",
+                datatype=DataType.SPARSE_FLOAT_VECTOR,
+            )
+
+            # Create indexes
+            index_params = self.client.prepare_index_params()
+            index_params.add_index(
+                field_name="vector",
+                index_type="FLAT",
+                metric_type="COSINE",
+            )
+
+            # Add BM25 function for full-text search
+            from pymilvus import Function, FunctionType
+            bm25_function = Function(
+                name="text_bm25_emb",
+                input_field_names=["content"],
+                output_field_names=["sparse"],
+                function_type=FunctionType.BM25,
+            )
+            schema.add_function(bm25_function)
+
             await asyncio.to_thread(
                 self.client.create_collection,
                 self.collection_name,
-                dimension=len(embeddings[0]),
-                auto_id=True,
+                schema=schema,
+                index_params=index_params,
                 consistency_level=self.consistency_level,
             )
 
@@ -83,8 +130,10 @@ class MilvusIndex(EmbeddingIndex):
             data.append(
                 {
                     "chunk_id": chunk.chunk_id,
+                    "content": chunk.content,
                     "vector": embedding,
                     "chunk_content": chunk.model_dump(),
+                    # sparse field will be automatically populated by BM25 function
                 }
             )
         try:
@@ -102,9 +151,10 @@ class MilvusIndex(EmbeddingIndex):
             self.client.search,
             collection_name=self.collection_name,
             data=[embedding],
+            anns_field="vector",
             limit=k,
             output_fields=["*"],
-            search_params={"params": {"radius": score_threshold}},
+            search_params={"metric_type": "COSINE", "params": {"score_threshold": score_threshold}},
         )
         chunks = [Chunk(**res["entity"]["chunk_content"]) for res in search_res[0]]
         scores = [res["distance"] for res in search_res[0]]
@@ -116,7 +166,41 @@ class MilvusIndex(EmbeddingIndex):
         k: int,
         score_threshold: float,
     ) -> QueryChunksResponse:
-        raise NotImplementedError("Keyword search is not supported in Milvus")
+        """
+        Perform BM25-based keyword search using Milvus's built-in full-text search.
+        """
+        try:
+            from pymilvus import Function, FunctionType
+            search_res = await asyncio.to_thread(
+                self.client.search,
+                collection_name=self.collection_name,
+                data=[query_string],  # Raw text query
+                anns_field="sparse",  # Use sparse field for BM25
+                output_fields=["chunk_content"],  # Output the chunk content
+                limit=k,
+                search_params={
+                    "params": {
+                        "drop_ratio_search": 0.2,  # Ignore low-importance terms
+                    }
+                },
+            )
+            chunks = []
+            scores = []
+            for res in search_res[0]:
+                chunk = Chunk(**res["entity"]["chunk_content"])
+                chunks.append(chunk)
+                scores.append(res["distance"])  # BM25 score from Milvus
+            # Filter by score threshold
+            filtered_results = [(chunk, score) for chunk, score in zip(chunks, scores, strict=False) if score >= score_threshold]
+            if filtered_results:
+                chunks, scores = zip(*filtered_results, strict=False)
+                return QueryChunksResponse(chunks=list(chunks), scores=list(scores))
+            else:
+                return QueryChunksResponse(chunks=[], scores=[])
+        except Exception as e:
+            logger.error(f"Error performing BM25 search: {e}")
+            # Fallback to simple text search
+            return await self._fallback_keyword_search(query_string, k, score_threshold)
 
     async def query_hybrid(
         self,
@@ -238,6 +322,14 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
         if not index:
             raise ValueError(f"Vector DB {vector_db_id} not found")
 
+        if params and params.get("mode") == "keyword":
+            # Check if this is inline Milvus (Milvus-Lite)
+            if hasattr(self.config, "db_path"):
+                raise NotImplementedError(
+                    "Keyword search is not supported in Milvus-Lite. "
+                    "Please use a remote Milvus server for keyword search functionality."
+                )
+
         return await index.query_chunks(query, params)
 
     async def _save_openai_vector_store(self, store_id: str, store_info: dict[str, Any]) -> None:
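Because the collection now carries both a dense `vector` field and a BM25-backed `sparse` field, the same collection serves both search modes. Below is a condensed, standalone sketch of the two pymilvus calls the provider relies on above; the URI, collection name, embedding dimension, and query text are placeholders, and it assumes a remote Milvus server with a collection created from the schema built in `add_chunks`.

```python
# Standalone sketch, not part of this diff: the pymilvus calls behind
# query_vector (dense) and query_keyword (BM25 full-text) above.
from pymilvus import MilvusClient

client = MilvusClient(uri="http://localhost:19530")  # placeholder server URI
collection = "test_collection"  # placeholder collection name

# Vector mode: search the dense "vector" field with a query embedding.
vector_hits = client.search(
    collection_name=collection,
    data=[[0.1] * 384],  # query embedding; 384 is an illustrative dimension
    anns_field="vector",
    limit=3,
    output_fields=["chunk_content"],
    search_params={"metric_type": "COSINE"},
)

# Keyword mode: pass the raw query string against the "sparse" field; Milvus
# tokenizes and scores it server-side via the BM25 function in the schema.
keyword_hits = client.search(
    collection_name=collection,
    data=["how to enable mTLS"],
    anns_field="sparse",
    limit=3,
    output_fields=["chunk_content"],
    search_params={"params": {"drop_ratio_search": 0.2}},
)
```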
diff --git a/tests/unit/providers/vector_io/remote/test_milvus.py b/tests/unit/providers/vector_io/remote/test_milvus.py
new file mode 100644
index 000000000..cc2f96ff3
--- /dev/null
+++ b/tests/unit/providers/vector_io/remote/test_milvus.py
@@ -0,0 +1,179 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from unittest.mock import MagicMock, patch
+
+import numpy as np
+import pytest
+import pytest_asyncio
+
+from llama_stack.apis.vector_io import QueryChunksResponse
+
+# Mock the entire pymilvus module
+pymilvus_mock = MagicMock()
+pymilvus_mock.DataType = MagicMock()
+pymilvus_mock.MilvusClient = MagicMock
+
+# Apply the mock before importing MilvusIndex
+with patch.dict("sys.modules", {"pymilvus": pymilvus_mock}):
+    from llama_stack.providers.remote.vector_io.milvus.milvus import MilvusIndex
+
+# This test is a unit test for the MilvusVectorIOAdapter class. This should only contain
+# tests which are specific to this class. More general (API-level) tests should be placed in
+# tests/integration/vector_io/
+#
+# How to run this test:
+#
+# pytest tests/unit/providers/vector_io/remote/test_milvus.py \
+#     -v -s --tb=short --disable-warnings --asyncio-mode=auto
+
+MILVUS_PROVIDER = "milvus"
+
+
+@pytest_asyncio.fixture
+async def mock_milvus_client():
+    """Create a mock Milvus client with common method behaviors."""
+    client = MagicMock()
+
+    # Mock collection operations
+    client.has_collection.return_value = False  # Initially no collection
+    client.create_collection.return_value = None
+    client.drop_collection.return_value = None
+
+    # Mock insert operation
+    client.insert.return_value = {"insert_count": 10}
+
+    # Mock search operation - return mock results (data should be dict, not JSON string)
+    client.search.return_value = [
+        [
+            {
+                "id": 0,
+                "distance": 0.1,
+                "entity": {"chunk_content": {"content": "mock chunk 1", "metadata": {"document_id": "doc1"}}},
+            },
+            {
+                "id": 1,
+                "distance": 0.2,
+                "entity": {"chunk_content": {"content": "mock chunk 2", "metadata": {"document_id": "doc2"}}},
+            },
+        ]
+    ]
+
+    # Mock query operation for keyword search (data should be dict, not JSON string)
+    client.query.return_value = [
+        {
+            "chunk_id": "chunk1",
+            "chunk_content": {"content": "mock chunk 1", "metadata": {"document_id": "doc1"}},
+            "score": 0.9,
+        },
+        {
+            "chunk_id": "chunk2",
+            "chunk_content": {"content": "mock chunk 2", "metadata": {"document_id": "doc2"}},
+            "score": 0.8,
+        },
+        {
+            "chunk_id": "chunk3",
+            "chunk_content": {"content": "mock chunk 3", "metadata": {"document_id": "doc3"}},
+            "score": 0.7,
+        },
+    ]
+
+    return client
+
+
+@pytest_asyncio.fixture
+async def milvus_index(mock_milvus_client):
+    """Create a MilvusIndex with mocked client."""
+    index = MilvusIndex(client=mock_milvus_client, collection_name="test_collection")
+    yield index
+    # No real cleanup needed since we're using mocks
+
+
+@pytest.mark.asyncio
+async def test_add_chunks(milvus_index, sample_chunks, sample_embeddings, mock_milvus_client):
+    # Setup: collection doesn't exist initially, then exists after creation
+    mock_milvus_client.has_collection.side_effect = [False, True]
+
+    await milvus_index.add_chunks(sample_chunks, sample_embeddings)
+
+    # Verify collection was created and data was inserted
+    mock_milvus_client.create_collection.assert_called_once()
+    mock_milvus_client.insert.assert_called_once()
+
+    # Verify the insert call had the right number of chunks
+    insert_call = mock_milvus_client.insert.call_args
+    assert len(insert_call[1]["data"]) == len(sample_chunks)
+
+
+@pytest.mark.asyncio
+async def test_query_chunks_vector(
+    milvus_index, sample_chunks, sample_embeddings, embedding_dimension, mock_milvus_client
+):
+    # Setup: Add chunks first
+    mock_milvus_client.has_collection.return_value = True
+    await milvus_index.add_chunks(sample_chunks, sample_embeddings)
+
+    # Test vector search
+    query_embedding = np.random.rand(embedding_dimension).astype(np.float32)
+    response = await milvus_index.query_vector(query_embedding, k=2, score_threshold=0.0)
+
+    assert isinstance(response, QueryChunksResponse)
+    assert len(response.chunks) == 2
+    mock_milvus_client.search.assert_called_once()
+
+
+@pytest.mark.asyncio
+async def test_query_chunks_keyword_search(milvus_index, sample_chunks, sample_embeddings, mock_milvus_client):
+    mock_milvus_client.has_collection.return_value = True
+    await milvus_index.add_chunks(sample_chunks, sample_embeddings)
+
+    # Test keyword search
+    query_string = "Sentence 5"
+    response = await milvus_index.query_keyword(query_string=query_string, k=3, score_threshold=0.0)
+
+    assert isinstance(response, QueryChunksResponse)
+    assert len(response.chunks) == 3
+    mock_milvus_client.query.assert_called_once()
+
+    # Test no results case
+    mock_milvus_client.query.return_value = []
+    response_no_results = await milvus_index.query_keyword(query_string="nonexistent", k=1, score_threshold=0.0)
+
+    assert isinstance(response_no_results, QueryChunksResponse)
+    assert len(response_no_results.chunks) == 0
+
+
+@pytest.mark.asyncio
+async def test_query_chunks_keyword_search_k_greater_than_results(
+    milvus_index, sample_chunks, sample_embeddings, mock_milvus_client
+):
+    mock_milvus_client.has_collection.return_value = True
+    await milvus_index.add_chunks(sample_chunks, sample_embeddings)
+
+    # Mock returning only 1 result even though k=5
+    mock_milvus_client.query.return_value = [
+        {
+            "chunk_id": "chunk1",
+            "chunk_content": {"content": "Sentence 1 from document 0", "metadata": {"document_id": "doc1"}},
+            "score": 0.9,
+        }
+    ]
+
+    query_str = "Sentence 1 from document 0"
+    response = await milvus_index.query_keyword(query_string=query_str, k=5, score_threshold=0.0)
+
+    assert 0 < len(response.chunks) <= 4
+    assert any("Sentence 1 from document 0" in chunk.content for chunk in response.chunks)
+
+
+@pytest.mark.asyncio
+async def test_delete_collection(milvus_index, mock_milvus_client):
+    # Test collection deletion
+    mock_milvus_client.has_collection.return_value = True
+
+    await milvus_index.delete()
+
+    mock_milvus_client.drop_collection.assert_called_once_with(collection_name=milvus_index.collection_name)
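Note that the tests above lean on `sample_chunks`, `sample_embeddings`, and `embedding_dimension` fixtures that are not part of this diff (they come from the shared vector_io unit-test conftest). Below is a rough, hypothetical sketch of what those fixtures need to provide, for readers looking at this file in isolation; the names, counts, and 384-dimension value are assumptions, not the real conftest.

```python
# Hypothetical conftest.py sketch -- the real fixtures live in the shared
# vector_io unit-test conftest; shapes and values here are illustrative only.
import numpy as np
import pytest

from llama_stack.apis.vector_io import Chunk


@pytest.fixture
def embedding_dimension():
    return 384  # assumed embedding size


@pytest.fixture
def sample_chunks():
    # Ten small chunks spread across two documents.
    return [
        Chunk(content=f"Sentence {i} from document {i // 5}", metadata={"document_id": f"document-{i // 5}"})
        for i in range(10)
    ]


@pytest.fixture
def sample_embeddings(sample_chunks, embedding_dimension):
    # Random vectors are sufficient because the Milvus client is mocked.
    return np.random.rand(len(sample_chunks), embedding_dimension).astype(np.float32)
```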