feat: migrate from MilvusClient to AsyncMilvusClient

This PR migrates the Milvus vector IO provider from the synchronous `MilvusClient` to `AsyncMilvusClient`, replacing the `asyncio.to_thread` wrappers with direct `await` calls and tightening the client type annotations along the way.

Part of https://github.com/meta-llama/llama-stack/issues/2647
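
At its core the change is mechanical: every call that previously pushed the blocking `MilvusClient` onto a worker thread via `asyncio.to_thread(...)` now awaits the corresponding coroutine on `AsyncMilvusClient` directly. A minimal before/after sketch of that pattern (the collection name, file path, and schema below are illustrative only, and whether your installed pymilvus version exposes all of these methods on `AsyncMilvusClient` should be checked against its docs):

```python
import asyncio

from pymilvus import AsyncMilvusClient, MilvusClient


# Before: the synchronous client blocks, so each call is shipped to a worker thread.
async def insert_with_sync_client(client: MilvusClient, collection: str, rows: list[dict]) -> None:
    await asyncio.to_thread(client.insert, collection, data=rows)


# After: the async client returns coroutines, so the call is awaited on the event loop.
async def insert_with_async_client(client: AsyncMilvusClient, collection: str, rows: list[dict]) -> None:
    await client.insert(collection, data=rows)


async def main() -> None:
    # Milvus Lite via a local file path, mirroring the adapter's local-config branch.
    client = AsyncMilvusClient(uri="./demo_milvus.db")
    await client.create_collection("demo_collection", dimension=3)
    await insert_with_async_client(client, "demo_collection", [{"id": 1, "vector": [0.1, 0.2, 0.3]}])
    await client.close()


if __name__ == "__main__":
    asyncio.run(main())
```

The same substitution applies to `search`, `query`, `upsert`, `delete`, and `drop_collection` throughout the provider, as the diff below shows.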


Signed-off-by: Mustafa Elbehery <melbeher@redhat.com>
Committed by Mustafa Elbehery on 2025-07-09 15:44:39 +02:00
parent 81109a0f72
commit f75ec332b5
2 changed files with 103 additions and 107 deletions

@@ -4,7 +4,6 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-import asyncio
 import json
 import logging
 import os
@@ -12,7 +11,7 @@ import re
 from typing import Any

 from numpy.typing import NDArray
-from pymilvus import DataType, MilvusClient
+from pymilvus import AsyncMilvusClient, DataType

 from llama_stack.apis.files.files import Files
 from llama_stack.apis.inference import Inference, InterleavedContent
@@ -54,7 +53,11 @@ def sanitize_collection_name(name: str) -> str:

 class MilvusIndex(EmbeddingIndex):
     def __init__(
-        self, client: MilvusClient, collection_name: str, consistency_level="Strong", kvstore: KVStore | None = None
+        self,
+        client: AsyncMilvusClient,
+        collection_name: str,
+        consistency_level="Strong",
+        kvstore: KVStore | None = None,
     ):
         self.client = client
         self.collection_name = sanitize_collection_name(collection_name)
@@ -62,16 +65,15 @@ class MilvusIndex(EmbeddingIndex):
         self.kvstore = kvstore

     async def delete(self):
-        if await asyncio.to_thread(self.client.has_collection, self.collection_name):
-            await asyncio.to_thread(self.client.drop_collection, collection_name=self.collection_name)
+        if await self.client.collection_exists(self.collection_name):
+            await self.client.drop_collection(collection_name=self.collection_name)

     async def add_chunks(self, chunks: list[Chunk], embeddings: NDArray):
         assert len(chunks) == len(embeddings), (
             f"Chunk length {len(chunks)} does not match embedding length {len(embeddings)}"
         )
-        if not await asyncio.to_thread(self.client.has_collection, self.collection_name):
-            await asyncio.to_thread(
-                self.client.create_collection,
+        if not await self.client.collection_exists(self.collection_name):
+            await self.client.create_collection(
                 self.collection_name,
                 dimension=len(embeddings[0]),
                 auto_id=True,
@@ -88,8 +90,7 @@ class MilvusIndex(EmbeddingIndex):
                 }
             )
         try:
-            await asyncio.to_thread(
-                self.client.insert,
+            await self.client.insert(
                 self.collection_name,
                 data=data,
             )
@@ -98,8 +99,7 @@ class MilvusIndex(EmbeddingIndex):
             raise e

     async def query_vector(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
-        search_res = await asyncio.to_thread(
-            self.client.search,
+        search_res = await self.client.search(
             collection_name=self.collection_name,
             data=[embedding],
             limit=k,
@@ -139,57 +139,68 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
     ) -> None:
         self.config = config
         self.cache = {}
-        self.client = None
+        self.client: AsyncMilvusClient | None = None
         self.inference_api = inference_api
         self.files_api = files_api
         self.kvstore: KVStore | None = None
-        self.vector_db_store = None
         self.openai_vector_stores: dict[str, dict[str, Any]] = {}
         self.metadata_collection_name = "openai_vector_stores_metadata"

     async def initialize(self) -> None:
-        self.kvstore = await kvstore_impl(self.config.kvstore)
-        start_key = VECTOR_DBS_PREFIX
-        end_key = f"{VECTOR_DBS_PREFIX}\xff"
-        stored_vector_dbs = await self.kvstore.values_in_range(start_key, end_key)
-
-        for vector_db_data in stored_vector_dbs:
-            vector_db = VectorDB.model_validate_json(vector_db_data)
-            index = VectorDBWithIndex(
-                vector_db,
-                index=MilvusIndex(
-                    client=self.client,
-                    collection_name=vector_db.identifier,
-                    consistency_level=self.config.consistency_level,
-                    kvstore=self.kvstore,
-                ),
-                inference_api=self.inference_api,
-            )
-            self.cache[vector_db.identifier] = index
+        if self.config.kvstore is not None:
+            self.kvstore = await kvstore_impl(self.config.kvstore)
+
+        # Initialize client first before using it
         if isinstance(self.config, RemoteMilvusVectorIOConfig):
             logger.info(f"Connecting to Milvus server at {self.config.uri}")
-            self.client = MilvusClient(**self.config.model_dump(exclude_none=True))
+            self.client = AsyncMilvusClient(**self.config.model_dump(exclude_none=True))
         else:
             logger.info(f"Connecting to Milvus Lite at: {self.config.db_path}")
             uri = os.path.expanduser(self.config.db_path)
-            self.client = MilvusClient(uri=uri)
+            self.client = AsyncMilvusClient(uri=uri)
+
+        # Now load stored vector databases
+        if self.kvstore is not None:
+            start_key = VECTOR_DBS_PREFIX
+            end_key = f"{VECTOR_DBS_PREFIX}\xff"
+            stored_vector_dbs = await self.kvstore.values_in_range(start_key, end_key)
+
+            for vector_db_data in stored_vector_dbs:
+                vector_db = VectorDB.model_validate_json(vector_db_data)
+                index = VectorDBWithIndex(
+                    vector_db,
+                    index=MilvusIndex(
+                        client=self.client,
+                        collection_name=vector_db.identifier,
+                        consistency_level=self.config.consistency_level,
+                        kvstore=self.kvstore,
+                    ),
+                    inference_api=self.inference_api,
+                )
+                self.cache[vector_db.identifier] = index
+
         self.openai_vector_stores = await self._load_openai_vector_stores()

     async def shutdown(self) -> None:
-        self.client.close()
+        if self.client is not None:
+            await self.client.close()

-    async def register_vector_db(
-        self,
-        vector_db: VectorDB,
-    ) -> None:
+    def _ensure_client_initialized(self) -> AsyncMilvusClient:
+        """Ensure the client is initialized and return it."""
+        if self.client is None:
+            raise RuntimeError("Milvus client is not initialized. Call initialize() first.")
+        return self.client
+
+    async def register_vector_db(self, vector_db: VectorDB) -> None:
         if isinstance(self.config, RemoteMilvusVectorIOConfig):
             consistency_level = self.config.consistency_level
         else:
             consistency_level = "Strong"
         index = VectorDBWithIndex(
             vector_db=vector_db,
-            index=MilvusIndex(self.client, vector_db.identifier, consistency_level=consistency_level),
+            index=MilvusIndex(
+                self._ensure_client_initialized(), vector_db.identifier, consistency_level=consistency_level
+            ),
             inference_api=self.inference_api,
         )
@@ -199,17 +210,8 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
         if vector_db_id in self.cache:
             return self.cache[vector_db_id]

-        vector_db = await self.vector_db_store.get_vector_db(vector_db_id)
-        if not vector_db:
-            raise ValueError(f"Vector DB {vector_db_id} not found")
-
-        index = VectorDBWithIndex(
-            vector_db=vector_db,
-            index=MilvusIndex(client=self.client, collection_name=vector_db.identifier, kvstore=self.kvstore),
-            inference_api=self.inference_api,
-        )
-        self.cache[vector_db_id] = index
-        return index
+        # Vector DB should be registered before use
+        raise ValueError(f"Vector DB {vector_db_id} not found. Please register it first.")

     async def unregister_vector_db(self, vector_db_id: str) -> None:
         if vector_db_id in self.cache:
@@ -275,14 +277,17 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
     ) -> None:
         """Save vector store file metadata to Milvus database."""
         if store_id not in self.openai_vector_stores:
-            store_info = await self._load_openai_vector_stores(store_id)
-            if not store_info:
+            # Reload all vector stores to check if the store exists
+            self.openai_vector_stores = await self._load_openai_vector_stores()
+            if store_id not in self.openai_vector_stores:
                 logger.error(f"OpenAI vector store {store_id} not found")
                 raise ValueError(f"No vector store found with id {store_id}")

         try:
-            if not await asyncio.to_thread(self.client.has_collection, "openai_vector_store_files"):
-                file_schema = MilvusClient.create_schema(
+            client = self._ensure_client_initialized()
+
+            if not await client.collection_exists("openai_vector_store_files"):
+                file_schema = AsyncMilvusClient.create_schema(
                     auto_id=False,
                     enable_dynamic_field=True,
                     description="Metadata for OpenAI vector store files",
@@ -294,14 +299,13 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
                 file_schema.add_field(field_name="file_id", datatype=DataType.VARCHAR, max_length=512)
                 file_schema.add_field(field_name="file_info", datatype=DataType.VARCHAR, max_length=65535)

-                await asyncio.to_thread(
-                    self.client.create_collection,
+                await client.create_collection(
                     collection_name="openai_vector_store_files",
                     schema=file_schema,
                 )

-            if not await asyncio.to_thread(self.client.has_collection, "openai_vector_store_files_contents"):
-                content_schema = MilvusClient.create_schema(
+            if not await client.collection_exists("openai_vector_store_files_contents"):
+                content_schema = AsyncMilvusClient.create_schema(
                     auto_id=False,
                     enable_dynamic_field=True,
                     description="Contents for OpenAI vector store files",
@@ -314,8 +318,7 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
                 content_schema.add_field(field_name="file_id", datatype=DataType.VARCHAR, max_length=512)
                 content_schema.add_field(field_name="content", datatype=DataType.VARCHAR, max_length=65535)

-                await asyncio.to_thread(
-                    self.client.create_collection,
+                await client.create_collection(
                     collection_name="openai_vector_store_files_contents",
                     schema=content_schema,
                 )
@@ -328,44 +331,44 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
                     "file_info": json.dumps(file_info),
                 }
             ]
-            await asyncio.to_thread(
-                self.client.upsert,
+            await client.upsert(
                 collection_name="openai_vector_store_files",
                 data=file_data,
             )

             # Save file contents
-            contents_data = [
-                {
-                    "chunk_id": content.get("chunk_metadata").get("chunk_id"),
-                    "store_file_id": f"{store_id}_{file_id}",
-                    "store_id": store_id,
-                    "file_id": file_id,
-                    "content": json.dumps(content),
-                }
-                for content in file_contents
-            ]
-            await asyncio.to_thread(
-                self.client.upsert,
-                collection_name="openai_vector_store_files_contents",
-                data=contents_data,
-            )
+            contents_data = []
+            for content in file_contents:
+                chunk_metadata = content.get("chunk_metadata")
+                if chunk_metadata and chunk_metadata.get("chunk_id"):
+                    contents_data.append(
+                        {
+                            "chunk_id": chunk_metadata.get("chunk_id"),
+                            "store_file_id": f"{store_id}_{file_id}",
+                            "store_id": store_id,
+                            "file_id": file_id,
+                            "content": json.dumps(content),
+                        }
+                    )
+
+            if contents_data:
+                await client.upsert(collection_name="openai_vector_store_files_contents", data=contents_data)
         except Exception as e:
             logger.error(f"Error saving openai vector store file {file_id} for store {store_id}: {e}")
+            raise

     async def _load_openai_vector_store_file(self, store_id: str, file_id: str) -> dict[str, Any]:
         """Load vector store file metadata from Milvus database."""
         try:
-            if not await asyncio.to_thread(self.client.has_collection, "openai_vector_store_files"):
+            client = self._ensure_client_initialized()
+
+            if not await client.collection_exists("openai_vector_store_files"):
                 return {}

             query_filter = f"store_file_id == '{store_id}_{file_id}'"
-            results = await asyncio.to_thread(
-                self.client.query,
-                collection_name="openai_vector_store_files",
-                filter=query_filter,
-                output_fields=["file_info"],
+            results = await client.query(
+                collection_name="openai_vector_store_files", filter=query_filter, output_fields=["file_info"]
             )

             if results:
@@ -382,7 +385,9 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
     async def _update_openai_vector_store_file(self, store_id: str, file_id: str, file_info: dict[str, Any]) -> None:
         """Update vector store file metadata in Milvus database."""
         try:
-            if not await asyncio.to_thread(self.client.has_collection, "openai_vector_store_files"):
+            client = self._ensure_client_initialized()
+
+            if not await client.collection_exists("openai_vector_store_files"):
                 return

             file_data = [
@@ -393,11 +398,7 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
                     "file_info": json.dumps(file_info),
                 }
             ]
-            await asyncio.to_thread(
-                self.client.upsert,
-                collection_name="openai_vector_store_files",
-                data=file_data,
-            )
+            await client.upsert(collection_name="openai_vector_store_files", data=file_data)
         except Exception as e:
             logger.error(f"Error updating openai vector store file {file_id} for store {store_id}: {e}")
             raise
@@ -405,14 +406,15 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
     async def _load_openai_vector_store_file_contents(self, store_id: str, file_id: str) -> list[dict[str, Any]]:
        """Load vector store file contents from Milvus database."""
        try:
-            if not await asyncio.to_thread(self.client.has_collection, "openai_vector_store_files_contents"):
+            client = self._ensure_client_initialized()
+
+            if not await client.collection_exists("openai_vector_store_files_contents"):
                 return []

             query_filter = (
                 f"store_id == '{store_id}' AND file_id == '{file_id}' AND store_file_id == '{store_id}_{file_id}'"
             )
-            results = await asyncio.to_thread(
-                self.client.query,
+            results = await client.query(
                 collection_name="openai_vector_store_files_contents",
                 filter=query_filter,
                 output_fields=["chunk_id", "store_id", "file_id", "content"],
@@ -433,21 +435,15 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
     async def _delete_openai_vector_store_file_from_storage(self, store_id: str, file_id: str) -> None:
         """Delete vector store file metadata from Milvus database."""
         try:
-            if not await asyncio.to_thread(self.client.has_collection, "openai_vector_store_files"):
+            client = self._ensure_client_initialized()
+
+            if not await client.collection_exists("openai_vector_store_files"):
                 return

             query_filter = f"store_file_id in ['{store_id}_{file_id}']"
-            await asyncio.to_thread(
-                self.client.delete,
-                collection_name="openai_vector_store_files",
-                filter=query_filter,
-            )
-            if await asyncio.to_thread(self.client.has_collection, "openai_vector_store_files_contents"):
-                await asyncio.to_thread(
-                    self.client.delete,
-                    collection_name="openai_vector_store_files_contents",
-                    filter=query_filter,
-                )
+            await client.delete(collection_name="openai_vector_store_files", filter=query_filter)
+            if await client.collection_exists("openai_vector_store_files_contents"):
+                await client.delete(collection_name="openai_vector_store_files_contents", filter=query_filter)
         except Exception as e:
             logger.error(f"Error deleting openai vector store file {file_id} for store {store_id}: {e}")

@@ -11,7 +11,7 @@ from unittest.mock import AsyncMock
 import numpy as np
 import pytest
 import pytest_asyncio
-from pymilvus import Collection, MilvusClient, connections
+from pymilvus import AsyncMilvusClient, Collection, connections

 from llama_stack.apis.vector_dbs import VectorDB
 from llama_stack.apis.vector_io import Chunk, QueryChunksResponse
@@ -52,7 +52,7 @@ async def unique_kvstore_config(tmp_path_factory):
 async def milvus_vec_index(embedding_dimension, tmp_path_factory):
     temp_dir = tmp_path_factory.getbasetemp()
     db_path = str(temp_dir / "test_milvus.db")
-    client = MilvusClient(db_path)
+    client = AsyncMilvusClient(uri=db_path)
     name = f"{COLLECTION_PREFIX}_{np.random.randint(1e6)}"
     connections.connect(alias=MILVUS_ALIAS, uri=db_path)
     index = MilvusIndex(client, name, consistency_level="Strong")
@@ -148,7 +148,7 @@ async def test_initialize_with_milvus_client(milvus_vec_index, unique_kvstore_co
     await tmp_milvus_vec_adapter.kvstore.set(f"{VECTOR_DBS_PREFIX}/test_db", test_vector_db_data)

     assert milvus_vec_index.client is not None
-    assert isinstance(milvus_vec_index.client, MilvusClient)
+    assert isinstance(milvus_vec_index.client, AsyncMilvusClient)
     assert tmp_milvus_vec_adapter.cache is not None
     # registering a vector won't update the cache or openai_vector_store collection name
     assert (