feat(client): migrate MilvusClient to AsyncMilvusClient

The commit makes the follwing changes.
-  Import statements updated: MilvusClient → AsyncMilvusClient
-  Removed asyncio.to_thread() wrappers: All Milvus operations now use native async/await
-  Test compatibility: Mock objects and fixtures updated to work with AsyncMilvusClient

Signed-off-by: Mustafa Elbehery <melbeher@redhat.com>
This commit is contained in:
Mustafa Elbehery 2025-09-08 17:19:42 +02:00
parent 0e27016cf2
commit 142bd248e7
4 changed files with 76 additions and 66 deletions

View file

@ -4,12 +4,11 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import asyncio
import os
from typing import Any
from numpy.typing import NDArray
from pymilvus import AnnSearchRequest, DataType, Function, FunctionType, MilvusClient, RRFRanker, WeightedRanker
from pymilvus import AnnSearchRequest, AsyncMilvusClient, DataType, Function, FunctionType, RRFRanker, WeightedRanker
from llama_stack.apis.common.errors import VectorStoreNotFoundError
from llama_stack.apis.files.files import Files
@ -48,7 +47,11 @@ OPENAI_VECTOR_STORES_FILES_CONTENTS_PREFIX = f"openai_vector_stores_files_conten
class MilvusIndex(EmbeddingIndex):
def __init__(
self, client: MilvusClient, collection_name: str, consistency_level="Strong", kvstore: KVStore | None = None
self,
client: AsyncMilvusClient,
collection_name: str,
consistency_level="Strong",
kvstore: KVStore | None = None,
):
self.client = client
self.collection_name = sanitize_collection_name(collection_name)
@ -61,15 +64,15 @@ class MilvusIndex(EmbeddingIndex):
pass
async def delete(self):
if await asyncio.to_thread(self.client.has_collection, self.collection_name):
await asyncio.to_thread(self.client.drop_collection, collection_name=self.collection_name)
if await self.client.has_collection(self.collection_name):
await self.client.drop_collection(collection_name=self.collection_name)
async def add_chunks(self, chunks: list[Chunk], embeddings: NDArray):
assert len(chunks) == len(embeddings), (
f"Chunk length {len(chunks)} does not match embedding length {len(embeddings)}"
)
if not await asyncio.to_thread(self.client.has_collection, self.collection_name):
if not await self.client.has_collection(self.collection_name):
logger.info(f"Creating new collection {self.collection_name} with nullable sparse field")
# Create schema for vector search
schema = self.client.create_schema()
@ -123,8 +126,7 @@ class MilvusIndex(EmbeddingIndex):
)
schema.add_function(bm25_function)
await asyncio.to_thread(
self.client.create_collection,
await self.client.create_collection(
self.collection_name,
schema=schema,
index_params=index_params,
@ -143,8 +145,7 @@ class MilvusIndex(EmbeddingIndex):
}
)
try:
await asyncio.to_thread(
self.client.insert,
await self.client.insert(
self.collection_name,
data=data,
)
@ -153,8 +154,7 @@ class MilvusIndex(EmbeddingIndex):
raise e
async def query_vector(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
search_res = await asyncio.to_thread(
self.client.search,
search_res = await self.client.search(
collection_name=self.collection_name,
data=[embedding],
anns_field="vector",
@ -177,8 +177,7 @@ class MilvusIndex(EmbeddingIndex):
"""
try:
# Use Milvus's built-in BM25 search
search_res = await asyncio.to_thread(
self.client.search,
search_res = await self.client.search(
collection_name=self.collection_name,
data=[query_string], # Raw text query
anns_field="sparse", # Use sparse field for BM25
@ -219,8 +218,7 @@ class MilvusIndex(EmbeddingIndex):
Fallback to simple text search when BM25 search is not available.
"""
# Simple text search using content field
search_res = await asyncio.to_thread(
self.client.query,
search_res = await self.client.query(
collection_name=self.collection_name,
filter='content like "%{content}%"',
filter_params={"content": query_string},
@ -267,8 +265,7 @@ class MilvusIndex(EmbeddingIndex):
impact_factor = (reranker_params or {}).get("impact_factor", 60.0)
rerank = RRFRanker(impact_factor)
search_res = await asyncio.to_thread(
self.client.hybrid_search,
search_res = await self.client.hybrid_search(
collection_name=self.collection_name,
reqs=search_requests,
ranker=rerank,
@ -294,9 +291,7 @@ class MilvusIndex(EmbeddingIndex):
try:
# Use IN clause with square brackets and single quotes for VARCHAR field
chunk_ids_str = ", ".join(f"'{chunk_id}'" for chunk_id in chunk_ids)
await asyncio.to_thread(
self.client.delete, collection_name=self.collection_name, filter=f"chunk_id in [{chunk_ids_str}]"
)
await self.client.delete(collection_name=self.collection_name, filter=f"chunk_id in [{chunk_ids_str}]")
except Exception as e:
logger.error(f"Error deleting chunks from Milvus collection {self.collection_name}: {e}")
raise
@ -340,17 +335,17 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
self.cache[vector_db.identifier] = index
if isinstance(self.config, RemoteMilvusVectorIOConfig):
logger.info(f"Connecting to Milvus server at {self.config.uri}")
self.client = MilvusClient(**self.config.model_dump(exclude_none=True))
self.client = AsyncMilvusClient(**self.config.model_dump(exclude_none=True))
else:
logger.info(f"Connecting to Milvus Lite at: {self.config.db_path}")
uri = os.path.expanduser(self.config.db_path)
self.client = MilvusClient(uri=uri)
self.client = AsyncMilvusClient(uri=uri)
# Load existing OpenAI vector stores into the in-memory cache
await self.initialize_openai_vector_stores()
async def shutdown(self) -> None:
self.client.close()
await self.client.close()
async def register_vector_db(
self,