mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-05 10:13:05 +00:00
chore: Updating Milvus Client calls to be non-blocking
Signed-off-by: Francisco Javier Arceo <farceo@redhat.com>
This commit is contained in:
parent
37b6da37ba
commit
ad9b8da796
1 changed files with 10 additions and 6 deletions
|
@ -4,6 +4,7 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import asyncio
|
||||
import hashlib
|
||||
import logging
|
||||
import os
|
||||
|
@ -35,15 +36,16 @@ class MilvusIndex(EmbeddingIndex):
|
|||
self.consistency_level = consistency_level
|
||||
|
||||
async def delete(self):
|
||||
if self.client.has_collection(self.collection_name):
|
||||
self.client.drop_collection(collection_name=self.collection_name)
|
||||
if await asyncio.to_thread(self.client.has_collection, self.collection_name):
|
||||
await asyncio.to_thread(self.client.drop_collection, collection_name=self.collection_name)
|
||||
|
||||
async def add_chunks(self, chunks: List[Chunk], embeddings: NDArray):
|
||||
assert len(chunks) == len(embeddings), (
|
||||
f"Chunk length {len(chunks)} does not match embedding length {len(embeddings)}"
|
||||
)
|
||||
if not self.client.has_collection(self.collection_name):
|
||||
self.client.create_collection(
|
||||
if not await asyncio.to_thread(self.client.has_collection, self.collection_name):
|
||||
await asyncio.to_thread(
|
||||
self.client.create_collection,
|
||||
self.collection_name,
|
||||
dimension=len(embeddings[0]),
|
||||
auto_id=True,
|
||||
|
@ -62,7 +64,8 @@ class MilvusIndex(EmbeddingIndex):
|
|||
}
|
||||
)
|
||||
try:
|
||||
self.client.insert(
|
||||
await asyncio.to_thread(
|
||||
self.client.insert,
|
||||
self.collection_name,
|
||||
data=data,
|
||||
)
|
||||
|
@ -71,7 +74,8 @@ class MilvusIndex(EmbeddingIndex):
|
|||
raise e
|
||||
|
||||
async def query(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
|
||||
search_res = self.client.search(
|
||||
search_res = await asyncio.to_thread(
|
||||
self.client.search,
|
||||
collection_name=self.collection_name,
|
||||
data=[embedding],
|
||||
limit=k,
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue