fix(integration): init AsyncMilvusClient before MilvusIndex

Signed-off-by: Mustafa Elbehery <melbeher@redhat.com>
Mustafa Elbehery 2025-09-08 22:00:59 +02:00
parent e7444c1d9b
commit 733d0c70fe
2 changed files with 77 additions and 22 deletions
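
The diff below reorders `MilvusVectorIOAdapter.initialize()` so the `AsyncMilvusClient` is constructed before any `MilvusIndex` is rebuilt from persisted vector DBs, threads a `parent_adapter` back-reference into each index, and adds a `_recreate_client()` recovery path for event-loop errors. As a rough orientation before reading the hunks, here is a minimal sketch of the resulting initialization order; apart from `AsyncMilvusClient` and the `parent_adapter` keyword, the class and parameter names are simplified stand-ins, not the real adapter.

# Minimal sketch of the initialization order this commit establishes: the
# AsyncMilvusClient is created first, then the cached index objects.
# "IndexSketch", "MilvusAdapterSketch", and "stored_ids" are illustrative only.
from dataclasses import dataclass, field

from pymilvus import AsyncMilvusClient


@dataclass
class IndexSketch:
    client: AsyncMilvusClient
    collection_name: str
    parent_adapter: "MilvusAdapterSketch | None" = None  # back-reference added by this commit


@dataclass
class MilvusAdapterSketch:
    db_path: str
    client: AsyncMilvusClient | None = None
    cache: dict[str, IndexSketch] = field(default_factory=dict)

    async def initialize(self, stored_ids: list[str]) -> None:
        # 1. Build the client up front (Milvus Lite shown; a remote config
        #    would pass its server URI instead).
        self.client = AsyncMilvusClient(uri=self.db_path)
        # 2. Only then rebuild per-collection indexes, so each one receives a
        #    live client plus the adapter itself for later client recreation.
        for identifier in stored_ids:
            self.cache[identifier] = IndexSketch(
                client=self.client,
                collection_name=identifier,
                parent_adapter=self,
            )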

@@ -52,11 +52,13 @@ class MilvusIndex(EmbeddingIndex):
         collection_name: str,
         consistency_level="Strong",
         kvstore: KVStore | None = None,
+        parent_adapter=None,
     ):
         self.client = client
         self.collection_name = sanitize_collection_name(collection_name)
         self.consistency_level = consistency_level
         self.kvstore = kvstore
+        self._parent_adapter = parent_adapter
 
     async def initialize(self):
         # MilvusIndex does not require explicit initialization
@@ -64,15 +66,36 @@ class MilvusIndex(EmbeddingIndex):
         pass
 
     async def delete(self):
-        if await self.client.has_collection(self.collection_name):
-            await self.client.drop_collection(collection_name=self.collection_name)
+        try:
+            if await self.client.has_collection(self.collection_name):
+                await self.client.drop_collection(collection_name=self.collection_name)
+        except Exception as e:
+            logger.warning(f"Failed to check or delete collection {self.collection_name}: {e}")
 
     async def add_chunks(self, chunks: list[Chunk], embeddings: NDArray):
         assert len(chunks) == len(embeddings), (
             f"Chunk length {len(chunks)} does not match embedding length {len(embeddings)}"
         )
 
-        if not await self.client.has_collection(self.collection_name):
+        try:
+            collection_exists = await self.client.has_collection(self.collection_name)
+        except Exception as e:
+            logger.error(f"Failed to check collection existence: {self.collection_name} ({e})")
+            # If it's an event loop issue, try to recreate the client
+            if "attached to a different loop" in str(e):
+                logger.warning("Recreating client due to event loop issue")
+                if hasattr(self, "_parent_adapter"):
+                    await self._parent_adapter._recreate_client()
+                    collection_exists = await self.client.has_collection(self.collection_name)
+                else:
+                    # Assume collection doesn't exist if we can't check
+                    collection_exists = False
+            else:
+                # Assume collection doesn't exist if we can't check due to other issues
+                collection_exists = False
+
+        if not collection_exists:
             logger.info(f"Creating new collection {self.collection_name} with nullable sparse field")
             # Create schema for vector search
             schema = self.client.create_schema()
@@ -126,12 +149,16 @@ class MilvusIndex(EmbeddingIndex):
             )
             schema.add_function(bm25_function)
 
-            await self.client.create_collection(
-                self.collection_name,
-                schema=schema,
-                index_params=index_params,
-                consistency_level=self.consistency_level,
-            )
+            try:
+                await self.client.create_collection(
+                    self.collection_name,
+                    schema=schema,
+                    index_params=index_params,
+                    consistency_level=self.consistency_level,
+                )
+            except Exception as e:
+                logger.error(f"Failed to create collection {self.collection_name}: {e}")
+                raise e
 
         data = []
         for chunk, embedding in zip(chunks, embeddings, strict=False):
@@ -316,6 +343,15 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPrivate):
 
     async def initialize(self) -> None:
         self.kvstore = await kvstore_impl(self.config.kvstore)
+
+        if isinstance(self.config, RemoteMilvusVectorIOConfig):
+            logger.info(f"Connecting to Milvus server at {self.config.uri}")
+            self.client = AsyncMilvusClient(**self.config.model_dump(exclude_none=True))
+        else:
+            logger.info(f"Connecting to Milvus Lite at: {self.config.db_path}")
+            uri = os.path.expanduser(self.config.db_path)
+            self.client = AsyncMilvusClient(uri=uri)
+
         start_key = VECTOR_DBS_PREFIX
         end_key = f"{VECTOR_DBS_PREFIX}\xff"
         stored_vector_dbs = await self.kvstore.values_in_range(start_key, end_key)
@@ -329,23 +365,38 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPrivate):
                     collection_name=vector_db.identifier,
                     consistency_level=self.config.consistency_level,
                     kvstore=self.kvstore,
+                    parent_adapter=self,
                 ),
                 inference_api=self.inference_api,
             )
             self.cache[vector_db.identifier] = index
-        if isinstance(self.config, RemoteMilvusVectorIOConfig):
-            logger.info(f"Connecting to Milvus server at {self.config.uri}")
-            self.client = AsyncMilvusClient(**self.config.model_dump(exclude_none=True))
-        else:
-            logger.info(f"Connecting to Milvus Lite at: {self.config.db_path}")
-            uri = os.path.expanduser(self.config.db_path)
-            self.client = AsyncMilvusClient(uri=uri)
 
         # Load existing OpenAI vector stores into the in-memory cache
         await self.initialize_openai_vector_stores()
 
     async def shutdown(self) -> None:
-        await self.client.close()
+        if self.client:
+            await self.client.close()
+
+    async def _recreate_client(self) -> None:
+        """Recreate the AsyncMilvusClient when event loop issues occur"""
+        try:
+            if self.client:
+                await self.client.close()
+        except Exception as e:
+            logger.warning(f"Error closing old client: {e}")
+
+        if isinstance(self.config, RemoteMilvusVectorIOConfig):
+            logger.info(f"Recreating connection to Milvus server at {self.config.uri}")
+            self.client = AsyncMilvusClient(**self.config.model_dump(exclude_none=True))
+        else:
+            logger.info(f"Recreating connection to Milvus Lite at: {self.config.db_path}")
+            uri = os.path.expanduser(self.config.db_path)
+            self.client = AsyncMilvusClient(uri=uri)
+
+        for index_wrapper in self.cache.values():
+            if hasattr(index_wrapper, "index") and hasattr(index_wrapper.index, "client"):
+                index_wrapper.index.client = self.client
 
     async def register_vector_db(
         self,
@@ -357,7 +408,12 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPrivate):
             consistency_level = "Strong"
         index = VectorDBWithIndex(
             vector_db=vector_db,
-            index=MilvusIndex(self.client, vector_db.identifier, consistency_level=consistency_level),
+            index=MilvusIndex(
+                client=self.client,
+                collection_name=vector_db.identifier,
+                consistency_level=consistency_level,
+                parent_adapter=self,
+            ),
             inference_api=self.inference_api,
         )
 
@@ -376,7 +432,9 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPrivate):
 
         index = VectorDBWithIndex(
             vector_db=vector_db,
-            index=MilvusIndex(client=self.client, collection_name=vector_db.identifier, kvstore=self.kvstore),
+            index=MilvusIndex(
+                client=self.client, collection_name=vector_db.identifier, kvstore=self.kvstore, parent_adapter=self
+            ),
             inference_api=self.inference_api,
         )
         self.cache[vector_db_id] = index
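
Taken together, `_recreate_client()` closes the stale client, builds a fresh one from the same config, and points every cached index at it. The commit's own test changes (second file below) only trim mock comments, but a check for this propagation could look like the following hypothetical pytest-asyncio sketch; `adapter_under_test` is an assumed fixture (not part of this commit) that wires a `MilvusVectorIOAdapter` with a mocked `AsyncMilvusClient` and one cached index, in the same MagicMock/AsyncMock style as the test file below.

# Hypothetical check (not part of this commit) that _recreate_client()
# swaps the shared client on every cached index wrapper.
from unittest.mock import AsyncMock

import pytest


@pytest.mark.asyncio
async def test_recreate_client_updates_cached_indexes(adapter_under_test):
    old_client = adapter_under_test.client
    old_client.close = AsyncMock()

    await adapter_under_test._recreate_client()

    # The old client is closed, and each cached index now holds the new one.
    old_client.close.assert_awaited_once()
    assert adapter_under_test.client is not old_client
    for wrapper in adapter_under_test.cache.values():
        assert wrapper.index.client is adapter_under_test.client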

@@ -40,15 +40,12 @@ async def mock_milvus_client() -> MagicMock:
     """Create a mock Milvus client with common method behaviors."""
     client = MagicMock()
 
-    # Mock async collection operations
     client.has_collection = AsyncMock(return_value=False)  # Initially no collection
     client.create_collection = AsyncMock(return_value=None)
     client.drop_collection = AsyncMock(return_value=None)
 
-    # Mock async insert operation
     client.insert = AsyncMock(return_value={"insert_count": 10})
 
-    # Mock async search operation
     client.search = AsyncMock(
         return_value=[
             [