diff --git a/llama_stack/providers/remote/vector_io/milvus/milvus.py b/llama_stack/providers/remote/vector_io/milvus/milvus.py
index e07e8ff12..2e9ed9ced 100644
--- a/llama_stack/providers/remote/vector_io/milvus/milvus.py
+++ b/llama_stack/providers/remote/vector_io/milvus/milvus.py
@@ -4,12 +4,11 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-import asyncio
 import os
 from typing import Any
 
 from numpy.typing import NDArray
-from pymilvus import AnnSearchRequest, DataType, Function, FunctionType, MilvusClient, RRFRanker, WeightedRanker
+from pymilvus import AnnSearchRequest, AsyncMilvusClient, DataType, Function, FunctionType, RRFRanker, WeightedRanker
 
 from llama_stack.apis.common.errors import VectorStoreNotFoundError
 from llama_stack.apis.files.files import Files
@@ -48,12 +47,18 @@ OPENAI_VECTOR_STORES_FILES_CONTENTS_PREFIX = f"openai_vector_stores_files_conten
 
 class MilvusIndex(EmbeddingIndex):
     def __init__(
-        self, client: MilvusClient, collection_name: str, consistency_level="Strong", kvstore: KVStore | None = None
+        self,
+        client: AsyncMilvusClient,
+        collection_name: str,
+        consistency_level="Strong",
+        kvstore: KVStore | None = None,
+        parent_adapter=None,
     ):
         self.client = client
         self.collection_name = sanitize_collection_name(collection_name)
         self.consistency_level = consistency_level
         self.kvstore = kvstore
+        self._parent_adapter = parent_adapter
 
     async def initialize(self):
         # MilvusIndex does not require explicit initialization
@@ -61,15 +66,39 @@ class MilvusIndex(EmbeddingIndex):
         pass
 
     async def delete(self):
-        if await asyncio.to_thread(self.client.has_collection, self.collection_name):
-            await asyncio.to_thread(self.client.drop_collection, collection_name=self.collection_name)
+        try:
+            collections = await self.client.list_collections()
+            if self.collection_name in collections:
+                await self.client.drop_collection(collection_name=self.collection_name)
+        except Exception as e:
+            logger.warning(f"Failed to check or delete collection {self.collection_name}: {e}")
 
     async def add_chunks(self, chunks: list[Chunk], embeddings: NDArray):
         assert len(chunks) == len(embeddings), (
             f"Chunk length {len(chunks)} does not match embedding length {len(embeddings)}"
         )
 
-        if not await asyncio.to_thread(self.client.has_collection, self.collection_name):
+        try:
+            collections = await self.client.list_collections()
+            collection_exists = self.collection_name in collections
+        except Exception as e:
+            logger.error(f"Failed to check collection existence: {self.collection_name} ({e})")
+            # If it's an event loop issue, try to recreate the client
+            if "attached to a different loop" in str(e):
+                logger.warning("Recreating client due to event loop issue")
+
+                if self._parent_adapter is not None:  # always set in __init__, so check the value
+                    await self._parent_adapter._recreate_client()
+                    collections = await self.client.list_collections()
+                    collection_exists = self.collection_name in collections
+                else:
+                    # Assume collection doesn't exist if we can't check
+                    collection_exists = False
+            else:
+                # Assume collection doesn't exist if we can't check due to other issues
+                collection_exists = False
+
+        if not collection_exists:
             logger.info(f"Creating new collection {self.collection_name} with nullable sparse field")
             # Create schema for vector search
             schema = self.client.create_schema()
@@ -123,13 +152,16 @@ class MilvusIndex(EmbeddingIndex):
             )
             schema.add_function(bm25_function)
 
-            await asyncio.to_thread(
-                self.client.create_collection,
-                self.collection_name,
-                schema=schema,
-                index_params=index_params,
-                consistency_level=self.consistency_level,
-            )
+            try:
+                await self.client.create_collection(
+                    self.collection_name,
+                    schema=schema,
+                    index_params=index_params,
+                    consistency_level=self.consistency_level,
+                )
+            except Exception as e:
+                logger.error(f"Failed to create collection {self.collection_name}: {e}")
+                raise
 
         data = []
         for chunk, embedding in zip(chunks, embeddings, strict=False):
@@ -143,8 +175,7 @@ class MilvusIndex(EmbeddingIndex):
                 }
             )
         try:
-            await asyncio.to_thread(
-                self.client.insert,
+            await self.client.insert(
                 self.collection_name,
                 data=data,
             )
@@ -153,8 +184,7 @@ class MilvusIndex(EmbeddingIndex):
             raise e
 
     async def query_vector(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
-        search_res = await asyncio.to_thread(
-            self.client.search,
+        search_res = await self.client.search(
             collection_name=self.collection_name,
             data=[embedding],
             anns_field="vector",
@@ -177,8 +207,7 @@ class MilvusIndex(EmbeddingIndex):
         """
         try:
             # Use Milvus's built-in BM25 search
-            search_res = await asyncio.to_thread(
-                self.client.search,
+            search_res = await self.client.search(
                 collection_name=self.collection_name,
                 data=[query_string],  # Raw text query
                 anns_field="sparse",  # Use sparse field for BM25
@@ -219,8 +248,7 @@ class MilvusIndex(EmbeddingIndex):
         Fallback to simple text search when BM25 search is not available.
         """
         # Simple text search using content field
-        search_res = await asyncio.to_thread(
-            self.client.query,
+        search_res = await self.client.query(
             collection_name=self.collection_name,
             filter='content like "%{content}%"',
             filter_params={"content": query_string},
@@ -267,8 +295,7 @@ class MilvusIndex(EmbeddingIndex):
             impact_factor = (reranker_params or {}).get("impact_factor", 60.0)
             rerank = RRFRanker(impact_factor)
 
-        search_res = await asyncio.to_thread(
-            self.client.hybrid_search,
+        search_res = await self.client.hybrid_search(
             collection_name=self.collection_name,
             reqs=search_requests,
             ranker=rerank,
@@ -294,9 +321,7 @@ class MilvusIndex(EmbeddingIndex):
         try:
             # Use IN clause with square brackets and single quotes for VARCHAR field
             chunk_ids_str = ", ".join(f"'{chunk_id}'" for chunk_id in chunk_ids)
-            await asyncio.to_thread(
-                self.client.delete, collection_name=self.collection_name, filter=f"chunk_id in [{chunk_ids_str}]"
-            )
+            await self.client.delete(collection_name=self.collection_name, filter=f"chunk_id in [{chunk_ids_str}]")
         except Exception as e:
             logger.error(f"Error deleting chunks from Milvus collection {self.collection_name}: {e}")
             raise
@@ -321,6 +346,15 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
 
     async def initialize(self) -> None:
         self.kvstore = await kvstore_impl(self.config.kvstore)
+
+        if isinstance(self.config, RemoteMilvusVectorIOConfig):
+            logger.info(f"Connecting to Milvus server at {self.config.uri}")
+            self.client = AsyncMilvusClient(**self.config.model_dump(exclude_none=True))
+        else:
+            logger.info(f"Connecting to Milvus Lite at: {self.config.db_path}")
+            uri = os.path.expanduser(self.config.db_path)
+            self.client = AsyncMilvusClient(uri=uri)
+
         start_key = VECTOR_DBS_PREFIX
         end_key = f"{VECTOR_DBS_PREFIX}\xff"
         stored_vector_dbs = await self.kvstore.values_in_range(start_key, end_key)
@@ -334,23 +368,38 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
                     collection_name=vector_db.identifier,
                     consistency_level=self.config.consistency_level,
                     kvstore=self.kvstore,
+                    parent_adapter=self,
                 ),
                 inference_api=self.inference_api,
             )
             self.cache[vector_db.identifier] = index
-        if isinstance(self.config, RemoteMilvusVectorIOConfig):
-            logger.info(f"Connecting to Milvus server at {self.config.uri}")
-            self.client = MilvusClient(**self.config.model_dump(exclude_none=True))
-        else:
-            logger.info(f"Connecting to Milvus Lite at: {self.config.db_path}")
-            uri = os.path.expanduser(self.config.db_path)
-            self.client = MilvusClient(uri=uri)
 
         # Load existing OpenAI vector stores into the in-memory cache
         await self.initialize_openai_vector_stores()
 
     async def shutdown(self) -> None:
-        self.client.close()
+        if self.client:
+            await self.client.close()
+
+    async def _recreate_client(self) -> None:
+        """Recreate the AsyncMilvusClient when event loop issues occur"""
+        try:
+            if self.client:
+                await self.client.close()
+        except Exception as e:
+            logger.warning(f"Error closing old client: {e}")
+
+        if isinstance(self.config, RemoteMilvusVectorIOConfig):
+            logger.info(f"Recreating connection to Milvus server at {self.config.uri}")
+            self.client = AsyncMilvusClient(**self.config.model_dump(exclude_none=True))
+        else:
+            logger.info(f"Recreating connection to Milvus Lite at: {self.config.db_path}")
+            uri = os.path.expanduser(self.config.db_path)
+            self.client = AsyncMilvusClient(uri=uri)
+
+        for index_wrapper in self.cache.values():
+            if hasattr(index_wrapper, "index") and hasattr(index_wrapper.index, "client"):
+                index_wrapper.index.client = self.client
 
     async def register_vector_db(
         self,
@@ -362,7 +411,12 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
             consistency_level = "Strong"
         index = VectorDBWithIndex(
             vector_db=vector_db,
-            index=MilvusIndex(self.client, vector_db.identifier, consistency_level=consistency_level),
+            index=MilvusIndex(
+                client=self.client,
+                collection_name=vector_db.identifier,
+                consistency_level=consistency_level,
+                parent_adapter=self,
+            ),
             inference_api=self.inference_api,
         )
 
@@ -381,7 +435,9 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
 
         index = VectorDBWithIndex(
             vector_db=vector_db,
-            index=MilvusIndex(client=self.client, collection_name=vector_db.identifier, kvstore=self.kvstore),
+            index=MilvusIndex(
+                client=self.client, collection_name=vector_db.identifier, kvstore=self.kvstore, parent_adapter=self
+            ),
             inference_api=self.inference_api,
         )
         self.cache[vector_db_id] = index
diff --git a/tests/unit/providers/vector_io/conftest.py b/tests/unit/providers/vector_io/conftest.py
index 70ace695e..ce76c2dff 100644
--- a/tests/unit/providers/vector_io/conftest.py
+++ b/tests/unit/providers/vector_io/conftest.py
@@ -10,7 +10,7 @@ from unittest.mock import AsyncMock, MagicMock, patch
 import numpy as np
 import pytest
 from chromadb import PersistentClient
-from pymilvus import MilvusClient, connections
+from pymilvus import AsyncMilvusClient, connections
 
 from llama_stack.apis.vector_dbs import VectorDB
 from llama_stack.apis.vector_io import Chunk, ChunkMetadata, QueryChunksResponse
@@ -141,7 +141,7 @@ async def sqlite_vec_vec_index(embedding_dimension, tmp_path_factory):
     await index.initialize()
     index.db_path = db_path
     yield index
-    index.delete()
+    await index.delete()
 
 
 @pytest.fixture
@@ -178,13 +178,15 @@ def milvus_vec_db_path(tmp_path_factory):
 
 @pytest.fixture
 async def milvus_vec_index(milvus_vec_db_path, embedding_dimension):
-    client = MilvusClient(milvus_vec_db_path)
+    client = AsyncMilvusClient(milvus_vec_db_path)
     name = f"{COLLECTION_PREFIX}_{np.random.randint(1e6)}"
     connections.connect(alias=MILVUS_ALIAS, uri=milvus_vec_db_path)
     index = MilvusIndex(client, name, consistency_level="Strong")
     index.db_path = milvus_vec_db_path
     yield index
 
+    await client.close()
+
 
 @pytest.fixture
 async def milvus_vec_adapter(milvus_vec_db_path, mock_inference_api):
diff --git a/tests/unit/providers/vector_io/remote/test_milvus.py b/tests/unit/providers/vector_io/remote/test_milvus.py
index ca5f45fa2..7020900d7 100644
--- a/tests/unit/providers/vector_io/remote/test_milvus.py
+++ b/tests/unit/providers/vector_io/remote/test_milvus.py
@@ -4,7 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from unittest.mock import MagicMock, patch
+from unittest.mock import AsyncMock, MagicMock, patch
 
 import numpy as np
 import pytest
@@ -14,7 +14,7 @@ from llama_stack.apis.vector_io import QueryChunksResponse
 # Mock the entire pymilvus module
 pymilvus_mock = MagicMock()
 pymilvus_mock.DataType = MagicMock()
-pymilvus_mock.MilvusClient = MagicMock
+pymilvus_mock.AsyncMilvusClient = MagicMock
 pymilvus_mock.RRFRanker = MagicMock
 pymilvus_mock.WeightedRanker = MagicMock
 pymilvus_mock.AnnSearchRequest = MagicMock
@@ -40,48 +40,55 @@ async def mock_milvus_client() -> MagicMock:
     """Create a mock Milvus client with common method behaviors."""
     client = MagicMock()
 
-    # Mock collection operations
-    client.has_collection.return_value = False  # Initially no collection
-    client.create_collection.return_value = None
-    client.drop_collection.return_value = None
+    client.list_collections = AsyncMock(return_value=[])  # Initially no collections
+    client.create_collection = AsyncMock(return_value=None)
+    client.drop_collection = AsyncMock(return_value=None)
 
-    # Mock insert operation
-    client.insert.return_value = {"insert_count": 10}
+    client.insert = AsyncMock(return_value={"insert_count": 10})
 
-    # Mock search operation - return mock results (data should be dict, not JSON string)
-    client.search.return_value = [
-        [
+    client.search = AsyncMock(
+        return_value=[
+            [
+                {
+                    "id": 0,
+                    "distance": 0.1,
+                    "entity": {"chunk_content": {"content": "mock chunk 1", "metadata": {"document_id": "doc1"}}},
+                },
+                {
+                    "id": 1,
+                    "distance": 0.2,
+                    "entity": {"chunk_content": {"content": "mock chunk 2", "metadata": {"document_id": "doc2"}}},
+                },
+            ]
+        ]
+    )
+
+    # Mock async query operation for keyword search (data should be dict, not JSON string)
+    client.query = AsyncMock(
+        return_value=[
             {
-                "id": 0,
-                "distance": 0.1,
-                "entity": {"chunk_content": {"content": "mock chunk 1", "metadata": {"document_id": "doc1"}}},
+                "chunk_id": "chunk1",
+                "chunk_content": {"content": "mock chunk 1", "metadata": {"document_id": "doc1"}},
+                "score": 0.9,
             },
             {
-                "id": 1,
-                "distance": 0.2,
-                "entity": {"chunk_content": {"content": "mock chunk 2", "metadata": {"document_id": "doc2"}}},
+                "chunk_id": "chunk2",
+                "chunk_content": {"content": "mock chunk 2", "metadata": {"document_id": "doc2"}},
+                "score": 0.8,
+            },
+            {
+                "chunk_id": "chunk3",
+                "chunk_content": {"content": "mock chunk 3", "metadata": {"document_id": "doc3"}},
+                "score": 0.7,
             },
         ]
-    ]
+    )
 
-    # Mock query operation for keyword search (data should be dict, not JSON string)
-    client.query.return_value = [
-        {
-            "chunk_id": "chunk1",
-            "chunk_content": {"content": "mock chunk 1", "metadata": {"document_id": "doc1"}},
-            "score": 0.9,
-        },
-        {
-            "chunk_id": "chunk2",
-            "chunk_content": {"content": "mock chunk 2", "metadata": {"document_id": "doc2"}},
-            "score": 0.8,
-        },
-        {
-            "chunk_id": "chunk3",
-            "chunk_content": {"content": "mock chunk 3", "metadata": {"document_id": "doc3"}},
-            "score": 0.7,
-        },
-    ]
+    client.hybrid_search = AsyncMock(return_value=[])
+
+    client.delete = AsyncMock(return_value=None)
+
+    client.close = AsyncMock(return_value=None)
 
     return client
 
@@ -96,7 +103,7 @@ async def milvus_index(mock_milvus_client):
 
 async def test_add_chunks(milvus_index, sample_chunks, sample_embeddings, mock_milvus_client):
     # Setup: collection doesn't exist initially, then exists after creation
-    mock_milvus_client.has_collection.side_effect = [False, True]
+    mock_milvus_client.list_collections.side_effect = [[], ["test_collection"]]
 
     await milvus_index.add_chunks(sample_chunks, sample_embeddings)
 
@@ -113,7 +120,7 @@ async def test_query_chunks_vector(
     milvus_index, sample_chunks, sample_embeddings, embedding_dimension, mock_milvus_client
 ):
     # Setup: Add chunks first
-    mock_milvus_client.has_collection.return_value = True
+    mock_milvus_client.list_collections.return_value = ["test_collection"]
     await milvus_index.add_chunks(sample_chunks, sample_embeddings)
 
     # Test vector search
@@ -126,7 +133,7 @@ async def test_query_chunks_vector(
 
 
 async def test_query_chunks_keyword_search(milvus_index, sample_chunks, sample_embeddings, mock_milvus_client):
-    mock_milvus_client.has_collection.return_value = True
+    mock_milvus_client.list_collections.return_value = ["test_collection"]
     await milvus_index.add_chunks(sample_chunks, sample_embeddings)
 
     # Test keyword search
@@ -139,7 +146,7 @@ async def test_query_chunks_keyword_search(milvus_index, sample_chunks, sample_e
 
 async def test_bm25_fallback_to_simple_search(milvus_index, sample_chunks, sample_embeddings, mock_milvus_client):
     """Test that when BM25 search fails, the system falls back to simple text search."""
-    mock_milvus_client.has_collection.return_value = True
+    mock_milvus_client.list_collections.return_value = ["test_collection"]
     await milvus_index.add_chunks(sample_chunks, sample_embeddings)
 
     # Force BM25 search to fail
@@ -181,7 +188,7 @@ async def test_bm25_fallback_to_simple_search(milvus_index, sample_chunks, sampl
 
 async def test_delete_collection(milvus_index, mock_milvus_client):
     # Test collection deletion
-    mock_milvus_client.has_collection.return_value = True
+    mock_milvus_client.list_collections.return_value = ["test_collection"]
 
     await milvus_index.delete()
 
@@ -192,7 +199,7 @@ async def test_query_hybrid_search_rrf(
     milvus_index, sample_chunks, sample_embeddings, embedding_dimension, mock_milvus_client
 ):
     """Test hybrid search with RRF reranker."""
-    mock_milvus_client.has_collection.return_value = True
+    mock_milvus_client.list_collections.return_value = ["test_collection"]
     await milvus_index.add_chunks(sample_chunks, sample_embeddings)
 
     # Mock hybrid search results
@@ -244,7 +251,7 @@ async def test_query_hybrid_search_weighted(
     milvus_index, sample_chunks, sample_embeddings, embedding_dimension, mock_milvus_client
 ):
     """Test hybrid search with weighted reranker."""
-    mock_milvus_client.has_collection.return_value = True
+    mock_milvus_client.list_collections.return_value = ["test_collection"]
     await milvus_index.add_chunks(sample_chunks, sample_embeddings)
 
     # Mock hybrid search results
@@ -290,7 +297,7 @@ async def test_query_hybrid_search_default_rrf(
     milvus_index, sample_chunks, sample_embeddings, embedding_dimension, mock_milvus_client
 ):
     """Test hybrid search with default RRF reranker (no reranker_type specified)."""
-    mock_milvus_client.has_collection.return_value = True
+    mock_milvus_client.list_collections.return_value = ["test_collection"]
     await milvus_index.add_chunks(sample_chunks, sample_embeddings)
 
     # Mock hybrid search results
diff --git a/tests/unit/providers/vector_io/test_vector_io_openai_vector_stores.py b/tests/unit/providers/vector_io/test_vector_io_openai_vector_stores.py
index 98889f38e..9669d3922 100644
--- a/tests/unit/providers/vector_io/test_vector_io_openai_vector_stores.py
+++ b/tests/unit/providers/vector_io/test_vector_io_openai_vector_stores.py
@@ -30,12 +30,12 @@ async def test_initialize_index(vector_index):
 
 
 async def test_add_chunks_query_vector(vector_index, sample_chunks, sample_embeddings):
-    vector_index.delete()
-    vector_index.initialize()
+    await vector_index.delete()
+    await vector_index.initialize()
     await vector_index.add_chunks(sample_chunks, sample_embeddings)
     resp = await vector_index.query_vector(sample_embeddings[0], k=1, score_threshold=-1)
     assert resp.chunks[0].content == sample_chunks[0].content
-    vector_index.delete()
+    await vector_index.delete()
 
 
 async def test_chunk_id_conflict(vector_index, sample_chunks, embedding_dimension):