From ad9b8da7968ce2d29699675e9082d4010e4131e8 Mon Sep 17 00:00:00 2001 From: Francisco Javier Arceo Date: Fri, 28 Mar 2025 21:39:29 -0400 Subject: [PATCH] chore: Updating Milvus Client calls to be non-blocking Signed-off-by: Francisco Javier Arceo --- .../providers/remote/vector_io/milvus/milvus.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/llama_stack/providers/remote/vector_io/milvus/milvus.py b/llama_stack/providers/remote/vector_io/milvus/milvus.py index 8ca9212bc..1949d293d 100644 --- a/llama_stack/providers/remote/vector_io/milvus/milvus.py +++ b/llama_stack/providers/remote/vector_io/milvus/milvus.py @@ -4,6 +4,7 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. +import asyncio import hashlib import logging import os @@ -35,15 +36,16 @@ class MilvusIndex(EmbeddingIndex): self.consistency_level = consistency_level async def delete(self): - if self.client.has_collection(self.collection_name): - self.client.drop_collection(collection_name=self.collection_name) + if await asyncio.to_thread(self.client.has_collection, self.collection_name): + await asyncio.to_thread(self.client.drop_collection, collection_name=self.collection_name) async def add_chunks(self, chunks: List[Chunk], embeddings: NDArray): assert len(chunks) == len(embeddings), ( f"Chunk length {len(chunks)} does not match embedding length {len(embeddings)}" ) - if not self.client.has_collection(self.collection_name): - self.client.create_collection( + if not await asyncio.to_thread(self.client.has_collection, self.collection_name): + await asyncio.to_thread( + self.client.create_collection, self.collection_name, dimension=len(embeddings[0]), auto_id=True, @@ -62,7 +64,8 @@ class MilvusIndex(EmbeddingIndex): } ) try: - self.client.insert( + await asyncio.to_thread( + self.client.insert, self.collection_name, data=data, ) @@ -71,7 +74,8 @@ class MilvusIndex(EmbeddingIndex): raise e async def query(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse: - search_res = self.client.search( + search_res = await asyncio.to_thread( + self.client.search, collection_name=self.collection_name, data=[embedding], limit=k,