feat: migrate from MilvusClient to AsyncMilvusClient

<!-- Provide a short summary of what this PR does and why. Link to
relevant issues if applicable. -->
This PR adds static type coverage to `llama-stack`

Part of https://github.com/meta-llama/llama-stack/issues/2647

<!-- If resolving an issue, uncomment and update the line below -->
<!-- Closes #[issue-number] -->

<!-- Describe the tests you ran to verify your changes with result
summaries. *Provide clear instructions so the plan can be easily
re-executed.* -->

Signed-off-by: Mustafa Elbehery <melbeher@redhat.com>
This commit is contained in:
Mustafa Elbehery 2025-07-09 15:44:39 +02:00
parent 81109a0f72
commit f75ec332b5
2 changed files with 103 additions and 107 deletions

View file

@ -4,7 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import asyncio
import json
import logging
import os
@ -12,7 +11,7 @@ import re
from typing import Any
from numpy.typing import NDArray
from pymilvus import DataType, MilvusClient
from pymilvus import AsyncMilvusClient, DataType
from llama_stack.apis.files.files import Files
from llama_stack.apis.inference import Inference, InterleavedContent
@ -54,7 +53,11 @@ def sanitize_collection_name(name: str) -> str:
class MilvusIndex(EmbeddingIndex):
def __init__(
self, client: MilvusClient, collection_name: str, consistency_level="Strong", kvstore: KVStore | None = None
self,
client: AsyncMilvusClient,
collection_name: str,
consistency_level="Strong",
kvstore: KVStore | None = None,
):
self.client = client
self.collection_name = sanitize_collection_name(collection_name)
@ -62,16 +65,15 @@ class MilvusIndex(EmbeddingIndex):
self.kvstore = kvstore
async def delete(self):
if await asyncio.to_thread(self.client.has_collection, self.collection_name):
await asyncio.to_thread(self.client.drop_collection, collection_name=self.collection_name)
if await self.client.collection_exists(self.collection_name):
await self.client.drop_collection(collection_name=self.collection_name)
async def add_chunks(self, chunks: list[Chunk], embeddings: NDArray):
assert len(chunks) == len(embeddings), (
f"Chunk length {len(chunks)} does not match embedding length {len(embeddings)}"
)
if not await asyncio.to_thread(self.client.has_collection, self.collection_name):
await asyncio.to_thread(
self.client.create_collection,
if not await self.client.collection_exists(self.collection_name):
await self.client.create_collection(
self.collection_name,
dimension=len(embeddings[0]),
auto_id=True,
@ -88,8 +90,7 @@ class MilvusIndex(EmbeddingIndex):
}
)
try:
await asyncio.to_thread(
self.client.insert,
await self.client.insert(
self.collection_name,
data=data,
)
@ -98,8 +99,7 @@ class MilvusIndex(EmbeddingIndex):
raise e
async def query_vector(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
search_res = await asyncio.to_thread(
self.client.search,
search_res = await self.client.search(
collection_name=self.collection_name,
data=[embedding],
limit=k,
@ -139,57 +139,68 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
) -> None:
self.config = config
self.cache = {}
self.client = None
self.client: AsyncMilvusClient | None = None
self.inference_api = inference_api
self.files_api = files_api
self.kvstore: KVStore | None = None
self.vector_db_store = None
self.openai_vector_stores: dict[str, dict[str, Any]] = {}
self.metadata_collection_name = "openai_vector_stores_metadata"
async def initialize(self) -> None:
self.kvstore = await kvstore_impl(self.config.kvstore)
start_key = VECTOR_DBS_PREFIX
end_key = f"{VECTOR_DBS_PREFIX}\xff"
stored_vector_dbs = await self.kvstore.values_in_range(start_key, end_key)
if self.config.kvstore is not None:
self.kvstore = await kvstore_impl(self.config.kvstore)
for vector_db_data in stored_vector_dbs:
vector_db = VectorDB.model_validate_json(vector_db_data)
index = VectorDBWithIndex(
vector_db,
index=MilvusIndex(
client=self.client,
collection_name=vector_db.identifier,
consistency_level=self.config.consistency_level,
kvstore=self.kvstore,
),
inference_api=self.inference_api,
)
self.cache[vector_db.identifier] = index
# Initialize client first before using it
if isinstance(self.config, RemoteMilvusVectorIOConfig):
logger.info(f"Connecting to Milvus server at {self.config.uri}")
self.client = MilvusClient(**self.config.model_dump(exclude_none=True))
self.client = AsyncMilvusClient(**self.config.model_dump(exclude_none=True))
else:
logger.info(f"Connecting to Milvus Lite at: {self.config.db_path}")
uri = os.path.expanduser(self.config.db_path)
self.client = MilvusClient(uri=uri)
self.client = AsyncMilvusClient(uri=uri)
# Now load stored vector databases
if self.kvstore is not None:
start_key = VECTOR_DBS_PREFIX
end_key = f"{VECTOR_DBS_PREFIX}\xff"
stored_vector_dbs = await self.kvstore.values_in_range(start_key, end_key)
for vector_db_data in stored_vector_dbs:
vector_db = VectorDB.model_validate_json(vector_db_data)
index = VectorDBWithIndex(
vector_db,
index=MilvusIndex(
client=self.client,
collection_name=vector_db.identifier,
consistency_level=self.config.consistency_level,
kvstore=self.kvstore,
),
inference_api=self.inference_api,
)
self.cache[vector_db.identifier] = index
self.openai_vector_stores = await self._load_openai_vector_stores()
async def shutdown(self) -> None:
self.client.close()
if self.client is not None:
await self.client.close()
async def register_vector_db(
self,
vector_db: VectorDB,
) -> None:
def _ensure_client_initialized(self) -> AsyncMilvusClient:
"""Ensure the client is initialized and return it."""
if self.client is None:
raise RuntimeError("Milvus client is not initialized. Call initialize() first.")
return self.client
async def register_vector_db(self, vector_db: VectorDB) -> None:
if isinstance(self.config, RemoteMilvusVectorIOConfig):
consistency_level = self.config.consistency_level
else:
consistency_level = "Strong"
index = VectorDBWithIndex(
vector_db=vector_db,
index=MilvusIndex(self.client, vector_db.identifier, consistency_level=consistency_level),
index=MilvusIndex(
self._ensure_client_initialized(), vector_db.identifier, consistency_level=consistency_level
),
inference_api=self.inference_api,
)
@ -199,17 +210,8 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
if vector_db_id in self.cache:
return self.cache[vector_db_id]
vector_db = await self.vector_db_store.get_vector_db(vector_db_id)
if not vector_db:
raise ValueError(f"Vector DB {vector_db_id} not found")
index = VectorDBWithIndex(
vector_db=vector_db,
index=MilvusIndex(client=self.client, collection_name=vector_db.identifier, kvstore=self.kvstore),
inference_api=self.inference_api,
)
self.cache[vector_db_id] = index
return index
# Vector DB should be registered before use
raise ValueError(f"Vector DB {vector_db_id} not found. Please register it first.")
async def unregister_vector_db(self, vector_db_id: str) -> None:
if vector_db_id in self.cache:
@ -275,14 +277,17 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
) -> None:
"""Save vector store file metadata to Milvus database."""
if store_id not in self.openai_vector_stores:
store_info = await self._load_openai_vector_stores(store_id)
if not store_info:
# Reload all vector stores to check if the store exists
self.openai_vector_stores = await self._load_openai_vector_stores()
if store_id not in self.openai_vector_stores:
logger.error(f"OpenAI vector store {store_id} not found")
raise ValueError(f"No vector store found with id {store_id}")
try:
if not await asyncio.to_thread(self.client.has_collection, "openai_vector_store_files"):
file_schema = MilvusClient.create_schema(
client = self._ensure_client_initialized()
if not await client.collection_exists("openai_vector_store_files"):
file_schema = AsyncMilvusClient.create_schema(
auto_id=False,
enable_dynamic_field=True,
description="Metadata for OpenAI vector store files",
@ -294,14 +299,13 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
file_schema.add_field(field_name="file_id", datatype=DataType.VARCHAR, max_length=512)
file_schema.add_field(field_name="file_info", datatype=DataType.VARCHAR, max_length=65535)
await asyncio.to_thread(
self.client.create_collection,
await client.create_collection(
collection_name="openai_vector_store_files",
schema=file_schema,
)
if not await asyncio.to_thread(self.client.has_collection, "openai_vector_store_files_contents"):
content_schema = MilvusClient.create_schema(
if not await client.collection_exists("openai_vector_store_files_contents"):
content_schema = AsyncMilvusClient.create_schema(
auto_id=False,
enable_dynamic_field=True,
description="Contents for OpenAI vector store files",
@ -314,8 +318,7 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
content_schema.add_field(field_name="file_id", datatype=DataType.VARCHAR, max_length=512)
content_schema.add_field(field_name="content", datatype=DataType.VARCHAR, max_length=65535)
await asyncio.to_thread(
self.client.create_collection,
await client.create_collection(
collection_name="openai_vector_store_files_contents",
schema=content_schema,
)
@ -328,44 +331,44 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
"file_info": json.dumps(file_info),
}
]
await asyncio.to_thread(
self.client.upsert,
await client.upsert(
collection_name="openai_vector_store_files",
data=file_data,
)
# Save file contents
contents_data = [
{
"chunk_id": content.get("chunk_metadata").get("chunk_id"),
"store_file_id": f"{store_id}_{file_id}",
"store_id": store_id,
"file_id": file_id,
"content": json.dumps(content),
}
for content in file_contents
]
await asyncio.to_thread(
self.client.upsert,
collection_name="openai_vector_store_files_contents",
data=contents_data,
)
contents_data = []
for content in file_contents:
chunk_metadata = content.get("chunk_metadata")
if chunk_metadata and chunk_metadata.get("chunk_id"):
contents_data.append(
{
"chunk_id": chunk_metadata.get("chunk_id"),
"store_file_id": f"{store_id}_{file_id}",
"store_id": store_id,
"file_id": file_id,
"content": json.dumps(content),
}
)
if contents_data:
await client.upsert(collection_name="openai_vector_store_files_contents", data=contents_data)
except Exception as e:
logger.error(f"Error saving openai vector store file {file_id} for store {store_id}: {e}")
raise
async def _load_openai_vector_store_file(self, store_id: str, file_id: str) -> dict[str, Any]:
"""Load vector store file metadata from Milvus database."""
try:
if not await asyncio.to_thread(self.client.has_collection, "openai_vector_store_files"):
client = self._ensure_client_initialized()
if not await client.collection_exists("openai_vector_store_files"):
return {}
query_filter = f"store_file_id == '{store_id}_{file_id}'"
results = await asyncio.to_thread(
self.client.query,
collection_name="openai_vector_store_files",
filter=query_filter,
output_fields=["file_info"],
results = await client.query(
collection_name="openai_vector_store_files", filter=query_filter, output_fields=["file_info"]
)
if results:
@ -382,7 +385,9 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
async def _update_openai_vector_store_file(self, store_id: str, file_id: str, file_info: dict[str, Any]) -> None:
"""Update vector store file metadata in Milvus database."""
try:
if not await asyncio.to_thread(self.client.has_collection, "openai_vector_store_files"):
client = self._ensure_client_initialized()
if not await client.collection_exists("openai_vector_store_files"):
return
file_data = [
@ -393,11 +398,7 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
"file_info": json.dumps(file_info),
}
]
await asyncio.to_thread(
self.client.upsert,
collection_name="openai_vector_store_files",
data=file_data,
)
await client.upsert(collection_name="openai_vector_store_files", data=file_data)
except Exception as e:
logger.error(f"Error updating openai vector store file {file_id} for store {store_id}: {e}")
raise
@ -405,14 +406,15 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
async def _load_openai_vector_store_file_contents(self, store_id: str, file_id: str) -> list[dict[str, Any]]:
"""Load vector store file contents from Milvus database."""
try:
if not await asyncio.to_thread(self.client.has_collection, "openai_vector_store_files_contents"):
client = self._ensure_client_initialized()
if not await client.collection_exists("openai_vector_store_files_contents"):
return []
query_filter = (
f"store_id == '{store_id}' AND file_id == '{file_id}' AND store_file_id == '{store_id}_{file_id}'"
)
results = await asyncio.to_thread(
self.client.query,
results = await client.query(
collection_name="openai_vector_store_files_contents",
filter=query_filter,
output_fields=["chunk_id", "store_id", "file_id", "content"],
@ -433,21 +435,15 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
async def _delete_openai_vector_store_file_from_storage(self, store_id: str, file_id: str) -> None:
"""Delete vector store file metadata from Milvus database."""
try:
if not await asyncio.to_thread(self.client.has_collection, "openai_vector_store_files"):
client = self._ensure_client_initialized()
if not await client.collection_exists("openai_vector_store_files"):
return
query_filter = f"store_file_id in ['{store_id}_{file_id}']"
await asyncio.to_thread(
self.client.delete,
collection_name="openai_vector_store_files",
filter=query_filter,
)
if await asyncio.to_thread(self.client.has_collection, "openai_vector_store_files_contents"):
await asyncio.to_thread(
self.client.delete,
collection_name="openai_vector_store_files_contents",
filter=query_filter,
)
await client.delete(collection_name="openai_vector_store_files", filter=query_filter)
if await client.collection_exists("openai_vector_store_files_contents"):
await client.delete(collection_name="openai_vector_store_files_contents", filter=query_filter)
except Exception as e:
logger.error(f"Error deleting openai vector store file {file_id} for store {store_id}: {e}")

View file

@ -11,7 +11,7 @@ from unittest.mock import AsyncMock
import numpy as np
import pytest
import pytest_asyncio
from pymilvus import Collection, MilvusClient, connections
from pymilvus import AsyncMilvusClient, Collection, connections
from llama_stack.apis.vector_dbs import VectorDB
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse
@ -52,7 +52,7 @@ async def unique_kvstore_config(tmp_path_factory):
async def milvus_vec_index(embedding_dimension, tmp_path_factory):
temp_dir = tmp_path_factory.getbasetemp()
db_path = str(temp_dir / "test_milvus.db")
client = MilvusClient(db_path)
client = AsyncMilvusClient(uri=db_path)
name = f"{COLLECTION_PREFIX}_{np.random.randint(1e6)}"
connections.connect(alias=MILVUS_ALIAS, uri=db_path)
index = MilvusIndex(client, name, consistency_level="Strong")
@ -148,7 +148,7 @@ async def test_initialize_with_milvus_client(milvus_vec_index, unique_kvstore_co
await tmp_milvus_vec_adapter.kvstore.set(f"{VECTOR_DBS_PREFIX}/test_db", test_vector_db_data)
assert milvus_vec_index.client is not None
assert isinstance(milvus_vec_index.client, MilvusClient)
assert isinstance(milvus_vec_index.client, AsyncMilvusClient)
assert tmp_milvus_vec_adapter.cache is not None
# registering a vector won't update the cache or openai_vector_store collection name
assert (