chore: Updating Milvus Client calls to be non-blocking

Signed-off-by: Francisco Javier Arceo <farceo@redhat.com>
This commit is contained in:
Francisco Javier Arceo 2025-03-28 21:39:29 -04:00
parent 37b6da37ba
commit ad9b8da796

View file

@ -4,6 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import asyncio
import hashlib
import logging
import os
@ -35,15 +36,16 @@ class MilvusIndex(EmbeddingIndex):
self.consistency_level = consistency_level
async def delete(self):
if self.client.has_collection(self.collection_name):
self.client.drop_collection(collection_name=self.collection_name)
if await asyncio.to_thread(self.client.has_collection, self.collection_name):
await asyncio.to_thread(self.client.drop_collection, collection_name=self.collection_name)
async def add_chunks(self, chunks: List[Chunk], embeddings: NDArray):
assert len(chunks) == len(embeddings), (
f"Chunk length {len(chunks)} does not match embedding length {len(embeddings)}"
)
if not self.client.has_collection(self.collection_name):
self.client.create_collection(
if not await asyncio.to_thread(self.client.has_collection, self.collection_name):
await asyncio.to_thread(
self.client.create_collection,
self.collection_name,
dimension=len(embeddings[0]),
auto_id=True,
@ -62,7 +64,8 @@ class MilvusIndex(EmbeddingIndex):
}
)
try:
self.client.insert(
await asyncio.to_thread(
self.client.insert,
self.collection_name,
data=data,
)
@ -71,7 +74,8 @@ class MilvusIndex(EmbeddingIndex):
raise e
async def query(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
search_res = self.client.search(
search_res = await asyncio.to_thread(
self.client.search,
collection_name=self.collection_name,
data=[embedding],
limit=k,