mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-29 07:14:20 +00:00
feat: migrate from MilvusClient to AsyncMilvusClient
<!-- Provide a short summary of what this PR does and why. Link to relevant issues if applicable. --> This PR adds static type coverage to `llama-stack` Part of https://github.com/meta-llama/llama-stack/issues/2647 <!-- If resolving an issue, uncomment and update the line below --> <!-- Closes #[issue-number] --> <!-- Describe the tests you ran to verify your changes with result summaries. *Provide clear instructions so the plan can be easily re-executed.* --> Signed-off-by: Mustafa Elbehery <melbeher@redhat.com>
This commit is contained in:
parent
81109a0f72
commit
f75ec332b5
2 changed files with 103 additions and 107 deletions
|
@ -4,7 +4,6 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
|
@ -12,7 +11,7 @@ import re
|
|||
from typing import Any
|
||||
|
||||
from numpy.typing import NDArray
|
||||
from pymilvus import DataType, MilvusClient
|
||||
from pymilvus import AsyncMilvusClient, DataType
|
||||
|
||||
from llama_stack.apis.files.files import Files
|
||||
from llama_stack.apis.inference import Inference, InterleavedContent
|
||||
|
@ -54,7 +53,11 @@ def sanitize_collection_name(name: str) -> str:
|
|||
|
||||
class MilvusIndex(EmbeddingIndex):
|
||||
def __init__(
|
||||
self, client: MilvusClient, collection_name: str, consistency_level="Strong", kvstore: KVStore | None = None
|
||||
self,
|
||||
client: AsyncMilvusClient,
|
||||
collection_name: str,
|
||||
consistency_level="Strong",
|
||||
kvstore: KVStore | None = None,
|
||||
):
|
||||
self.client = client
|
||||
self.collection_name = sanitize_collection_name(collection_name)
|
||||
|
@ -62,16 +65,15 @@ class MilvusIndex(EmbeddingIndex):
|
|||
self.kvstore = kvstore
|
||||
|
||||
async def delete(self):
|
||||
if await asyncio.to_thread(self.client.has_collection, self.collection_name):
|
||||
await asyncio.to_thread(self.client.drop_collection, collection_name=self.collection_name)
|
||||
if await self.client.collection_exists(self.collection_name):
|
||||
await self.client.drop_collection(collection_name=self.collection_name)
|
||||
|
||||
async def add_chunks(self, chunks: list[Chunk], embeddings: NDArray):
|
||||
assert len(chunks) == len(embeddings), (
|
||||
f"Chunk length {len(chunks)} does not match embedding length {len(embeddings)}"
|
||||
)
|
||||
if not await asyncio.to_thread(self.client.has_collection, self.collection_name):
|
||||
await asyncio.to_thread(
|
||||
self.client.create_collection,
|
||||
if not await self.client.collection_exists(self.collection_name):
|
||||
await self.client.create_collection(
|
||||
self.collection_name,
|
||||
dimension=len(embeddings[0]),
|
||||
auto_id=True,
|
||||
|
@ -88,8 +90,7 @@ class MilvusIndex(EmbeddingIndex):
|
|||
}
|
||||
)
|
||||
try:
|
||||
await asyncio.to_thread(
|
||||
self.client.insert,
|
||||
await self.client.insert(
|
||||
self.collection_name,
|
||||
data=data,
|
||||
)
|
||||
|
@ -98,8 +99,7 @@ class MilvusIndex(EmbeddingIndex):
|
|||
raise e
|
||||
|
||||
async def query_vector(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
|
||||
search_res = await asyncio.to_thread(
|
||||
self.client.search,
|
||||
search_res = await self.client.search(
|
||||
collection_name=self.collection_name,
|
||||
data=[embedding],
|
||||
limit=k,
|
||||
|
@ -139,57 +139,68 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
|
|||
) -> None:
|
||||
self.config = config
|
||||
self.cache = {}
|
||||
self.client = None
|
||||
self.client: AsyncMilvusClient | None = None
|
||||
self.inference_api = inference_api
|
||||
self.files_api = files_api
|
||||
self.kvstore: KVStore | None = None
|
||||
self.vector_db_store = None
|
||||
self.openai_vector_stores: dict[str, dict[str, Any]] = {}
|
||||
self.metadata_collection_name = "openai_vector_stores_metadata"
|
||||
|
||||
async def initialize(self) -> None:
|
||||
self.kvstore = await kvstore_impl(self.config.kvstore)
|
||||
start_key = VECTOR_DBS_PREFIX
|
||||
end_key = f"{VECTOR_DBS_PREFIX}\xff"
|
||||
stored_vector_dbs = await self.kvstore.values_in_range(start_key, end_key)
|
||||
if self.config.kvstore is not None:
|
||||
self.kvstore = await kvstore_impl(self.config.kvstore)
|
||||
|
||||
for vector_db_data in stored_vector_dbs:
|
||||
vector_db = VectorDB.model_validate_json(vector_db_data)
|
||||
index = VectorDBWithIndex(
|
||||
vector_db,
|
||||
index=MilvusIndex(
|
||||
client=self.client,
|
||||
collection_name=vector_db.identifier,
|
||||
consistency_level=self.config.consistency_level,
|
||||
kvstore=self.kvstore,
|
||||
),
|
||||
inference_api=self.inference_api,
|
||||
)
|
||||
self.cache[vector_db.identifier] = index
|
||||
# Initialize client first before using it
|
||||
if isinstance(self.config, RemoteMilvusVectorIOConfig):
|
||||
logger.info(f"Connecting to Milvus server at {self.config.uri}")
|
||||
self.client = MilvusClient(**self.config.model_dump(exclude_none=True))
|
||||
self.client = AsyncMilvusClient(**self.config.model_dump(exclude_none=True))
|
||||
else:
|
||||
logger.info(f"Connecting to Milvus Lite at: {self.config.db_path}")
|
||||
uri = os.path.expanduser(self.config.db_path)
|
||||
self.client = MilvusClient(uri=uri)
|
||||
self.client = AsyncMilvusClient(uri=uri)
|
||||
|
||||
# Now load stored vector databases
|
||||
if self.kvstore is not None:
|
||||
start_key = VECTOR_DBS_PREFIX
|
||||
end_key = f"{VECTOR_DBS_PREFIX}\xff"
|
||||
stored_vector_dbs = await self.kvstore.values_in_range(start_key, end_key)
|
||||
|
||||
for vector_db_data in stored_vector_dbs:
|
||||
vector_db = VectorDB.model_validate_json(vector_db_data)
|
||||
index = VectorDBWithIndex(
|
||||
vector_db,
|
||||
index=MilvusIndex(
|
||||
client=self.client,
|
||||
collection_name=vector_db.identifier,
|
||||
consistency_level=self.config.consistency_level,
|
||||
kvstore=self.kvstore,
|
||||
),
|
||||
inference_api=self.inference_api,
|
||||
)
|
||||
self.cache[vector_db.identifier] = index
|
||||
|
||||
self.openai_vector_stores = await self._load_openai_vector_stores()
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
self.client.close()
|
||||
if self.client is not None:
|
||||
await self.client.close()
|
||||
|
||||
async def register_vector_db(
|
||||
self,
|
||||
vector_db: VectorDB,
|
||||
) -> None:
|
||||
def _ensure_client_initialized(self) -> AsyncMilvusClient:
|
||||
"""Ensure the client is initialized and return it."""
|
||||
if self.client is None:
|
||||
raise RuntimeError("Milvus client is not initialized. Call initialize() first.")
|
||||
return self.client
|
||||
|
||||
async def register_vector_db(self, vector_db: VectorDB) -> None:
|
||||
if isinstance(self.config, RemoteMilvusVectorIOConfig):
|
||||
consistency_level = self.config.consistency_level
|
||||
else:
|
||||
consistency_level = "Strong"
|
||||
index = VectorDBWithIndex(
|
||||
vector_db=vector_db,
|
||||
index=MilvusIndex(self.client, vector_db.identifier, consistency_level=consistency_level),
|
||||
index=MilvusIndex(
|
||||
self._ensure_client_initialized(), vector_db.identifier, consistency_level=consistency_level
|
||||
),
|
||||
inference_api=self.inference_api,
|
||||
)
|
||||
|
||||
|
@ -199,17 +210,8 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
|
|||
if vector_db_id in self.cache:
|
||||
return self.cache[vector_db_id]
|
||||
|
||||
vector_db = await self.vector_db_store.get_vector_db(vector_db_id)
|
||||
if not vector_db:
|
||||
raise ValueError(f"Vector DB {vector_db_id} not found")
|
||||
|
||||
index = VectorDBWithIndex(
|
||||
vector_db=vector_db,
|
||||
index=MilvusIndex(client=self.client, collection_name=vector_db.identifier, kvstore=self.kvstore),
|
||||
inference_api=self.inference_api,
|
||||
)
|
||||
self.cache[vector_db_id] = index
|
||||
return index
|
||||
# Vector DB should be registered before use
|
||||
raise ValueError(f"Vector DB {vector_db_id} not found. Please register it first.")
|
||||
|
||||
async def unregister_vector_db(self, vector_db_id: str) -> None:
|
||||
if vector_db_id in self.cache:
|
||||
|
@ -275,14 +277,17 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
|
|||
) -> None:
|
||||
"""Save vector store file metadata to Milvus database."""
|
||||
if store_id not in self.openai_vector_stores:
|
||||
store_info = await self._load_openai_vector_stores(store_id)
|
||||
if not store_info:
|
||||
# Reload all vector stores to check if the store exists
|
||||
self.openai_vector_stores = await self._load_openai_vector_stores()
|
||||
if store_id not in self.openai_vector_stores:
|
||||
logger.error(f"OpenAI vector store {store_id} not found")
|
||||
raise ValueError(f"No vector store found with id {store_id}")
|
||||
|
||||
try:
|
||||
if not await asyncio.to_thread(self.client.has_collection, "openai_vector_store_files"):
|
||||
file_schema = MilvusClient.create_schema(
|
||||
client = self._ensure_client_initialized()
|
||||
|
||||
if not await client.collection_exists("openai_vector_store_files"):
|
||||
file_schema = AsyncMilvusClient.create_schema(
|
||||
auto_id=False,
|
||||
enable_dynamic_field=True,
|
||||
description="Metadata for OpenAI vector store files",
|
||||
|
@ -294,14 +299,13 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
|
|||
file_schema.add_field(field_name="file_id", datatype=DataType.VARCHAR, max_length=512)
|
||||
file_schema.add_field(field_name="file_info", datatype=DataType.VARCHAR, max_length=65535)
|
||||
|
||||
await asyncio.to_thread(
|
||||
self.client.create_collection,
|
||||
await client.create_collection(
|
||||
collection_name="openai_vector_store_files",
|
||||
schema=file_schema,
|
||||
)
|
||||
|
||||
if not await asyncio.to_thread(self.client.has_collection, "openai_vector_store_files_contents"):
|
||||
content_schema = MilvusClient.create_schema(
|
||||
if not await client.collection_exists("openai_vector_store_files_contents"):
|
||||
content_schema = AsyncMilvusClient.create_schema(
|
||||
auto_id=False,
|
||||
enable_dynamic_field=True,
|
||||
description="Contents for OpenAI vector store files",
|
||||
|
@ -314,8 +318,7 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
|
|||
content_schema.add_field(field_name="file_id", datatype=DataType.VARCHAR, max_length=512)
|
||||
content_schema.add_field(field_name="content", datatype=DataType.VARCHAR, max_length=65535)
|
||||
|
||||
await asyncio.to_thread(
|
||||
self.client.create_collection,
|
||||
await client.create_collection(
|
||||
collection_name="openai_vector_store_files_contents",
|
||||
schema=content_schema,
|
||||
)
|
||||
|
@ -328,44 +331,44 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
|
|||
"file_info": json.dumps(file_info),
|
||||
}
|
||||
]
|
||||
await asyncio.to_thread(
|
||||
self.client.upsert,
|
||||
await client.upsert(
|
||||
collection_name="openai_vector_store_files",
|
||||
data=file_data,
|
||||
)
|
||||
|
||||
# Save file contents
|
||||
contents_data = [
|
||||
{
|
||||
"chunk_id": content.get("chunk_metadata").get("chunk_id"),
|
||||
"store_file_id": f"{store_id}_{file_id}",
|
||||
"store_id": store_id,
|
||||
"file_id": file_id,
|
||||
"content": json.dumps(content),
|
||||
}
|
||||
for content in file_contents
|
||||
]
|
||||
await asyncio.to_thread(
|
||||
self.client.upsert,
|
||||
collection_name="openai_vector_store_files_contents",
|
||||
data=contents_data,
|
||||
)
|
||||
contents_data = []
|
||||
for content in file_contents:
|
||||
chunk_metadata = content.get("chunk_metadata")
|
||||
if chunk_metadata and chunk_metadata.get("chunk_id"):
|
||||
contents_data.append(
|
||||
{
|
||||
"chunk_id": chunk_metadata.get("chunk_id"),
|
||||
"store_file_id": f"{store_id}_{file_id}",
|
||||
"store_id": store_id,
|
||||
"file_id": file_id,
|
||||
"content": json.dumps(content),
|
||||
}
|
||||
)
|
||||
|
||||
if contents_data:
|
||||
await client.upsert(collection_name="openai_vector_store_files_contents", data=contents_data)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error saving openai vector store file {file_id} for store {store_id}: {e}")
|
||||
raise
|
||||
|
||||
async def _load_openai_vector_store_file(self, store_id: str, file_id: str) -> dict[str, Any]:
|
||||
"""Load vector store file metadata from Milvus database."""
|
||||
try:
|
||||
if not await asyncio.to_thread(self.client.has_collection, "openai_vector_store_files"):
|
||||
client = self._ensure_client_initialized()
|
||||
|
||||
if not await client.collection_exists("openai_vector_store_files"):
|
||||
return {}
|
||||
|
||||
query_filter = f"store_file_id == '{store_id}_{file_id}'"
|
||||
results = await asyncio.to_thread(
|
||||
self.client.query,
|
||||
collection_name="openai_vector_store_files",
|
||||
filter=query_filter,
|
||||
output_fields=["file_info"],
|
||||
results = await client.query(
|
||||
collection_name="openai_vector_store_files", filter=query_filter, output_fields=["file_info"]
|
||||
)
|
||||
|
||||
if results:
|
||||
|
@ -382,7 +385,9 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
|
|||
async def _update_openai_vector_store_file(self, store_id: str, file_id: str, file_info: dict[str, Any]) -> None:
|
||||
"""Update vector store file metadata in Milvus database."""
|
||||
try:
|
||||
if not await asyncio.to_thread(self.client.has_collection, "openai_vector_store_files"):
|
||||
client = self._ensure_client_initialized()
|
||||
|
||||
if not await client.collection_exists("openai_vector_store_files"):
|
||||
return
|
||||
|
||||
file_data = [
|
||||
|
@ -393,11 +398,7 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
|
|||
"file_info": json.dumps(file_info),
|
||||
}
|
||||
]
|
||||
await asyncio.to_thread(
|
||||
self.client.upsert,
|
||||
collection_name="openai_vector_store_files",
|
||||
data=file_data,
|
||||
)
|
||||
await client.upsert(collection_name="openai_vector_store_files", data=file_data)
|
||||
except Exception as e:
|
||||
logger.error(f"Error updating openai vector store file {file_id} for store {store_id}: {e}")
|
||||
raise
|
||||
|
@ -405,14 +406,15 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
|
|||
async def _load_openai_vector_store_file_contents(self, store_id: str, file_id: str) -> list[dict[str, Any]]:
|
||||
"""Load vector store file contents from Milvus database."""
|
||||
try:
|
||||
if not await asyncio.to_thread(self.client.has_collection, "openai_vector_store_files_contents"):
|
||||
client = self._ensure_client_initialized()
|
||||
|
||||
if not await client.collection_exists("openai_vector_store_files_contents"):
|
||||
return []
|
||||
|
||||
query_filter = (
|
||||
f"store_id == '{store_id}' AND file_id == '{file_id}' AND store_file_id == '{store_id}_{file_id}'"
|
||||
)
|
||||
results = await asyncio.to_thread(
|
||||
self.client.query,
|
||||
results = await client.query(
|
||||
collection_name="openai_vector_store_files_contents",
|
||||
filter=query_filter,
|
||||
output_fields=["chunk_id", "store_id", "file_id", "content"],
|
||||
|
@ -433,21 +435,15 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
|
|||
async def _delete_openai_vector_store_file_from_storage(self, store_id: str, file_id: str) -> None:
|
||||
"""Delete vector store file metadata from Milvus database."""
|
||||
try:
|
||||
if not await asyncio.to_thread(self.client.has_collection, "openai_vector_store_files"):
|
||||
client = self._ensure_client_initialized()
|
||||
|
||||
if not await client.collection_exists("openai_vector_store_files"):
|
||||
return
|
||||
|
||||
query_filter = f"store_file_id in ['{store_id}_{file_id}']"
|
||||
await asyncio.to_thread(
|
||||
self.client.delete,
|
||||
collection_name="openai_vector_store_files",
|
||||
filter=query_filter,
|
||||
)
|
||||
if await asyncio.to_thread(self.client.has_collection, "openai_vector_store_files_contents"):
|
||||
await asyncio.to_thread(
|
||||
self.client.delete,
|
||||
collection_name="openai_vector_store_files_contents",
|
||||
filter=query_filter,
|
||||
)
|
||||
await client.delete(collection_name="openai_vector_store_files", filter=query_filter)
|
||||
if await client.collection_exists("openai_vector_store_files_contents"):
|
||||
await client.delete(collection_name="openai_vector_store_files_contents", filter=query_filter)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error deleting openai vector store file {file_id} for store {store_id}: {e}")
|
||||
|
|
|
@ -11,7 +11,7 @@ from unittest.mock import AsyncMock
|
|||
import numpy as np
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
from pymilvus import Collection, MilvusClient, connections
|
||||
from pymilvus import AsyncMilvusClient, Collection, connections
|
||||
|
||||
from llama_stack.apis.vector_dbs import VectorDB
|
||||
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse
|
||||
|
@ -52,7 +52,7 @@ async def unique_kvstore_config(tmp_path_factory):
|
|||
async def milvus_vec_index(embedding_dimension, tmp_path_factory):
|
||||
temp_dir = tmp_path_factory.getbasetemp()
|
||||
db_path = str(temp_dir / "test_milvus.db")
|
||||
client = MilvusClient(db_path)
|
||||
client = AsyncMilvusClient(uri=db_path)
|
||||
name = f"{COLLECTION_PREFIX}_{np.random.randint(1e6)}"
|
||||
connections.connect(alias=MILVUS_ALIAS, uri=db_path)
|
||||
index = MilvusIndex(client, name, consistency_level="Strong")
|
||||
|
@ -148,7 +148,7 @@ async def test_initialize_with_milvus_client(milvus_vec_index, unique_kvstore_co
|
|||
await tmp_milvus_vec_adapter.kvstore.set(f"{VECTOR_DBS_PREFIX}/test_db", test_vector_db_data)
|
||||
|
||||
assert milvus_vec_index.client is not None
|
||||
assert isinstance(milvus_vec_index.client, MilvusClient)
|
||||
assert isinstance(milvus_vec_index.client, AsyncMilvusClient)
|
||||
assert tmp_milvus_vec_adapter.cache is not None
|
||||
# registering a vector won't update the cache or openai_vector_store collection name
|
||||
assert (
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue