diff --git a/llama_stack/providers/remote/vector_io/milvus/milvus.py b/llama_stack/providers/remote/vector_io/milvus/milvus.py index 5e217bb55..383fa517a 100644 --- a/llama_stack/providers/remote/vector_io/milvus/milvus.py +++ b/llama_stack/providers/remote/vector_io/milvus/milvus.py @@ -52,11 +52,13 @@ class MilvusIndex(EmbeddingIndex): collection_name: str, consistency_level="Strong", kvstore: KVStore | None = None, + parent_adapter=None, ): self.client = client self.collection_name = sanitize_collection_name(collection_name) self.consistency_level = consistency_level self.kvstore = kvstore + self._parent_adapter = parent_adapter async def initialize(self): # MilvusIndex does not require explicit initialization @@ -64,15 +66,36 @@ class MilvusIndex(EmbeddingIndex): pass async def delete(self): - if await self.client.has_collection(self.collection_name): - await self.client.drop_collection(collection_name=self.collection_name) + try: + if await self.client.has_collection(self.collection_name): + await self.client.drop_collection(collection_name=self.collection_name) + except Exception as e: + logger.warning(f"Failed to check or delete collection {self.collection_name}: {e}") async def add_chunks(self, chunks: list[Chunk], embeddings: NDArray): assert len(chunks) == len(embeddings), ( f"Chunk length {len(chunks)} does not match embedding length {len(embeddings)}" ) - if not await self.client.has_collection(self.collection_name): + try: + collection_exists = await self.client.has_collection(self.collection_name) + except Exception as e: + logger.error(f"Failed to check collection existence: {self.collection_name} ({e})") + # If it's an event loop issue, try to recreate the client + if "attached to a different loop" in str(e): + logger.warning("Recreating client due to event loop issue") + + if hasattr(self, "_parent_adapter"): + await self._parent_adapter._recreate_client() + collection_exists = await self.client.has_collection(self.collection_name) + else: + # Assume collection doesn't exist if we can't check + collection_exists = False + else: + # Assume collection doesn't exist if we can't check due to other issues + collection_exists = False + + if not collection_exists: logger.info(f"Creating new collection {self.collection_name} with nullable sparse field") # Create schema for vector search schema = self.client.create_schema() @@ -126,12 +149,16 @@ class MilvusIndex(EmbeddingIndex): ) schema.add_function(bm25_function) - await self.client.create_collection( - self.collection_name, - schema=schema, - index_params=index_params, - consistency_level=self.consistency_level, - ) + try: + await self.client.create_collection( + self.collection_name, + schema=schema, + index_params=index_params, + consistency_level=self.consistency_level, + ) + except Exception as e: + logger.error(f"Failed to create collection {self.collection_name}: {e}") + raise e data = [] for chunk, embedding in zip(chunks, embeddings, strict=False): @@ -316,6 +343,15 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP async def initialize(self) -> None: self.kvstore = await kvstore_impl(self.config.kvstore) + + if isinstance(self.config, RemoteMilvusVectorIOConfig): + logger.info(f"Connecting to Milvus server at {self.config.uri}") + self.client = AsyncMilvusClient(**self.config.model_dump(exclude_none=True)) + else: + logger.info(f"Connecting to Milvus Lite at: {self.config.db_path}") + uri = os.path.expanduser(self.config.db_path) + self.client = AsyncMilvusClient(uri=uri) + start_key = VECTOR_DBS_PREFIX end_key = f"{VECTOR_DBS_PREFIX}\xff" stored_vector_dbs = await self.kvstore.values_in_range(start_key, end_key) @@ -329,23 +365,38 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP collection_name=vector_db.identifier, consistency_level=self.config.consistency_level, kvstore=self.kvstore, + parent_adapter=self, ), inference_api=self.inference_api, ) self.cache[vector_db.identifier] = index - if isinstance(self.config, RemoteMilvusVectorIOConfig): - logger.info(f"Connecting to Milvus server at {self.config.uri}") - self.client = AsyncMilvusClient(**self.config.model_dump(exclude_none=True)) - else: - logger.info(f"Connecting to Milvus Lite at: {self.config.db_path}") - uri = os.path.expanduser(self.config.db_path) - self.client = AsyncMilvusClient(uri=uri) # Load existing OpenAI vector stores into the in-memory cache await self.initialize_openai_vector_stores() async def shutdown(self) -> None: - await self.client.close() + if self.client: + await self.client.close() + + async def _recreate_client(self) -> None: + """Recreate the AsyncMilvusClient when event loop issues occur""" + try: + if self.client: + await self.client.close() + except Exception as e: + logger.warning(f"Error closing old client: {e}") + + if isinstance(self.config, RemoteMilvusVectorIOConfig): + logger.info(f"Recreating connection to Milvus server at {self.config.uri}") + self.client = AsyncMilvusClient(**self.config.model_dump(exclude_none=True)) + else: + logger.info(f"Recreating connection to Milvus Lite at: {self.config.db_path}") + uri = os.path.expanduser(self.config.db_path) + self.client = AsyncMilvusClient(uri=uri) + + for index_wrapper in self.cache.values(): + if hasattr(index_wrapper, "index") and hasattr(index_wrapper.index, "client"): + index_wrapper.index.client = self.client async def register_vector_db( self, @@ -357,7 +408,12 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP consistency_level = "Strong" index = VectorDBWithIndex( vector_db=vector_db, - index=MilvusIndex(self.client, vector_db.identifier, consistency_level=consistency_level), + index=MilvusIndex( + client=self.client, + collection_name=vector_db.identifier, + consistency_level=consistency_level, + parent_adapter=self, + ), inference_api=self.inference_api, ) @@ -376,7 +432,9 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP index = VectorDBWithIndex( vector_db=vector_db, - index=MilvusIndex(client=self.client, collection_name=vector_db.identifier, kvstore=self.kvstore), + index=MilvusIndex( + client=self.client, collection_name=vector_db.identifier, kvstore=self.kvstore, parent_adapter=self + ), inference_api=self.inference_api, ) self.cache[vector_db_id] = index diff --git a/tests/unit/providers/vector_io/remote/test_milvus.py b/tests/unit/providers/vector_io/remote/test_milvus.py index 04bac71c2..25374e617 100644 --- a/tests/unit/providers/vector_io/remote/test_milvus.py +++ b/tests/unit/providers/vector_io/remote/test_milvus.py @@ -40,15 +40,12 @@ async def mock_milvus_client() -> MagicMock: """Create a mock Milvus client with common method behaviors.""" client = MagicMock() - # Mock async collection operations client.has_collection = AsyncMock(return_value=False) # Initially no collection client.create_collection = AsyncMock(return_value=None) client.drop_collection = AsyncMock(return_value=None) - # Mock async insert operation client.insert = AsyncMock(return_value={"insert_count": 10}) - # Mock async search operation client.search = AsyncMock( return_value=[ [