diff --git a/docs/source/providers/vector_io/remote_milvus.md b/docs/source/providers/vector_io/remote_milvus.md
index f3089e615..6734d8315 100644
--- a/docs/source/providers/vector_io/remote_milvus.md
+++ b/docs/source/providers/vector_io/remote_milvus.md
@@ -114,7 +114,7 @@ For more details on TLS configuration, refer to the [TLS setup guide](https://mi
 | `uri` | `` | No | PydanticUndefined | The URI of the Milvus server |
 | `token` | `str \| None` | No | PydanticUndefined | The token of the Milvus server |
 | `consistency_level` | `` | No | Strong | The consistency level of the Milvus server |
-| `kvstore` | `utils.kvstore.config.RedisKVStoreConfig \| utils.kvstore.config.SqliteKVStoreConfig \| utils.kvstore.config.PostgresKVStoreConfig \| utils.kvstore.config.MongoDBKVStoreConfig, annotation=NoneType, required=False, default='sqlite', discriminator='type'` | No | | Config for KV store backend (SQLite only for now) |
+| `kvstore` | `utils.kvstore.config.RedisKVStoreConfig \| utils.kvstore.config.SqliteKVStoreConfig \| utils.kvstore.config.PostgresKVStoreConfig \| utils.kvstore.config.MongoDBKVStoreConfig` | No | sqlite | Config for KV store backend |
 | `config` | `dict` | No | {} | This configuration allows additional fields to be passed through to the underlying Milvus client. See the [Milvus](https://milvus.io/docs/install-overview.md) documentation for more details about Milvus in general. |
 
 > **Note**: This configuration class accepts additional fields beyond those listed above. You can pass any additional configuration options that will be forwarded to the underlying provider.
@@ -124,6 +124,9 @@ For more details on TLS configuration, refer to the [TLS setup guide](https://mi
 ```yaml
 uri: ${env.MILVUS_ENDPOINT}
 token: ${env.MILVUS_TOKEN}
+kvstore:
+  type: sqlite
+  db_path: ${env.SQLITE_STORE_DIR:=~/.llama/dummy}/milvus_remote_registry.db
 ```
 
diff --git a/llama_stack/providers/remote/vector_io/milvus/config.py b/llama_stack/providers/remote/vector_io/milvus/config.py
index e3f51b4f4..899d3678d 100644
--- a/llama_stack/providers/remote/vector_io/milvus/config.py
+++ b/llama_stack/providers/remote/vector_io/milvus/config.py
@@ -8,7 +8,7 @@ from typing import Any
 
 from pydantic import BaseModel, ConfigDict, Field
 
-from llama_stack.providers.utils.kvstore.config import KVStoreConfig
+from llama_stack.providers.utils.kvstore.config import KVStoreConfig, SqliteKVStoreConfig
 from llama_stack.schema_utils import json_schema_type
 
 
@@ -17,7 +17,7 @@ class MilvusVectorIOConfig(BaseModel):
     uri: str = Field(description="The URI of the Milvus server")
     token: str | None = Field(description="The token of the Milvus server")
     consistency_level: str = Field(description="The consistency level of the Milvus server", default="Strong")
-    kvstore: KVStoreConfig | None = Field(description="Config for KV store backend (SQLite only for now)", default=None)
+    kvstore: KVStoreConfig = Field(description="Config for KV store backend")
 
     # This configuration allows additional fields to be passed through to the underlying Milvus client.
     # See the [Milvus](https://milvus.io/docs/install-overview.md) documentation for more details about Milvus in general.
@@ -25,4 +25,11 @@ class MilvusVectorIOConfig(BaseModel):
 
     @classmethod
     def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> dict[str, Any]:
-        return {"uri": "${env.MILVUS_ENDPOINT}", "token": "${env.MILVUS_TOKEN}"}
+        return {
+            "uri": "${env.MILVUS_ENDPOINT}",
+            "token": "${env.MILVUS_TOKEN}",
+            "kvstore": SqliteKVStoreConfig.sample_run_config(
+                __distro_dir__=__distro_dir__,
+                db_name="milvus_remote_registry.db",
+            ),
+        }
diff --git a/llama_stack/providers/remote/vector_io/milvus/milvus.py b/llama_stack/providers/remote/vector_io/milvus/milvus.py
index d88d954ef..f301942cb 100644
--- a/llama_stack/providers/remote/vector_io/milvus/milvus.py
+++ b/llama_stack/providers/remote/vector_io/milvus/milvus.py
@@ -12,7 +12,7 @@ import re
 from typing import Any
 
 from numpy.typing import NDArray
-from pymilvus import DataType, MilvusClient
+from pymilvus import DataType, Function, FunctionType, MilvusClient
 
 from llama_stack.apis.files.files import Files
 from llama_stack.apis.inference import Inference, InterleavedContent
@@ -74,12 +74,66 @@ class MilvusIndex(EmbeddingIndex):
         assert len(chunks) == len(embeddings), (
             f"Chunk length {len(chunks)} does not match embedding length {len(embeddings)}"
         )
+
         if not await asyncio.to_thread(self.client.has_collection, self.collection_name):
+            logger.info(f"Creating new collection {self.collection_name} with nullable sparse field")
+            # Create schema for vector search
+            schema = self.client.create_schema()
+            schema.add_field(
+                field_name="chunk_id",
+                datatype=DataType.VARCHAR,
+                is_primary=True,
+                max_length=100,
+            )
+            schema.add_field(
+                field_name="content",
+                datatype=DataType.VARCHAR,
+                max_length=65535,
+                enable_analyzer=True,  # Enable text analysis for BM25
+            )
+            schema.add_field(
+                field_name="vector",
+                datatype=DataType.FLOAT_VECTOR,
+                dim=len(embeddings[0]),
+            )
+            schema.add_field(
+                field_name="chunk_content",
+                datatype=DataType.JSON,
+            )
+            # Add sparse vector field for BM25 (required by the function)
+            schema.add_field(
+                field_name="sparse",
+                datatype=DataType.SPARSE_FLOAT_VECTOR,
+            )
+
+            # Create indexes
+            index_params = self.client.prepare_index_params()
+            index_params.add_index(
+                field_name="vector",
+                index_type="FLAT",
+                metric_type="COSINE",
+            )
+            # Add index for sparse field (required by BM25 function)
+            index_params.add_index(
+                field_name="sparse",
+                index_type="SPARSE_INVERTED_INDEX",
+                metric_type="BM25",
+            )
+
+            # Add BM25 function for full-text search
+            bm25_function = Function(
+                name="text_bm25_emb",
+                input_field_names=["content"],
+                output_field_names=["sparse"],
+                function_type=FunctionType.BM25,
+            )
+            schema.add_function(bm25_function)
+
             await asyncio.to_thread(
                 self.client.create_collection,
                 self.collection_name,
-                dimension=len(embeddings[0]),
-                auto_id=True,
+                schema=schema,
+                index_params=index_params,
                 consistency_level=self.consistency_level,
             )
 
@@ -88,8 +142,10 @@ class MilvusIndex(EmbeddingIndex):
             data.append(
                 {
                     "chunk_id": chunk.chunk_id,
+                    "content": chunk.content,
                     "vector": embedding,
                     "chunk_content": chunk.model_dump(),
+                    # sparse field will be handled by BM25 function automatically
                 }
             )
         try:
@@ -107,6 +163,7 @@ class MilvusIndex(EmbeddingIndex):
             self.client.search,
             collection_name=self.collection_name,
             data=[embedding],
+            anns_field="vector",
             limit=k,
             output_fields=["*"],
             search_params={"params": {"radius": score_threshold}},
@@ -121,7 +178,64 @@ class MilvusIndex(EmbeddingIndex):
         k: int,
         score_threshold: float,
     ) -> QueryChunksResponse:
-        raise NotImplementedError("Keyword search is not supported in Milvus")
Milvus") + """ + Perform BM25-based keyword search using Milvus's built-in full-text search. + """ + try: + # Use Milvus's built-in BM25 search + search_res = await asyncio.to_thread( + self.client.search, + collection_name=self.collection_name, + data=[query_string], # Raw text query + anns_field="sparse", # Use sparse field for BM25 + output_fields=["chunk_content"], # Output the chunk content + limit=k, + search_params={ + "params": { + "drop_ratio_search": 0.2, # Ignore low-importance terms + } + }, + ) + + chunks = [] + scores = [] + for res in search_res[0]: + chunk = Chunk(**res["entity"]["chunk_content"]) + chunks.append(chunk) + scores.append(res["distance"]) # BM25 score from Milvus + + # Filter by score threshold + filtered_chunks = [chunk for chunk, score in zip(chunks, scores, strict=False) if score >= score_threshold] + filtered_scores = [score for score in scores if score >= score_threshold] + + return QueryChunksResponse(chunks=filtered_chunks, scores=filtered_scores) + + except Exception as e: + logger.error(f"Error performing BM25 search: {e}") + # Fallback to simple text search + return await self._fallback_keyword_search(query_string, k, score_threshold) + + async def _fallback_keyword_search( + self, + query_string: str, + k: int, + score_threshold: float, + ) -> QueryChunksResponse: + """ + Fallback to simple text search when BM25 search is not available. + """ + # Simple text search using content field + search_res = await asyncio.to_thread( + self.client.query, + collection_name=self.collection_name, + filter='content like "%{content}%"', + filter_params={"content": query_string}, + output_fields=["*"], + limit=k, + ) + chunks = [Chunk(**res["chunk_content"]) for res in search_res] + scores = [1.0] * len(chunks) # Simple binary score for text search + return QueryChunksResponse(chunks=chunks, scores=scores) async def query_hybrid( self, @@ -247,6 +361,14 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP if not index: raise ValueError(f"Vector DB {vector_db_id} not found") + if params and params.get("mode") == "keyword": + # Check if this is inline Milvus (Milvus-Lite) + if hasattr(self.config, "db_path"): + raise NotImplementedError( + "Keyword search is not supported in Milvus-Lite. " + "Please use a remote Milvus server for keyword search functionality." + ) + return await index.query_chunks(query, params) async def _save_openai_vector_store_file( diff --git a/tests/unit/providers/vector_io/remote/test_milvus.py b/tests/unit/providers/vector_io/remote/test_milvus.py new file mode 100644 index 000000000..2f212e374 --- /dev/null +++ b/tests/unit/providers/vector_io/remote/test_milvus.py @@ -0,0 +1,191 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from unittest.mock import MagicMock, patch + +import numpy as np +import pytest +import pytest_asyncio + +from llama_stack.apis.vector_io import QueryChunksResponse + +# Mock the entire pymilvus module +pymilvus_mock = MagicMock() +pymilvus_mock.DataType = MagicMock() +pymilvus_mock.MilvusClient = MagicMock + +# Apply the mock before importing MilvusIndex +with patch.dict("sys.modules", {"pymilvus": pymilvus_mock}): + from llama_stack.providers.remote.vector_io.milvus.milvus import MilvusIndex + +# This test is a unit test for the MilvusVectorIOAdapter class. This should only contain +# tests which are specific to this class. 
+# tests/integration/vector_io/
+#
+# How to run this test:
+#
+# pytest tests/unit/providers/vector_io/remote/test_milvus.py \
+#     -v -s --tb=short --disable-warnings --asyncio-mode=auto
+
+MILVUS_PROVIDER = "milvus"
+
+
+@pytest_asyncio.fixture
+async def mock_milvus_client() -> MagicMock:
+    """Create a mock Milvus client with common method behaviors."""
+    client = MagicMock()
+
+    # Mock collection operations
+    client.has_collection.return_value = False  # Initially no collection
+    client.create_collection.return_value = None
+    client.drop_collection.return_value = None
+
+    # Mock insert operation
+    client.insert.return_value = {"insert_count": 10}
+
+    # Mock search operation - return mock results (data should be dict, not JSON string)
+    client.search.return_value = [
+        [
+            {
+                "id": 0,
+                "distance": 0.1,
+                "entity": {"chunk_content": {"content": "mock chunk 1", "metadata": {"document_id": "doc1"}}},
+            },
+            {
+                "id": 1,
+                "distance": 0.2,
+                "entity": {"chunk_content": {"content": "mock chunk 2", "metadata": {"document_id": "doc2"}}},
+            },
+        ]
+    ]
+
+    # Mock query operation for keyword search (data should be dict, not JSON string)
+    client.query.return_value = [
+        {
+            "chunk_id": "chunk1",
+            "chunk_content": {"content": "mock chunk 1", "metadata": {"document_id": "doc1"}},
+            "score": 0.9,
+        },
+        {
+            "chunk_id": "chunk2",
+            "chunk_content": {"content": "mock chunk 2", "metadata": {"document_id": "doc2"}},
+            "score": 0.8,
+        },
+        {
+            "chunk_id": "chunk3",
+            "chunk_content": {"content": "mock chunk 3", "metadata": {"document_id": "doc3"}},
+            "score": 0.7,
+        },
+    ]
+
+    return client
+
+
+@pytest_asyncio.fixture
+async def milvus_index(mock_milvus_client):
+    """Create a MilvusIndex with mocked client."""
+    index = MilvusIndex(client=mock_milvus_client, collection_name="test_collection")
+    yield index
+    # No real cleanup needed since we're using mocks
+
+
+@pytest.mark.asyncio
+async def test_add_chunks(milvus_index, sample_chunks, sample_embeddings, mock_milvus_client):
+    # Setup: collection doesn't exist initially, then exists after creation
+    mock_milvus_client.has_collection.side_effect = [False, True]
+
+    await milvus_index.add_chunks(sample_chunks, sample_embeddings)
+
+    # Verify collection was created and data was inserted
+    mock_milvus_client.create_collection.assert_called_once()
+    mock_milvus_client.insert.assert_called_once()
+
+    # Verify the insert call had the right number of chunks
+    insert_call = mock_milvus_client.insert.call_args
+    assert len(insert_call[1]["data"]) == len(sample_chunks)
+
+
+@pytest.mark.asyncio
+async def test_query_chunks_vector(
+    milvus_index, sample_chunks, sample_embeddings, embedding_dimension, mock_milvus_client
+):
+    # Setup: Add chunks first
+    mock_milvus_client.has_collection.return_value = True
+    await milvus_index.add_chunks(sample_chunks, sample_embeddings)
+
+    # Test vector search
+    query_embedding = np.random.rand(embedding_dimension).astype(np.float32)
+    response = await milvus_index.query_vector(query_embedding, k=2, score_threshold=0.0)
+
+    assert isinstance(response, QueryChunksResponse)
+    assert len(response.chunks) == 2
+    mock_milvus_client.search.assert_called_once()
+
+
+@pytest.mark.asyncio
+async def test_query_chunks_keyword_search(milvus_index, sample_chunks, sample_embeddings, mock_milvus_client):
+    mock_milvus_client.has_collection.return_value = True
+    await milvus_index.add_chunks(sample_chunks, sample_embeddings)
+
+    # Test keyword search
+    query_string = "Sentence 5"
+    response = await milvus_index.query_keyword(query_string=query_string, k=2, score_threshold=0.0)
+
+    assert isinstance(response, QueryChunksResponse)
+    assert len(response.chunks) == 2
+
+
+@pytest.mark.asyncio
+async def test_bm25_fallback_to_simple_search(milvus_index, sample_chunks, sample_embeddings, mock_milvus_client):
+    """Test that when BM25 search fails, the system falls back to simple text search."""
+    mock_milvus_client.has_collection.return_value = True
+    await milvus_index.add_chunks(sample_chunks, sample_embeddings)
+
+    # Force BM25 search to fail
+    mock_milvus_client.search.side_effect = Exception("BM25 search not available")
+
+    # Mock simple text search results
+    mock_milvus_client.query.return_value = [
+        {
+            "chunk_id": "chunk1",
+            "chunk_content": {"content": "Python programming language", "metadata": {"document_id": "doc1"}},
+        },
+        {
+            "chunk_id": "chunk2",
+            "chunk_content": {"content": "Machine learning algorithms", "metadata": {"document_id": "doc2"}},
+        },
+    ]
+
+    # Test keyword search that should fall back to simple text search
+    query_string = "Python"
+    response = await milvus_index.query_keyword(query_string=query_string, k=3, score_threshold=0.0)
+
+    # Verify response structure
+    assert isinstance(response, QueryChunksResponse)
+    assert len(response.chunks) > 0, "Fallback search should return results"
+
+    # Verify that simple text search was used (query method called instead of search)
+    mock_milvus_client.query.assert_called_once()
+    mock_milvus_client.search.assert_called_once()  # Called once but failed
+
+    # Verify the query uses parameterized filter with filter_params
+    query_call_args = mock_milvus_client.query.call_args
+    assert "filter" in query_call_args[1], "Query should include filter for text search"
+    assert "filter_params" in query_call_args[1], "Query should use parameterized filter"
+    assert query_call_args[1]["filter_params"]["content"] == "Python", "Filter params should contain the search term"
+
+    # Verify all returned chunks have score 1.0 (simple binary scoring)
+    assert all(score == 1.0 for score in response.scores), "Simple text search should use binary scoring"
+
+
+@pytest.mark.asyncio
+async def test_delete_collection(milvus_index, mock_milvus_client):
+    # Test collection deletion
+    mock_milvus_client.has_collection.return_value = True
+
+    await milvus_index.delete()
+
+    mock_milvus_client.drop_collection.assert_called_once_with(collection_name=milvus_index.collection_name)
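
The patch above gives `MilvusIndex.query_keyword` a BM25-backed implementation and makes the adapter reject keyword mode on Milvus-Lite. The sketch below is an illustrative usage example only, based on the signatures visible in this diff; the endpoint URI, token, collection name, and query text are hypothetical placeholders, and it assumes chunks were already ingested into a remote Milvus server via `add_chunks` so the `sparse` field and BM25 function exist.

```python
import asyncio

from pymilvus import MilvusClient  # requires a remote Milvus server, not Milvus-Lite

from llama_stack.providers.remote.vector_io.milvus.milvus import MilvusIndex


async def main() -> None:
    # Hypothetical connection details; keyword mode is rejected for local db_path (Milvus-Lite) configs.
    client = MilvusClient(uri="http://localhost:19530", token="root:Milvus")
    index = MilvusIndex(client=client, collection_name="test_collection")

    # BM25 keyword search added by this patch: raw query text in, chunks and scores out.
    response = await index.query_keyword(query_string="machine learning", k=3, score_threshold=0.0)
    for chunk, score in zip(response.chunks, response.scores, strict=False):
        print(f"{score:.3f}", chunk.metadata.get("document_id"))


asyncio.run(main())
```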