diff --git a/docs/source/providers/vector_io/milvus.md b/docs/source/providers/vector_io/milvus.md index e030c85f8..58d33e322 100644 --- a/docs/source/providers/vector_io/milvus.md +++ b/docs/source/providers/vector_io/milvus.md @@ -96,11 +96,20 @@ vector_io: #### Key Parameters for TLS Configuration - **`secure`**: Enables TLS encryption when set to `true`. Defaults to `false`. -- **`server_pem_path`**: Path to the **server certificate** for verifying the server’s identity (used in one-way TLS). +- **`server_pem_path`**: Path to the **server certificate** for verifying the server's identity (used in one-way TLS). - **`ca_pem_path`**: Path to the **Certificate Authority (CA) certificate** for validating the server certificate (required in mTLS). - **`client_pem_path`**: Path to the **client certificate** file (required for mTLS). - **`client_key_path`**: Path to the **client private key** file (required for mTLS). +## Supported Search Modes + +The Milvus provider supports both vector-based and keyword-based (full-text) search modes, but with some limitations: + +- Remote Milvus supports both vector-based and keyword-based search modes. +- Inline Milvus (Milvus-Lite) only supports vector-based search. Keyword search is not supported as Milvus-Lite has not implemented this functionality yet. For updates on this feature, see [Milvus GitHub Issue #40848](https://github.com/milvus-io/milvus/issues/40848). + +When using the RAGTool interface, you can specify the desired search behavior via the `mode` parameter in `RAGQueryConfig`. For more details on Milvus's implementation of keyword search modes, refer to the [Milvus documentation](https://milvus.io/docs/full_text_search_with_milvus.md). + ## Documentation See the [Milvus documentation](https://milvus.io/docs/install-overview.md) for more details about Milvus in general. diff --git a/llama_stack/providers/remote/vector_io/milvus/milvus.py b/llama_stack/providers/remote/vector_io/milvus/milvus.py index 182227a85..cf5a396d5 100644 --- a/llama_stack/providers/remote/vector_io/milvus/milvus.py +++ b/llama_stack/providers/remote/vector_io/milvus/milvus.py @@ -12,7 +12,7 @@ import uuid from typing import Any from numpy.typing import NDArray -from pymilvus import MilvusClient +from pymilvus import DataType, MilvusClient from llama_stack.apis.inference import InterleavedContent from llama_stack.apis.vector_dbs import VectorDB @@ -48,6 +48,8 @@ class MilvusIndex(EmbeddingIndex): self.client = client self.collection_name = collection_name.replace("-", "_") self.consistency_level = consistency_level + self.bm25 = None + self.vectorizer = None async def delete(self): if await asyncio.to_thread(self.client.has_collection, self.collection_name): @@ -58,11 +60,42 @@ class MilvusIndex(EmbeddingIndex): f"Chunk length {len(chunks)} does not match embedding length {len(embeddings)}" ) if not await asyncio.to_thread(self.client.has_collection, self.collection_name): + # Create schema for vector search + schema = self.client.create_schema() + schema.add_field( + field_name="chunk_id", + datatype=DataType.VARCHAR, + is_primary=True, + max_length=100, + ) + schema.add_field( + field_name="content", + datatype=DataType.VARCHAR, + max_length=65535, + ) + schema.add_field( + field_name="vector", + datatype=DataType.FLOAT_VECTOR, + dim=len(embeddings[0]), + ) + schema.add_field( + field_name="chunk_content", + datatype=DataType.JSON, + ) + + # Create indexes + index_params = self.client.prepare_index_params() + index_params.add_index( + field_name="vector", + index_type="FLAT", + metric_type="COSINE", + ) + await asyncio.to_thread( self.client.create_collection, self.collection_name, - dimension=len(embeddings[0]), - auto_id=True, + schema=schema, + index_params=index_params, consistency_level=self.consistency_level, ) @@ -73,6 +106,7 @@ class MilvusIndex(EmbeddingIndex): data.append( { "chunk_id": chunk_id, + "content": chunk.content, "vector": embedding, "chunk_content": chunk.model_dump(), } @@ -92,9 +126,10 @@ class MilvusIndex(EmbeddingIndex): self.client.search, collection_name=self.collection_name, data=[embedding], + anns_field="vector", limit=k, output_fields=["*"], - search_params={"params": {"radius": score_threshold}}, + search_params={"metric_type": "COSINE", "params": {"score_threshold": score_threshold}}, ) chunks = [Chunk(**res["entity"]["chunk_content"]) for res in search_res[0]] scores = [res["distance"] for res in search_res[0]] @@ -106,7 +141,17 @@ class MilvusIndex(EmbeddingIndex): k: int, score_threshold: float, ) -> QueryChunksResponse: - raise NotImplementedError("Keyword search is not supported in Milvus") + # Simple text search using content field + search_res = await asyncio.to_thread( + self.client.query, + collection_name=self.collection_name, + filter=f'content like "%{query_string}%"', + output_fields=["*"], + limit=k, + ) + chunks = [Chunk(**res["chunk_content"]) for res in search_res] + scores = [1.0] * len(chunks) # Simple binary score for text search + return QueryChunksResponse(chunks=chunks, scores=scores) async def query_hybrid( self, @@ -200,6 +245,14 @@ class MilvusVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate): if not index: raise ValueError(f"Vector DB {vector_db_id} not found") + if params and params.get("mode") == "keyword": + # Check if this is inline Milvus (Milvus-Lite) + if hasattr(self.config, "db_path"): + raise NotImplementedError( + "Keyword search is not supported in Milvus-Lite. " + "Please use a remote Milvus server for keyword search functionality." + ) + return await index.query_chunks(query, params) async def openai_create_vector_store( diff --git a/tests/unit/providers/vector_io/remote/test_milvus.py b/tests/unit/providers/vector_io/remote/test_milvus.py new file mode 100644 index 000000000..cc2f96ff3 --- /dev/null +++ b/tests/unit/providers/vector_io/remote/test_milvus.py @@ -0,0 +1,179 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from unittest.mock import MagicMock, patch + +import numpy as np +import pytest +import pytest_asyncio + +from llama_stack.apis.vector_io import QueryChunksResponse + +# Mock the entire pymilvus module +pymilvus_mock = MagicMock() +pymilvus_mock.DataType = MagicMock() +pymilvus_mock.MilvusClient = MagicMock + +# Apply the mock before importing MilvusIndex +with patch.dict("sys.modules", {"pymilvus": pymilvus_mock}): + from llama_stack.providers.remote.vector_io.milvus.milvus import MilvusIndex + +# This test is a unit test for the MilvusVectorIOAdapter class. This should only contain +# tests which are specific to this class. More general (API-level) tests should be placed in +# tests/integration/vector_io/ +# +# How to run this test: +# +# pytest tests/unit/providers/vector_io/test_milvus.py \ +# -v -s --tb=short --disable-warnings --asyncio-mode=auto + +MILVUS_PROVIDER = "milvus" + + +@pytest_asyncio.fixture +async def mock_milvus_client(): + """Create a mock Milvus client with common method behaviors.""" + client = MagicMock() + + # Mock collection operations + client.has_collection.return_value = False # Initially no collection + client.create_collection.return_value = None + client.drop_collection.return_value = None + + # Mock insert operation + client.insert.return_value = {"insert_count": 10} + + # Mock search operation - return mock results (data should be dict, not JSON string) + client.search.return_value = [ + [ + { + "id": 0, + "distance": 0.1, + "entity": {"chunk_content": {"content": "mock chunk 1", "metadata": {"document_id": "doc1"}}}, + }, + { + "id": 1, + "distance": 0.2, + "entity": {"chunk_content": {"content": "mock chunk 2", "metadata": {"document_id": "doc2"}}}, + }, + ] + ] + + # Mock query operation for keyword search (data should be dict, not JSON string) + client.query.return_value = [ + { + "chunk_id": "chunk1", + "chunk_content": {"content": "mock chunk 1", "metadata": {"document_id": "doc1"}}, + "score": 0.9, + }, + { + "chunk_id": "chunk2", + "chunk_content": {"content": "mock chunk 2", "metadata": {"document_id": "doc2"}}, + "score": 0.8, + }, + { + "chunk_id": "chunk3", + "chunk_content": {"content": "mock chunk 3", "metadata": {"document_id": "doc3"}}, + "score": 0.7, + }, + ] + + return client + + +@pytest_asyncio.fixture +async def milvus_index(mock_milvus_client): + """Create a MilvusIndex with mocked client.""" + index = MilvusIndex(client=mock_milvus_client, collection_name="test_collection") + yield index + # No real cleanup needed since we're using mocks + + +@pytest.mark.asyncio +async def test_add_chunks(milvus_index, sample_chunks, sample_embeddings, mock_milvus_client): + # Setup: collection doesn't exist initially, then exists after creation + mock_milvus_client.has_collection.side_effect = [False, True] + + await milvus_index.add_chunks(sample_chunks, sample_embeddings) + + # Verify collection was created and data was inserted + mock_milvus_client.create_collection.assert_called_once() + mock_milvus_client.insert.assert_called_once() + + # Verify the insert call had the right number of chunks + insert_call = mock_milvus_client.insert.call_args + assert len(insert_call[1]["data"]) == len(sample_chunks) + + +@pytest.mark.asyncio +async def test_query_chunks_vector( + milvus_index, sample_chunks, sample_embeddings, embedding_dimension, mock_milvus_client +): + # Setup: Add chunks first + mock_milvus_client.has_collection.return_value = True + await milvus_index.add_chunks(sample_chunks, sample_embeddings) + + # Test vector search + query_embedding = np.random.rand(embedding_dimension).astype(np.float32) + response = await milvus_index.query_vector(query_embedding, k=2, score_threshold=0.0) + + assert isinstance(response, QueryChunksResponse) + assert len(response.chunks) == 2 + mock_milvus_client.search.assert_called_once() + + +@pytest.mark.asyncio +async def test_query_chunks_keyword_search(milvus_index, sample_chunks, sample_embeddings, mock_milvus_client): + mock_milvus_client.has_collection.return_value = True + await milvus_index.add_chunks(sample_chunks, sample_embeddings) + + # Test keyword search + query_string = "Sentence 5" + response = await milvus_index.query_keyword(query_string=query_string, k=3, score_threshold=0.0) + + assert isinstance(response, QueryChunksResponse) + assert len(response.chunks) == 3 + mock_milvus_client.query.assert_called_once() + + # Test no results case + mock_milvus_client.query.return_value = [] + response_no_results = await milvus_index.query_keyword(query_string="nonexistent", k=1, score_threshold=0.0) + + assert isinstance(response_no_results, QueryChunksResponse) + assert len(response_no_results.chunks) == 0 + + +@pytest.mark.asyncio +async def test_query_chunks_keyword_search_k_greater_than_results( + milvus_index, sample_chunks, sample_embeddings, mock_milvus_client +): + mock_milvus_client.has_collection.return_value = True + await milvus_index.add_chunks(sample_chunks, sample_embeddings) + + # Mock returning only 1 result even though k=5 + mock_milvus_client.query.return_value = [ + { + "chunk_id": "chunk1", + "chunk_content": {"content": "Sentence 1 from document 0", "metadata": {"document_id": "doc1"}}, + "score": 0.9, + } + ] + + query_str = "Sentence 1 from document 0" + response = await milvus_index.query_keyword(query_string=query_str, k=5, score_threshold=0.0) + + assert 0 < len(response.chunks) <= 4 + assert any("Sentence 1 from document 0" in chunk.content for chunk in response.chunks) + + +@pytest.mark.asyncio +async def test_delete_collection(milvus_index, mock_milvus_client): + # Test collection deletion + mock_milvus_client.has_collection.return_value = True + + await milvus_index.delete() + + mock_milvus_client.drop_collection.assert_called_once_with(collection_name=milvus_index.collection_name)