mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-06 02:32:40 +00:00
chore: Updating Milvus Client calls to be non-blocking
Signed-off-by: Francisco Javier Arceo <farceo@redhat.com>
This commit is contained in:
parent
37b6da37ba
commit
ad9b8da796
1 changed file with 10 additions and 6 deletions
|
@@ -4,6 +4,7 @@
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
import asyncio
|
||||||
import hashlib
|
import hashlib
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
|
@@ -35,15 +36,16 @@ class MilvusIndex(EmbeddingIndex):
|
||||||
self.consistency_level = consistency_level
|
self.consistency_level = consistency_level
|
||||||
|
|
||||||
async def delete(self):
|
async def delete(self):
|
||||||
if self.client.has_collection(self.collection_name):
|
if await asyncio.to_thread(self.client.has_collection, self.collection_name):
|
||||||
self.client.drop_collection(collection_name=self.collection_name)
|
await asyncio.to_thread(self.client.drop_collection, collection_name=self.collection_name)
|
||||||
|
|
||||||
async def add_chunks(self, chunks: List[Chunk], embeddings: NDArray):
|
async def add_chunks(self, chunks: List[Chunk], embeddings: NDArray):
|
||||||
assert len(chunks) == len(embeddings), (
|
assert len(chunks) == len(embeddings), (
|
||||||
f"Chunk length {len(chunks)} does not match embedding length {len(embeddings)}"
|
f"Chunk length {len(chunks)} does not match embedding length {len(embeddings)}"
|
||||||
)
|
)
|
||||||
if not self.client.has_collection(self.collection_name):
|
if not await asyncio.to_thread(self.client.has_collection, self.collection_name):
|
||||||
self.client.create_collection(
|
await asyncio.to_thread(
|
||||||
|
self.client.create_collection,
|
||||||
self.collection_name,
|
self.collection_name,
|
||||||
dimension=len(embeddings[0]),
|
dimension=len(embeddings[0]),
|
||||||
auto_id=True,
|
auto_id=True,
|
||||||
|
@@ -62,7 +64,8 @@ class MilvusIndex(EmbeddingIndex):
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
try:
|
try:
|
||||||
self.client.insert(
|
await asyncio.to_thread(
|
||||||
|
self.client.insert,
|
||||||
self.collection_name,
|
self.collection_name,
|
||||||
data=data,
|
data=data,
|
||||||
)
|
)
|
||||||
|
@@ -71,7 +74,8 @@ class MilvusIndex(EmbeddingIndex):
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
async def query(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
|
async def query(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
|
||||||
search_res = self.client.search(
|
search_res = await asyncio.to_thread(
|
||||||
|
self.client.search,
|
||||||
collection_name=self.collection_name,
|
collection_name=self.collection_name,
|
||||||
data=[embedding],
|
data=[embedding],
|
||||||
limit=k,
|
limit=k,
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue