diff --git a/llama_stack/apis/vector_io/vector_io.py b/llama_stack/apis/vector_io/vector_io.py index e07175c49..f688be8a4 100644 --- a/llama_stack/apis/vector_io/vector_io.py +++ b/llama_stack/apis/vector_io/vector_io.py @@ -11,8 +11,6 @@ import uuid from typing import Annotated, Any, Literal, Protocol, runtime_checkable -from pydantic import BaseModel, Field - from llama_stack.apis.inference import InterleavedContent from llama_stack.apis.vector_dbs import VectorDB from llama_stack.apis.version import LLAMA_STACK_API_V1 @@ -21,6 +19,8 @@ from llama_stack.providers.utils.vector_io.vector_utils import generate_chunk_id from llama_stack.schema_utils import json_schema_type, webmethod from llama_stack.strong_typing.schema import register_schema +from pydantic import BaseModel, Field + @json_schema_type class ChunkMetadata(BaseModel): @@ -350,7 +350,12 @@ class VectorStoreFileLastError(BaseModel): message: str -VectorStoreFileStatus = Literal["completed"] | Literal["in_progress"] | Literal["cancelled"] | Literal["failed"] +VectorStoreFileStatus = ( + Literal["completed"] + | Literal["in_progress"] + | Literal["cancelled"] + | Literal["failed"] +) register_schema(VectorStoreFileStatus, name="VectorStoreFileStatus") @@ -556,7 +561,9 @@ class VectorIO(Protocol): """ ... - @webmethod(route="/vector_stores/{vector_store_id}", method="GET", level=LLAMA_STACK_API_V1) + @webmethod( + route="/vector_stores/{vector_store_id}", method="GET", level=LLAMA_STACK_API_V1 + ) async def openai_retrieve_vector_store( self, vector_store_id: str, diff --git a/llama_stack/core/routers/vector_io.py b/llama_stack/core/routers/vector_io.py index 0e3f9d8d9..9c20b9144 100644 --- a/llama_stack/core/routers/vector_io.py +++ b/llama_stack/core/routers/vector_io.py @@ -344,6 +344,64 @@ class VectorIORouter(VectorIO): file_id=file_id, ) + async def openai_create_vector_store_file_batch( + self, + vector_store_id: str, + file_ids: list[str], + attributes: dict[str, Any] | None = None, + chunking_strategy: VectorStoreChunkingStrategy | None = None, + ) -> VectorStoreFileBatchObject: + logger.debug(f"VectorIORouter.openai_create_vector_store_file_batch: {vector_store_id}, {len(file_ids)} files") + return await self.routing_table.openai_create_vector_store_file_batch( + vector_store_id=vector_store_id, + file_ids=file_ids, + attributes=attributes, + chunking_strategy=chunking_strategy, + ) + + async def openai_retrieve_vector_store_file_batch( + self, + batch_id: str, + vector_store_id: str, + ) -> VectorStoreFileBatchObject: + logger.debug(f"VectorIORouter.openai_retrieve_vector_store_file_batch: {batch_id}, {vector_store_id}") + return await self.routing_table.openai_retrieve_vector_store_file_batch( + batch_id=batch_id, + vector_store_id=vector_store_id, + ) + + async def openai_list_files_in_vector_store_file_batch( + self, + batch_id: str, + vector_store_id: str, + after: str | None = None, + before: str | None = None, + filter: str | None = None, + limit: int | None = 20, + order: str | None = "desc", + ) -> VectorStoreFilesListInBatchResponse: + logger.debug(f"VectorIORouter.openai_list_files_in_vector_store_file_batch: {batch_id}, {vector_store_id}") + return await self.routing_table.openai_list_files_in_vector_store_file_batch( + batch_id=batch_id, + vector_store_id=vector_store_id, + after=after, + before=before, + filter=filter, + limit=limit, + order=order, + ) + + async def openai_cancel_vector_store_file_batch( + self, + batch_id: str, + vector_store_id: str, + ) -> VectorStoreFileBatchObject: + 
logger.debug(f"VectorIORouter.openai_cancel_vector_store_file_batch: {batch_id}, {vector_store_id}") + return await self.routing_table.openai_cancel_vector_store_file_batch( + batch_id=batch_id, + vector_store_id=vector_store_id, + ) + async def health(self) -> dict[str, HealthResponse]: health_statuses = {} timeout = 1 # increasing the timeout to 1 second for health checks diff --git a/llama_stack/core/routing_tables/vector_dbs.py b/llama_stack/core/routing_tables/vector_dbs.py index 497894064..932bbdba8 100644 --- a/llama_stack/core/routing_tables/vector_dbs.py +++ b/llama_stack/core/routing_tables/vector_dbs.py @@ -245,3 +245,65 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs): vector_store_id=vector_store_id, file_id=file_id, ) + + async def openai_create_vector_store_file_batch( + self, + vector_store_id: str, + file_ids: list[str], + attributes: dict[str, Any] | None = None, + chunking_strategy: Any | None = None, + ): + await self.assert_action_allowed("update", "vector_db", vector_store_id) + provider = await self.get_provider_impl(vector_store_id) + return await provider.openai_create_vector_store_file_batch( + vector_store_id=vector_store_id, + file_ids=file_ids, + attributes=attributes, + chunking_strategy=chunking_strategy, + ) + + async def openai_retrieve_vector_store_file_batch( + self, + batch_id: str, + vector_store_id: str, + ): + await self.assert_action_allowed("read", "vector_db", vector_store_id) + provider = await self.get_provider_impl(vector_store_id) + return await provider.openai_retrieve_vector_store_file_batch( + batch_id=batch_id, + vector_store_id=vector_store_id, + ) + + async def openai_list_files_in_vector_store_file_batch( + self, + batch_id: str, + vector_store_id: str, + after: str | None = None, + before: str | None = None, + filter: str | None = None, + limit: int | None = 20, + order: str | None = "desc", + ): + await self.assert_action_allowed("read", "vector_db", vector_store_id) + provider = await self.get_provider_impl(vector_store_id) + return await provider.openai_list_files_in_vector_store_file_batch( + batch_id=batch_id, + vector_store_id=vector_store_id, + after=after, + before=before, + filter=filter, + limit=limit, + order=order, + ) + + async def openai_cancel_vector_store_file_batch( + self, + batch_id: str, + vector_store_id: str, + ): + await self.assert_action_allowed("update", "vector_db", vector_store_id) + provider = await self.get_provider_impl(vector_store_id) + return await provider.openai_cancel_vector_store_file_batch( + batch_id=batch_id, + vector_store_id=vector_store_id, + ) diff --git a/llama_stack/providers/inline/vector_io/faiss/faiss.py b/llama_stack/providers/inline/vector_io/faiss/faiss.py index 258c6e7aa..d5fde8595 100644 --- a/llama_stack/providers/inline/vector_io/faiss/faiss.py +++ b/llama_stack/providers/inline/vector_io/faiss/faiss.py @@ -206,6 +206,7 @@ class FaissVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPr self.cache: dict[str, VectorDBWithIndex] = {} self.kvstore: KVStore | None = None self.openai_vector_stores: dict[str, dict[str, Any]] = {} + self.openai_file_batches: dict[str, dict[str, Any]] = {} async def initialize(self) -> None: self.kvstore = await kvstore_impl(self.config.kvstore) diff --git a/llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py b/llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py index f34f8f6fb..53573e4aa 100644 --- a/llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py +++ 
b/llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py @@ -415,6 +415,7 @@ class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtoc self.files_api = files_api self.cache: dict[str, VectorDBWithIndex] = {} self.openai_vector_stores: dict[str, dict[str, Any]] = {} + self.openai_file_batches: dict[str, dict[str, Any]] = {} self.kvstore: KVStore | None = None async def initialize(self) -> None: diff --git a/llama_stack/providers/remote/vector_io/chroma/chroma.py b/llama_stack/providers/remote/vector_io/chroma/chroma.py index a9ec644ef..6d85b9d16 100644 --- a/llama_stack/providers/remote/vector_io/chroma/chroma.py +++ b/llama_stack/providers/remote/vector_io/chroma/chroma.py @@ -166,6 +166,7 @@ class ChromaVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP log.info(f"Connecting to Chroma local db at: {self.config.db_path}") self.client = chromadb.PersistentClient(path=self.config.db_path) self.openai_vector_stores = await self._load_openai_vector_stores() + self.openai_file_batches: dict[str, dict[str, Any]] = {} async def shutdown(self) -> None: pass diff --git a/llama_stack/providers/remote/vector_io/milvus/milvus.py b/llama_stack/providers/remote/vector_io/milvus/milvus.py index e07e8ff12..f4f7ad8e4 100644 --- a/llama_stack/providers/remote/vector_io/milvus/milvus.py +++ b/llama_stack/providers/remote/vector_io/milvus/milvus.py @@ -317,6 +317,7 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP self.kvstore: KVStore | None = None self.vector_db_store = None self.openai_vector_stores: dict[str, dict[str, Any]] = {} + self.openai_file_batches: dict[str, dict[str, Any]] = {} self.metadata_collection_name = "openai_vector_stores_metadata" async def initialize(self) -> None: diff --git a/llama_stack/providers/remote/vector_io/pgvector/pgvector.py b/llama_stack/providers/remote/vector_io/pgvector/pgvector.py index 1c140e782..58dbf3618 100644 --- a/llama_stack/providers/remote/vector_io/pgvector/pgvector.py +++ b/llama_stack/providers/remote/vector_io/pgvector/pgvector.py @@ -353,6 +353,7 @@ class PGVectorVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtoco self.kvstore: KVStore | None = None self.vector_db_store = None self.openai_vector_stores: dict[str, dict[str, Any]] = {} + self.openai_file_batches: dict[str, dict[str, Any]] = {} self.metadata_collection_name = "openai_vector_stores_metadata" async def initialize(self) -> None: diff --git a/llama_stack/providers/remote/vector_io/qdrant/qdrant.py b/llama_stack/providers/remote/vector_io/qdrant/qdrant.py index ec3869495..142143128 100644 --- a/llama_stack/providers/remote/vector_io/qdrant/qdrant.py +++ b/llama_stack/providers/remote/vector_io/qdrant/qdrant.py @@ -170,6 +170,7 @@ class QdrantVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP self.vector_db_store = None self.kvstore: KVStore | None = None self.openai_vector_stores: dict[str, dict[str, Any]] = {} + self.openai_file_batches: dict[str, dict[str, Any]] = {} self._qdrant_lock = asyncio.Lock() async def initialize(self) -> None: diff --git a/llama_stack/providers/remote/vector_io/weaviate/weaviate.py b/llama_stack/providers/remote/vector_io/weaviate/weaviate.py index 59b6bf124..e99ff00f0 100644 --- a/llama_stack/providers/remote/vector_io/weaviate/weaviate.py +++ b/llama_stack/providers/remote/vector_io/weaviate/weaviate.py @@ -170,6 +170,7 @@ class WeaviateVectorIOAdapter( self.kvstore: KVStore | None = None self.vector_db_store = None self.openai_vector_stores: 
dict[str, dict[str, Any]] = {} + self.openai_file_batches: dict[str, dict[str, Any]] = {} self.metadata_collection_name = "openai_vector_stores_metadata" def _get_client(self) -> weaviate.Client: diff --git a/llama_stack/providers/utils/memory/openai_vector_store_mixin.py b/llama_stack/providers/utils/memory/openai_vector_store_mixin.py index 36432767f..53be11e46 100644 --- a/llama_stack/providers/utils/memory/openai_vector_store_mixin.py +++ b/llama_stack/providers/utils/memory/openai_vector_store_mixin.py @@ -55,7 +55,9 @@ VERSION = "v3" VECTOR_DBS_PREFIX = f"vector_dbs:{VERSION}::" OPENAI_VECTOR_STORES_PREFIX = f"openai_vector_stores:{VERSION}::" OPENAI_VECTOR_STORES_FILES_PREFIX = f"openai_vector_stores_files:{VERSION}::" -OPENAI_VECTOR_STORES_FILES_CONTENTS_PREFIX = f"openai_vector_stores_files_contents:{VERSION}::" +OPENAI_VECTOR_STORES_FILES_CONTENTS_PREFIX = ( + f"openai_vector_stores_files_contents:{VERSION}::" +) class OpenAIVectorStoreMixin(ABC): @@ -67,11 +69,14 @@ class OpenAIVectorStoreMixin(ABC): # These should be provided by the implementing class openai_vector_stores: dict[str, dict[str, Any]] + openai_file_batches: dict[str, dict[str, Any]] files_api: Files | None # KV store for persisting OpenAI vector store metadata kvstore: KVStore | None - async def _save_openai_vector_store(self, store_id: str, store_info: dict[str, Any]) -> None: + async def _save_openai_vector_store( + self, store_id: str, store_info: dict[str, Any] + ) -> None: """Save vector store metadata to persistent storage.""" assert self.kvstore key = f"{OPENAI_VECTOR_STORES_PREFIX}{store_id}" @@ -92,7 +97,9 @@ class OpenAIVectorStoreMixin(ABC): stores[info["id"]] = info return stores - async def _update_openai_vector_store(self, store_id: str, store_info: dict[str, Any]) -> None: + async def _update_openai_vector_store( + self, store_id: str, store_info: dict[str, Any] + ) -> None: """Update vector store metadata in persistent storage.""" assert self.kvstore key = f"{OPENAI_VECTOR_STORES_PREFIX}{store_id}" @@ -119,18 +126,26 @@ class OpenAIVectorStoreMixin(ABC): assert self.kvstore meta_key = f"{OPENAI_VECTOR_STORES_FILES_PREFIX}{store_id}:{file_id}" await self.kvstore.set(key=meta_key, value=json.dumps(file_info)) - contents_prefix = f"{OPENAI_VECTOR_STORES_FILES_CONTENTS_PREFIX}{store_id}:{file_id}:" + contents_prefix = ( + f"{OPENAI_VECTOR_STORES_FILES_CONTENTS_PREFIX}{store_id}:{file_id}:" + ) for idx, chunk in enumerate(file_contents): - await self.kvstore.set(key=f"{contents_prefix}{idx}", value=json.dumps(chunk)) + await self.kvstore.set( + key=f"{contents_prefix}{idx}", value=json.dumps(chunk) + ) - async def _load_openai_vector_store_file(self, store_id: str, file_id: str) -> dict[str, Any]: + async def _load_openai_vector_store_file( + self, store_id: str, file_id: str + ) -> dict[str, Any]: """Load vector store file metadata from persistent storage.""" assert self.kvstore key = f"{OPENAI_VECTOR_STORES_FILES_PREFIX}{store_id}:{file_id}" stored_data = await self.kvstore.get(key) return json.loads(stored_data) if stored_data else {} - async def _load_openai_vector_store_file_contents(self, store_id: str, file_id: str) -> list[dict[str, Any]]: + async def _load_openai_vector_store_file_contents( + self, store_id: str, file_id: str + ) -> list[dict[str, Any]]: """Load vector store file contents from persistent storage.""" assert self.kvstore prefix = f"{OPENAI_VECTOR_STORES_FILES_CONTENTS_PREFIX}{store_id}:{file_id}:" @@ -138,20 +153,26 @@ class OpenAIVectorStoreMixin(ABC): raw_items = await 
self.kvstore.values_in_range(prefix, end_key) return [json.loads(item) for item in raw_items] - async def _update_openai_vector_store_file(self, store_id: str, file_id: str, file_info: dict[str, Any]) -> None: + async def _update_openai_vector_store_file( + self, store_id: str, file_id: str, file_info: dict[str, Any] + ) -> None: """Update vector store file metadata in persistent storage.""" assert self.kvstore key = f"{OPENAI_VECTOR_STORES_FILES_PREFIX}{store_id}:{file_id}" await self.kvstore.set(key=key, value=json.dumps(file_info)) - async def _delete_openai_vector_store_file_from_storage(self, store_id: str, file_id: str) -> None: + async def _delete_openai_vector_store_file_from_storage( + self, store_id: str, file_id: str + ) -> None: """Delete vector store file metadata from persistent storage.""" assert self.kvstore meta_key = f"{OPENAI_VECTOR_STORES_FILES_PREFIX}{store_id}:{file_id}" await self.kvstore.delete(meta_key) - contents_prefix = f"{OPENAI_VECTOR_STORES_FILES_CONTENTS_PREFIX}{store_id}:{file_id}:" + contents_prefix = ( + f"{OPENAI_VECTOR_STORES_FILES_CONTENTS_PREFIX}{store_id}:{file_id}:" + ) end_key = f"{contents_prefix}\xff" # load all stored chunk values (values_in_range is implemented by all backends) raw_items = await self.kvstore.values_in_range(contents_prefix, end_key) @@ -164,7 +185,9 @@ class OpenAIVectorStoreMixin(ABC): self.openai_vector_stores = await self._load_openai_vector_stores() @abstractmethod - async def delete_chunks(self, store_id: str, chunks_for_deletion: list[ChunkForDeletion]) -> None: + async def delete_chunks( + self, store_id: str, chunks_for_deletion: list[ChunkForDeletion] + ) -> None: """Delete chunks from a vector store.""" pass @@ -275,7 +298,10 @@ class OpenAIVectorStoreMixin(ABC): # Now that our vector store is created, attach any files that were provided file_ids = file_ids or [] - tasks = [self.openai_attach_file_to_vector_store(vector_db_id, file_id) for file_id in file_ids] + tasks = [ + self.openai_attach_file_to_vector_store(vector_db_id, file_id) + for file_id in file_ids + ] await asyncio.gather(*tasks) # Get the updated store info and return it @@ -302,7 +328,9 @@ class OpenAIVectorStoreMixin(ABC): # Apply cursor-based pagination if after: - after_index = next((i for i, store in enumerate(all_stores) if store["id"] == after), -1) + after_index = next( + (i for i, store in enumerate(all_stores) if store["id"] == after), -1 + ) if after_index >= 0: all_stores = all_stores[after_index + 1 :] @@ -391,7 +419,9 @@ class OpenAIVectorStoreMixin(ABC): try: await self.unregister_vector_db(vector_store_id) except Exception as e: - logger.warning(f"Failed to delete underlying vector DB {vector_store_id}: {e}") + logger.warning( + f"Failed to delete underlying vector DB {vector_store_id}: {e}" + ) return VectorStoreDeleteResponse( id=vector_store_id, @@ -416,7 +446,9 @@ class OpenAIVectorStoreMixin(ABC): # Validate search_mode valid_modes = {"keyword", "vector", "hybrid"} if search_mode not in valid_modes: - raise ValueError(f"search_mode must be one of {valid_modes}, got {search_mode}") + raise ValueError( + f"search_mode must be one of {valid_modes}, got {search_mode}" + ) if vector_store_id not in self.openai_vector_stores: raise VectorStoreNotFoundError(vector_store_id) @@ -484,7 +516,9 @@ class OpenAIVectorStoreMixin(ABC): next_page=None, ) - def _matches_filters(self, metadata: dict[str, Any], filters: dict[str, Any]) -> bool: + def _matches_filters( + self, metadata: dict[str, Any], filters: dict[str, Any] + ) -> bool: """Check if 
metadata matches the provided filters.""" if not filters: return True @@ -604,7 +638,9 @@ class OpenAIVectorStoreMixin(ABC): try: file_response = await self.files_api.openai_retrieve_file(file_id) mime_type, _ = mimetypes.guess_type(file_response.filename) - content_response = await self.files_api.openai_retrieve_file_content(file_id) + content_response = await self.files_api.openai_retrieve_file_content( + file_id + ) content = content_from_data_and_mime_type(content_response.body, mime_type) @@ -643,7 +679,9 @@ class OpenAIVectorStoreMixin(ABC): # Save vector store file to persistent storage (provider-specific) dict_chunks = [c.model_dump() for c in chunks] # This should be updated to include chunk_id - await self._save_openai_vector_store_file(vector_store_id, file_id, file_info, dict_chunks) + await self._save_openai_vector_store_file( + vector_store_id, file_id, file_info, dict_chunks + ) # Update file_ids and file_counts in vector store metadata store_info = self.openai_vector_stores[vector_store_id].copy() @@ -679,7 +717,9 @@ class OpenAIVectorStoreMixin(ABC): file_objects: list[VectorStoreFileObject] = [] for file_id in store_info["file_ids"]: - file_info = await self._load_openai_vector_store_file(vector_store_id, file_id) + file_info = await self._load_openai_vector_store_file( + vector_store_id, file_id + ) file_object = VectorStoreFileObject(**file_info) if filter and file_object.status != filter: continue @@ -691,7 +731,9 @@ class OpenAIVectorStoreMixin(ABC): # Apply cursor-based pagination if after: - after_index = next((i for i, file in enumerate(file_objects) if file.id == after), -1) + after_index = next( + (i for i, file in enumerate(file_objects) if file.id == after), -1 + ) if after_index >= 0: file_objects = file_objects[after_index + 1 :] @@ -728,7 +770,9 @@ class OpenAIVectorStoreMixin(ABC): store_info = self.openai_vector_stores[vector_store_id] if file_id not in store_info["file_ids"]: - raise ValueError(f"File {file_id} not found in vector store {vector_store_id}") + raise ValueError( + f"File {file_id} not found in vector store {vector_store_id}" + ) file_info = await self._load_openai_vector_store_file(vector_store_id, file_id) return VectorStoreFileObject(**file_info) @@ -743,7 +787,9 @@ class OpenAIVectorStoreMixin(ABC): raise VectorStoreNotFoundError(vector_store_id) file_info = await self._load_openai_vector_store_file(vector_store_id, file_id) - dict_chunks = await self._load_openai_vector_store_file_contents(vector_store_id, file_id) + dict_chunks = await self._load_openai_vector_store_file_contents( + vector_store_id, file_id + ) chunks = [Chunk.model_validate(c) for c in dict_chunks] content = [] for chunk in chunks: @@ -767,7 +813,9 @@ class OpenAIVectorStoreMixin(ABC): store_info = self.openai_vector_stores[vector_store_id] if file_id not in store_info["file_ids"]: - raise ValueError(f"File {file_id} not found in vector store {vector_store_id}") + raise ValueError( + f"File {file_id} not found in vector store {vector_store_id}" + ) file_info = await self._load_openai_vector_store_file(vector_store_id, file_id) file_info["attributes"] = attributes @@ -783,7 +831,9 @@ class OpenAIVectorStoreMixin(ABC): if vector_store_id not in self.openai_vector_stores: raise VectorStoreNotFoundError(vector_store_id) - dict_chunks = await self._load_openai_vector_store_file_contents(vector_store_id, file_id) + dict_chunks = await self._load_openai_vector_store_file_contents( + vector_store_id, file_id + ) chunks = [Chunk.model_validate(c) for c in dict_chunks] # Create 
ChunkForDeletion objects with both chunk_id and document_id @@ -794,9 +844,15 @@ class OpenAIVectorStoreMixin(ABC): c.chunk_metadata.document_id if c.chunk_metadata else None ) if document_id: - chunks_for_deletion.append(ChunkForDeletion(chunk_id=str(c.chunk_id), document_id=document_id)) + chunks_for_deletion.append( + ChunkForDeletion( + chunk_id=str(c.chunk_id), document_id=document_id + ) + ) else: - logger.warning(f"Chunk {c.chunk_id} has no document_id, skipping deletion") + logger.warning( + f"Chunk {c.chunk_id} has no document_id, skipping deletion" + ) if chunks_for_deletion: await self.delete_chunks(vector_store_id, chunks_for_deletion) @@ -804,7 +860,9 @@ class OpenAIVectorStoreMixin(ABC): store_info = self.openai_vector_stores[vector_store_id].copy() file = await self.openai_retrieve_vector_store_file(vector_store_id, file_id) - await self._delete_openai_vector_store_file_from_storage(vector_store_id, file_id) + await self._delete_openai_vector_store_file_from_storage( + vector_store_id, file_id + ) # Update in-memory cache store_info["file_ids"].remove(file_id) @@ -828,7 +886,156 @@ class OpenAIVectorStoreMixin(ABC): chunking_strategy: VectorStoreChunkingStrategy | None = None, ) -> VectorStoreFileBatchObject: """Create a vector store file batch.""" - raise NotImplementedError("openai_create_vector_store_file_batch is not implemented yet") + if vector_store_id not in self.openai_vector_stores: + raise VectorStoreNotFoundError(vector_store_id) + + chunking_strategy = chunking_strategy or VectorStoreChunkingStrategyAuto() + + created_at = int(time.time()) + batch_id = f"batch_{uuid.uuid4()}" + + # Initialize batch file counts - all files start as in_progress + file_counts = VectorStoreFileCounts( + completed=0, + cancelled=0, + failed=0, + in_progress=len(file_ids), + total=len(file_ids), + ) + + # Create batch object immediately with in_progress status + batch_object = VectorStoreFileBatchObject( + id=batch_id, + created_at=created_at, + vector_store_id=vector_store_id, + status="in_progress", + file_counts=file_counts, + ) + + # Store batch object and file_ids in memory + self.openai_file_batches[batch_id] = { + "batch_object": batch_object, + "file_ids": file_ids, + } + + # Start background processing of files + asyncio.create_task( + self._process_file_batch_async( + batch_id, file_ids, attributes, chunking_strategy + ) + ) + + return batch_object + + async def _process_file_batch_async( + self, + batch_id: str, + file_ids: list[str], + attributes: dict[str, Any] | None, + chunking_strategy: VectorStoreChunkingStrategy | None, + ) -> None: + """Process files in a batch asynchronously in the background.""" + batch_info = self.openai_file_batches[batch_id] + batch_object = batch_info["batch_object"] + vector_store_id = batch_object.vector_store_id + + for file_id in file_ids: + try: + # Process each file + await self.openai_attach_file_to_vector_store( + vector_store_id=vector_store_id, + file_id=file_id, + attributes=attributes, + chunking_strategy=chunking_strategy, + ) + + # Update counts atomically + batch_object.file_counts.completed += 1 + batch_object.file_counts.in_progress -= 1 + + except Exception as e: + logger.error( + f"Failed to process file {file_id} in batch {batch_id}: {e}" + ) + batch_object.file_counts.failed += 1 + batch_object.file_counts.in_progress -= 1 + + # Update final status when all files are processed + if batch_object.file_counts.failed == 0: + batch_object.status = "completed" + elif batch_object.file_counts.completed == 0: + batch_object.status 
= "failed" + else: + batch_object.status = "completed" # Partial success counts as completed + + logger.info( + f"File batch {batch_id} processing completed with status: {batch_object.status}" + ) + + def _get_and_validate_batch( + self, batch_id: str, vector_store_id: str + ) -> tuple[dict[str, Any], VectorStoreFileBatchObject]: + """Get and validate batch exists and belongs to vector store.""" + if vector_store_id not in self.openai_vector_stores: + raise VectorStoreNotFoundError(vector_store_id) + + if batch_id not in self.openai_file_batches: + raise ValueError(f"File batch {batch_id} not found") + + batch_info = self.openai_file_batches[batch_id] + batch_object = batch_info["batch_object"] + + if batch_object.vector_store_id != vector_store_id: + raise ValueError( + f"File batch {batch_id} does not belong to vector store {vector_store_id}" + ) + + return batch_info, batch_object + + def _paginate_objects( + self, + objects: list[Any], + limit: int | None = 20, + after: str | None = None, + before: str | None = None, + ) -> tuple[list[Any], bool, str | None, str | None]: + """Apply pagination to a list of objects with id fields.""" + limit = min(limit or 20, 100) # Cap at 100 as per OpenAI + + # Find start index + start_idx = 0 + if after: + for i, obj in enumerate(objects): + if obj.id == after: + start_idx = i + 1 + break + + # Find end index + end_idx = start_idx + limit + if before: + for i, obj in enumerate(objects[start_idx:], start_idx): + if obj.id == before: + end_idx = i + break + + # Apply pagination + paginated_objects = objects[start_idx:end_idx] + + # Determine pagination info + has_more = end_idx < len(objects) + first_id = paginated_objects[0].id if paginated_objects else None + last_id = paginated_objects[-1].id if paginated_objects else None + + return paginated_objects, has_more, first_id, last_id + + async def openai_retrieve_vector_store_file_batch( + self, + batch_id: str, + vector_store_id: str, + ) -> VectorStoreFileBatchObject: + """Retrieve a vector store file batch.""" + _, batch_object = self._get_and_validate_batch(batch_id, vector_store_id) + return batch_object async def openai_list_files_in_vector_store_file_batch( self, @@ -841,15 +1048,45 @@ class OpenAIVectorStoreMixin(ABC): order: str | None = "desc", ) -> VectorStoreFilesListInBatchResponse: """Returns a list of vector store files in a batch.""" - raise NotImplementedError("openai_list_files_in_vector_store_file_batch is not implemented yet") + batch_info, _ = self._get_and_validate_batch(batch_id, vector_store_id) + batch_file_ids = batch_info["file_ids"] - async def openai_retrieve_vector_store_file_batch( - self, - batch_id: str, - vector_store_id: str, - ) -> VectorStoreFileBatchObject: - """Retrieve a vector store file batch.""" - raise NotImplementedError("openai_retrieve_vector_store_file_batch is not implemented yet") + # Load file objects for files in this batch + batch_file_objects = [] + + for file_id in batch_file_ids: + try: + file_info = await self._load_openai_vector_store_file( + vector_store_id, file_id + ) + file_object = VectorStoreFileObject(**file_info) + + # Apply status filter if provided + if filter and file_object.status != filter: + continue + + batch_file_objects.append(file_object) + except Exception as e: + logger.warning( + f"Could not load file {file_id} from batch {batch_id}: {e}" + ) + continue + + # Sort by created_at + reverse_order = order == "desc" + batch_file_objects.sort(key=lambda x: x.created_at, reverse=reverse_order) + + # Apply pagination using helper + 
paginated_files, has_more, first_id, last_id = self._paginate_objects( + batch_file_objects, limit, after, before + ) + + return VectorStoreFilesListInBatchResponse( + data=paginated_files, + first_id=first_id, + last_id=last_id, + has_more=has_more, + ) async def openai_cancel_vector_store_file_batch( self, @@ -857,4 +1094,28 @@ class OpenAIVectorStoreMixin(ABC): vector_store_id: str, ) -> VectorStoreFileBatchObject: """Cancel a vector store file batch.""" - raise NotImplementedError("openai_cancel_vector_store_file_batch is not implemented yet") + batch_info, batch_object = self._get_and_validate_batch( + batch_id, vector_store_id + ) + + # Only allow cancellation if batch is in progress + if batch_object.status not in ["in_progress"]: + raise ValueError( + f"Cannot cancel batch {batch_id} with status {batch_object.status}" + ) + + # Create updated batch object with cancelled status + updated_batch = VectorStoreFileBatchObject( + id=batch_object.id, + object=batch_object.object, + created_at=batch_object.created_at, + vector_store_id=batch_object.vector_store_id, + status="cancelled", + file_counts=batch_object.file_counts, + ) + + # Update the stored batch info + batch_info["batch_object"] = updated_batch + self.openai_file_batches[batch_id] = batch_info + + return updated_batch diff --git a/tests/unit/providers/vector_io/test_vector_io_openai_vector_stores.py b/tests/unit/providers/vector_io/test_vector_io_openai_vector_stores.py index 98889f38e..c6f84c906 100644 --- a/tests/unit/providers/vector_io/test_vector_io_openai_vector_stores.py +++ b/tests/unit/providers/vector_io/test_vector_io_openai_vector_stores.py @@ -11,11 +11,12 @@ from unittest.mock import AsyncMock import numpy as np import pytest +from llama_stack.apis.common.errors import VectorStoreNotFoundError from llama_stack.apis.vector_dbs import VectorDB from llama_stack.apis.vector_io import Chunk, QueryChunksResponse from llama_stack.providers.remote.vector_io.milvus.milvus import VECTOR_DBS_PREFIX -# This test is a unit test for the inline VectoerIO providers. This should only contain +# This test is a unit test for the inline VectorIO providers. This should only contain # tests which are specific to this class. 
More general (API-level) tests should be placed in # tests/integration/vector_io/ # @@ -294,3 +295,347 @@ async def test_delete_openai_vector_store_file_from_storage(vector_io_adapter, t assert loaded_file_info == {} loaded_contents = await vector_io_adapter._load_openai_vector_store_file_contents(store_id, file_id) assert loaded_contents == [] + + +async def test_create_vector_store_file_batch(vector_io_adapter): + """Test creating a file batch.""" + store_id = "vs_1234" + file_ids = ["file_1", "file_2", "file_3"] + + # Setup vector store + vector_io_adapter.openai_vector_stores[store_id] = { + "id": store_id, + "name": "Test Store", + "files": {}, + "file_ids": [], + } + + # Mock attach method to avoid actual processing + vector_io_adapter.openai_attach_file_to_vector_store = AsyncMock(return_value={"status": "completed"}) + + batch = await vector_io_adapter.openai_create_vector_store_file_batch( + vector_store_id=store_id, + file_ids=file_ids, + ) + + assert batch.vector_store_id == store_id + assert batch.status == "in_progress" + assert batch.file_counts.total == len(file_ids) + assert batch.file_counts.in_progress == len(file_ids) + assert batch.id in vector_io_adapter.openai_file_batches + + +async def test_retrieve_vector_store_file_batch(vector_io_adapter): + """Test retrieving a file batch.""" + store_id = "vs_1234" + file_ids = ["file_1", "file_2"] + + # Setup vector store + vector_io_adapter.openai_vector_stores[store_id] = { + "id": store_id, + "name": "Test Store", + "files": {}, + "file_ids": [], + } + + vector_io_adapter.openai_attach_file_to_vector_store = AsyncMock() + + # Create batch first + created_batch = await vector_io_adapter.openai_create_vector_store_file_batch( + vector_store_id=store_id, + file_ids=file_ids, + ) + + # Retrieve batch + retrieved_batch = await vector_io_adapter.openai_retrieve_vector_store_file_batch( + batch_id=created_batch.id, + vector_store_id=store_id, + ) + + assert retrieved_batch.id == created_batch.id + assert retrieved_batch.vector_store_id == store_id + assert retrieved_batch.status == "in_progress" + + +async def test_cancel_vector_store_file_batch(vector_io_adapter): + """Test cancelling a file batch.""" + store_id = "vs_1234" + file_ids = ["file_1"] + + # Setup vector store + vector_io_adapter.openai_vector_stores[store_id] = { + "id": store_id, + "name": "Test Store", + "files": {}, + "file_ids": [], + } + + vector_io_adapter.openai_attach_file_to_vector_store = AsyncMock() + + # Create batch + batch = await vector_io_adapter.openai_create_vector_store_file_batch( + vector_store_id=store_id, + file_ids=file_ids, + ) + + # Cancel batch + cancelled_batch = await vector_io_adapter.openai_cancel_vector_store_file_batch( + batch_id=batch.id, + vector_store_id=store_id, + ) + + assert cancelled_batch.status == "cancelled" + + +async def test_list_files_in_vector_store_file_batch(vector_io_adapter): + """Test listing files in a batch.""" + store_id = "vs_1234" + file_ids = ["file_1", "file_2"] + + # Setup vector store with files + from llama_stack.apis.vector_io import VectorStoreChunkingStrategyAuto, VectorStoreFileObject + + files = {} + for i, file_id in enumerate(file_ids): + files[file_id] = VectorStoreFileObject( + id=file_id, + object="vector_store.file", + usage_bytes=1000, + created_at=int(time.time()) + i, + vector_store_id=store_id, + status="completed", + chunking_strategy=VectorStoreChunkingStrategyAuto(), + ) + + vector_io_adapter.openai_vector_stores[store_id] = { + "id": store_id, + "name": "Test Store", + "files": files, 
+ "file_ids": file_ids, + } + + # Mock file loading + async def mock_load_file(vs_id, f_id): + return files[f_id].model_dump() + + vector_io_adapter._load_openai_vector_store_file = mock_load_file + vector_io_adapter.openai_attach_file_to_vector_store = AsyncMock() + + # Create batch + batch = await vector_io_adapter.openai_create_vector_store_file_batch( + vector_store_id=store_id, + file_ids=file_ids, + ) + + # List files + response = await vector_io_adapter.openai_list_files_in_vector_store_file_batch( + batch_id=batch.id, + vector_store_id=store_id, + ) + + assert len(response.data) == len(file_ids) + assert response.first_id is not None + assert response.last_id is not None + + +async def test_file_batch_validation_errors(vector_io_adapter): + """Test file batch validation errors.""" + # Test nonexistent vector store + with pytest.raises(VectorStoreNotFoundError): + await vector_io_adapter.openai_create_vector_store_file_batch( + vector_store_id="nonexistent", + file_ids=["file_1"], + ) + + # Setup store for remaining tests + store_id = "vs_test" + vector_io_adapter.openai_vector_stores[store_id] = {"id": store_id, "files": {}, "file_ids": []} + + # Test nonexistent batch + with pytest.raises(ValueError, match="File batch .* not found"): + await vector_io_adapter.openai_retrieve_vector_store_file_batch( + batch_id="nonexistent_batch", + vector_store_id=store_id, + ) + + # Test wrong vector store for batch + vector_io_adapter.openai_attach_file_to_vector_store = AsyncMock() + batch = await vector_io_adapter.openai_create_vector_store_file_batch( + vector_store_id=store_id, + file_ids=["file_1"], + ) + + # Create wrong_store so it exists but the batch doesn't belong to it + wrong_store_id = "wrong_store" + vector_io_adapter.openai_vector_stores[wrong_store_id] = {"id": wrong_store_id, "files": {}, "file_ids": []} + + with pytest.raises(ValueError, match="does not belong to vector store"): + await vector_io_adapter.openai_retrieve_vector_store_file_batch( + batch_id=batch.id, + vector_store_id=wrong_store_id, + ) + + +async def test_file_batch_pagination(vector_io_adapter): + """Test file batch pagination.""" + store_id = "vs_1234" + file_ids = ["file_1", "file_2", "file_3", "file_4", "file_5"] + + # Setup vector store with multiple files + from llama_stack.apis.vector_io import VectorStoreChunkingStrategyAuto, VectorStoreFileObject + + files = {} + for i, file_id in enumerate(file_ids): + files[file_id] = VectorStoreFileObject( + id=file_id, + object="vector_store.file", + usage_bytes=1000, + created_at=int(time.time()) + i, + vector_store_id=store_id, + status="completed", + chunking_strategy=VectorStoreChunkingStrategyAuto(), + ) + + vector_io_adapter.openai_vector_stores[store_id] = { + "id": store_id, + "name": "Test Store", + "files": files, + "file_ids": file_ids, + } + + # Mock file loading + async def mock_load_file(vs_id, f_id): + return files[f_id].model_dump() + + vector_io_adapter._load_openai_vector_store_file = mock_load_file + vector_io_adapter.openai_attach_file_to_vector_store = AsyncMock() + + # Create batch + batch = await vector_io_adapter.openai_create_vector_store_file_batch( + vector_store_id=store_id, + file_ids=file_ids, + ) + + # Test pagination with limit + response = await vector_io_adapter.openai_list_files_in_vector_store_file_batch( + batch_id=batch.id, + vector_store_id=store_id, + limit=3, + ) + + assert len(response.data) == 3 + assert response.has_more is True + + # Test pagination with after cursor + first_page = await 
vector_io_adapter.openai_list_files_in_vector_store_file_batch( + batch_id=batch.id, + vector_store_id=store_id, + limit=2, + ) + + second_page = await vector_io_adapter.openai_list_files_in_vector_store_file_batch( + batch_id=batch.id, + vector_store_id=store_id, + limit=2, + after=first_page.last_id, + ) + + assert len(first_page.data) == 2 + assert len(second_page.data) == 2 + assert first_page.data[0].id != second_page.data[0].id + + +async def test_file_batch_status_filtering(vector_io_adapter): + """Test file batch status filtering.""" + store_id = "vs_1234" + file_ids = ["file_1", "file_2", "file_3"] + + # Setup vector store with files having different statuses + from llama_stack.apis.vector_io import VectorStoreChunkingStrategyAuto, VectorStoreFileObject + + files = {} + statuses = ["completed", "in_progress", "completed"] + for i, (file_id, status) in enumerate(zip(file_ids, statuses, strict=False)): + files[file_id] = VectorStoreFileObject( + id=file_id, + object="vector_store.file", + usage_bytes=1000, + created_at=int(time.time()) + i, + vector_store_id=store_id, + status=status, + chunking_strategy=VectorStoreChunkingStrategyAuto(), + ) + + vector_io_adapter.openai_vector_stores[store_id] = { + "id": store_id, + "name": "Test Store", + "files": files, + "file_ids": file_ids, + } + + # Mock file loading + async def mock_load_file(vs_id, f_id): + return files[f_id].model_dump() + + vector_io_adapter._load_openai_vector_store_file = mock_load_file + vector_io_adapter.openai_attach_file_to_vector_store = AsyncMock() + + # Create batch + batch = await vector_io_adapter.openai_create_vector_store_file_batch( + vector_store_id=store_id, + file_ids=file_ids, + ) + + # Test filtering by completed status + response = await vector_io_adapter.openai_list_files_in_vector_store_file_batch( + batch_id=batch.id, + vector_store_id=store_id, + filter="completed", + ) + + assert len(response.data) == 2 # Only 2 completed files + for file_obj in response.data: + assert file_obj.status == "completed" + + # Test filtering by in_progress status + response = await vector_io_adapter.openai_list_files_in_vector_store_file_batch( + batch_id=batch.id, + vector_store_id=store_id, + filter="in_progress", + ) + + assert len(response.data) == 1 # Only 1 in_progress file + assert response.data[0].status == "in_progress" + + +async def test_cancel_completed_batch_fails(vector_io_adapter): + """Test that cancelling completed batch fails.""" + store_id = "vs_1234" + file_ids = ["file_1"] + + # Setup vector store + vector_io_adapter.openai_vector_stores[store_id] = { + "id": store_id, + "name": "Test Store", + "files": {}, + "file_ids": [], + } + + vector_io_adapter.openai_attach_file_to_vector_store = AsyncMock() + + # Create batch + batch = await vector_io_adapter.openai_create_vector_store_file_batch( + vector_store_id=store_id, + file_ids=file_ids, + ) + + # Manually update status to completed + batch_info = vector_io_adapter.openai_file_batches[batch.id] + batch_info["batch_object"].status = "completed" + + # Try to cancel - should fail + with pytest.raises(ValueError, match="Cannot cancel batch .* with status completed"): + await vector_io_adapter.openai_cancel_vector_store_file_batch( + batch_id=batch.id, + vector_store_id=store_id, + )
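
---

For reference, the end-to-end lifecycle these new mixin methods implement looks like the sketch below. This is a minimal illustration, not part of the patch: it assumes an already-initialized provider that mixes in `OpenAIVectorStoreMixin` (e.g. `FaissVectorIOAdapter`) with a registered vector store and uploaded files; `adapter`, `store_id`, and `file_ids` are placeholder names.

```python
import asyncio


async def ingest_with_batch(adapter, store_id: str, file_ids: list[str]):
    """Illustrative sketch of the file batch lifecycle added in this PR."""
    # Create the batch. Files begin processing in a background task
    # (asyncio.create_task), and the returned object reports
    # status="in_progress" immediately with all files counted as in_progress.
    batch = await adapter.openai_create_vector_store_file_batch(
        vector_store_id=store_id,
        file_ids=file_ids,
    )

    # Poll until the background task flips the status to a terminal value
    # ("completed", "failed", or "cancelled").
    while batch.status == "in_progress":
        await asyncio.sleep(0.5)
        batch = await adapter.openai_retrieve_vector_store_file_batch(
            batch_id=batch.id,
            vector_store_id=store_id,
        )

    # Inspect per-file results; `filter` narrows the listing by file status.
    failed = await adapter.openai_list_files_in_vector_store_file_batch(
        batch_id=batch.id,
        vector_store_id=store_id,
        filter="failed",
    )
    return batch, failed.data
```

Note that batch state lives only in the in-memory `openai_file_batches` dict (it is not written to the KV store the way vector store and file metadata are), so a batch does not survive a process restart and must be polled from the process that created it.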