From f75ec332b5bc080536fcc0953b6c15cfa299284c Mon Sep 17 00:00:00 2001 From: Mustafa Elbehery Date: Wed, 9 Jul 2025 15:44:39 +0200 Subject: [PATCH] feat: migrate from MilvusClient to AsyncMilvusClient This PR adds static type coverage to `llama-stack` Part of https://github.com/meta-llama/llama-stack/issues/2647 Signed-off-by: Mustafa Elbehery --- .../remote/vector_io/milvus/milvus.py | 204 +++++++++--------- .../test_vector_io_openai_vector_stores.py | 6 +- 2 files changed, 103 insertions(+), 107 deletions(-) diff --git a/llama_stack/providers/remote/vector_io/milvus/milvus.py b/llama_stack/providers/remote/vector_io/milvus/milvus.py index 1f65e580e..87262699c 100644 --- a/llama_stack/providers/remote/vector_io/milvus/milvus.py +++ b/llama_stack/providers/remote/vector_io/milvus/milvus.py @@ -4,7 +4,6 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -import asyncio import json import logging import os @@ -12,7 +11,7 @@ import re from typing import Any from numpy.typing import NDArray -from pymilvus import DataType, MilvusClient +from pymilvus import AsyncMilvusClient, DataType from llama_stack.apis.files.files import Files from llama_stack.apis.inference import Inference, InterleavedContent @@ -54,7 +53,11 @@ def sanitize_collection_name(name: str) -> str: class MilvusIndex(EmbeddingIndex): def __init__( - self, client: MilvusClient, collection_name: str, consistency_level="Strong", kvstore: KVStore | None = None + self, + client: AsyncMilvusClient, + collection_name: str, + consistency_level="Strong", + kvstore: KVStore | None = None, ): self.client = client self.collection_name = sanitize_collection_name(collection_name) @@ -62,16 +65,15 @@ class MilvusIndex(EmbeddingIndex): self.kvstore = kvstore async def delete(self): - if await asyncio.to_thread(self.client.has_collection, self.collection_name): - await asyncio.to_thread(self.client.drop_collection, collection_name=self.collection_name) + if await self.client.collection_exists(self.collection_name): + await self.client.drop_collection(collection_name=self.collection_name) async def add_chunks(self, chunks: list[Chunk], embeddings: NDArray): assert len(chunks) == len(embeddings), ( f"Chunk length {len(chunks)} does not match embedding length {len(embeddings)}" ) - if not await asyncio.to_thread(self.client.has_collection, self.collection_name): - await asyncio.to_thread( - self.client.create_collection, + if not await self.client.collection_exists(self.collection_name): + await self.client.create_collection( self.collection_name, dimension=len(embeddings[0]), auto_id=True, @@ -88,8 +90,7 @@ class MilvusIndex(EmbeddingIndex): } ) try: - await asyncio.to_thread( - self.client.insert, + await self.client.insert( self.collection_name, data=data, ) @@ -98,8 +99,7 @@ class MilvusIndex(EmbeddingIndex): raise e async def query_vector(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse: - search_res = await asyncio.to_thread( - self.client.search, + search_res = await self.client.search( collection_name=self.collection_name, data=[embedding], limit=k, @@ -139,57 +139,68 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP ) -> None: self.config = config self.cache = {} - self.client = None + self.client: AsyncMilvusClient | None = None self.inference_api = inference_api self.files_api = files_api self.kvstore: KVStore | None = None - self.vector_db_store = None self.openai_vector_stores: dict[str, dict[str, Any]] = {} self.metadata_collection_name = "openai_vector_stores_metadata" async def initialize(self) -> None: - self.kvstore = await kvstore_impl(self.config.kvstore) - start_key = VECTOR_DBS_PREFIX - end_key = f"{VECTOR_DBS_PREFIX}\xff" - stored_vector_dbs = await self.kvstore.values_in_range(start_key, end_key) + if self.config.kvstore is not None: + self.kvstore = await kvstore_impl(self.config.kvstore) - for vector_db_data in stored_vector_dbs: - vector_db = VectorDB.model_validate_json(vector_db_data) - index = VectorDBWithIndex( - vector_db, - index=MilvusIndex( - client=self.client, - collection_name=vector_db.identifier, - consistency_level=self.config.consistency_level, - kvstore=self.kvstore, - ), - inference_api=self.inference_api, - ) - self.cache[vector_db.identifier] = index + # Initialize client first before using it if isinstance(self.config, RemoteMilvusVectorIOConfig): logger.info(f"Connecting to Milvus server at {self.config.uri}") - self.client = MilvusClient(**self.config.model_dump(exclude_none=True)) + self.client = AsyncMilvusClient(**self.config.model_dump(exclude_none=True)) else: logger.info(f"Connecting to Milvus Lite at: {self.config.db_path}") uri = os.path.expanduser(self.config.db_path) - self.client = MilvusClient(uri=uri) + self.client = AsyncMilvusClient(uri=uri) + + # Now load stored vector databases + if self.kvstore is not None: + start_key = VECTOR_DBS_PREFIX + end_key = f"{VECTOR_DBS_PREFIX}\xff" + stored_vector_dbs = await self.kvstore.values_in_range(start_key, end_key) + + for vector_db_data in stored_vector_dbs: + vector_db = VectorDB.model_validate_json(vector_db_data) + index = VectorDBWithIndex( + vector_db, + index=MilvusIndex( + client=self.client, + collection_name=vector_db.identifier, + consistency_level=self.config.consistency_level, + kvstore=self.kvstore, + ), + inference_api=self.inference_api, + ) + self.cache[vector_db.identifier] = index self.openai_vector_stores = await self._load_openai_vector_stores() async def shutdown(self) -> None: - self.client.close() + if self.client is not None: + await self.client.close() - async def register_vector_db( - self, - vector_db: VectorDB, - ) -> None: + def _ensure_client_initialized(self) -> AsyncMilvusClient: + """Ensure the client is initialized and return it.""" + if self.client is None: + raise RuntimeError("Milvus client is not initialized. Call initialize() first.") + return self.client + + async def register_vector_db(self, vector_db: VectorDB) -> None: if isinstance(self.config, RemoteMilvusVectorIOConfig): consistency_level = self.config.consistency_level else: consistency_level = "Strong" index = VectorDBWithIndex( vector_db=vector_db, - index=MilvusIndex(self.client, vector_db.identifier, consistency_level=consistency_level), + index=MilvusIndex( + self._ensure_client_initialized(), vector_db.identifier, consistency_level=consistency_level + ), inference_api=self.inference_api, ) @@ -199,17 +210,8 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP if vector_db_id in self.cache: return self.cache[vector_db_id] - vector_db = await self.vector_db_store.get_vector_db(vector_db_id) - if not vector_db: - raise ValueError(f"Vector DB {vector_db_id} not found") - - index = VectorDBWithIndex( - vector_db=vector_db, - index=MilvusIndex(client=self.client, collection_name=vector_db.identifier, kvstore=self.kvstore), - inference_api=self.inference_api, - ) - self.cache[vector_db_id] = index - return index + # Vector DB should be registered before use + raise ValueError(f"Vector DB {vector_db_id} not found. Please register it first.") async def unregister_vector_db(self, vector_db_id: str) -> None: if vector_db_id in self.cache: @@ -275,14 +277,17 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP ) -> None: """Save vector store file metadata to Milvus database.""" if store_id not in self.openai_vector_stores: - store_info = await self._load_openai_vector_stores(store_id) - if not store_info: + # Reload all vector stores to check if the store exists + self.openai_vector_stores = await self._load_openai_vector_stores() + if store_id not in self.openai_vector_stores: logger.error(f"OpenAI vector store {store_id} not found") raise ValueError(f"No vector store found with id {store_id}") try: - if not await asyncio.to_thread(self.client.has_collection, "openai_vector_store_files"): - file_schema = MilvusClient.create_schema( + client = self._ensure_client_initialized() + + if not await client.collection_exists("openai_vector_store_files"): + file_schema = AsyncMilvusClient.create_schema( auto_id=False, enable_dynamic_field=True, description="Metadata for OpenAI vector store files", @@ -294,14 +299,13 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP file_schema.add_field(field_name="file_id", datatype=DataType.VARCHAR, max_length=512) file_schema.add_field(field_name="file_info", datatype=DataType.VARCHAR, max_length=65535) - await asyncio.to_thread( - self.client.create_collection, + await client.create_collection( collection_name="openai_vector_store_files", schema=file_schema, ) - if not await asyncio.to_thread(self.client.has_collection, "openai_vector_store_files_contents"): - content_schema = MilvusClient.create_schema( + if not await client.collection_exists("openai_vector_store_files_contents"): + content_schema = AsyncMilvusClient.create_schema( auto_id=False, enable_dynamic_field=True, description="Contents for OpenAI vector store files", @@ -314,8 +318,7 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP content_schema.add_field(field_name="file_id", datatype=DataType.VARCHAR, max_length=512) content_schema.add_field(field_name="content", datatype=DataType.VARCHAR, max_length=65535) - await asyncio.to_thread( - self.client.create_collection, + await client.create_collection( collection_name="openai_vector_store_files_contents", schema=content_schema, ) @@ -328,44 +331,44 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP "file_info": json.dumps(file_info), } ] - await asyncio.to_thread( - self.client.upsert, + await client.upsert( collection_name="openai_vector_store_files", data=file_data, ) # Save file contents - contents_data = [ - { - "chunk_id": content.get("chunk_metadata").get("chunk_id"), - "store_file_id": f"{store_id}_{file_id}", - "store_id": store_id, - "file_id": file_id, - "content": json.dumps(content), - } - for content in file_contents - ] - await asyncio.to_thread( - self.client.upsert, - collection_name="openai_vector_store_files_contents", - data=contents_data, - ) + contents_data = [] + for content in file_contents: + chunk_metadata = content.get("chunk_metadata") + if chunk_metadata and chunk_metadata.get("chunk_id"): + contents_data.append( + { + "chunk_id": chunk_metadata.get("chunk_id"), + "store_file_id": f"{store_id}_{file_id}", + "store_id": store_id, + "file_id": file_id, + "content": json.dumps(content), + } + ) + + if contents_data: + await client.upsert(collection_name="openai_vector_store_files_contents", data=contents_data) except Exception as e: logger.error(f"Error saving openai vector store file {file_id} for store {store_id}: {e}") + raise async def _load_openai_vector_store_file(self, store_id: str, file_id: str) -> dict[str, Any]: """Load vector store file metadata from Milvus database.""" try: - if not await asyncio.to_thread(self.client.has_collection, "openai_vector_store_files"): + client = self._ensure_client_initialized() + + if not await client.collection_exists("openai_vector_store_files"): return {} query_filter = f"store_file_id == '{store_id}_{file_id}'" - results = await asyncio.to_thread( - self.client.query, - collection_name="openai_vector_store_files", - filter=query_filter, - output_fields=["file_info"], + results = await client.query( + collection_name="openai_vector_store_files", filter=query_filter, output_fields=["file_info"] ) if results: @@ -382,7 +385,9 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP async def _update_openai_vector_store_file(self, store_id: str, file_id: str, file_info: dict[str, Any]) -> None: """Update vector store file metadata in Milvus database.""" try: - if not await asyncio.to_thread(self.client.has_collection, "openai_vector_store_files"): + client = self._ensure_client_initialized() + + if not await client.collection_exists("openai_vector_store_files"): return file_data = [ @@ -393,11 +398,7 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP "file_info": json.dumps(file_info), } ] - await asyncio.to_thread( - self.client.upsert, - collection_name="openai_vector_store_files", - data=file_data, - ) + await client.upsert(collection_name="openai_vector_store_files", data=file_data) except Exception as e: logger.error(f"Error updating openai vector store file {file_id} for store {store_id}: {e}") raise @@ -405,14 +406,15 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP async def _load_openai_vector_store_file_contents(self, store_id: str, file_id: str) -> list[dict[str, Any]]: """Load vector store file contents from Milvus database.""" try: - if not await asyncio.to_thread(self.client.has_collection, "openai_vector_store_files_contents"): + client = self._ensure_client_initialized() + + if not await client.collection_exists("openai_vector_store_files_contents"): return [] query_filter = ( f"store_id == '{store_id}' AND file_id == '{file_id}' AND store_file_id == '{store_id}_{file_id}'" ) - results = await asyncio.to_thread( - self.client.query, + results = await client.query( collection_name="openai_vector_store_files_contents", filter=query_filter, output_fields=["chunk_id", "store_id", "file_id", "content"], @@ -433,21 +435,15 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP async def _delete_openai_vector_store_file_from_storage(self, store_id: str, file_id: str) -> None: """Delete vector store file metadata from Milvus database.""" try: - if not await asyncio.to_thread(self.client.has_collection, "openai_vector_store_files"): + client = self._ensure_client_initialized() + + if not await client.collection_exists("openai_vector_store_files"): return query_filter = f"store_file_id in ['{store_id}_{file_id}']" - await asyncio.to_thread( - self.client.delete, - collection_name="openai_vector_store_files", - filter=query_filter, - ) - if await asyncio.to_thread(self.client.has_collection, "openai_vector_store_files_contents"): - await asyncio.to_thread( - self.client.delete, - collection_name="openai_vector_store_files_contents", - filter=query_filter, - ) + await client.delete(collection_name="openai_vector_store_files", filter=query_filter) + if await client.collection_exists("openai_vector_store_files_contents"): + await client.delete(collection_name="openai_vector_store_files_contents", filter=query_filter) except Exception as e: logger.error(f"Error deleting openai vector store file {file_id} for store {store_id}: {e}") diff --git a/tests/unit/providers/vector_io/test_vector_io_openai_vector_stores.py b/tests/unit/providers/vector_io/test_vector_io_openai_vector_stores.py index 0a109e833..dc9513487 100644 --- a/tests/unit/providers/vector_io/test_vector_io_openai_vector_stores.py +++ b/tests/unit/providers/vector_io/test_vector_io_openai_vector_stores.py @@ -11,7 +11,7 @@ from unittest.mock import AsyncMock import numpy as np import pytest import pytest_asyncio -from pymilvus import Collection, MilvusClient, connections +from pymilvus import AsyncMilvusClient, Collection, connections from llama_stack.apis.vector_dbs import VectorDB from llama_stack.apis.vector_io import Chunk, QueryChunksResponse @@ -52,7 +52,7 @@ async def unique_kvstore_config(tmp_path_factory): async def milvus_vec_index(embedding_dimension, tmp_path_factory): temp_dir = tmp_path_factory.getbasetemp() db_path = str(temp_dir / "test_milvus.db") - client = MilvusClient(db_path) + client = AsyncMilvusClient(uri=db_path) name = f"{COLLECTION_PREFIX}_{np.random.randint(1e6)}" connections.connect(alias=MILVUS_ALIAS, uri=db_path) index = MilvusIndex(client, name, consistency_level="Strong") @@ -148,7 +148,7 @@ async def test_initialize_with_milvus_client(milvus_vec_index, unique_kvstore_co await tmp_milvus_vec_adapter.kvstore.set(f"{VECTOR_DBS_PREFIX}/test_db", test_vector_db_data) assert milvus_vec_index.client is not None - assert isinstance(milvus_vec_index.client, MilvusClient) + assert isinstance(milvus_vec_index.client, AsyncMilvusClient) assert tmp_milvus_vec_adapter.cache is not None # registering a vector won't update the cache or openai_vector_store collection name assert (