feat(client): migrate MilvusClient to AsyncMilvusClient

The commit makes the follwing changes.
-  Import statements updated: MilvusClient → AsyncMilvusClient
-  Removed asyncio.to_thread() wrappers: All Milvus operations now use native async/await
-  Test compatibility: Mock objects and fixtures updated to work with AsyncMilvusClient

Signed-off-by: Mustafa Elbehery <melbeher@redhat.com>
This commit is contained in:
Mustafa Elbehery 2025-09-08 17:19:42 +02:00
parent 0e27016cf2
commit 142bd248e7
4 changed files with 76 additions and 66 deletions

View file

@ -4,12 +4,11 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import asyncio
import os
from typing import Any
from numpy.typing import NDArray
from pymilvus import AnnSearchRequest, DataType, Function, FunctionType, MilvusClient, RRFRanker, WeightedRanker
from pymilvus import AnnSearchRequest, AsyncMilvusClient, DataType, Function, FunctionType, RRFRanker, WeightedRanker
from llama_stack.apis.common.errors import VectorStoreNotFoundError
from llama_stack.apis.files.files import Files
@ -48,7 +47,11 @@ OPENAI_VECTOR_STORES_FILES_CONTENTS_PREFIX = f"openai_vector_stores_files_conten
class MilvusIndex(EmbeddingIndex):
def __init__(
self, client: MilvusClient, collection_name: str, consistency_level="Strong", kvstore: KVStore | None = None
self,
client: AsyncMilvusClient,
collection_name: str,
consistency_level="Strong",
kvstore: KVStore | None = None,
):
self.client = client
self.collection_name = sanitize_collection_name(collection_name)
@ -61,15 +64,15 @@ class MilvusIndex(EmbeddingIndex):
pass
async def delete(self):
if await asyncio.to_thread(self.client.has_collection, self.collection_name):
await asyncio.to_thread(self.client.drop_collection, collection_name=self.collection_name)
if await self.client.has_collection(self.collection_name):
await self.client.drop_collection(collection_name=self.collection_name)
async def add_chunks(self, chunks: list[Chunk], embeddings: NDArray):
assert len(chunks) == len(embeddings), (
f"Chunk length {len(chunks)} does not match embedding length {len(embeddings)}"
)
if not await asyncio.to_thread(self.client.has_collection, self.collection_name):
if not await self.client.has_collection(self.collection_name):
logger.info(f"Creating new collection {self.collection_name} with nullable sparse field")
# Create schema for vector search
schema = self.client.create_schema()
@ -123,8 +126,7 @@ class MilvusIndex(EmbeddingIndex):
)
schema.add_function(bm25_function)
await asyncio.to_thread(
self.client.create_collection,
await self.client.create_collection(
self.collection_name,
schema=schema,
index_params=index_params,
@ -143,8 +145,7 @@ class MilvusIndex(EmbeddingIndex):
}
)
try:
await asyncio.to_thread(
self.client.insert,
await self.client.insert(
self.collection_name,
data=data,
)
@ -153,8 +154,7 @@ class MilvusIndex(EmbeddingIndex):
raise e
async def query_vector(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
search_res = await asyncio.to_thread(
self.client.search,
search_res = await self.client.search(
collection_name=self.collection_name,
data=[embedding],
anns_field="vector",
@ -177,8 +177,7 @@ class MilvusIndex(EmbeddingIndex):
"""
try:
# Use Milvus's built-in BM25 search
search_res = await asyncio.to_thread(
self.client.search,
search_res = await self.client.search(
collection_name=self.collection_name,
data=[query_string], # Raw text query
anns_field="sparse", # Use sparse field for BM25
@ -219,8 +218,7 @@ class MilvusIndex(EmbeddingIndex):
Fallback to simple text search when BM25 search is not available.
"""
# Simple text search using content field
search_res = await asyncio.to_thread(
self.client.query,
search_res = await self.client.query(
collection_name=self.collection_name,
filter='content like "%{content}%"',
filter_params={"content": query_string},
@ -267,8 +265,7 @@ class MilvusIndex(EmbeddingIndex):
impact_factor = (reranker_params or {}).get("impact_factor", 60.0)
rerank = RRFRanker(impact_factor)
search_res = await asyncio.to_thread(
self.client.hybrid_search,
search_res = await self.client.hybrid_search(
collection_name=self.collection_name,
reqs=search_requests,
ranker=rerank,
@ -294,9 +291,7 @@ class MilvusIndex(EmbeddingIndex):
try:
# Use IN clause with square brackets and single quotes for VARCHAR field
chunk_ids_str = ", ".join(f"'{chunk_id}'" for chunk_id in chunk_ids)
await asyncio.to_thread(
self.client.delete, collection_name=self.collection_name, filter=f"chunk_id in [{chunk_ids_str}]"
)
await self.client.delete(collection_name=self.collection_name, filter=f"chunk_id in [{chunk_ids_str}]")
except Exception as e:
logger.error(f"Error deleting chunks from Milvus collection {self.collection_name}: {e}")
raise
@ -340,17 +335,17 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
self.cache[vector_db.identifier] = index
if isinstance(self.config, RemoteMilvusVectorIOConfig):
logger.info(f"Connecting to Milvus server at {self.config.uri}")
self.client = MilvusClient(**self.config.model_dump(exclude_none=True))
self.client = AsyncMilvusClient(**self.config.model_dump(exclude_none=True))
else:
logger.info(f"Connecting to Milvus Lite at: {self.config.db_path}")
uri = os.path.expanduser(self.config.db_path)
self.client = MilvusClient(uri=uri)
self.client = AsyncMilvusClient(uri=uri)
# Load existing OpenAI vector stores into the in-memory cache
await self.initialize_openai_vector_stores()
async def shutdown(self) -> None:
self.client.close()
await self.client.close()
async def register_vector_db(
self,

View file

@ -10,7 +10,7 @@ from unittest.mock import AsyncMock, MagicMock, patch
import numpy as np
import pytest
from chromadb import PersistentClient
from pymilvus import MilvusClient, connections
from pymilvus import AsyncMilvusClient, connections
from llama_stack.apis.vector_dbs import VectorDB
from llama_stack.apis.vector_io import Chunk, ChunkMetadata, QueryChunksResponse
@ -139,7 +139,7 @@ async def sqlite_vec_vec_index(embedding_dimension, tmp_path_factory):
await index.initialize()
index.db_path = db_path
yield index
index.delete()
await index.delete()
@pytest.fixture
@ -176,12 +176,14 @@ def milvus_vec_db_path(tmp_path_factory):
@pytest.fixture
async def milvus_vec_index(milvus_vec_db_path, embedding_dimension):
client = MilvusClient(milvus_vec_db_path)
client = AsyncMilvusClient(milvus_vec_db_path)
name = f"{COLLECTION_PREFIX}_{np.random.randint(1e6)}"
connections.connect(alias=MILVUS_ALIAS, uri=milvus_vec_db_path)
index = MilvusIndex(client, name, consistency_level="Strong")
index.db_path = milvus_vec_db_path
yield index
# Proper cleanup: close the async client
await client.close()
@pytest.fixture

View file

@ -4,7 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from unittest.mock import MagicMock, patch
from unittest.mock import AsyncMock, MagicMock, patch
import numpy as np
import pytest
@ -14,7 +14,7 @@ from llama_stack.apis.vector_io import QueryChunksResponse
# Mock the entire pymilvus module
pymilvus_mock = MagicMock()
pymilvus_mock.DataType = MagicMock()
pymilvus_mock.MilvusClient = MagicMock
pymilvus_mock.AsyncMilvusClient = MagicMock
pymilvus_mock.RRFRanker = MagicMock
pymilvus_mock.WeightedRanker = MagicMock
pymilvus_mock.AnnSearchRequest = MagicMock
@ -40,48 +40,61 @@ async def mock_milvus_client() -> MagicMock:
"""Create a mock Milvus client with common method behaviors."""
client = MagicMock()
# Mock collection operations
client.has_collection.return_value = False # Initially no collection
client.create_collection.return_value = None
client.drop_collection.return_value = None
# Mock async collection operations
client.has_collection = AsyncMock(return_value=False) # Initially no collection
client.create_collection = AsyncMock(return_value=None)
client.drop_collection = AsyncMock(return_value=None)
# Mock insert operation
client.insert.return_value = {"insert_count": 10}
# Mock async insert operation
client.insert = AsyncMock(return_value={"insert_count": 10})
# Mock search operation - return mock results (data should be dict, not JSON string)
client.search.return_value = [
[
# Mock async search operation - return mock results (data should be dict, not JSON string)
client.search = AsyncMock(
return_value=[
[
{
"id": 0,
"distance": 0.1,
"entity": {"chunk_content": {"content": "mock chunk 1", "metadata": {"document_id": "doc1"}}},
},
{
"id": 1,
"distance": 0.2,
"entity": {"chunk_content": {"content": "mock chunk 2", "metadata": {"document_id": "doc2"}}},
},
]
]
)
# Mock async query operation for keyword search (data should be dict, not JSON string)
client.query = AsyncMock(
return_value=[
{
"id": 0,
"distance": 0.1,
"entity": {"chunk_content": {"content": "mock chunk 1", "metadata": {"document_id": "doc1"}}},
"chunk_id": "chunk1",
"chunk_content": {"content": "mock chunk 1", "metadata": {"document_id": "doc1"}},
"score": 0.9,
},
{
"id": 1,
"distance": 0.2,
"entity": {"chunk_content": {"content": "mock chunk 2", "metadata": {"document_id": "doc2"}}},
"chunk_id": "chunk2",
"chunk_content": {"content": "mock chunk 2", "metadata": {"document_id": "doc2"}},
"score": 0.8,
},
{
"chunk_id": "chunk3",
"chunk_content": {"content": "mock chunk 3", "metadata": {"document_id": "doc3"}},
"score": 0.7,
},
]
]
)
# Mock query operation for keyword search (data should be dict, not JSON string)
client.query.return_value = [
{
"chunk_id": "chunk1",
"chunk_content": {"content": "mock chunk 1", "metadata": {"document_id": "doc1"}},
"score": 0.9,
},
{
"chunk_id": "chunk2",
"chunk_content": {"content": "mock chunk 2", "metadata": {"document_id": "doc2"}},
"score": 0.8,
},
{
"chunk_id": "chunk3",
"chunk_content": {"content": "mock chunk 3", "metadata": {"document_id": "doc3"}},
"score": 0.7,
},
]
# Mock async hybrid_search operation
client.hybrid_search = AsyncMock(return_value=[])
# Mock async delete operation
client.delete = AsyncMock(return_value=None)
# Mock async close operation
client.close = AsyncMock(return_value=None)
return client

View file

@ -30,12 +30,12 @@ async def test_initialize_index(vector_index):
async def test_add_chunks_query_vector(vector_index, sample_chunks, sample_embeddings):
vector_index.delete()
vector_index.initialize()
await vector_index.delete()
await vector_index.initialize()
await vector_index.add_chunks(sample_chunks, sample_embeddings)
resp = await vector_index.query_vector(sample_embeddings[0], k=1, score_threshold=-1)
assert resp.chunks[0].content == sample_chunks[0].content
vector_index.delete()
await vector_index.delete()
async def test_chunk_id_conflict(vector_index, sample_chunks, embedding_dimension):