diff --git a/llama_stack/providers/remote/vector_io/milvus/milvus.py b/llama_stack/providers/remote/vector_io/milvus/milvus.py
index 5e0a449b8..25fe237c0 100644
--- a/llama_stack/providers/remote/vector_io/milvus/milvus.py
+++ b/llama_stack/providers/remote/vector_io/milvus/milvus.py
@@ -8,10 +8,11 @@ import asyncio
 import json
 import logging
 import os
+import re
 from typing import Any
 
 from numpy.typing import NDArray
-from pymilvus import MilvusClient
+from pymilvus import DataType, MilvusClient
 
 from llama_stack.apis.files.files import Files
 from llama_stack.apis.inference import Inference, InterleavedContent
@@ -43,12 +44,20 @@ OPENAI_VECTOR_STORES_FILES_PREFIX = f"openai_vector_stores_files:milvus:{VERSION
 OPENAI_VECTOR_STORES_FILES_CONTENTS_PREFIX = f"openai_vector_stores_files_contents:milvus:{VERSION}::"
 
 
+def sanitize_collection_name(name: str) -> str:
+    """
+    Sanitize collection name to ensure it only contains numbers, letters, and underscores.
+    Any other characters are replaced with underscores.
+    """
+    return re.sub(r"[^a-zA-Z0-9_]", "_", name)
+
+
 class MilvusIndex(EmbeddingIndex):
     def __init__(
         self, client: MilvusClient, collection_name: str, consistency_level="Strong", kvstore: KVStore | None = None
     ):
         self.client = client
-        self.collection_name = collection_name.replace("-", "_")
+        self.collection_name = sanitize_collection_name(collection_name)
         self.consistency_level = consistency_level
         self.kvstore = kvstore
 
@@ -196,7 +205,7 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
 
         index = VectorDBWithIndex(
             vector_db=vector_db,
-            index=MilvusIndex(client=self.client, collection_name=vector_db.identifier),
+            index=MilvusIndex(client=self.client, collection_name=vector_db.identifier, kvstore=self.kvstore),
             inference_api=self.inference_api,
         )
         self.cache[vector_db_id] = index
@@ -251,16 +260,6 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
         key = f"{OPENAI_VECTOR_STORES_PREFIX}{store_id}"
         await self.kvstore.delete(key)
 
-    async def _save_openai_vector_store_file(
-        self, store_id: str, file_id: str, file_info: dict[str, Any], file_contents: list[dict[str, Any]]
-    ) -> None:
-        """Save vector store file metadata to Milvus database."""
-        assert self.kvstore is not None
-        key = f"{OPENAI_VECTOR_STORES_FILES_PREFIX}{store_id}:{file_id}"
-        await self.kvstore.set(key=key, value=json.dumps(file_info))
-        content_key = f"{OPENAI_VECTOR_STORES_FILES_CONTENTS_PREFIX}{store_id}:{file_id}"
-        await self.kvstore.set(key=content_key, value=json.dumps(file_contents))
-
     async def _load_openai_vector_stores(self) -> dict[str, dict[str, Any]]:
         """Load all vector store metadata from persistent storage."""
         assert self.kvstore is not None
@@ -273,20 +272,181 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
         self, store_id: str, file_id: str, file_info: dict[str, Any], file_contents: list[dict[str, Any]]
     ) -> None:
         """Save vector store file metadata to Milvus database."""
-        raise NotImplementedError("Files API not yet implemented for Milvus")
+        if store_id not in self.openai_vector_stores:
+            store_info = (await self._load_openai_vector_stores()).get(store_id)
+            if not store_info:
+                logger.error(f"OpenAI vector store {store_id} not found")
+                raise ValueError(f"No vector store found with id {store_id}")
+
+        try:
+            if not await asyncio.to_thread(self.client.has_collection, "openai_vector_store_files"):
+                file_schema = MilvusClient.create_schema(
+                    auto_id=False,
+                    enable_dynamic_field=True,
+                    description="Metadata for OpenAI vector store files",
+                )
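+                # One row is stored per attached file; the composite "store_file_id"
+                # primary key ("<store_id>_<file_id>") makes repeated saves idempotent upserts.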
+                file_schema.add_field(
+                    field_name="store_file_id", datatype=DataType.VARCHAR, is_primary=True, max_length=512
+                )
+                file_schema.add_field(field_name="store_id", datatype=DataType.VARCHAR, max_length=512)
+                file_schema.add_field(field_name="file_id", datatype=DataType.VARCHAR, max_length=512)
+                file_schema.add_field(field_name="file_info", datatype=DataType.VARCHAR, max_length=65535)
+
+                await asyncio.to_thread(
+                    self.client.create_collection,
+                    collection_name="openai_vector_store_files",
+                    schema=file_schema,
+                )
+
+            if not await asyncio.to_thread(self.client.has_collection, "openai_vector_store_files_contents"):
+                content_schema = MilvusClient.create_schema(
+                    auto_id=False,
+                    enable_dynamic_field=True,
+                    description="Contents for OpenAI vector store files",
+                )
+                content_schema.add_field(
+                    field_name="chunk_id", datatype=DataType.VARCHAR, is_primary=True, max_length=1024
+                )
+                content_schema.add_field(field_name="store_file_id", datatype=DataType.VARCHAR, max_length=1024)
+                content_schema.add_field(field_name="store_id", datatype=DataType.VARCHAR, max_length=512)
+                content_schema.add_field(field_name="file_id", datatype=DataType.VARCHAR, max_length=512)
+                content_schema.add_field(field_name="content", datatype=DataType.VARCHAR, max_length=65535)
+
+                await asyncio.to_thread(
+                    self.client.create_collection,
+                    collection_name="openai_vector_store_files_contents",
+                    schema=content_schema,
+                )
+
+            file_data = [
+                {
+                    "store_file_id": f"{store_id}_{file_id}",
+                    "store_id": store_id,
+                    "file_id": file_id,
+                    "file_info": json.dumps(file_info),
+                }
+            ]
+            await asyncio.to_thread(
+                self.client.upsert,
+                collection_name="openai_vector_store_files",
+                data=file_data,
+            )
+
+            # Save file contents
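+            # Each entry in file_contents is expected to carry chunk_metadata.chunk_id,
+            # which doubles as the primary key of the contents collection.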
+            contents_data = [
+                {
+                    "chunk_id": content.get("chunk_metadata").get("chunk_id"),
+                    "store_file_id": f"{store_id}_{file_id}",
+                    "store_id": store_id,
+                    "file_id": file_id,
+                    "content": json.dumps(content),
+                }
+                for content in file_contents
+            ]
+            await asyncio.to_thread(
+                self.client.upsert,
+                collection_name="openai_vector_store_files_contents",
+                data=contents_data,
+            )
+
+        except Exception as e:
+            logger.error(f"Error saving openai vector store file {file_id} for store {store_id}: {e}")
 
     async def _load_openai_vector_store_file(self, store_id: str, file_id: str) -> dict[str, Any]:
         """Load vector store file metadata from Milvus database."""
-        raise NotImplementedError("Files API not yet implemented for Milvus")
+        try:
+            if not await asyncio.to_thread(self.client.has_collection, "openai_vector_store_files"):
+                return {}
+
+            query_filter = f"store_file_id == '{store_id}_{file_id}'"
+            results = await asyncio.to_thread(
+                self.client.query,
+                collection_name="openai_vector_store_files",
+                filter=query_filter,
+                output_fields=["file_info"],
+            )
+
+            if results:
+                try:
+                    return json.loads(results[0]["file_info"])
+                except json.JSONDecodeError as e:
+                    logger.error(f"Failed to decode file_info for store {store_id}, file {file_id}: {e}")
+                    return {}
+            return {}
+        except Exception as e:
+            logger.error(f"Error loading openai vector store file {file_id} for store {store_id}: {e}")
+            return {}
 
     async def _load_openai_vector_store_file_contents(self, store_id: str, file_id: str) -> list[dict[str, Any]]:
         """Load vector store file contents from Milvus database."""
-        raise NotImplementedError("Files API not yet implemented for Milvus")
+        try:
+            if not await asyncio.to_thread(self.client.has_collection, "openai_vector_store_files_contents"):
+                return []
+
+            query_filter = (
+                f"store_id == '{store_id}' AND file_id == '{file_id}' AND store_file_id == '{store_id}_{file_id}'"
+            )
+            results = await asyncio.to_thread(
+                self.client.query,
+                collection_name="openai_vector_store_files_contents",
+                filter=query_filter,
+                output_fields=["chunk_id", "store_id", "file_id", "content"],
+            )
+
+            contents = []
+            for result in results:
+                try:
+                    content = json.loads(result["content"])
+                    contents.append(content)
+                except json.JSONDecodeError as e:
+                    logger.error(f"Failed to decode content for store {store_id}, file {file_id}: {e}")
+            return contents
+        except Exception as e:
+            logger.error(f"Error loading openai vector store file contents for {file_id} in store {store_id}: {e}")
+            return []
 
     async def _update_openai_vector_store_file(self, store_id: str, file_id: str, file_info: dict[str, Any]) -> None:
         """Update vector store file metadata in Milvus database."""
-        raise NotImplementedError("Files API not yet implemented for Milvus")
+        try:
+            if not await asyncio.to_thread(self.client.has_collection, "openai_vector_store_files"):
+                return
+
+            file_data = [
+                {
+                    "store_file_id": f"{store_id}_{file_id}",
+                    "store_id": store_id,
+                    "file_id": file_id,
+                    "file_info": json.dumps(file_info),
+                }
+            ]
+            await asyncio.to_thread(
+                self.client.upsert,
+                collection_name="openai_vector_store_files",
+                data=file_data,
+            )
+        except Exception as e:
+            logger.error(f"Error updating openai vector store file {file_id} for store {store_id}: {e}")
+            raise
 
     async def _delete_openai_vector_store_file_from_storage(self, store_id: str, file_id: str) -> None:
         """Delete vector store file metadata from Milvus database."""
-        raise NotImplementedError("Files API not yet implemented for Milvus")
+        try:
+            if not await asyncio.to_thread(self.client.has_collection, "openai_vector_store_files"):
+                return
+
+            query_filter = f"store_file_id in ['{store_id}_{file_id}']"
+            await asyncio.to_thread(
+                self.client.delete,
+                collection_name="openai_vector_store_files",
+                filter=query_filter,
+            )
+            if await asyncio.to_thread(self.client.has_collection, "openai_vector_store_files_contents"):
+                await asyncio.to_thread(
+                    self.client.delete,
+                    collection_name="openai_vector_store_files_contents",
+                    filter=query_filter,
+                )
+
+        except Exception as e:
+            logger.error(f"Error deleting openai vector store file {file_id} for store {store_id}: {e}")
+            raise
diff --git a/tests/integration/vector_io/test_openai_vector_stores.py b/tests/integration/vector_io/test_openai_vector_stores.py
index e961ac5ec..cc2860e26 100644
--- a/tests/integration/vector_io/test_openai_vector_stores.py
+++ b/tests/integration/vector_io/test_openai_vector_stores.py
@@ -31,7 +31,7 @@ def skip_if_provider_doesnt_support_openai_vector_stores(client_with_models):
 def skip_if_provider_doesnt_support_openai_vector_store_files_api(client_with_models):
     vector_io_providers = [p for p in client_with_models.providers.list() if p.api == "vector_io"]
     for p in vector_io_providers:
-        if p.provider_type in ["inline::faiss", "inline::sqlite-vec"]:
+        if p.provider_type in ["inline::faiss", "inline::sqlite-vec", "inline::milvus"]:
             return
     pytest.skip("OpenAI vector stores are not supported by any provider")
 
@@ -524,7 +524,6 @@ def test_openai_vector_store_attach_files_on_creation(compat_client_with_empty_s
     file_ids = valid_file_ids + [failed_file_id]
     num_failed = len(file_ids) - len(valid_file_ids)
 
-    # Create a vector store
     vector_store = compat_client.vector_stores.create(
         name="test_store",
         file_ids=file_ids,
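For reference, a minimal sketch of the round trip the new helpers implement, written directly against pymilvus and mirroring the schema the patch creates. It assumes Milvus Lite via pymilvus; the database path and the store/file identifiers below are illustrative and not part of the patch.

import json

from pymilvus import DataType, MilvusClient

client = MilvusClient("./milvus_demo.db")  # illustrative Milvus Lite database path

# Mirror the metadata collection the patch creates lazily on first save.
if not client.has_collection("openai_vector_store_files"):
    schema = MilvusClient.create_schema(auto_id=False, enable_dynamic_field=True)
    schema.add_field(field_name="store_file_id", datatype=DataType.VARCHAR, is_primary=True, max_length=512)
    schema.add_field(field_name="store_id", datatype=DataType.VARCHAR, max_length=512)
    schema.add_field(field_name="file_id", datatype=DataType.VARCHAR, max_length=512)
    schema.add_field(field_name="file_info", datatype=DataType.VARCHAR, max_length=65535)
    client.create_collection(collection_name="openai_vector_store_files", schema=schema)

store_id, file_id = "vs_123", "file_456"  # illustrative identifiers

# Save/update: upsert keyed on the composite primary key, as in _save/_update.
client.upsert(
    collection_name="openai_vector_store_files",
    data=[
        {
            "store_file_id": f"{store_id}_{file_id}",
            "store_id": store_id,
            "file_id": file_id,
            "file_info": json.dumps({"id": file_id, "status": "completed"}),
        }
    ],
)

# Load: the same filter _load_openai_vector_store_file uses.
rows = client.query(
    collection_name="openai_vector_store_files",
    filter=f"store_file_id == '{store_id}_{file_id}'",
    output_fields=["file_info"],
)
print(json.loads(rows[0]["file_info"]) if rows else {})

# Delete: the same "in" filter _delete_openai_vector_store_file_from_storage uses.
client.delete(
    collection_name="openai_vector_store_files",
    filter=f"store_file_id in ['{store_id}_{file_id}']",
)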