mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-04 12:07:34 +00:00
feat(client): migrate MilvusClient to AsyncMilvusClient
The commit makes the follwing changes. - Import statements updated: MilvusClient → AsyncMilvusClient - Removed asyncio.to_thread() wrappers: All Milvus operations now use native async/await - Test compatibility: Mock objects and fixtures updated to work with AsyncMilvusClient Signed-off-by: Mustafa Elbehery <melbeher@redhat.com>
This commit is contained in:
parent
0e27016cf2
commit
142bd248e7
4 changed files with 76 additions and 66 deletions
|
@ -4,12 +4,11 @@
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import os
|
import os
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from numpy.typing import NDArray
|
from numpy.typing import NDArray
|
||||||
from pymilvus import AnnSearchRequest, DataType, Function, FunctionType, MilvusClient, RRFRanker, WeightedRanker
|
from pymilvus import AnnSearchRequest, AsyncMilvusClient, DataType, Function, FunctionType, RRFRanker, WeightedRanker
|
||||||
|
|
||||||
from llama_stack.apis.common.errors import VectorStoreNotFoundError
|
from llama_stack.apis.common.errors import VectorStoreNotFoundError
|
||||||
from llama_stack.apis.files.files import Files
|
from llama_stack.apis.files.files import Files
|
||||||
|
@ -48,7 +47,11 @@ OPENAI_VECTOR_STORES_FILES_CONTENTS_PREFIX = f"openai_vector_stores_files_conten
|
||||||
|
|
||||||
class MilvusIndex(EmbeddingIndex):
|
class MilvusIndex(EmbeddingIndex):
|
||||||
def __init__(
|
def __init__(
|
||||||
self, client: MilvusClient, collection_name: str, consistency_level="Strong", kvstore: KVStore | None = None
|
self,
|
||||||
|
client: AsyncMilvusClient,
|
||||||
|
collection_name: str,
|
||||||
|
consistency_level="Strong",
|
||||||
|
kvstore: KVStore | None = None,
|
||||||
):
|
):
|
||||||
self.client = client
|
self.client = client
|
||||||
self.collection_name = sanitize_collection_name(collection_name)
|
self.collection_name = sanitize_collection_name(collection_name)
|
||||||
|
@ -61,15 +64,15 @@ class MilvusIndex(EmbeddingIndex):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
async def delete(self):
|
async def delete(self):
|
||||||
if await asyncio.to_thread(self.client.has_collection, self.collection_name):
|
if await self.client.has_collection(self.collection_name):
|
||||||
await asyncio.to_thread(self.client.drop_collection, collection_name=self.collection_name)
|
await self.client.drop_collection(collection_name=self.collection_name)
|
||||||
|
|
||||||
async def add_chunks(self, chunks: list[Chunk], embeddings: NDArray):
|
async def add_chunks(self, chunks: list[Chunk], embeddings: NDArray):
|
||||||
assert len(chunks) == len(embeddings), (
|
assert len(chunks) == len(embeddings), (
|
||||||
f"Chunk length {len(chunks)} does not match embedding length {len(embeddings)}"
|
f"Chunk length {len(chunks)} does not match embedding length {len(embeddings)}"
|
||||||
)
|
)
|
||||||
|
|
||||||
if not await asyncio.to_thread(self.client.has_collection, self.collection_name):
|
if not await self.client.has_collection(self.collection_name):
|
||||||
logger.info(f"Creating new collection {self.collection_name} with nullable sparse field")
|
logger.info(f"Creating new collection {self.collection_name} with nullable sparse field")
|
||||||
# Create schema for vector search
|
# Create schema for vector search
|
||||||
schema = self.client.create_schema()
|
schema = self.client.create_schema()
|
||||||
|
@ -123,8 +126,7 @@ class MilvusIndex(EmbeddingIndex):
|
||||||
)
|
)
|
||||||
schema.add_function(bm25_function)
|
schema.add_function(bm25_function)
|
||||||
|
|
||||||
await asyncio.to_thread(
|
await self.client.create_collection(
|
||||||
self.client.create_collection,
|
|
||||||
self.collection_name,
|
self.collection_name,
|
||||||
schema=schema,
|
schema=schema,
|
||||||
index_params=index_params,
|
index_params=index_params,
|
||||||
|
@ -143,8 +145,7 @@ class MilvusIndex(EmbeddingIndex):
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
try:
|
try:
|
||||||
await asyncio.to_thread(
|
await self.client.insert(
|
||||||
self.client.insert,
|
|
||||||
self.collection_name,
|
self.collection_name,
|
||||||
data=data,
|
data=data,
|
||||||
)
|
)
|
||||||
|
@ -153,8 +154,7 @@ class MilvusIndex(EmbeddingIndex):
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
async def query_vector(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
|
async def query_vector(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
|
||||||
search_res = await asyncio.to_thread(
|
search_res = await self.client.search(
|
||||||
self.client.search,
|
|
||||||
collection_name=self.collection_name,
|
collection_name=self.collection_name,
|
||||||
data=[embedding],
|
data=[embedding],
|
||||||
anns_field="vector",
|
anns_field="vector",
|
||||||
|
@ -177,8 +177,7 @@ class MilvusIndex(EmbeddingIndex):
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
# Use Milvus's built-in BM25 search
|
# Use Milvus's built-in BM25 search
|
||||||
search_res = await asyncio.to_thread(
|
search_res = await self.client.search(
|
||||||
self.client.search,
|
|
||||||
collection_name=self.collection_name,
|
collection_name=self.collection_name,
|
||||||
data=[query_string], # Raw text query
|
data=[query_string], # Raw text query
|
||||||
anns_field="sparse", # Use sparse field for BM25
|
anns_field="sparse", # Use sparse field for BM25
|
||||||
|
@ -219,8 +218,7 @@ class MilvusIndex(EmbeddingIndex):
|
||||||
Fallback to simple text search when BM25 search is not available.
|
Fallback to simple text search when BM25 search is not available.
|
||||||
"""
|
"""
|
||||||
# Simple text search using content field
|
# Simple text search using content field
|
||||||
search_res = await asyncio.to_thread(
|
search_res = await self.client.query(
|
||||||
self.client.query,
|
|
||||||
collection_name=self.collection_name,
|
collection_name=self.collection_name,
|
||||||
filter='content like "%{content}%"',
|
filter='content like "%{content}%"',
|
||||||
filter_params={"content": query_string},
|
filter_params={"content": query_string},
|
||||||
|
@ -267,8 +265,7 @@ class MilvusIndex(EmbeddingIndex):
|
||||||
impact_factor = (reranker_params or {}).get("impact_factor", 60.0)
|
impact_factor = (reranker_params or {}).get("impact_factor", 60.0)
|
||||||
rerank = RRFRanker(impact_factor)
|
rerank = RRFRanker(impact_factor)
|
||||||
|
|
||||||
search_res = await asyncio.to_thread(
|
search_res = await self.client.hybrid_search(
|
||||||
self.client.hybrid_search,
|
|
||||||
collection_name=self.collection_name,
|
collection_name=self.collection_name,
|
||||||
reqs=search_requests,
|
reqs=search_requests,
|
||||||
ranker=rerank,
|
ranker=rerank,
|
||||||
|
@ -294,9 +291,7 @@ class MilvusIndex(EmbeddingIndex):
|
||||||
try:
|
try:
|
||||||
# Use IN clause with square brackets and single quotes for VARCHAR field
|
# Use IN clause with square brackets and single quotes for VARCHAR field
|
||||||
chunk_ids_str = ", ".join(f"'{chunk_id}'" for chunk_id in chunk_ids)
|
chunk_ids_str = ", ".join(f"'{chunk_id}'" for chunk_id in chunk_ids)
|
||||||
await asyncio.to_thread(
|
await self.client.delete(collection_name=self.collection_name, filter=f"chunk_id in [{chunk_ids_str}]")
|
||||||
self.client.delete, collection_name=self.collection_name, filter=f"chunk_id in [{chunk_ids_str}]"
|
|
||||||
)
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error deleting chunks from Milvus collection {self.collection_name}: {e}")
|
logger.error(f"Error deleting chunks from Milvus collection {self.collection_name}: {e}")
|
||||||
raise
|
raise
|
||||||
|
@ -340,17 +335,17 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
|
||||||
self.cache[vector_db.identifier] = index
|
self.cache[vector_db.identifier] = index
|
||||||
if isinstance(self.config, RemoteMilvusVectorIOConfig):
|
if isinstance(self.config, RemoteMilvusVectorIOConfig):
|
||||||
logger.info(f"Connecting to Milvus server at {self.config.uri}")
|
logger.info(f"Connecting to Milvus server at {self.config.uri}")
|
||||||
self.client = MilvusClient(**self.config.model_dump(exclude_none=True))
|
self.client = AsyncMilvusClient(**self.config.model_dump(exclude_none=True))
|
||||||
else:
|
else:
|
||||||
logger.info(f"Connecting to Milvus Lite at: {self.config.db_path}")
|
logger.info(f"Connecting to Milvus Lite at: {self.config.db_path}")
|
||||||
uri = os.path.expanduser(self.config.db_path)
|
uri = os.path.expanduser(self.config.db_path)
|
||||||
self.client = MilvusClient(uri=uri)
|
self.client = AsyncMilvusClient(uri=uri)
|
||||||
|
|
||||||
# Load existing OpenAI vector stores into the in-memory cache
|
# Load existing OpenAI vector stores into the in-memory cache
|
||||||
await self.initialize_openai_vector_stores()
|
await self.initialize_openai_vector_stores()
|
||||||
|
|
||||||
async def shutdown(self) -> None:
|
async def shutdown(self) -> None:
|
||||||
self.client.close()
|
await self.client.close()
|
||||||
|
|
||||||
async def register_vector_db(
|
async def register_vector_db(
|
||||||
self,
|
self,
|
||||||
|
|
|
@ -10,7 +10,7 @@ from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pytest
|
import pytest
|
||||||
from chromadb import PersistentClient
|
from chromadb import PersistentClient
|
||||||
from pymilvus import MilvusClient, connections
|
from pymilvus import AsyncMilvusClient, connections
|
||||||
|
|
||||||
from llama_stack.apis.vector_dbs import VectorDB
|
from llama_stack.apis.vector_dbs import VectorDB
|
||||||
from llama_stack.apis.vector_io import Chunk, ChunkMetadata, QueryChunksResponse
|
from llama_stack.apis.vector_io import Chunk, ChunkMetadata, QueryChunksResponse
|
||||||
|
@ -139,7 +139,7 @@ async def sqlite_vec_vec_index(embedding_dimension, tmp_path_factory):
|
||||||
await index.initialize()
|
await index.initialize()
|
||||||
index.db_path = db_path
|
index.db_path = db_path
|
||||||
yield index
|
yield index
|
||||||
index.delete()
|
await index.delete()
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
|
@ -176,12 +176,14 @@ def milvus_vec_db_path(tmp_path_factory):
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
async def milvus_vec_index(milvus_vec_db_path, embedding_dimension):
|
async def milvus_vec_index(milvus_vec_db_path, embedding_dimension):
|
||||||
client = MilvusClient(milvus_vec_db_path)
|
client = AsyncMilvusClient(milvus_vec_db_path)
|
||||||
name = f"{COLLECTION_PREFIX}_{np.random.randint(1e6)}"
|
name = f"{COLLECTION_PREFIX}_{np.random.randint(1e6)}"
|
||||||
connections.connect(alias=MILVUS_ALIAS, uri=milvus_vec_db_path)
|
connections.connect(alias=MILVUS_ALIAS, uri=milvus_vec_db_path)
|
||||||
index = MilvusIndex(client, name, consistency_level="Strong")
|
index = MilvusIndex(client, name, consistency_level="Strong")
|
||||||
index.db_path = milvus_vec_db_path
|
index.db_path = milvus_vec_db_path
|
||||||
yield index
|
yield index
|
||||||
|
# Proper cleanup: close the async client
|
||||||
|
await client.close()
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
|
|
|
@ -4,7 +4,7 @@
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
from unittest.mock import MagicMock, patch
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pytest
|
import pytest
|
||||||
|
@ -14,7 +14,7 @@ from llama_stack.apis.vector_io import QueryChunksResponse
|
||||||
# Mock the entire pymilvus module
|
# Mock the entire pymilvus module
|
||||||
pymilvus_mock = MagicMock()
|
pymilvus_mock = MagicMock()
|
||||||
pymilvus_mock.DataType = MagicMock()
|
pymilvus_mock.DataType = MagicMock()
|
||||||
pymilvus_mock.MilvusClient = MagicMock
|
pymilvus_mock.AsyncMilvusClient = MagicMock
|
||||||
pymilvus_mock.RRFRanker = MagicMock
|
pymilvus_mock.RRFRanker = MagicMock
|
||||||
pymilvus_mock.WeightedRanker = MagicMock
|
pymilvus_mock.WeightedRanker = MagicMock
|
||||||
pymilvus_mock.AnnSearchRequest = MagicMock
|
pymilvus_mock.AnnSearchRequest = MagicMock
|
||||||
|
@ -40,16 +40,17 @@ async def mock_milvus_client() -> MagicMock:
|
||||||
"""Create a mock Milvus client with common method behaviors."""
|
"""Create a mock Milvus client with common method behaviors."""
|
||||||
client = MagicMock()
|
client = MagicMock()
|
||||||
|
|
||||||
# Mock collection operations
|
# Mock async collection operations
|
||||||
client.has_collection.return_value = False # Initially no collection
|
client.has_collection = AsyncMock(return_value=False) # Initially no collection
|
||||||
client.create_collection.return_value = None
|
client.create_collection = AsyncMock(return_value=None)
|
||||||
client.drop_collection.return_value = None
|
client.drop_collection = AsyncMock(return_value=None)
|
||||||
|
|
||||||
# Mock insert operation
|
# Mock async insert operation
|
||||||
client.insert.return_value = {"insert_count": 10}
|
client.insert = AsyncMock(return_value={"insert_count": 10})
|
||||||
|
|
||||||
# Mock search operation - return mock results (data should be dict, not JSON string)
|
# Mock async search operation - return mock results (data should be dict, not JSON string)
|
||||||
client.search.return_value = [
|
client.search = AsyncMock(
|
||||||
|
return_value=[
|
||||||
[
|
[
|
||||||
{
|
{
|
||||||
"id": 0,
|
"id": 0,
|
||||||
|
@ -63,9 +64,11 @@ async def mock_milvus_client() -> MagicMock:
|
||||||
},
|
},
|
||||||
]
|
]
|
||||||
]
|
]
|
||||||
|
)
|
||||||
|
|
||||||
# Mock query operation for keyword search (data should be dict, not JSON string)
|
# Mock async query operation for keyword search (data should be dict, not JSON string)
|
||||||
client.query.return_value = [
|
client.query = AsyncMock(
|
||||||
|
return_value=[
|
||||||
{
|
{
|
||||||
"chunk_id": "chunk1",
|
"chunk_id": "chunk1",
|
||||||
"chunk_content": {"content": "mock chunk 1", "metadata": {"document_id": "doc1"}},
|
"chunk_content": {"content": "mock chunk 1", "metadata": {"document_id": "doc1"}},
|
||||||
|
@ -82,6 +85,16 @@ async def mock_milvus_client() -> MagicMock:
|
||||||
"score": 0.7,
|
"score": 0.7,
|
||||||
},
|
},
|
||||||
]
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
# Mock async hybrid_search operation
|
||||||
|
client.hybrid_search = AsyncMock(return_value=[])
|
||||||
|
|
||||||
|
# Mock async delete operation
|
||||||
|
client.delete = AsyncMock(return_value=None)
|
||||||
|
|
||||||
|
# Mock async close operation
|
||||||
|
client.close = AsyncMock(return_value=None)
|
||||||
|
|
||||||
return client
|
return client
|
||||||
|
|
||||||
|
|
|
@ -30,12 +30,12 @@ async def test_initialize_index(vector_index):
|
||||||
|
|
||||||
|
|
||||||
async def test_add_chunks_query_vector(vector_index, sample_chunks, sample_embeddings):
|
async def test_add_chunks_query_vector(vector_index, sample_chunks, sample_embeddings):
|
||||||
vector_index.delete()
|
await vector_index.delete()
|
||||||
vector_index.initialize()
|
await vector_index.initialize()
|
||||||
await vector_index.add_chunks(sample_chunks, sample_embeddings)
|
await vector_index.add_chunks(sample_chunks, sample_embeddings)
|
||||||
resp = await vector_index.query_vector(sample_embeddings[0], k=1, score_threshold=-1)
|
resp = await vector_index.query_vector(sample_embeddings[0], k=1, score_threshold=-1)
|
||||||
assert resp.chunks[0].content == sample_chunks[0].content
|
assert resp.chunks[0].content == sample_chunks[0].content
|
||||||
vector_index.delete()
|
await vector_index.delete()
|
||||||
|
|
||||||
|
|
||||||
async def test_chunk_id_conflict(vector_index, sample_chunks, embedding_dimension):
|
async def test_chunk_id_conflict(vector_index, sample_chunks, embedding_dimension):
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue