chore: Updating Milvus Client calls to be non-blocking

Signed-off-by: Francisco Javier Arceo <farceo@redhat.com>
This commit is contained in:
Francisco Javier Arceo 2025-03-28 21:39:29 -04:00
parent 37b6da37ba
commit ad9b8da796

View file

@ -4,6 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
import asyncio
import hashlib import hashlib
import logging import logging
import os import os
@ -35,15 +36,16 @@ class MilvusIndex(EmbeddingIndex):
self.consistency_level = consistency_level self.consistency_level = consistency_level
async def delete(self): async def delete(self):
if self.client.has_collection(self.collection_name): if await asyncio.to_thread(self.client.has_collection, self.collection_name):
self.client.drop_collection(collection_name=self.collection_name) await asyncio.to_thread(self.client.drop_collection, collection_name=self.collection_name)
async def add_chunks(self, chunks: List[Chunk], embeddings: NDArray): async def add_chunks(self, chunks: List[Chunk], embeddings: NDArray):
assert len(chunks) == len(embeddings), ( assert len(chunks) == len(embeddings), (
f"Chunk length {len(chunks)} does not match embedding length {len(embeddings)}" f"Chunk length {len(chunks)} does not match embedding length {len(embeddings)}"
) )
if not self.client.has_collection(self.collection_name): if not await asyncio.to_thread(self.client.has_collection, self.collection_name):
self.client.create_collection( await asyncio.to_thread(
self.client.create_collection,
self.collection_name, self.collection_name,
dimension=len(embeddings[0]), dimension=len(embeddings[0]),
auto_id=True, auto_id=True,
@ -62,7 +64,8 @@ class MilvusIndex(EmbeddingIndex):
} }
) )
try: try:
self.client.insert( await asyncio.to_thread(
self.client.insert,
self.collection_name, self.collection_name,
data=data, data=data,
) )
@ -71,7 +74,8 @@ class MilvusIndex(EmbeddingIndex):
raise e raise e
async def query(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse: async def query(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
search_res = self.client.search( search_res = await asyncio.to_thread(
self.client.search,
collection_name=self.collection_name, collection_name=self.collection_name,
data=[embedding], data=[embedding],
limit=k, limit=k,