diff --git a/llama_stack/providers/remote/vector_io/milvus/milvus.py b/llama_stack/providers/remote/vector_io/milvus/milvus.py index 8ca9212bc..1949d293d 100644 --- a/llama_stack/providers/remote/vector_io/milvus/milvus.py +++ b/llama_stack/providers/remote/vector_io/milvus/milvus.py @@ -4,6 +4,7 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. +import asyncio import hashlib import logging import os @@ -35,15 +36,16 @@ class MilvusIndex(EmbeddingIndex): self.consistency_level = consistency_level async def delete(self): - if self.client.has_collection(self.collection_name): - self.client.drop_collection(collection_name=self.collection_name) + if await asyncio.to_thread(self.client.has_collection, self.collection_name): + await asyncio.to_thread(self.client.drop_collection, collection_name=self.collection_name) async def add_chunks(self, chunks: List[Chunk], embeddings: NDArray): assert len(chunks) == len(embeddings), ( f"Chunk length {len(chunks)} does not match embedding length {len(embeddings)}" ) - if not self.client.has_collection(self.collection_name): - self.client.create_collection( + if not await asyncio.to_thread(self.client.has_collection, self.collection_name): + await asyncio.to_thread( + self.client.create_collection, self.collection_name, dimension=len(embeddings[0]), auto_id=True, @@ -62,7 +64,8 @@ class MilvusIndex(EmbeddingIndex): } ) try: - self.client.insert( + await asyncio.to_thread( + self.client.insert, self.collection_name, data=data, ) @@ -71,7 +74,8 @@ class MilvusIndex(EmbeddingIndex): raise e async def query(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse: - search_res = self.client.search( + search_res = await asyncio.to_thread( + self.client.search, collection_name=self.collection_name, data=[embedding], limit=k,