mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-03 09:53:45 +00:00
feat(client): migrate MilvusClient to AsyncMilvusClient
The commit makes the follwing changes. - Import statements updated: MilvusClient → AsyncMilvusClient - Removed asyncio.to_thread() wrappers: All Milvus operations now use native async/await - Test compatibility: Mock objects and fixtures updated to work with AsyncMilvusClient Signed-off-by: Mustafa Elbehery <melbeher@redhat.com>
This commit is contained in:
parent
fcf649b97a
commit
1db4800c9c
4 changed files with 413 additions and 27 deletions
|
|
@ -9,20 +9,26 @@ from unittest.mock import AsyncMock, MagicMock, patch
|
|||
|
||||
import numpy as np
|
||||
import pytest
|
||||
from chromadb import PersistentClient
|
||||
from pymilvus import AsyncMilvusClient, connections
|
||||
|
||||
from llama_stack.apis.vector_dbs import VectorDB
|
||||
from llama_stack.apis.vector_io import Chunk, ChunkMetadata, QueryChunksResponse
|
||||
from llama_stack.apis.vector_stores import VectorStore
|
||||
from llama_stack.core.storage.datatypes import KVStoreReference, SqliteKVStoreConfig
|
||||
from llama_stack.providers.inline.vector_io.faiss.config import FaissVectorIOConfig
|
||||
from llama_stack.providers.inline.vector_io.faiss.faiss import FaissIndex, FaissVectorIOAdapter
|
||||
from llama_stack.providers.inline.vector_io.milvus.config import MilvusVectorIOConfig
|
||||
from llama_stack.providers.inline.vector_io.sqlite_vec import SQLiteVectorIOConfig
|
||||
from llama_stack.providers.inline.vector_io.sqlite_vec.sqlite_vec import SQLiteVecIndex, SQLiteVecVectorIOAdapter
|
||||
from llama_stack.providers.remote.vector_io.milvus.milvus import MilvusIndex, MilvusVectorIOAdapter
|
||||
from llama_stack.providers.remote.vector_io.pgvector.config import PGVectorVectorIOConfig
|
||||
from llama_stack.providers.remote.vector_io.pgvector.pgvector import PGVectorIndex, PGVectorVectorIOAdapter
|
||||
from llama_stack.providers.utils.kvstore import register_kvstore_backends
|
||||
|
||||
EMBEDDING_DIMENSION = 768
|
||||
COLLECTION_PREFIX = "test_collection"
|
||||
MILVUS_ALIAS = "test_milvus"
|
||||
|
||||
|
||||
@pytest.fixture(params=["sqlite_vec", "faiss", "pgvector"])
|
||||
|
|
@ -141,7 +147,7 @@ async def sqlite_vec_vec_index(embedding_dimension, tmp_path_factory):
|
|||
await index.initialize()
|
||||
index.db_path = db_path
|
||||
yield index
|
||||
index.delete()
|
||||
await index.delete()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
|
@ -170,6 +176,48 @@ async def sqlite_vec_adapter(sqlite_vec_db_path, unique_kvstore_config, mock_inf
|
|||
await adapter.shutdown()
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def milvus_vec_db_path(tmp_path_factory):
|
||||
db_path = str(tmp_path_factory.getbasetemp() / "test_milvus.db")
|
||||
return db_path
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def milvus_vec_index(milvus_vec_db_path, embedding_dimension):
|
||||
client = AsyncMilvusClient(milvus_vec_db_path)
|
||||
name = f"{COLLECTION_PREFIX}_{np.random.randint(1e6)}"
|
||||
connections.connect(alias=MILVUS_ALIAS, uri=milvus_vec_db_path)
|
||||
index = MilvusIndex(client, name, consistency_level="Strong")
|
||||
index.db_path = milvus_vec_db_path
|
||||
yield index
|
||||
# Proper cleanup: close the async client
|
||||
await client.close()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def milvus_vec_adapter(milvus_vec_db_path, mock_inference_api):
|
||||
config = MilvusVectorIOConfig(
|
||||
db_path=milvus_vec_db_path,
|
||||
kvstore=SqliteKVStoreConfig(),
|
||||
)
|
||||
adapter = MilvusVectorIOAdapter(
|
||||
config=config,
|
||||
inference_api=mock_inference_api,
|
||||
files_api=None,
|
||||
)
|
||||
await adapter.initialize()
|
||||
await adapter.register_vector_db(
|
||||
VectorDB(
|
||||
identifier=adapter.metadata_collection_name,
|
||||
provider_id="test_provider",
|
||||
embedding_model="test_model",
|
||||
embedding_dimension=128,
|
||||
)
|
||||
)
|
||||
yield adapter
|
||||
await adapter.shutdown()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def faiss_vec_db_path(tmp_path_factory):
|
||||
db_path = str(tmp_path_factory.getbasetemp() / "test_faiss.db")
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue