commit 6ffecd5507
slekkala1 · 2025-10-03 12:32:33 -07:00 · committed by GitHub
11 changed files with 1343 additions and 41 deletions


@@ -245,3 +245,65 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
             vector_store_id=vector_store_id,
             file_id=file_id,
         )
+
+    async def openai_create_vector_store_file_batch(
+        self,
+        vector_store_id: str,
+        file_ids: list[str],
+        attributes: dict[str, Any] | None = None,
+        chunking_strategy: Any | None = None,
+    ):
+        await self.assert_action_allowed("update", "vector_db", vector_store_id)
+        provider = await self.get_provider_impl(vector_store_id)
+        return await provider.openai_create_vector_store_file_batch(
+            vector_store_id=vector_store_id,
+            file_ids=file_ids,
+            attributes=attributes,
+            chunking_strategy=chunking_strategy,
+        )
+
+    async def openai_retrieve_vector_store_file_batch(
+        self,
+        batch_id: str,
+        vector_store_id: str,
+    ):
+        await self.assert_action_allowed("read", "vector_db", vector_store_id)
+        provider = await self.get_provider_impl(vector_store_id)
+        return await provider.openai_retrieve_vector_store_file_batch(
+            batch_id=batch_id,
+            vector_store_id=vector_store_id,
+        )
+
+    async def openai_list_files_in_vector_store_file_batch(
+        self,
+        batch_id: str,
+        vector_store_id: str,
+        after: str | None = None,
+        before: str | None = None,
+        filter: str | None = None,
+        limit: int | None = 20,
+        order: str | None = "desc",
+    ):
+        await self.assert_action_allowed("read", "vector_db", vector_store_id)
+        provider = await self.get_provider_impl(vector_store_id)
+        return await provider.openai_list_files_in_vector_store_file_batch(
+            batch_id=batch_id,
+            vector_store_id=vector_store_id,
+            after=after,
+            before=before,
+            filter=filter,
+            limit=limit,
+            order=order,
+        )
+
+    async def openai_cancel_vector_store_file_batch(
+        self,
+        batch_id: str,
+        vector_store_id: str,
+    ):
+        await self.assert_action_allowed("update", "vector_db", vector_store_id)
+        provider = await self.get_provider_impl(vector_store_id)
+        return await provider.openai_cancel_vector_store_file_batch(
+            batch_id=batch_id,
+            vector_store_id=vector_store_id,
+        )
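
These routes delegate to the provider after an access-control check: "update" for create/cancel, "read" for retrieve/list. A minimal client-side sketch, assuming a locally running Llama Stack server exposing the OpenAI-compatible API (the base URL, API key, and file IDs below are illustrative, not part of this commit):

    from openai import OpenAI

    client = OpenAI(base_url="http://localhost:8321/v1/openai/v1", api_key="none")  # assumed endpoint

    store = client.vector_stores.create(name="docs")
    batch = client.vector_stores.file_batches.create(
        vector_store_id=store.id,
        file_ids=["file_abc", "file_def"],  # hypothetical, pre-uploaded file IDs
    )
    # Files are processed in the background; poll until the batch settles.
    batch = client.vector_stores.file_batches.retrieve(vector_store_id=store.id, batch_id=batch.id)
    print(batch.status, batch.file_counts)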


@@ -200,12 +200,10 @@ class FaissIndex(EmbeddingIndex):
 class FaissVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPrivate):
     def __init__(self, config: FaissVectorIOConfig, inference_api: Inference, files_api: Files | None) -> None:
+        super().__init__(files_api=files_api, kvstore=None)
         self.config = config
         self.inference_api = inference_api
-        self.files_api = files_api
         self.cache: dict[str, VectorDBWithIndex] = {}
-        self.kvstore: KVStore | None = None
-        self.openai_vector_stores: dict[str, dict[str, Any]] = {}

     async def initialize(self) -> None:
         self.kvstore = await kvstore_impl(self.config.kvstore)


@@ -410,12 +410,10 @@ class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtoc
     """

     def __init__(self, config, inference_api: Inference, files_api: Files | None) -> None:
+        super().__init__(files_api=files_api, kvstore=None)
         self.config = config
         self.inference_api = inference_api
-        self.files_api = files_api
         self.cache: dict[str, VectorDBWithIndex] = {}
-        self.openai_vector_stores: dict[str, dict[str, Any]] = {}
-        self.kvstore: KVStore | None = None

     async def initialize(self) -> None:
         self.kvstore = await kvstore_impl(self.config.kvstore)


@@ -140,14 +140,13 @@ class ChromaVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
         inference_api: Api.inference,
         files_api: Files | None,
     ) -> None:
+        super().__init__(files_api=files_api, kvstore=None)
         log.info(f"Initializing ChromaVectorIOAdapter with url: {config}")
         self.config = config
         self.inference_api = inference_api
         self.client = None
         self.cache = {}
-        self.kvstore: KVStore | None = None
         self.vector_db_store = None
-        self.files_api = files_api

     async def initialize(self) -> None:
         self.kvstore = await kvstore_impl(self.config.kvstore)


@@ -309,14 +309,12 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
         inference_api: Inference,
         files_api: Files | None,
     ) -> None:
+        super().__init__(files_api=files_api, kvstore=None)
         self.config = config
         self.cache = {}
         self.client = None
         self.inference_api = inference_api
-        self.files_api = files_api
-        self.kvstore: KVStore | None = None
         self.vector_db_store = None
-        self.openai_vector_stores: dict[str, dict[str, Any]] = {}
         self.metadata_collection_name = "openai_vector_stores_metadata"

     async def initialize(self) -> None:


@@ -345,14 +345,12 @@ class PGVectorVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtoco
         inference_api: Api.inference,
         files_api: Files | None = None,
     ) -> None:
+        super().__init__(files_api=files_api, kvstore=None)
         self.config = config
         self.inference_api = inference_api
         self.conn = None
         self.cache = {}
-        self.files_api = files_api
-        self.kvstore: KVStore | None = None
         self.vector_db_store = None
-        self.openai_vector_stores: dict[str, dict[str, Any]] = {}
         self.metadata_collection_name = "openai_vector_stores_metadata"

     async def initialize(self) -> None:


@@ -27,7 +27,7 @@ from llama_stack.apis.vector_io import (
 from llama_stack.log import get_logger
 from llama_stack.providers.datatypes import Api, VectorDBsProtocolPrivate
 from llama_stack.providers.inline.vector_io.qdrant import QdrantVectorIOConfig as InlineQdrantVectorIOConfig
-from llama_stack.providers.utils.kvstore import KVStore, kvstore_impl
+from llama_stack.providers.utils.kvstore import kvstore_impl
 from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin
 from llama_stack.providers.utils.memory.vector_store import (
     ChunkForDeletion,
@@ -162,14 +162,12 @@ class QdrantVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
         inference_api: Api.inference,
         files_api: Files | None = None,
     ) -> None:
+        super().__init__(files_api=files_api, kvstore=None)
         self.config = config
         self.client: AsyncQdrantClient = None
         self.cache = {}
         self.inference_api = inference_api
-        self.files_api = files_api
         self.vector_db_store = None
-        self.kvstore: KVStore | None = None
-        self.openai_vector_stores: dict[str, dict[str, Any]] = {}
         self._qdrant_lock = asyncio.Lock()

     async def initialize(self) -> None:


@@ -284,14 +284,12 @@ class WeaviateVectorIOAdapter(
         inference_api: Api.inference,
         files_api: Files | None,
     ) -> None:
+        super().__init__(files_api=files_api, kvstore=None)
         self.config = config
         self.inference_api = inference_api
         self.client_cache = {}
         self.cache = {}
-        self.files_api = files_api
-        self.kvstore: KVStore | None = None
         self.vector_db_store = None
-        self.openai_vector_stores: dict[str, dict[str, Any]] = {}
         self.metadata_collection_name = "openai_vector_stores_metadata"

     def _get_client(self) -> weaviate.WeaviateClient:


@@ -12,6 +12,8 @@ import uuid
 from abc import ABC, abstractmethod
 from typing import Any

+from pydantic import TypeAdapter
+
 from llama_stack.apis.common.errors import VectorStoreNotFoundError
 from llama_stack.apis.files import Files, OpenAIFileObject
 from llama_stack.apis.vector_dbs import VectorDB
@@ -50,12 +52,16 @@ logger = get_logger(name=__name__, category="providers::utils")

 # Constants for OpenAI vector stores
 CHUNK_MULTIPLIER = 5
+FILE_BATCH_CLEANUP_INTERVAL_SECONDS = 24 * 60 * 60  # 1 day in seconds
+MAX_CONCURRENT_FILES_PER_BATCH = 5  # Maximum concurrent file processing within a batch
+FILE_BATCH_CHUNK_SIZE = 10  # Process files in chunks of this size (2x concurrency)

 VERSION = "v3"
 VECTOR_DBS_PREFIX = f"vector_dbs:{VERSION}::"
 OPENAI_VECTOR_STORES_PREFIX = f"openai_vector_stores:{VERSION}::"
 OPENAI_VECTOR_STORES_FILES_PREFIX = f"openai_vector_stores_files:{VERSION}::"
 OPENAI_VECTOR_STORES_FILES_CONTENTS_PREFIX = f"openai_vector_stores_files_contents:{VERSION}::"
+OPENAI_VECTOR_STORES_FILE_BATCHES_PREFIX = f"openai_vector_stores_file_batches:{VERSION}::"
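
For reference, the new prefix combines with VERSION and a batch ID into keys of the following shape (a quick sketch; batch_123 is a made-up ID), which is what the unit tests later assert against:

    batch_id = "batch_123"  # made-up ID
    key = f"{OPENAI_VECTOR_STORES_FILE_BATCHES_PREFIX}{batch_id}"
    assert key == "openai_vector_stores_file_batches:v3::batch_123"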

 class OpenAIVectorStoreMixin(ABC):
@@ -65,11 +71,15 @@ class OpenAIVectorStoreMixin(ABC):
     an openai_vector_stores in-memory cache.
     """

-    # These should be provided by the implementing class
-    openai_vector_stores: dict[str, dict[str, Any]]
-    files_api: Files | None
-    # KV store for persisting OpenAI vector store metadata
-    kvstore: KVStore | None
+    # Implementing classes should call super().__init__() in their __init__ method
+    # to properly initialize the mixin attributes.
+    def __init__(self, files_api: Files | None = None, kvstore: KVStore | None = None):
+        self.openai_vector_stores: dict[str, dict[str, Any]] = {}
+        self.openai_file_batches: dict[str, dict[str, Any]] = {}
+        self.files_api = files_api
+        self.kvstore = kvstore
+        self._last_file_batch_cleanup_time = 0
+        self._file_batch_tasks: dict[str, asyncio.Task[None]] = {}
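
Under the new contract, an adapter passes files_api through and leaves kvstore as None until its own initialize() runs, which is exactly what the Faiss through Weaviate changes above do. A minimal sketch of a conforming adapter (the class and config names are placeholders, not part of this commit):

    class MyVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPrivate):
        def __init__(self, config, inference_api: Inference, files_api: Files | None) -> None:
            # The mixin now owns openai_vector_stores, openai_file_batches,
            # files_api, kvstore, and the file-batch task bookkeeping.
            super().__init__(files_api=files_api, kvstore=None)
            self.config = config
            self.inference_api = inference_api
            self.cache: dict[str, VectorDBWithIndex] = {}

        async def initialize(self) -> None:
            # kvstore becomes available here; then stores and batches are loaded.
            self.kvstore = await kvstore_impl(self.config.kvstore)
            await self.initialize_openai_vector_stores()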

     async def _save_openai_vector_store(self, store_id: str, store_info: dict[str, Any]) -> None:
         """Save vector store metadata to persistent storage."""
@@ -159,9 +169,74 @@ class OpenAIVectorStoreMixin(ABC):
         for idx in range(len(raw_items)):
             await self.kvstore.delete(f"{contents_prefix}{idx}")
+    async def _save_openai_vector_store_file_batch(self, batch_id: str, batch_info: dict[str, Any]) -> None:
+        """Save file batch metadata to persistent storage."""
+        assert self.kvstore
+        key = f"{OPENAI_VECTOR_STORES_FILE_BATCHES_PREFIX}{batch_id}"
+        await self.kvstore.set(key=key, value=json.dumps(batch_info))
+        # update in-memory cache
+        self.openai_file_batches[batch_id] = batch_info
+
+    async def _load_openai_vector_store_file_batches(self) -> dict[str, dict[str, Any]]:
+        """Load all file batch metadata from persistent storage."""
+        assert self.kvstore
+        start_key = OPENAI_VECTOR_STORES_FILE_BATCHES_PREFIX
+        end_key = f"{OPENAI_VECTOR_STORES_FILE_BATCHES_PREFIX}\xff"
+        stored_data = await self.kvstore.values_in_range(start_key, end_key)
+
+        batches: dict[str, dict[str, Any]] = {}
+        for item in stored_data:
+            info = json.loads(item)
+            batches[info["id"]] = info
+        return batches
+
+    async def _delete_openai_vector_store_file_batch(self, batch_id: str) -> None:
+        """Delete file batch metadata from persistent storage and in-memory cache."""
+        assert self.kvstore
+        key = f"{OPENAI_VECTOR_STORES_FILE_BATCHES_PREFIX}{batch_id}"
+        await self.kvstore.delete(key)
+        # remove from in-memory cache
+        self.openai_file_batches.pop(batch_id, None)
+
+    async def _cleanup_expired_file_batches(self) -> None:
+        """Clean up expired file batches from persistent storage."""
+        assert self.kvstore
+        start_key = OPENAI_VECTOR_STORES_FILE_BATCHES_PREFIX
+        end_key = f"{OPENAI_VECTOR_STORES_FILE_BATCHES_PREFIX}\xff"
+        stored_data = await self.kvstore.values_in_range(start_key, end_key)
+
+        current_time = int(time.time())
+        expired_count = 0
+        for item in stored_data:
+            info = json.loads(item)
+            expires_at = info.get("expires_at")
+            if expires_at and current_time > expires_at:
+                logger.info(f"Cleaning up expired file batch: {info['id']}")
+                await self.kvstore.delete(f"{OPENAI_VECTOR_STORES_FILE_BATCHES_PREFIX}{info['id']}")
+                # Remove from in-memory cache if present
+                self.openai_file_batches.pop(info["id"], None)
+                expired_count += 1
+
+        if expired_count > 0:
+            logger.info(f"Cleaned up {expired_count} expired file batches")
+
+    async def _resume_incomplete_batches(self) -> None:
+        """Resume processing of incomplete file batches after server restart."""
+        for batch_id, batch_info in self.openai_file_batches.items():
+            if batch_info["status"] == "in_progress":
+                logger.info(f"Resuming incomplete file batch: {batch_id}")
+                # Restart the background processing task
+                task = asyncio.create_task(self._process_file_batch_async(batch_id, batch_info))
+                self._file_batch_tasks[batch_id] = task
+
     async def initialize_openai_vector_stores(self) -> None:
-        """Load existing OpenAI vector stores into the in-memory cache."""
+        """Load existing OpenAI vector stores and file batches into the in-memory cache."""
         self.openai_vector_stores = await self._load_openai_vector_stores()
+        self.openai_file_batches = await self._load_openai_vector_store_file_batches()
+        self._file_batch_tasks = {}
+        await self._resume_incomplete_batches()
+        self._last_file_batch_cleanup_time = 0

     @abstractmethod
     async def delete_chunks(self, store_id: str, chunks_for_deletion: list[ChunkForDeletion]) -> None:
@@ -615,7 +690,6 @@ class OpenAIVectorStoreMixin(ABC):
             chunk_overlap_tokens,
             attributes,
         )
-
         if not chunks:
             vector_store_file_object.status = "failed"
             vector_store_file_object.last_error = VectorStoreFileLastError(
@@ -828,7 +902,227 @@ class OpenAIVectorStoreMixin(ABC):
         chunking_strategy: VectorStoreChunkingStrategy | None = None,
     ) -> VectorStoreFileBatchObject:
         """Create a vector store file batch."""
-        raise NotImplementedError("openai_create_vector_store_file_batch is not implemented yet")
+        if vector_store_id not in self.openai_vector_stores:
+            raise VectorStoreNotFoundError(vector_store_id)
+
+        chunking_strategy = chunking_strategy or VectorStoreChunkingStrategyAuto()
+
+        created_at = int(time.time())
+        batch_id = f"batch_{uuid.uuid4()}"
+        # File batches expire after 7 days
+        expires_at = created_at + (7 * 24 * 60 * 60)
+
+        # Initialize batch file counts - all files start as in_progress
+        file_counts = VectorStoreFileCounts(
+            completed=0,
+            cancelled=0,
+            failed=0,
+            in_progress=len(file_ids),
+            total=len(file_ids),
+        )
+
+        # Create batch object immediately with in_progress status
+        batch_object = VectorStoreFileBatchObject(
+            id=batch_id,
+            created_at=created_at,
+            vector_store_id=vector_store_id,
+            status="in_progress",
+            file_counts=file_counts,
+        )
+
+        batch_info = {
+            **batch_object.model_dump(),
+            "file_ids": file_ids,
+            "attributes": attributes,
+            "chunking_strategy": chunking_strategy.model_dump(),
+            "expires_at": expires_at,
+        }
+        await self._save_openai_vector_store_file_batch(batch_id, batch_info)
+
+        # Start background processing of files
+        task = asyncio.create_task(self._process_file_batch_async(batch_id, batch_info))
+        self._file_batch_tasks[batch_id] = task
+
+        # Run cleanup if needed (throttled to once every 1 day)
+        current_time = int(time.time())
+        if current_time - self._last_file_batch_cleanup_time >= FILE_BATCH_CLEANUP_INTERVAL_SECONDS:
+            logger.info("Running throttled cleanup of expired file batches")
+            asyncio.create_task(self._cleanup_expired_file_batches())
+            self._last_file_batch_cleanup_time = current_time
+
+        return batch_object
+    async def _process_files_with_concurrency(
+        self,
+        file_ids: list[str],
+        vector_store_id: str,
+        attributes: dict[str, Any],
+        chunking_strategy_obj: Any,
+        batch_id: str,
+        batch_info: dict[str, Any],
+    ) -> None:
+        """Process files with controlled concurrency and chunking."""
+        semaphore = asyncio.Semaphore(MAX_CONCURRENT_FILES_PER_BATCH)
+
+        async def process_single_file(file_id: str) -> tuple[str, bool]:
+            """Process a single file with concurrency control."""
+            async with semaphore:
+                try:
+                    await self.openai_attach_file_to_vector_store(
+                        vector_store_id=vector_store_id,
+                        file_id=file_id,
+                        attributes=attributes,
+                        chunking_strategy=chunking_strategy_obj,
+                    )
+                    return file_id, True
+                except Exception as e:
+                    logger.error(f"Failed to process file {file_id} in batch {batch_id}: {e}")
+                    return file_id, False
+
+        # Process files in chunks to avoid creating too many tasks at once
+        total_files = len(file_ids)
+        for chunk_start in range(0, total_files, FILE_BATCH_CHUNK_SIZE):
+            chunk_end = min(chunk_start + FILE_BATCH_CHUNK_SIZE, total_files)
+            chunk = file_ids[chunk_start:chunk_end]
+
+            logger.info(
+                f"Processing chunk {chunk_start // FILE_BATCH_CHUNK_SIZE + 1} of {(total_files + FILE_BATCH_CHUNK_SIZE - 1) // FILE_BATCH_CHUNK_SIZE} ({len(chunk)} files)"
+            )
+
+            async with asyncio.TaskGroup() as tg:
+                chunk_tasks = [tg.create_task(process_single_file(file_id)) for file_id in chunk]
+
+            chunk_results = [task.result() for task in chunk_tasks]
+
+            # Update counts after each chunk for progressive feedback
+            for _, success in chunk_results:
+                self._update_file_counts(batch_info, success=success)
+
+            # Save progress after each chunk
+            await self._save_openai_vector_store_file_batch(batch_id, batch_info)
+
+    def _update_file_counts(self, batch_info: dict[str, Any], success: bool) -> None:
+        """Update file counts based on processing result."""
+        if success:
+            batch_info["file_counts"]["completed"] += 1
+        else:
+            batch_info["file_counts"]["failed"] += 1
+        batch_info["file_counts"]["in_progress"] -= 1
+
+    def _update_batch_status(self, batch_info: dict[str, Any]) -> None:
+        """Update final batch status based on file processing results."""
+        if batch_info["file_counts"]["failed"] == 0:
+            batch_info["status"] = "completed"
+        elif batch_info["file_counts"]["completed"] == 0:
+            batch_info["status"] = "failed"
+        else:
+            batch_info["status"] = "completed"  # Partial success counts as completed
+
+    async def _process_file_batch_async(
+        self,
+        batch_id: str,
+        batch_info: dict[str, Any],
+    ) -> None:
+        """Process files in a batch asynchronously in the background."""
+        file_ids = batch_info["file_ids"]
+        attributes = batch_info["attributes"]
+        chunking_strategy = batch_info["chunking_strategy"]
+        vector_store_id = batch_info["vector_store_id"]
+        chunking_strategy_adapter: TypeAdapter[VectorStoreChunkingStrategy] = TypeAdapter(VectorStoreChunkingStrategy)
+        chunking_strategy_obj = chunking_strategy_adapter.validate_python(chunking_strategy)
+
+        try:
+            # Process all files with controlled concurrency
+            await self._process_files_with_concurrency(
+                file_ids=file_ids,
+                vector_store_id=vector_store_id,
+                attributes=attributes,
+                chunking_strategy_obj=chunking_strategy_obj,
+                batch_id=batch_id,
+                batch_info=batch_info,
+            )
+
+            # Update final batch status
+            self._update_batch_status(batch_info)
+            await self._save_openai_vector_store_file_batch(batch_id, batch_info)
+
+            logger.info(f"File batch {batch_id} processing completed with status: {batch_info['status']}")
+        except asyncio.CancelledError:
+            logger.info(f"File batch {batch_id} processing was cancelled")
+            # Clean up task reference if it still exists
+            self._file_batch_tasks.pop(batch_id, None)
+            raise  # Re-raise to ensure proper cancellation propagation
+        finally:
+            # Always clean up task reference when processing ends
+            self._file_batch_tasks.pop(batch_id, None)
+
+    def _get_and_validate_batch(self, batch_id: str, vector_store_id: str) -> dict[str, Any]:
+        """Get and validate batch exists and belongs to vector store."""
+        if vector_store_id not in self.openai_vector_stores:
+            raise VectorStoreNotFoundError(vector_store_id)
+
+        if batch_id not in self.openai_file_batches:
+            raise ValueError(f"File batch {batch_id} not found")
+
+        batch_info = self.openai_file_batches[batch_id]
+
+        # Check if batch has expired (read-only check)
+        expires_at = batch_info.get("expires_at")
+        if expires_at:
+            current_time = int(time.time())
+            if current_time > expires_at:
+                raise ValueError(f"File batch {batch_id} has expired after 7 days from creation")
+
+        if batch_info["vector_store_id"] != vector_store_id:
+            raise ValueError(f"File batch {batch_id} does not belong to vector store {vector_store_id}")
+
+        return batch_info
+
+    def _paginate_objects(
+        self,
+        objects: list[Any],
+        limit: int | None = 20,
+        after: str | None = None,
+        before: str | None = None,
+    ) -> tuple[list[Any], bool, str | None, str | None]:
+        """Apply pagination to a list of objects with id fields."""
+        limit = min(limit or 20, 100)  # Cap at 100 as per OpenAI
+
+        # Find start index
+        start_idx = 0
+        if after:
+            for i, obj in enumerate(objects):
+                if obj.id == after:
+                    start_idx = i + 1
+                    break
+
+        # Find end index
+        end_idx = start_idx + limit
+        if before:
+            for i, obj in enumerate(objects[start_idx:], start_idx):
+                if obj.id == before:
+                    end_idx = i
+                    break
+
+        # Apply pagination
+        paginated_objects = objects[start_idx:end_idx]
+
+        # Determine pagination info
+        has_more = end_idx < len(objects)
+        first_id = paginated_objects[0].id if paginated_objects else None
+        last_id = paginated_objects[-1].id if paginated_objects else None
+
+        return paginated_objects, has_more, first_id, last_id
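
Cursor semantics in brief: after is exclusive, before truncates the page, and the limit is capped at 100. A tiny illustrative run (mixin stands in for any adapter instance; the objects only need an id attribute):

    from types import SimpleNamespace

    objs = [SimpleNamespace(id=f"file_{i}") for i in range(5)]
    page, has_more, first_id, last_id = mixin._paginate_objects(objs, limit=2, after="file_1")
    # page == [file_2, file_3], has_more is True, first_id == "file_2", last_id == "file_3"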
+    async def openai_retrieve_vector_store_file_batch(
+        self,
+        batch_id: str,
+        vector_store_id: str,
+    ) -> VectorStoreFileBatchObject:
+        """Retrieve a vector store file batch."""
+        batch_info = self._get_and_validate_batch(batch_id, vector_store_id)
+        return VectorStoreFileBatchObject(**batch_info)

     async def openai_list_files_in_vector_store_file_batch(
         self,
@@ -841,15 +1135,39 @@ class OpenAIVectorStoreMixin(ABC):
         order: str | None = "desc",
     ) -> VectorStoreFilesListInBatchResponse:
         """Returns a list of vector store files in a batch."""
-        raise NotImplementedError("openai_list_files_in_vector_store_file_batch is not implemented yet")
-
-    async def openai_retrieve_vector_store_file_batch(
-        self,
-        batch_id: str,
-        vector_store_id: str,
-    ) -> VectorStoreFileBatchObject:
-        """Retrieve a vector store file batch."""
-        raise NotImplementedError("openai_retrieve_vector_store_file_batch is not implemented yet")
+        batch_info = self._get_and_validate_batch(batch_id, vector_store_id)
+        batch_file_ids = batch_info["file_ids"]
+
+        # Load file objects for files in this batch
+        batch_file_objects = []
+
+        for file_id in batch_file_ids:
+            try:
+                file_info = await self._load_openai_vector_store_file(vector_store_id, file_id)
+                file_object = VectorStoreFileObject(**file_info)
+                # Apply status filter if provided
+                if filter and file_object.status != filter:
+                    continue
+                batch_file_objects.append(file_object)
+            except Exception as e:
+                logger.warning(f"Could not load file {file_id} from batch {batch_id}: {e}")
+                continue
+
+        # Sort by created_at
+        reverse_order = order == "desc"
+        batch_file_objects.sort(key=lambda x: x.created_at, reverse=reverse_order)
+
+        # Apply pagination using helper
+        paginated_files, has_more, first_id, last_id = self._paginate_objects(batch_file_objects, limit, after, before)
+
+        return VectorStoreFilesListInBatchResponse(
+            data=paginated_files,
+            first_id=first_id,
+            last_id=last_id,
+            has_more=has_more,
+        )

     async def openai_cancel_vector_store_file_batch(
         self,
@@ -857,4 +1175,24 @@ class OpenAIVectorStoreMixin(ABC):
         vector_store_id: str,
     ) -> VectorStoreFileBatchObject:
         """Cancel a vector store file batch."""
-        raise NotImplementedError("openai_cancel_vector_store_file_batch is not implemented yet")
+        batch_info = self._get_and_validate_batch(batch_id, vector_store_id)
+
+        if batch_info["status"] not in ["in_progress"]:
+            raise ValueError(f"Cannot cancel batch {batch_id} with status {batch_info['status']}")
+
+        # Cancel the actual processing task if it exists
+        if batch_id in self._file_batch_tasks:
+            task = self._file_batch_tasks[batch_id]
+            if not task.done():
+                task.cancel()
+                logger.info(f"Cancelled processing task for file batch: {batch_id}")
+            # Remove from task tracking
+            del self._file_batch_tasks[batch_id]
+
+        batch_info["status"] = "cancelled"
+        await self._save_openai_vector_store_file_batch(batch_id, batch_info)
+
+        updated_batch = VectorStoreFileBatchObject(**batch_info)
+
+        return updated_batch


@@ -902,3 +902,224 @@ def test_openai_vector_store_search_modes(llama_stack_client, client_with_models
         search_mode=search_mode,
     )
     assert search_response is not None
+
+
+def test_openai_vector_store_file_batch_create_and_retrieve(compat_client_with_empty_stores, client_with_models):
+    """Test creating and retrieving a vector store file batch."""
+    skip_if_provider_doesnt_support_openai_vector_stores(client_with_models)
+    compat_client = compat_client_with_empty_stores
+
+    # Create a vector store
+    vector_store = compat_client.vector_stores.create(name="batch_test_store")
+
+    # Create multiple files
+    file_ids = []
+    for i in range(3):
+        with BytesIO(f"This is batch test file {i}".encode()) as file_buffer:
+            file_buffer.name = f"batch_test_{i}.txt"
+            file = compat_client.files.create(file=file_buffer, purpose="assistants")
+            file_ids.append(file.id)
+
+    # Create a file batch
+    batch = compat_client.vector_stores.file_batches.create(
+        vector_store_id=vector_store.id,
+        file_ids=file_ids,
+    )
+
+    assert batch is not None
+    assert batch.object == "vector_store.file_batch"
+    assert batch.vector_store_id == vector_store.id
+    assert batch.status in ["in_progress", "completed"]
+    assert batch.file_counts.total == len(file_ids)
+    assert hasattr(batch, "id")
+    assert hasattr(batch, "created_at")
+
+    # Wait for batch processing to complete
+    max_retries = 30  # 30 seconds max wait
+    retries = 0
+    retrieved_batch = None
+    while retries < max_retries:
+        retrieved_batch = compat_client.vector_stores.file_batches.retrieve(
+            vector_store_id=vector_store.id,
+            batch_id=batch.id,
+        )
+        if retrieved_batch.status in ["completed", "failed"]:
+            break
+        time.sleep(1)
+        retries += 1
+
+    assert retrieved_batch is not None
+    assert retrieved_batch.id == batch.id
+    assert retrieved_batch.vector_store_id == vector_store.id
+    assert retrieved_batch.object == "vector_store.file_batch"
+    assert retrieved_batch.file_counts.total == len(file_ids)
+    assert retrieved_batch.status == "completed"  # Should be completed after processing
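
The retrieve-and-sleep loop above is repeated nearly verbatim in the next test; a small helper along these lines (hypothetical, not part of this commit) could factor it out:

    def wait_for_batch_completion(compat_client, vector_store_id, batch_id, max_retries=30):
        """Poll a file batch once per second until it reaches a terminal status (hypothetical helper)."""
        batch = None
        for _ in range(max_retries):
            batch = compat_client.vector_stores.file_batches.retrieve(
                vector_store_id=vector_store_id, batch_id=batch_id
            )
            if batch.status in ["completed", "failed"]:
                break
            time.sleep(1)
        return batch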
+
+
+def test_openai_vector_store_file_batch_list_files(compat_client_with_empty_stores, client_with_models):
+    """Test listing files in a vector store file batch."""
+    skip_if_provider_doesnt_support_openai_vector_stores(client_with_models)
+    compat_client = compat_client_with_empty_stores
+
+    # Create a vector store
+    vector_store = compat_client.vector_stores.create(name="batch_list_test_store")
+
+    # Create multiple files
+    file_ids = []
+    for i in range(5):
+        with BytesIO(f"This is batch list test file {i}".encode()) as file_buffer:
+            file_buffer.name = f"batch_list_test_{i}.txt"
+            file = compat_client.files.create(file=file_buffer, purpose="assistants")
+            file_ids.append(file.id)
+
+    # Create a file batch
+    batch = compat_client.vector_stores.file_batches.create(
+        vector_store_id=vector_store.id,
+        file_ids=file_ids,
+    )
+
+    # Wait for batch processing to complete
+    max_retries = 30  # 30 seconds max wait
+    retries = 0
+    while retries < max_retries:
+        retrieved_batch = compat_client.vector_stores.file_batches.retrieve(
+            vector_store_id=vector_store.id,
+            batch_id=batch.id,
+        )
+        if retrieved_batch.status in ["completed", "failed"]:
+            break
+        time.sleep(1)
+        retries += 1
+
+    # List all files in the batch
+    files_response = compat_client.vector_stores.file_batches.list_files(
+        vector_store_id=vector_store.id,
+        batch_id=batch.id,
+    )
+
+    assert files_response is not None
+    assert files_response.object == "list"
+    assert hasattr(files_response, "data")
+    assert len(files_response.data) == len(file_ids)
+
+    # Verify all files are in the response
+    response_file_ids = {file.id for file in files_response.data}
+    assert response_file_ids == set(file_ids)
+
+    # Test pagination with limit
+    limited_response = compat_client.vector_stores.file_batches.list_files(
+        vector_store_id=vector_store.id,
+        batch_id=batch.id,
+        limit=3,
+    )
+
+    assert len(limited_response.data) == 3
+    assert limited_response.has_more is True
+
+    # Test pagination with after cursor
+    first_page = compat_client.vector_stores.file_batches.list_files(
+        vector_store_id=vector_store.id,
+        batch_id=batch.id,
+        limit=2,
+    )
+
+    second_page = compat_client.vector_stores.file_batches.list_files(
+        vector_store_id=vector_store.id,
+        batch_id=batch.id,
+        limit=2,
+        after=first_page.data[-1].id,
+    )
+
+    assert len(first_page.data) == 2
+    assert len(second_page.data) <= 3  # Should be <= remaining files
+
+    # Ensure no overlap between pages
+    first_page_ids = {file.id for file in first_page.data}
+    second_page_ids = {file.id for file in second_page.data}
+    assert first_page_ids.isdisjoint(second_page_ids)
+
+
+def test_openai_vector_store_file_batch_cancel(compat_client_with_empty_stores, client_with_models):
+    """Test cancelling a vector store file batch."""
+    skip_if_provider_doesnt_support_openai_vector_stores(client_with_models)
+    compat_client = compat_client_with_empty_stores
+
+    # Create a vector store
+    vector_store = compat_client.vector_stores.create(name="batch_cancel_test_store")
+
+    # Create multiple files
+    file_ids = []
+    for i in range(3):
+        with BytesIO(f"This is batch cancel test file {i}".encode()) as file_buffer:
+            file_buffer.name = f"batch_cancel_test_{i}.txt"
+            file = compat_client.files.create(file=file_buffer, purpose="assistants")
+            file_ids.append(file.id)
+
+    # Create a file batch
+    batch = compat_client.vector_stores.file_batches.create(
+        vector_store_id=vector_store.id,
+        file_ids=file_ids,
+    )
+
+    # Try to cancel the batch (may fail if already completed)
+    try:
+        cancelled_batch = compat_client.vector_stores.file_batches.cancel(
+            vector_store_id=vector_store.id,
+            batch_id=batch.id,
+        )
+
+        assert cancelled_batch is not None
+        assert cancelled_batch.id == batch.id
+        assert cancelled_batch.vector_store_id == vector_store.id
+        assert cancelled_batch.status == "cancelled"
+        assert cancelled_batch.object == "vector_store.file_batch"
+    except Exception as e:
+        # If cancellation fails because batch is already completed, that's acceptable
+        if "Cannot cancel" in str(e) or "already completed" in str(e):
+            pytest.skip(f"Batch completed too quickly to cancel: {e}")
+        else:
+            raise
+
+
+def test_openai_vector_store_file_batch_error_handling(compat_client_with_empty_stores, client_with_models):
+    """Test error handling for file batch operations."""
+    skip_if_provider_doesnt_support_openai_vector_stores(client_with_models)
+    compat_client = compat_client_with_empty_stores
+
+    # Create a vector store
+    vector_store = compat_client.vector_stores.create(name="batch_error_test_store")
+
+    # Test with invalid file IDs (should handle gracefully)
+    file_ids = ["invalid_file_id_1", "invalid_file_id_2"]
+
+    batch = compat_client.vector_stores.file_batches.create(
+        vector_store_id=vector_store.id,
+        file_ids=file_ids,
+    )
+
+    assert batch is not None
+    assert batch.file_counts.total == len(file_ids)
+    # Invalid files should be marked as failed
+    assert batch.file_counts.failed >= 0  # Implementation may vary
+
+    # Determine expected errors based on client type
+    if isinstance(compat_client, LlamaStackAsLibraryClient):
+        errors = ValueError
+    else:
+        errors = (BadRequestError, OpenAIBadRequestError)
+
+    # Test retrieving non-existent batch
+    with pytest.raises(errors):  # Should raise an error for non-existent batch
+        compat_client.vector_stores.file_batches.retrieve(
+            vector_store_id=vector_store.id,
+            batch_id="non_existent_batch_id",
+        )
+
+    # Test operations on non-existent vector store
+    with pytest.raises(errors):  # Should raise an error for non-existent vector store
+        compat_client.vector_stores.file_batches.create(
+            vector_store_id="non_existent_vector_store",
+            file_ids=["any_file_id"],
+        )


@@ -6,16 +6,22 @@
 import json
 import time
-from unittest.mock import AsyncMock
+from unittest.mock import AsyncMock, patch

 import numpy as np
 import pytest

+from llama_stack.apis.common.errors import VectorStoreNotFoundError
 from llama_stack.apis.vector_dbs import VectorDB
-from llama_stack.apis.vector_io import Chunk, QueryChunksResponse
+from llama_stack.apis.vector_io import (
+    Chunk,
+    QueryChunksResponse,
+    VectorStoreChunkingStrategyAuto,
+    VectorStoreFileObject,
+)
 from llama_stack.providers.remote.vector_io.milvus.milvus import VECTOR_DBS_PREFIX

-# This test is a unit test for the inline VectoerIO providers. This should only contain
+# This test is a unit test for the inline VectorIO providers. This should only contain
 # tests which are specific to this class. More general (API-level) tests should be placed in
 # tests/integration/vector_io/
 #
@@ -25,6 +31,24 @@ from llama_stack.providers.remote.vector_io.milvus.milvus import VECTOR_DBS_PREF
 # -v -s --tb=short --disable-warnings --asyncio-mode=auto

+
+@pytest.fixture(autouse=True)
+def mock_resume_file_batches(request):
+    """Mock the resume functionality to prevent stale file batches from being processed during tests."""
+    # Skip mocking for tests that specifically test the resume functionality
+    if any(
+        test_name in request.node.name
+        for test_name in ["test_only_in_progress_batches_resumed", "test_file_batch_persistence_across_restarts"]
+    ):
+        yield
+        return
+
+    with patch(
+        "llama_stack.providers.utils.memory.openai_vector_store_mixin.OpenAIVectorStoreMixin._resume_incomplete_batches",
+        new_callable=AsyncMock,
+    ):
+        yield
+

 async def test_initialize_index(vector_index):
     await vector_index.initialize()
@@ -294,3 +318,673 @@ async def test_delete_openai_vector_store_file_from_storage(vector_io_adapter, t
     assert loaded_file_info == {}
     loaded_contents = await vector_io_adapter._load_openai_vector_store_file_contents(store_id, file_id)
     assert loaded_contents == []
+
+
+async def test_create_vector_store_file_batch(vector_io_adapter):
+    """Test creating a file batch."""
+    store_id = "vs_1234"
+    file_ids = ["file_1", "file_2", "file_3"]
+
+    # Setup vector store
+    vector_io_adapter.openai_vector_stores[store_id] = {
+        "id": store_id,
+        "name": "Test Store",
+        "files": {},
+        "file_ids": [],
+    }
+
+    # Mock attach method and batch processing to avoid actual processing
+    vector_io_adapter.openai_attach_file_to_vector_store = AsyncMock()
+    vector_io_adapter._process_file_batch_async = AsyncMock()
+
+    batch = await vector_io_adapter.openai_create_vector_store_file_batch(
+        vector_store_id=store_id,
+        file_ids=file_ids,
+    )
+
+    assert batch.vector_store_id == store_id
+    assert batch.status == "in_progress"
+    assert batch.file_counts.total == len(file_ids)
+    assert batch.file_counts.in_progress == len(file_ids)
+    assert batch.id in vector_io_adapter.openai_file_batches
+
+
+async def test_retrieve_vector_store_file_batch(vector_io_adapter):
+    """Test retrieving a file batch."""
+    store_id = "vs_1234"
+    file_ids = ["file_1", "file_2"]
+
+    # Setup vector store
+    vector_io_adapter.openai_vector_stores[store_id] = {
+        "id": store_id,
+        "name": "Test Store",
+        "files": {},
+        "file_ids": [],
+    }
+
+    vector_io_adapter.openai_attach_file_to_vector_store = AsyncMock()
+
+    # Create batch first
+    created_batch = await vector_io_adapter.openai_create_vector_store_file_batch(
+        vector_store_id=store_id,
+        file_ids=file_ids,
+    )
+
+    # Retrieve batch
+    retrieved_batch = await vector_io_adapter.openai_retrieve_vector_store_file_batch(
+        batch_id=created_batch.id,
+        vector_store_id=store_id,
+    )
+
+    assert retrieved_batch.id == created_batch.id
+    assert retrieved_batch.vector_store_id == store_id
+    assert retrieved_batch.status == "in_progress"
+
+
+async def test_cancel_vector_store_file_batch(vector_io_adapter):
+    """Test cancelling a file batch."""
+    store_id = "vs_1234"
+    file_ids = ["file_1"]
+
+    # Setup vector store
+    vector_io_adapter.openai_vector_stores[store_id] = {
+        "id": store_id,
+        "name": "Test Store",
+        "files": {},
+        "file_ids": [],
+    }
+
+    # Mock both file attachment and batch processing to prevent automatic completion
+    vector_io_adapter.openai_attach_file_to_vector_store = AsyncMock()
+    vector_io_adapter._process_file_batch_async = AsyncMock()
+
+    # Create batch
+    batch = await vector_io_adapter.openai_create_vector_store_file_batch(
+        vector_store_id=store_id,
+        file_ids=file_ids,
+    )
+
+    # Cancel batch
+    cancelled_batch = await vector_io_adapter.openai_cancel_vector_store_file_batch(
+        batch_id=batch.id,
+        vector_store_id=store_id,
+    )
+
+    assert cancelled_batch.status == "cancelled"
+
+
+async def test_list_files_in_vector_store_file_batch(vector_io_adapter):
+    """Test listing files in a batch."""
+    store_id = "vs_1234"
+    file_ids = ["file_1", "file_2"]
+
+    # Setup vector store with files
+    files = {}
+    for i, file_id in enumerate(file_ids):
+        files[file_id] = VectorStoreFileObject(
+            id=file_id,
+            object="vector_store.file",
+            usage_bytes=1000,
+            created_at=int(time.time()) + i,
+            vector_store_id=store_id,
+            status="completed",
+            chunking_strategy=VectorStoreChunkingStrategyAuto(),
+        )
+
+    vector_io_adapter.openai_vector_stores[store_id] = {
+        "id": store_id,
+        "name": "Test Store",
+        "files": files,
+        "file_ids": file_ids,
+    }
+
+    # Mock file loading
+    vector_io_adapter._load_openai_vector_store_file = AsyncMock(
+        side_effect=lambda vs_id, f_id: files[f_id].model_dump()
+    )
+    vector_io_adapter.openai_attach_file_to_vector_store = AsyncMock()
+
+    # Create batch
+    batch = await vector_io_adapter.openai_create_vector_store_file_batch(
+        vector_store_id=store_id,
+        file_ids=file_ids,
+    )
+
+    # List files
+    response = await vector_io_adapter.openai_list_files_in_vector_store_file_batch(
+        batch_id=batch.id,
+        vector_store_id=store_id,
+    )
+
+    assert len(response.data) == len(file_ids)
+    assert response.first_id is not None
+    assert response.last_id is not None
+
+
+async def test_file_batch_validation_errors(vector_io_adapter):
+    """Test file batch validation errors."""
+    # Test nonexistent vector store
+    with pytest.raises(VectorStoreNotFoundError):
+        await vector_io_adapter.openai_create_vector_store_file_batch(
+            vector_store_id="nonexistent",
+            file_ids=["file_1"],
+        )
+
+    # Setup store for remaining tests
+    store_id = "vs_test"
+    vector_io_adapter.openai_vector_stores[store_id] = {"id": store_id, "files": {}, "file_ids": []}
+
+    # Test nonexistent batch
+    with pytest.raises(ValueError, match="File batch .* not found"):
+        await vector_io_adapter.openai_retrieve_vector_store_file_batch(
+            batch_id="nonexistent_batch",
+            vector_store_id=store_id,
+        )
+
+    # Test wrong vector store for batch
+    vector_io_adapter.openai_attach_file_to_vector_store = AsyncMock()
+    batch = await vector_io_adapter.openai_create_vector_store_file_batch(
+        vector_store_id=store_id,
+        file_ids=["file_1"],
+    )
+
+    # Create wrong_store so it exists but the batch doesn't belong to it
+    wrong_store_id = "wrong_store"
+    vector_io_adapter.openai_vector_stores[wrong_store_id] = {"id": wrong_store_id, "files": {}, "file_ids": []}
+
+    with pytest.raises(ValueError, match="does not belong to vector store"):
+        await vector_io_adapter.openai_retrieve_vector_store_file_batch(
+            batch_id=batch.id,
+            vector_store_id=wrong_store_id,
+        )
+
+
+async def test_file_batch_pagination(vector_io_adapter):
+    """Test file batch pagination."""
+    store_id = "vs_1234"
+    file_ids = ["file_1", "file_2", "file_3", "file_4", "file_5"]
+
+    # Setup vector store with multiple files
+    files = {}
+    for i, file_id in enumerate(file_ids):
+        files[file_id] = VectorStoreFileObject(
+            id=file_id,
+            object="vector_store.file",
+            usage_bytes=1000,
+            created_at=int(time.time()) + i,
+            vector_store_id=store_id,
+            status="completed",
+            chunking_strategy=VectorStoreChunkingStrategyAuto(),
+        )
+
+    vector_io_adapter.openai_vector_stores[store_id] = {
+        "id": store_id,
+        "name": "Test Store",
+        "files": files,
+        "file_ids": file_ids,
+    }
+
+    # Mock file loading
+    vector_io_adapter._load_openai_vector_store_file = AsyncMock(
+        side_effect=lambda vs_id, f_id: files[f_id].model_dump()
+    )
+    vector_io_adapter.openai_attach_file_to_vector_store = AsyncMock()
+
+    # Create batch
+    batch = await vector_io_adapter.openai_create_vector_store_file_batch(
+        vector_store_id=store_id,
+        file_ids=file_ids,
+    )
+
+    # Test pagination with limit
+    response = await vector_io_adapter.openai_list_files_in_vector_store_file_batch(
+        batch_id=batch.id,
+        vector_store_id=store_id,
+        limit=3,
+    )
+
+    assert len(response.data) == 3
+    assert response.has_more is True
+
+    # Test pagination with after cursor
+    first_page = await vector_io_adapter.openai_list_files_in_vector_store_file_batch(
+        batch_id=batch.id,
+        vector_store_id=store_id,
+        limit=2,
+    )
+
+    second_page = await vector_io_adapter.openai_list_files_in_vector_store_file_batch(
+        batch_id=batch.id,
+        vector_store_id=store_id,
+        limit=2,
+        after=first_page.last_id,
+    )
+
+    assert len(first_page.data) == 2
+    assert len(second_page.data) == 2
+
+    # Ensure no overlap between pages
+    first_page_ids = {file_obj.id for file_obj in first_page.data}
+    second_page_ids = {file_obj.id for file_obj in second_page.data}
+    assert first_page_ids.isdisjoint(second_page_ids)
+
+    # Verify we got all expected files across both pages (in desc order: file_5, file_4, file_3, file_2, file_1)
+    all_returned_ids = first_page_ids | second_page_ids
+    assert all_returned_ids == {"file_2", "file_3", "file_4", "file_5"}
+
+
+async def test_file_batch_status_filtering(vector_io_adapter):
+    """Test file batch status filtering."""
+    store_id = "vs_1234"
+    file_ids = ["file_1", "file_2", "file_3"]
+
+    # Setup vector store with files having different statuses
+    files = {}
+    statuses = ["completed", "in_progress", "completed"]
+    for i, (file_id, status) in enumerate(zip(file_ids, statuses, strict=False)):
+        files[file_id] = VectorStoreFileObject(
+            id=file_id,
+            object="vector_store.file",
+            usage_bytes=1000,
+            created_at=int(time.time()) + i,
+            vector_store_id=store_id,
+            status=status,
+            chunking_strategy=VectorStoreChunkingStrategyAuto(),
+        )
+
+    vector_io_adapter.openai_vector_stores[store_id] = {
+        "id": store_id,
+        "name": "Test Store",
+        "files": files,
+        "file_ids": file_ids,
+    }
+
+    # Mock file loading
+    vector_io_adapter._load_openai_vector_store_file = AsyncMock(
+        side_effect=lambda vs_id, f_id: files[f_id].model_dump()
+    )
+    vector_io_adapter.openai_attach_file_to_vector_store = AsyncMock()
+
+    # Create batch
+    batch = await vector_io_adapter.openai_create_vector_store_file_batch(
+        vector_store_id=store_id,
+        file_ids=file_ids,
+    )
+
+    # Test filtering by completed status
+    response = await vector_io_adapter.openai_list_files_in_vector_store_file_batch(
+        batch_id=batch.id,
+        vector_store_id=store_id,
+        filter="completed",
+    )
+
+    assert len(response.data) == 2  # Only 2 completed files
+    for file_obj in response.data:
+        assert file_obj.status == "completed"
+
+    # Test filtering by in_progress status
+    response = await vector_io_adapter.openai_list_files_in_vector_store_file_batch(
+        batch_id=batch.id,
+        vector_store_id=store_id,
+        filter="in_progress",
+    )
+
+    assert len(response.data) == 1  # Only 1 in_progress file
+    assert response.data[0].status == "in_progress"
+
+
+async def test_cancel_completed_batch_fails(vector_io_adapter):
+    """Test that cancelling completed batch fails."""
+    store_id = "vs_1234"
+    file_ids = ["file_1"]
+
+    # Setup vector store
+    vector_io_adapter.openai_vector_stores[store_id] = {
+        "id": store_id,
+        "name": "Test Store",
+        "files": {},
+        "file_ids": [],
+    }
+
+    vector_io_adapter.openai_attach_file_to_vector_store = AsyncMock()
+
+    # Create batch
+    batch = await vector_io_adapter.openai_create_vector_store_file_batch(
+        vector_store_id=store_id,
+        file_ids=file_ids,
+    )
+
+    # Manually update status to completed
+    batch_info = vector_io_adapter.openai_file_batches[batch.id]
+    batch_info["status"] = "completed"
+
+    # Try to cancel - should fail
+    with pytest.raises(ValueError, match="Cannot cancel batch .* with status completed"):
+        await vector_io_adapter.openai_cancel_vector_store_file_batch(
+            batch_id=batch.id,
+            vector_store_id=store_id,
+        )
+
+
+async def test_file_batch_persistence_across_restarts(vector_io_adapter):
+    """Test that in-progress file batches are persisted and resumed after restart."""
+    store_id = "vs_1234"
+    file_ids = ["file_1", "file_2"]
+
+    # Setup vector store
+    vector_io_adapter.openai_vector_stores[store_id] = {
+        "id": store_id,
+        "name": "Test Store",
+        "files": {},
+        "file_ids": [],
+    }
+
+    # Mock attach method and batch processing to avoid actual processing
+    vector_io_adapter.openai_attach_file_to_vector_store = AsyncMock()
+    vector_io_adapter._process_file_batch_async = AsyncMock()
+
+    # Create batch
+    batch = await vector_io_adapter.openai_create_vector_store_file_batch(
+        vector_store_id=store_id,
+        file_ids=file_ids,
+    )
+    batch_id = batch.id
+
+    # Verify batch is saved to persistent storage
+    assert batch_id in vector_io_adapter.openai_file_batches
+    saved_batch_key = f"openai_vector_stores_file_batches:v3::{batch_id}"
+    saved_batch = await vector_io_adapter.kvstore.get(saved_batch_key)
+    assert saved_batch is not None
+
+    # Verify the saved batch data contains all necessary information
+    saved_data = json.loads(saved_batch)
+    assert saved_data["id"] == batch_id
+    assert saved_data["status"] == "in_progress"
+    assert saved_data["file_ids"] == file_ids
+
+    # Simulate restart - clear in-memory cache and reload
+    vector_io_adapter.openai_file_batches.clear()
+
+    # Temporarily restore the real initialize_openai_vector_stores method
+    from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin
+
+    real_method = OpenAIVectorStoreMixin.initialize_openai_vector_stores
+    await real_method(vector_io_adapter)
+
+    # Re-mock the processing method to prevent any resumed batches from processing
+    vector_io_adapter._process_file_batch_async = AsyncMock()
+
+    # Verify batch was restored
+    assert batch_id in vector_io_adapter.openai_file_batches
+    restored_batch = vector_io_adapter.openai_file_batches[batch_id]
+    assert restored_batch["status"] == "in_progress"
+    assert restored_batch["id"] == batch_id
+    assert vector_io_adapter.openai_file_batches[batch_id]["file_ids"] == file_ids
+
+
+async def test_cancelled_batch_persists_in_storage(vector_io_adapter):
+    """Test that cancelled batches persist in storage with updated status."""
+    store_id = "vs_1234"
+    file_ids = ["file_1", "file_2"]
+
+    # Setup vector store
+    vector_io_adapter.openai_vector_stores[store_id] = {
+        "id": store_id,
+        "name": "Test Store",
+        "files": {},
+        "file_ids": [],
+    }
+
+    # Mock attach method and batch processing to avoid actual processing
+    vector_io_adapter.openai_attach_file_to_vector_store = AsyncMock()
+    vector_io_adapter._process_file_batch_async = AsyncMock()
+
+    # Create batch
+    batch = await vector_io_adapter.openai_create_vector_store_file_batch(
+        vector_store_id=store_id,
+        file_ids=file_ids,
+    )
+    batch_id = batch.id
+
+    # Verify batch is initially saved to persistent storage
+    saved_batch_key = f"openai_vector_stores_file_batches:v3::{batch_id}"
+    saved_batch = await vector_io_adapter.kvstore.get(saved_batch_key)
+    assert saved_batch is not None
+
+    # Cancel the batch
+    cancelled_batch = await vector_io_adapter.openai_cancel_vector_store_file_batch(
+        batch_id=batch_id,
+        vector_store_id=store_id,
+    )
+
+    # Verify batch status is cancelled
+    assert cancelled_batch.status == "cancelled"
+
+    # Verify batch persists in storage with cancelled status
+    updated_batch = await vector_io_adapter.kvstore.get(saved_batch_key)
+    assert updated_batch is not None
+    batch_data = json.loads(updated_batch)
+    assert batch_data["status"] == "cancelled"
+
+    # Batch should remain in memory cache (matches vector store pattern)
+    assert batch_id in vector_io_adapter.openai_file_batches
+    assert vector_io_adapter.openai_file_batches[batch_id]["status"] == "cancelled"
+
+
+async def test_only_in_progress_batches_resumed(vector_io_adapter):
+    """Test that only in-progress batches are resumed for processing, but all batches are persisted."""
+    store_id = "vs_1234"
+
+    # Setup vector store
+    vector_io_adapter.openai_vector_stores[store_id] = {
+        "id": store_id,
+        "name": "Test Store",
+        "files": {},
+        "file_ids": [],
+    }
+
+    # Mock attach method and batch processing to prevent automatic completion
+    vector_io_adapter.openai_attach_file_to_vector_store = AsyncMock()
+    vector_io_adapter._process_file_batch_async = AsyncMock()
+
+    # Create multiple batches
+    batch1 = await vector_io_adapter.openai_create_vector_store_file_batch(
+        vector_store_id=store_id, file_ids=["file_1"]
+    )
+    batch2 = await vector_io_adapter.openai_create_vector_store_file_batch(
+        vector_store_id=store_id, file_ids=["file_2"]
+    )
+
+    # Complete one batch (should persist with completed status)
+    batch1_info = vector_io_adapter.openai_file_batches[batch1.id]
+    batch1_info["status"] = "completed"
+    await vector_io_adapter._save_openai_vector_store_file_batch(batch1.id, batch1_info)
+
+    # Cancel the other batch (should persist with cancelled status)
+    await vector_io_adapter.openai_cancel_vector_store_file_batch(batch_id=batch2.id, vector_store_id=store_id)
+
+    # Create a third batch that stays in progress
+    batch3 = await vector_io_adapter.openai_create_vector_store_file_batch(
+        vector_store_id=store_id, file_ids=["file_3"]
+    )
+
+    # Simulate restart - first clear memory, then reload from persistence
+    vector_io_adapter.openai_file_batches.clear()
+
+    # Mock the processing method BEFORE calling initialize to capture the resume calls
+    mock_process = AsyncMock()
+    vector_io_adapter._process_file_batch_async = mock_process
+
+    # Temporarily restore the real initialize_openai_vector_stores method
+    from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin
+
+    real_method = OpenAIVectorStoreMixin.initialize_openai_vector_stores
+    await real_method(vector_io_adapter)
+
+    # All batches should be restored from persistence
+    assert batch1.id in vector_io_adapter.openai_file_batches  # completed, persisted
+    assert batch2.id in vector_io_adapter.openai_file_batches  # cancelled, persisted
+    assert batch3.id in vector_io_adapter.openai_file_batches  # in-progress, restored
+
+    # Check their statuses
+    assert vector_io_adapter.openai_file_batches[batch1.id]["status"] == "completed"
+    assert vector_io_adapter.openai_file_batches[batch2.id]["status"] == "cancelled"
+    assert vector_io_adapter.openai_file_batches[batch3.id]["status"] == "in_progress"
+
+    # But only in-progress batches should have processing resumed (check mock was called)
+    mock_process.assert_called()
+
+
+async def test_cleanup_expired_file_batches(vector_io_adapter):
+    """Test that expired file batches are cleaned up properly."""
+    store_id = "vs_1234"
+
+    # Setup vector store
+    vector_io_adapter.openai_vector_stores[store_id] = {
+        "id": store_id,
+        "name": "Test Store",
+        "files": {},
+        "file_ids": [],
+    }
+
+    # Mock processing to prevent automatic completion
+    vector_io_adapter.openai_attach_file_to_vector_store = AsyncMock()
+    vector_io_adapter._process_file_batch_async = AsyncMock()
+
+    # Create batches with different ages
+    import time
+
+    current_time = int(time.time())
+
+    # Create an old expired batch (10 days old)
+    old_batch_info = {
+        "id": "batch_old",
+        "vector_store_id": store_id,
+        "status": "completed",
+        "created_at": current_time - (10 * 24 * 60 * 60),  # 10 days ago
+        "expires_at": current_time - (3 * 24 * 60 * 60),  # Expired 3 days ago
+        "file_ids": ["file_1"],
+    }
+
+    # Create a recent valid batch
+    new_batch_info = {
+        "id": "batch_new",
+        "vector_store_id": store_id,
+        "status": "completed",
+        "created_at": current_time - (1 * 24 * 60 * 60),  # 1 day ago
+        "expires_at": current_time + (6 * 24 * 60 * 60),  # Expires in 6 days
+        "file_ids": ["file_2"],
+    }
+
+    # Store both batches in persistent storage
+    await vector_io_adapter._save_openai_vector_store_file_batch("batch_old", old_batch_info)
+    await vector_io_adapter._save_openai_vector_store_file_batch("batch_new", new_batch_info)
+
+    # Add to in-memory cache
+    vector_io_adapter.openai_file_batches["batch_old"] = old_batch_info
+    vector_io_adapter.openai_file_batches["batch_new"] = new_batch_info
+
+    # Verify both batches exist before cleanup
+    assert "batch_old" in vector_io_adapter.openai_file_batches
+    assert "batch_new" in vector_io_adapter.openai_file_batches
+
+    # Run cleanup
+    await vector_io_adapter._cleanup_expired_file_batches()
+
+    # Verify expired batch was removed from memory
+    assert "batch_old" not in vector_io_adapter.openai_file_batches
+    assert "batch_new" in vector_io_adapter.openai_file_batches
+
+    # Verify expired batch was removed from storage
+    old_batch_key = "openai_vector_stores_file_batches:v3::batch_old"
+    new_batch_key = "openai_vector_stores_file_batches:v3::batch_new"
+    old_stored = await vector_io_adapter.kvstore.get(old_batch_key)
+    new_stored = await vector_io_adapter.kvstore.get(new_batch_key)
+
+    assert old_stored is None  # Expired batch should be deleted
+    assert new_stored is not None  # Valid batch should remain
+
+
+async def test_expired_batch_access_error(vector_io_adapter):
+    """Test that accessing expired batches returns clear error message."""
+    store_id = "vs_1234"
+
+    # Setup vector store
+    vector_io_adapter.openai_vector_stores[store_id] = {
+        "id": store_id,
+        "name": "Test Store",
+        "files": {},
+        "file_ids": [],
+    }
+
+    # Create an expired batch
+    import time
+
+    current_time = int(time.time())
+
+    expired_batch_info = {
+        "id": "batch_expired",
+        "vector_store_id": store_id,
+        "status": "completed",
+        "created_at": current_time - (10 * 24 * 60 * 60),  # 10 days ago
+        "expires_at": current_time - (3 * 24 * 60 * 60),  # Expired 3 days ago
+        "file_ids": ["file_1"],
+    }
+
+    # Add to in-memory cache (simulating it was loaded before expiration)
+    vector_io_adapter.openai_file_batches["batch_expired"] = expired_batch_info
+
+    # Try to access expired batch
+    with pytest.raises(ValueError, match="File batch batch_expired has expired after 7 days from creation"):
+        vector_io_adapter._get_and_validate_batch("batch_expired", store_id)
+
+
+async def test_max_concurrent_files_per_batch(vector_io_adapter):
+    """Test that file batch processing respects MAX_CONCURRENT_FILES_PER_BATCH limit."""
+    import asyncio
+
+    store_id = "vs_1234"
+
+    # Setup vector store
+    vector_io_adapter.openai_vector_stores[store_id] = {
+        "id": store_id,
+        "name": "Test Store",
+        "files": {},
+        "file_ids": [],
+    }
+
+    active_files = 0
+
+    async def mock_attach_file_with_delay(vector_store_id: str, file_id: str, **kwargs):
+        """Mock that tracks concurrency and blocks indefinitely to test concurrency limit."""
+        nonlocal active_files
+        active_files += 1
+        # Block indefinitely to test concurrency limit
+        await asyncio.sleep(float("inf"))
+
+    # Replace the attachment method
+    vector_io_adapter.openai_attach_file_to_vector_store = mock_attach_file_with_delay
+
+    # Create a batch with more files than the concurrency limit
+    file_ids = [f"file_{i}" for i in range(8)]  # 8 files, but limit should be 5
+
+    batch = await vector_io_adapter.openai_create_vector_store_file_batch(
+        vector_store_id=store_id,
+        file_ids=file_ids,
+    )
+
+    # Give time for the semaphore logic to start processing files
+    await asyncio.sleep(0.2)
+
+    # Verify that only MAX_CONCURRENT_FILES_PER_BATCH files are processing concurrently
+    # The semaphore in _process_files_with_concurrency should limit this
+    from llama_stack.providers.utils.memory.openai_vector_store_mixin import MAX_CONCURRENT_FILES_PER_BATCH
+
+    assert active_files == MAX_CONCURRENT_FILES_PER_BATCH, (
+        f"Expected {MAX_CONCURRENT_FILES_PER_BATCH} active files, got {active_files}"
+    )
+
+    # Verify batch is in progress
+    assert batch.status == "in_progress"
+    assert batch.file_counts.total == 8
+    assert batch.file_counts.in_progress == 8