From 142bd248e73260a8be48777fedb9fb8653dadc21 Mon Sep 17 00:00:00 2001 From: Mustafa Elbehery Date: Mon, 8 Sep 2025 17:19:42 +0200 Subject: [PATCH] feat(client): migrate MilvusClient to AsyncMilvusClient MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The commit makes the following changes. - Import statements updated: MilvusClient → AsyncMilvusClient - Removed asyncio.to_thread() wrappers: All Milvus operations now use native async/await - Test compatibility: Mock objects and fixtures updated to work with AsyncMilvusClient Signed-off-by: Mustafa Elbehery --- .../remote/vector_io/milvus/milvus.py | 43 +++++----- tests/unit/providers/vector_io/conftest.py | 8 +- .../providers/vector_io/remote/test_milvus.py | 85 +++++++++++-------- .../test_vector_io_openai_vector_stores.py | 6 +- 4 files changed, 76 insertions(+), 66 deletions(-) diff --git a/llama_stack/providers/remote/vector_io/milvus/milvus.py b/llama_stack/providers/remote/vector_io/milvus/milvus.py index e07e8ff12..5e217bb55 100644 --- a/llama_stack/providers/remote/vector_io/milvus/milvus.py +++ b/llama_stack/providers/remote/vector_io/milvus/milvus.py @@ -4,12 +4,11 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. 
-import asyncio import os from typing import Any from numpy.typing import NDArray -from pymilvus import AnnSearchRequest, DataType, Function, FunctionType, MilvusClient, RRFRanker, WeightedRanker +from pymilvus import AnnSearchRequest, AsyncMilvusClient, DataType, Function, FunctionType, RRFRanker, WeightedRanker from llama_stack.apis.common.errors import VectorStoreNotFoundError from llama_stack.apis.files.files import Files @@ -48,7 +47,11 @@ OPENAI_VECTOR_STORES_FILES_CONTENTS_PREFIX = f"openai_vector_stores_files_conten class MilvusIndex(EmbeddingIndex): def __init__( - self, client: MilvusClient, collection_name: str, consistency_level="Strong", kvstore: KVStore | None = None + self, + client: AsyncMilvusClient, + collection_name: str, + consistency_level="Strong", + kvstore: KVStore | None = None, ): self.client = client self.collection_name = sanitize_collection_name(collection_name) @@ -61,15 +64,15 @@ class MilvusIndex(EmbeddingIndex): pass async def delete(self): - if await asyncio.to_thread(self.client.has_collection, self.collection_name): - await asyncio.to_thread(self.client.drop_collection, collection_name=self.collection_name) + if await self.client.has_collection(self.collection_name): + await self.client.drop_collection(collection_name=self.collection_name) async def add_chunks(self, chunks: list[Chunk], embeddings: NDArray): assert len(chunks) == len(embeddings), ( f"Chunk length {len(chunks)} does not match embedding length {len(embeddings)}" ) - if not await asyncio.to_thread(self.client.has_collection, self.collection_name): + if not await self.client.has_collection(self.collection_name): logger.info(f"Creating new collection {self.collection_name} with nullable sparse field") # Create schema for vector search schema = self.client.create_schema() @@ -123,8 +126,7 @@ class MilvusIndex(EmbeddingIndex): ) schema.add_function(bm25_function) - await asyncio.to_thread( - self.client.create_collection, + await self.client.create_collection( 
self.collection_name, schema=schema, index_params=index_params, @@ -143,8 +145,7 @@ class MilvusIndex(EmbeddingIndex): } ) try: - await asyncio.to_thread( - self.client.insert, + await self.client.insert( self.collection_name, data=data, ) @@ -153,8 +154,7 @@ class MilvusIndex(EmbeddingIndex): raise e async def query_vector(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse: - search_res = await asyncio.to_thread( - self.client.search, + search_res = await self.client.search( collection_name=self.collection_name, data=[embedding], anns_field="vector", @@ -177,8 +177,7 @@ class MilvusIndex(EmbeddingIndex): """ try: # Use Milvus's built-in BM25 search - search_res = await asyncio.to_thread( - self.client.search, + search_res = await self.client.search( collection_name=self.collection_name, data=[query_string], # Raw text query anns_field="sparse", # Use sparse field for BM25 @@ -219,8 +218,7 @@ class MilvusIndex(EmbeddingIndex): Fallback to simple text search when BM25 search is not available. 
""" # Simple text search using content field - search_res = await asyncio.to_thread( - self.client.query, + search_res = await self.client.query( collection_name=self.collection_name, filter='content like "%{content}%"', filter_params={"content": query_string}, @@ -267,8 +265,7 @@ class MilvusIndex(EmbeddingIndex): impact_factor = (reranker_params or {}).get("impact_factor", 60.0) rerank = RRFRanker(impact_factor) - search_res = await asyncio.to_thread( - self.client.hybrid_search, + search_res = await self.client.hybrid_search( collection_name=self.collection_name, reqs=search_requests, ranker=rerank, @@ -294,9 +291,7 @@ class MilvusIndex(EmbeddingIndex): try: # Use IN clause with square brackets and single quotes for VARCHAR field chunk_ids_str = ", ".join(f"'{chunk_id}'" for chunk_id in chunk_ids) - await asyncio.to_thread( - self.client.delete, collection_name=self.collection_name, filter=f"chunk_id in [{chunk_ids_str}]" - ) + await self.client.delete(collection_name=self.collection_name, filter=f"chunk_id in [{chunk_ids_str}]") except Exception as e: logger.error(f"Error deleting chunks from Milvus collection {self.collection_name}: {e}") raise @@ -340,17 +335,17 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP self.cache[vector_db.identifier] = index if isinstance(self.config, RemoteMilvusVectorIOConfig): logger.info(f"Connecting to Milvus server at {self.config.uri}") - self.client = MilvusClient(**self.config.model_dump(exclude_none=True)) + self.client = AsyncMilvusClient(**self.config.model_dump(exclude_none=True)) else: logger.info(f"Connecting to Milvus Lite at: {self.config.db_path}") uri = os.path.expanduser(self.config.db_path) - self.client = MilvusClient(uri=uri) + self.client = AsyncMilvusClient(uri=uri) # Load existing OpenAI vector stores into the in-memory cache await self.initialize_openai_vector_stores() async def shutdown(self) -> None: - self.client.close() + await self.client.close() async def 
register_vector_db( self, diff --git a/tests/unit/providers/vector_io/conftest.py b/tests/unit/providers/vector_io/conftest.py index 91bddd037..b8ea7a203 100644 --- a/tests/unit/providers/vector_io/conftest.py +++ b/tests/unit/providers/vector_io/conftest.py @@ -10,7 +10,7 @@ from unittest.mock import AsyncMock, MagicMock, patch import numpy as np import pytest from chromadb import PersistentClient -from pymilvus import MilvusClient, connections +from pymilvus import AsyncMilvusClient, connections from llama_stack.apis.vector_dbs import VectorDB from llama_stack.apis.vector_io import Chunk, ChunkMetadata, QueryChunksResponse @@ -139,7 +139,7 @@ async def sqlite_vec_vec_index(embedding_dimension, tmp_path_factory): await index.initialize() index.db_path = db_path yield index - index.delete() + await index.delete() @pytest.fixture @@ -176,12 +176,14 @@ def milvus_vec_db_path(tmp_path_factory): @pytest.fixture async def milvus_vec_index(milvus_vec_db_path, embedding_dimension): - client = MilvusClient(milvus_vec_db_path) + client = AsyncMilvusClient(milvus_vec_db_path) name = f"{COLLECTION_PREFIX}_{np.random.randint(1e6)}" connections.connect(alias=MILVUS_ALIAS, uri=milvus_vec_db_path) index = MilvusIndex(client, name, consistency_level="Strong") index.db_path = milvus_vec_db_path yield index + # Proper cleanup: close the async client + await client.close() @pytest.fixture diff --git a/tests/unit/providers/vector_io/remote/test_milvus.py b/tests/unit/providers/vector_io/remote/test_milvus.py index ca5f45fa2..6e8817366 100644 --- a/tests/unit/providers/vector_io/remote/test_milvus.py +++ b/tests/unit/providers/vector_io/remote/test_milvus.py @@ -4,7 +4,7 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. 
-from unittest.mock import MagicMock, patch +from unittest.mock import AsyncMock, MagicMock, patch import numpy as np import pytest @@ -14,7 +14,7 @@ from llama_stack.apis.vector_io import QueryChunksResponse # Mock the entire pymilvus module pymilvus_mock = MagicMock() pymilvus_mock.DataType = MagicMock() -pymilvus_mock.MilvusClient = MagicMock +pymilvus_mock.AsyncMilvusClient = MagicMock pymilvus_mock.RRFRanker = MagicMock pymilvus_mock.WeightedRanker = MagicMock pymilvus_mock.AnnSearchRequest = MagicMock @@ -40,48 +40,61 @@ async def mock_milvus_client() -> MagicMock: """Create a mock Milvus client with common method behaviors.""" client = MagicMock() - # Mock collection operations - client.has_collection.return_value = False # Initially no collection - client.create_collection.return_value = None - client.drop_collection.return_value = None + # Mock async collection operations + client.has_collection = AsyncMock(return_value=False) # Initially no collection + client.create_collection = AsyncMock(return_value=None) + client.drop_collection = AsyncMock(return_value=None) - # Mock insert operation - client.insert.return_value = {"insert_count": 10} + # Mock async insert operation + client.insert = AsyncMock(return_value={"insert_count": 10}) - # Mock search operation - return mock results (data should be dict, not JSON string) - client.search.return_value = [ - [ + # Mock async search operation - return mock results (data should be dict, not JSON string) + client.search = AsyncMock( + return_value=[ + [ + { + "id": 0, + "distance": 0.1, + "entity": {"chunk_content": {"content": "mock chunk 1", "metadata": {"document_id": "doc1"}}}, + }, + { + "id": 1, + "distance": 0.2, + "entity": {"chunk_content": {"content": "mock chunk 2", "metadata": {"document_id": "doc2"}}}, + }, + ] + ] + ) + + # Mock async query operation for keyword search (data should be dict, not JSON string) + client.query = AsyncMock( + return_value=[ { - "id": 0, - "distance": 0.1, - "entity": 
{"chunk_content": {"content": "mock chunk 1", "metadata": {"document_id": "doc1"}}}, + "chunk_id": "chunk1", + "chunk_content": {"content": "mock chunk 1", "metadata": {"document_id": "doc1"}}, + "score": 0.9, }, { - "id": 1, - "distance": 0.2, - "entity": {"chunk_content": {"content": "mock chunk 2", "metadata": {"document_id": "doc2"}}}, + "chunk_id": "chunk2", + "chunk_content": {"content": "mock chunk 2", "metadata": {"document_id": "doc2"}}, + "score": 0.8, + }, + { + "chunk_id": "chunk3", + "chunk_content": {"content": "mock chunk 3", "metadata": {"document_id": "doc3"}}, + "score": 0.7, }, ] - ] + ) - # Mock query operation for keyword search (data should be dict, not JSON string) - client.query.return_value = [ - { - "chunk_id": "chunk1", - "chunk_content": {"content": "mock chunk 1", "metadata": {"document_id": "doc1"}}, - "score": 0.9, - }, - { - "chunk_id": "chunk2", - "chunk_content": {"content": "mock chunk 2", "metadata": {"document_id": "doc2"}}, - "score": 0.8, - }, - { - "chunk_id": "chunk3", - "chunk_content": {"content": "mock chunk 3", "metadata": {"document_id": "doc3"}}, - "score": 0.7, - }, - ] + # Mock async hybrid_search operation + client.hybrid_search = AsyncMock(return_value=[]) + + # Mock async delete operation + client.delete = AsyncMock(return_value=None) + + # Mock async close operation + client.close = AsyncMock(return_value=None) return client diff --git a/tests/unit/providers/vector_io/test_vector_io_openai_vector_stores.py b/tests/unit/providers/vector_io/test_vector_io_openai_vector_stores.py index 98889f38e..9669d3922 100644 --- a/tests/unit/providers/vector_io/test_vector_io_openai_vector_stores.py +++ b/tests/unit/providers/vector_io/test_vector_io_openai_vector_stores.py @@ -30,12 +30,12 @@ async def test_initialize_index(vector_index): async def test_add_chunks_query_vector(vector_index, sample_chunks, sample_embeddings): - vector_index.delete() - vector_index.initialize() + await vector_index.delete() + await 
vector_index.initialize() await vector_index.add_chunks(sample_chunks, sample_embeddings) resp = await vector_index.query_vector(sample_embeddings[0], k=1, score_threshold=-1) assert resp.chunks[0].content == sample_chunks[0].content - vector_index.delete() + await vector_index.delete() async def test_chunk_id_conflict(vector_index, sample_chunks, embedding_dimension):