diff --git a/src/llama_stack/providers/remote/vector_io/milvus/milvus.py b/src/llama_stack/providers/remote/vector_io/milvus/milvus.py index 73339b5be..d0bee8965 100644 --- a/src/llama_stack/providers/remote/vector_io/milvus/milvus.py +++ b/src/llama_stack/providers/remote/vector_io/milvus/milvus.py @@ -4,12 +4,11 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -import asyncio import os from typing import Any from numpy.typing import NDArray -from pymilvus import AnnSearchRequest, DataType, Function, FunctionType, MilvusClient, RRFRanker, WeightedRanker +from pymilvus import AnnSearchRequest, AsyncMilvusClient, DataType, Function, FunctionType, RRFRanker, WeightedRanker from llama_stack.apis.common.errors import VectorStoreNotFoundError from llama_stack.apis.files import Files @@ -44,7 +43,11 @@ OPENAI_VECTOR_STORES_FILES_CONTENTS_PREFIX = f"openai_vector_stores_files_conten class MilvusIndex(EmbeddingIndex): def __init__( - self, client: MilvusClient, collection_name: str, consistency_level="Strong", kvstore: KVStore | None = None + self, + client: AsyncMilvusClient, + collection_name: str, + consistency_level="Strong", + kvstore: KVStore | None = None, ): self.client = client self.collection_name = sanitize_collection_name(collection_name) @@ -57,15 +60,15 @@ class MilvusIndex(EmbeddingIndex): pass async def delete(self): - if await asyncio.to_thread(self.client.has_collection, self.collection_name): - await asyncio.to_thread(self.client.drop_collection, collection_name=self.collection_name) + if await self.client.has_collection(self.collection_name): + await self.client.drop_collection(collection_name=self.collection_name) async def add_chunks(self, chunks: list[Chunk], embeddings: NDArray): assert len(chunks) == len(embeddings), ( f"Chunk length {len(chunks)} does not match embedding length {len(embeddings)}" ) - if not await asyncio.to_thread(self.client.has_collection, 
self.collection_name): + if not await self.client.has_collection(self.collection_name): logger.info(f"Creating new collection {self.collection_name} with nullable sparse field") # Create schema for vector search schema = self.client.create_schema() @@ -96,8 +99,7 @@ class MilvusIndex(EmbeddingIndex): ) schema.add_function(bm25_function) - await asyncio.to_thread( - self.client.create_collection, + await self.client.create_collection( self.collection_name, schema=schema, index_params=index_params, @@ -116,14 +118,16 @@ class MilvusIndex(EmbeddingIndex): } ) try: - await asyncio.to_thread(self.client.insert, self.collection_name, data=data) + await self.client.insert( + self.collection_name, + data=data, + ) except Exception as e: logger.error(f"Error inserting chunks into Milvus collection {self.collection_name}: {e}") raise e async def query_vector(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse: - search_res = await asyncio.to_thread( - self.client.search, + search_res = await self.client.search( collection_name=self.collection_name, data=[embedding], anns_field="vector", @@ -141,8 +145,7 @@ class MilvusIndex(EmbeddingIndex): """ try: # Use Milvus's built-in BM25 search - search_res = await asyncio.to_thread( - self.client.search, + search_res = await self.client.search( collection_name=self.collection_name, data=[query_string], # Raw text query anns_field="sparse", # Use sparse field for BM25 @@ -178,8 +181,7 @@ class MilvusIndex(EmbeddingIndex): Fallback to simple text search when BM25 search is not available. 
""" # Simple text search using content field - search_res = await asyncio.to_thread( - self.client.query, + search_res = await self.client.query( collection_name=self.collection_name, filter='content like "%{content}%"', filter_params={"content": query_string}, @@ -226,8 +228,7 @@ class MilvusIndex(EmbeddingIndex): impact_factor = (reranker_params or {}).get("impact_factor", 60.0) rerank = RRFRanker(impact_factor) - search_res = await asyncio.to_thread( - self.client.hybrid_search, + search_res = await self.client.hybrid_search( collection_name=self.collection_name, reqs=search_requests, ranker=rerank, @@ -253,9 +254,7 @@ class MilvusIndex(EmbeddingIndex): try: # Use IN clause with square brackets and single quotes for VARCHAR field chunk_ids_str = ", ".join(f"'{chunk_id}'" for chunk_id in chunk_ids) - await asyncio.to_thread( - self.client.delete, collection_name=self.collection_name, filter=f"chunk_id in [{chunk_ids_str}]" - ) + await self.client.delete(collection_name=self.collection_name, filter=f"chunk_id in [{chunk_ids_str}]") except Exception as e: logger.error(f"Error deleting chunks from Milvus collection {self.collection_name}: {e}") raise @@ -297,17 +296,17 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProtoc self.cache[vector_store.identifier] = index if isinstance(self.config, RemoteMilvusVectorIOConfig): logger.info(f"Connecting to Milvus server at {self.config.uri}") - self.client = MilvusClient(**self.config.model_dump(exclude_none=True)) + self.client = AsyncMilvusClient(**self.config.model_dump(exclude_none=True)) else: logger.info(f"Connecting to Milvus Lite at: {self.config.db_path}") uri = os.path.expanduser(self.config.db_path) - self.client = MilvusClient(uri=uri) + self.client = AsyncMilvusClient(uri=uri) # Load existing OpenAI vector stores into the in-memory cache await self.initialize_openai_vector_stores() async def shutdown(self) -> None: - self.client.close() + await self.client.close() # Clean up mixin 
resources (file batch tasks) await super().shutdown() diff --git a/tests/unit/providers/vector_io/conftest.py b/tests/unit/providers/vector_io/conftest.py index 5e56ea417..0960a9c52 100644 --- a/tests/unit/providers/vector_io/conftest.py +++ b/tests/unit/providers/vector_io/conftest.py @@ -9,20 +9,26 @@ from unittest.mock import AsyncMock, MagicMock, patch import numpy as np import pytest +from chromadb import PersistentClient +from pymilvus import AsyncMilvusClient, connections +from llama_stack.apis.vector_dbs import VectorDB from llama_stack.apis.vector_io import Chunk, ChunkMetadata, QueryChunksResponse from llama_stack.apis.vector_stores import VectorStore from llama_stack.core.storage.datatypes import KVStoreReference, SqliteKVStoreConfig from llama_stack.providers.inline.vector_io.faiss.config import FaissVectorIOConfig from llama_stack.providers.inline.vector_io.faiss.faiss import FaissIndex, FaissVectorIOAdapter +from llama_stack.providers.inline.vector_io.milvus.config import MilvusVectorIOConfig from llama_stack.providers.inline.vector_io.sqlite_vec import SQLiteVectorIOConfig from llama_stack.providers.inline.vector_io.sqlite_vec.sqlite_vec import SQLiteVecIndex, SQLiteVecVectorIOAdapter +from llama_stack.providers.remote.vector_io.milvus.milvus import MilvusIndex, MilvusVectorIOAdapter from llama_stack.providers.remote.vector_io.pgvector.config import PGVectorVectorIOConfig from llama_stack.providers.remote.vector_io.pgvector.pgvector import PGVectorIndex, PGVectorVectorIOAdapter from llama_stack.providers.utils.kvstore import register_kvstore_backends EMBEDDING_DIMENSION = 768 COLLECTION_PREFIX = "test_collection" +MILVUS_ALIAS = "test_milvus" @pytest.fixture(params=["sqlite_vec", "faiss", "pgvector"]) @@ -141,7 +147,7 @@ async def sqlite_vec_vec_index(embedding_dimension, tmp_path_factory): await index.initialize() index.db_path = db_path yield index - index.delete() + await index.delete() @pytest.fixture @@ -170,6 +176,48 @@ async def 
sqlite_vec_adapter(sqlite_vec_db_path, unique_kvstore_config, mock_inf await adapter.shutdown() +@pytest.fixture(scope="session") +def milvus_vec_db_path(tmp_path_factory): + db_path = str(tmp_path_factory.getbasetemp() / "test_milvus.db") + return db_path + + +@pytest.fixture +async def milvus_vec_index(milvus_vec_db_path, embedding_dimension): + client = AsyncMilvusClient(milvus_vec_db_path) + name = f"{COLLECTION_PREFIX}_{np.random.randint(1e6)}" + connections.connect(alias=MILVUS_ALIAS, uri=milvus_vec_db_path) + index = MilvusIndex(client, name, consistency_level="Strong") + index.db_path = milvus_vec_db_path + yield index + # Proper cleanup: close the async client + await client.close() + + +@pytest.fixture +async def milvus_vec_adapter(milvus_vec_db_path, mock_inference_api): + config = MilvusVectorIOConfig( + db_path=milvus_vec_db_path, + kvstore=SqliteKVStoreConfig(), + ) + adapter = MilvusVectorIOAdapter( + config=config, + inference_api=mock_inference_api, + files_api=None, + ) + await adapter.initialize() + await adapter.register_vector_db( + VectorDB( + identifier=adapter.metadata_collection_name, + provider_id="test_provider", + embedding_model="test_model", + embedding_dimension=128, + ) + ) + yield adapter + await adapter.shutdown() + + @pytest.fixture def faiss_vec_db_path(tmp_path_factory): db_path = str(tmp_path_factory.getbasetemp() / "test_faiss.db") diff --git a/tests/unit/providers/vector_io/remote/test_milvus.py b/tests/unit/providers/vector_io/remote/test_milvus.py new file mode 100644 index 000000000..6e8817366 --- /dev/null +++ b/tests/unit/providers/vector_io/remote/test_milvus.py @@ -0,0 +1,339 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
+ +from unittest.mock import AsyncMock, MagicMock, patch + +import numpy as np +import pytest + +from llama_stack.apis.vector_io import QueryChunksResponse + +# Mock the entire pymilvus module +pymilvus_mock = MagicMock() +pymilvus_mock.DataType = MagicMock() +pymilvus_mock.AsyncMilvusClient = MagicMock +pymilvus_mock.RRFRanker = MagicMock +pymilvus_mock.WeightedRanker = MagicMock +pymilvus_mock.AnnSearchRequest = MagicMock + +# Apply the mock before importing MilvusIndex +with patch.dict("sys.modules", {"pymilvus": pymilvus_mock}): + from llama_stack.providers.remote.vector_io.milvus.milvus import MilvusIndex + +# This test is a unit test for the MilvusIndex class. This should only contain +# tests which are specific to this class. More general (API-level) tests should be placed in +# tests/integration/vector_io/ +# +# How to run this test: +# +# pytest tests/unit/providers/vector_io/remote/test_milvus.py \ +# -v -s --tb=short --disable-warnings --asyncio-mode=auto + +MILVUS_PROVIDER = "milvus" + + +@pytest.fixture +async def mock_milvus_client() -> MagicMock: + """Create a mock Milvus client with common method behaviors.""" + client = MagicMock() + + # Mock async collection operations + client.has_collection = AsyncMock(return_value=False) # Initially no collection + client.create_collection = AsyncMock(return_value=None) + client.drop_collection = AsyncMock(return_value=None) + + # Mock async insert operation + client.insert = AsyncMock(return_value={"insert_count": 10}) + + # Mock async search operation - return mock results (data should be dict, not JSON string) + client.search = AsyncMock( + return_value=[ + [ + { + "id": 0, + "distance": 0.1, + "entity": {"chunk_content": {"content": "mock chunk 1", "metadata": {"document_id": "doc1"}}}, + }, + { + "id": 1, + "distance": 0.2, + "entity": {"chunk_content": {"content": "mock chunk 2", "metadata": {"document_id": "doc2"}}}, + }, + ] + ] + ) + + # Mock async query operation for keyword search (data should be
dict, not JSON string) + client.query = AsyncMock( + return_value=[ + { + "chunk_id": "chunk1", + "chunk_content": {"content": "mock chunk 1", "metadata": {"document_id": "doc1"}}, + "score": 0.9, + }, + { + "chunk_id": "chunk2", + "chunk_content": {"content": "mock chunk 2", "metadata": {"document_id": "doc2"}}, + "score": 0.8, + }, + { + "chunk_id": "chunk3", + "chunk_content": {"content": "mock chunk 3", "metadata": {"document_id": "doc3"}}, + "score": 0.7, + }, + ] + ) + + # Mock async hybrid_search operation + client.hybrid_search = AsyncMock(return_value=[]) + + # Mock async delete operation + client.delete = AsyncMock(return_value=None) + + # Mock async close operation + client.close = AsyncMock(return_value=None) + + return client + + +@pytest.fixture +async def milvus_index(mock_milvus_client): + """Create a MilvusIndex with mocked client.""" + index = MilvusIndex(client=mock_milvus_client, collection_name="test_collection") + yield index + # No real cleanup needed since we're using mocks + + +async def test_add_chunks(milvus_index, sample_chunks, sample_embeddings, mock_milvus_client): + # Setup: collection doesn't exist initially, then exists after creation + mock_milvus_client.has_collection.side_effect = [False, True] + + await milvus_index.add_chunks(sample_chunks, sample_embeddings) + + # Verify collection was created and data was inserted + mock_milvus_client.create_collection.assert_called_once() + mock_milvus_client.insert.assert_called_once() + + # Verify the insert call had the right number of chunks + insert_call = mock_milvus_client.insert.call_args + assert len(insert_call[1]["data"]) == len(sample_chunks) + + +async def test_query_chunks_vector( + milvus_index, sample_chunks, sample_embeddings, embedding_dimension, mock_milvus_client +): + # Setup: Add chunks first + mock_milvus_client.has_collection.return_value = True + await milvus_index.add_chunks(sample_chunks, sample_embeddings) + + # Test vector search + query_embedding = 
np.random.rand(embedding_dimension).astype(np.float32) + response = await milvus_index.query_vector(query_embedding, k=2, score_threshold=0.0) + + assert isinstance(response, QueryChunksResponse) + assert len(response.chunks) == 2 + mock_milvus_client.search.assert_called_once() + + +async def test_query_chunks_keyword_search(milvus_index, sample_chunks, sample_embeddings, mock_milvus_client): + mock_milvus_client.has_collection.return_value = True + await milvus_index.add_chunks(sample_chunks, sample_embeddings) + + # Test keyword search + query_string = "Sentence 5" + response = await milvus_index.query_keyword(query_string=query_string, k=2, score_threshold=0.0) + + assert isinstance(response, QueryChunksResponse) + assert len(response.chunks) == 2 + + +async def test_bm25_fallback_to_simple_search(milvus_index, sample_chunks, sample_embeddings, mock_milvus_client): + """Test that when BM25 search fails, the system falls back to simple text search.""" + mock_milvus_client.has_collection.return_value = True + await milvus_index.add_chunks(sample_chunks, sample_embeddings) + + # Force BM25 search to fail + mock_milvus_client.search.side_effect = Exception("BM25 search not available") + + # Mock simple text search results + mock_milvus_client.query.return_value = [ + { + "chunk_id": "chunk1", + "chunk_content": {"content": "Python programming language", "metadata": {"document_id": "doc1"}}, + }, + { + "chunk_id": "chunk2", + "chunk_content": {"content": "Machine learning algorithms", "metadata": {"document_id": "doc2"}}, + }, + ] + + # Test keyword search that should fall back to simple text search + query_string = "Python" + response = await milvus_index.query_keyword(query_string=query_string, k=3, score_threshold=0.0) + + # Verify response structure + assert isinstance(response, QueryChunksResponse) + assert len(response.chunks) > 0, "Fallback search should return results" + + # Verify that simple text search was used (query method called instead of search) + 
mock_milvus_client.query.assert_called_once() + mock_milvus_client.search.assert_called_once() # Called once but failed + + # Verify the query uses parameterized filter with filter_params + query_call_args = mock_milvus_client.query.call_args + assert "filter" in query_call_args[1], "Query should include filter for text search" + assert "filter_params" in query_call_args[1], "Query should use parameterized filter" + assert query_call_args[1]["filter_params"]["content"] == "Python", "Filter params should contain the search term" + + # Verify all returned chunks have score 1.0 (simple binary scoring) + assert all(score == 1.0 for score in response.scores), "Simple text search should use binary scoring" + + +async def test_delete_collection(milvus_index, mock_milvus_client): + # Test collection deletion + mock_milvus_client.has_collection.return_value = True + + await milvus_index.delete() + + mock_milvus_client.drop_collection.assert_called_once_with(collection_name=milvus_index.collection_name) + + +async def test_query_hybrid_search_rrf( + milvus_index, sample_chunks, sample_embeddings, embedding_dimension, mock_milvus_client +): + """Test hybrid search with RRF reranker.""" + mock_milvus_client.has_collection.return_value = True + await milvus_index.add_chunks(sample_chunks, sample_embeddings) + + # Mock hybrid search results + mock_milvus_client.hybrid_search.return_value = [ + [ + { + "id": 0, + "distance": 0.1, + "entity": {"chunk_content": {"content": "mock chunk 1", "metadata": {"document_id": "doc1"}}}, + }, + { + "id": 1, + "distance": 0.2, + "entity": {"chunk_content": {"content": "mock chunk 2", "metadata": {"document_id": "doc2"}}}, + }, + ] + ] + + # Test hybrid search with RRF reranker + query_embedding = np.random.rand(embedding_dimension).astype(np.float32) + query_string = "test query" + response = await milvus_index.query_hybrid( + embedding=query_embedding, + query_string=query_string, + k=2, + score_threshold=0.0, + reranker_type="rrf", + 
reranker_params={"impact_factor": 60.0}, + ) + + assert isinstance(response, QueryChunksResponse) + assert len(response.chunks) == 2 + assert len(response.scores) == 2 + + # Verify hybrid search was called with correct parameters + mock_milvus_client.hybrid_search.assert_called_once() + call_args = mock_milvus_client.hybrid_search.call_args + + # Check that the request contains both vector and BM25 search requests + reqs = call_args[1]["reqs"] + assert len(reqs) == 2 + assert reqs[0].anns_field == "vector" + assert reqs[1].anns_field == "sparse" + ranker = call_args[1]["ranker"] + assert ranker is not None + + +async def test_query_hybrid_search_weighted( + milvus_index, sample_chunks, sample_embeddings, embedding_dimension, mock_milvus_client +): + """Test hybrid search with weighted reranker.""" + mock_milvus_client.has_collection.return_value = True + await milvus_index.add_chunks(sample_chunks, sample_embeddings) + + # Mock hybrid search results + mock_milvus_client.hybrid_search.return_value = [ + [ + { + "id": 0, + "distance": 0.1, + "entity": {"chunk_content": {"content": "mock chunk 1", "metadata": {"document_id": "doc1"}}}, + }, + { + "id": 1, + "distance": 0.2, + "entity": {"chunk_content": {"content": "mock chunk 2", "metadata": {"document_id": "doc2"}}}, + }, + ] + ] + + # Test hybrid search with weighted reranker + query_embedding = np.random.rand(embedding_dimension).astype(np.float32) + query_string = "test query" + response = await milvus_index.query_hybrid( + embedding=query_embedding, + query_string=query_string, + k=2, + score_threshold=0.0, + reranker_type="weighted", + reranker_params={"alpha": 0.7}, + ) + + assert isinstance(response, QueryChunksResponse) + assert len(response.chunks) == 2 + assert len(response.scores) == 2 + + # Verify hybrid search was called with correct parameters + mock_milvus_client.hybrid_search.assert_called_once() + call_args = mock_milvus_client.hybrid_search.call_args + ranker = call_args[1]["ranker"] + assert 
ranker is not None + + +async def test_query_hybrid_search_default_rrf( + milvus_index, sample_chunks, sample_embeddings, embedding_dimension, mock_milvus_client +): + """Test hybrid search with default RRF reranker (no reranker_type specified).""" + mock_milvus_client.has_collection.return_value = True + await milvus_index.add_chunks(sample_chunks, sample_embeddings) + + # Mock hybrid search results + mock_milvus_client.hybrid_search.return_value = [ + [ + { + "id": 0, + "distance": 0.1, + "entity": {"chunk_content": {"content": "mock chunk 1", "metadata": {"document_id": "doc1"}}}, + }, + ] + ] + + # Test hybrid search with default reranker (should be RRF) + query_embedding = np.random.rand(embedding_dimension).astype(np.float32) + query_string = "test query" + response = await milvus_index.query_hybrid( + embedding=query_embedding, + query_string=query_string, + k=1, + score_threshold=0.0, + reranker_type="unknown_type", # Should default to RRF + reranker_params=None, # Should use default impact_factor + ) + + assert isinstance(response, QueryChunksResponse) + assert len(response.chunks) == 1 + + # Verify hybrid search was called with RRF reranker + mock_milvus_client.hybrid_search.assert_called_once() + call_args = mock_milvus_client.hybrid_search.call_args + ranker = call_args[1]["ranker"] + assert ranker is not None diff --git a/tests/unit/providers/vector_io/test_vector_io_openai_vector_stores.py b/tests/unit/providers/vector_io/test_vector_io_openai_vector_stores.py index 121623e1b..2143f5ccc 100644 --- a/tests/unit/providers/vector_io/test_vector_io_openai_vector_stores.py +++ b/tests/unit/providers/vector_io/test_vector_io_openai_vector_stores.py @@ -48,12 +48,12 @@ async def test_initialize_index(vector_index): async def test_add_chunks_query_vector(vector_index, sample_chunks, sample_embeddings): - vector_index.delete() - vector_index.initialize() + await vector_index.delete() + await vector_index.initialize() await 
vector_index.add_chunks(sample_chunks, sample_embeddings) resp = await vector_index.query_vector(sample_embeddings[0], k=1, score_threshold=-1) assert resp.chunks[0].content == sample_chunks[0].content - vector_index.delete() + await vector_index.delete() async def test_chunk_id_conflict(vector_index, sample_chunks, embedding_dimension):