feat(api): Add vector store file batches api

This commit is contained in:
Swapna Lekkala 2025-10-01 14:44:48 -07:00
parent 7ec7e0c1ac
commit c7d50d4496
11 changed files with 1229 additions and 23 deletions

View file

@ -245,3 +245,65 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
vector_store_id=vector_store_id,
file_id=file_id,
)
async def openai_create_vector_store_file_batch(
self,
vector_store_id: str,
file_ids: list[str],
attributes: dict[str, Any] | None = None,
chunking_strategy: Any | None = None,
):
await self.assert_action_allowed("update", "vector_db", vector_store_id)
provider = await self.get_provider_impl(vector_store_id)
return await provider.openai_create_vector_store_file_batch(
vector_store_id=vector_store_id,
file_ids=file_ids,
attributes=attributes,
chunking_strategy=chunking_strategy,
)
async def openai_retrieve_vector_store_file_batch(
self,
batch_id: str,
vector_store_id: str,
):
await self.assert_action_allowed("read", "vector_db", vector_store_id)
provider = await self.get_provider_impl(vector_store_id)
return await provider.openai_retrieve_vector_store_file_batch(
batch_id=batch_id,
vector_store_id=vector_store_id,
)
async def openai_list_files_in_vector_store_file_batch(
self,
batch_id: str,
vector_store_id: str,
after: str | None = None,
before: str | None = None,
filter: str | None = None,
limit: int | None = 20,
order: str | None = "desc",
):
await self.assert_action_allowed("read", "vector_db", vector_store_id)
provider = await self.get_provider_impl(vector_store_id)
return await provider.openai_list_files_in_vector_store_file_batch(
batch_id=batch_id,
vector_store_id=vector_store_id,
after=after,
before=before,
filter=filter,
limit=limit,
order=order,
)
async def openai_cancel_vector_store_file_batch(
self,
batch_id: str,
vector_store_id: str,
):
await self.assert_action_allowed("update", "vector_db", vector_store_id)
provider = await self.get_provider_impl(vector_store_id)
return await provider.openai_cancel_vector_store_file_batch(
batch_id=batch_id,
vector_store_id=vector_store_id,
)

View file

@ -206,6 +206,7 @@ class FaissVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPr
self.cache: dict[str, VectorDBWithIndex] = {}
self.kvstore: KVStore | None = None
self.openai_vector_stores: dict[str, dict[str, Any]] = {}
self.openai_file_batches: dict[str, dict[str, Any]] = {}
async def initialize(self) -> None:
self.kvstore = await kvstore_impl(self.config.kvstore)

View file

@ -415,6 +415,7 @@ class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtoc
self.files_api = files_api
self.cache: dict[str, VectorDBWithIndex] = {}
self.openai_vector_stores: dict[str, dict[str, Any]] = {}
self.openai_file_batches: dict[str, dict[str, Any]] = {}
self.kvstore: KVStore | None = None
async def initialize(self) -> None:

View file

@ -166,6 +166,7 @@ class ChromaVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
log.info(f"Connecting to Chroma local db at: {self.config.db_path}")
self.client = chromadb.PersistentClient(path=self.config.db_path)
self.openai_vector_stores = await self._load_openai_vector_stores()
self.openai_file_batches: dict[str, dict[str, Any]] = {}
async def shutdown(self) -> None:
pass

View file

@ -317,6 +317,7 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
self.kvstore: KVStore | None = None
self.vector_db_store = None
self.openai_vector_stores: dict[str, dict[str, Any]] = {}
self.openai_file_batches: dict[str, dict[str, Any]] = {}
self.metadata_collection_name = "openai_vector_stores_metadata"
async def initialize(self) -> None:

View file

@ -353,6 +353,7 @@ class PGVectorVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtoco
self.kvstore: KVStore | None = None
self.vector_db_store = None
self.openai_vector_stores: dict[str, dict[str, Any]] = {}
self.openai_file_batches: dict[str, dict[str, Any]] = {}
self.metadata_collection_name = "openai_vector_stores_metadata"
async def initialize(self) -> None:

View file

@ -170,6 +170,7 @@ class QdrantVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
self.vector_db_store = None
self.kvstore: KVStore | None = None
self.openai_vector_stores: dict[str, dict[str, Any]] = {}
self.openai_file_batches: dict[str, dict[str, Any]] = {}
self._qdrant_lock = asyncio.Lock()
async def initialize(self) -> None:

View file

@ -292,6 +292,7 @@ class WeaviateVectorIOAdapter(
self.kvstore: KVStore | None = None
self.vector_db_store = None
self.openai_vector_stores: dict[str, dict[str, Any]] = {}
self.openai_file_batches: dict[str, dict[str, Any]] = {}
self.metadata_collection_name = "openai_vector_stores_metadata"
def _get_client(self) -> weaviate.WeaviateClient:

View file

@ -56,6 +56,7 @@ VECTOR_DBS_PREFIX = f"vector_dbs:{VERSION}::"
OPENAI_VECTOR_STORES_PREFIX = f"openai_vector_stores:{VERSION}::"
OPENAI_VECTOR_STORES_FILES_PREFIX = f"openai_vector_stores_files:{VERSION}::"
OPENAI_VECTOR_STORES_FILES_CONTENTS_PREFIX = f"openai_vector_stores_files_contents:{VERSION}::"
OPENAI_VECTOR_STORES_FILE_BATCHES_PREFIX = f"openai_vector_stores_file_batches:{VERSION}::"
class OpenAIVectorStoreMixin(ABC):
@ -67,9 +68,12 @@ class OpenAIVectorStoreMixin(ABC):
# These should be provided by the implementing class
openai_vector_stores: dict[str, dict[str, Any]]
openai_file_batches: dict[str, dict[str, Any]]
files_api: Files | None
# KV store for persisting OpenAI vector store metadata
kvstore: KVStore | None
# Track last cleanup time to throttle cleanup operations
_last_cleanup_time: int
async def _save_openai_vector_store(self, store_id: str, store_info: dict[str, Any]) -> None:
"""Save vector store metadata to persistent storage."""
@ -159,9 +163,85 @@ class OpenAIVectorStoreMixin(ABC):
for idx in range(len(raw_items)):
await self.kvstore.delete(f"{contents_prefix}{idx}")
async def _save_openai_vector_store_file_batch(self, batch_id: str, batch_info: dict[str, Any]) -> None:
"""Save file batch metadata to persistent storage."""
assert self.kvstore
key = f"{OPENAI_VECTOR_STORES_FILE_BATCHES_PREFIX}{batch_id}"
await self.kvstore.set(key=key, value=json.dumps(batch_info))
# update in-memory cache
self.openai_file_batches[batch_id] = batch_info
async def _load_openai_vector_store_file_batches(self) -> dict[str, dict[str, Any]]:
"""Load all file batch metadata from persistent storage."""
assert self.kvstore
start_key = OPENAI_VECTOR_STORES_FILE_BATCHES_PREFIX
end_key = f"{OPENAI_VECTOR_STORES_FILE_BATCHES_PREFIX}\xff"
stored_data = await self.kvstore.values_in_range(start_key, end_key)
batches: dict[str, dict[str, Any]] = {}
for item in stored_data:
info = json.loads(item)
batches[info["id"]] = info
return batches
async def _delete_openai_vector_store_file_batch(self, batch_id: str) -> None:
"""Delete file batch metadata from persistent storage and in-memory cache."""
assert self.kvstore
key = f"{OPENAI_VECTOR_STORES_FILE_BATCHES_PREFIX}{batch_id}"
await self.kvstore.delete(key)
# remove from in-memory cache
self.openai_file_batches.pop(batch_id, None)
async def _cleanup_expired_file_batches(self) -> None:
"""Clean up expired file batches from persistent storage."""
assert self.kvstore
start_key = OPENAI_VECTOR_STORES_FILE_BATCHES_PREFIX
end_key = f"{OPENAI_VECTOR_STORES_FILE_BATCHES_PREFIX}\xff"
stored_data = await self.kvstore.values_in_range(start_key, end_key)
current_time = int(time.time())
expired_count = 0
for item in stored_data:
info = json.loads(item)
expires_at = info.get("expires_at")
if expires_at and current_time > expires_at:
logger.info(f"Cleaning up expired file batch: {info['id']}")
await self.kvstore.delete(f"{OPENAI_VECTOR_STORES_FILE_BATCHES_PREFIX}{info['id']}")
# Remove from in-memory cache if present
self.openai_file_batches.pop(info["id"], None)
expired_count += 1
if expired_count > 0:
logger.info(f"Cleaned up {expired_count} expired file batches")
async def _cleanup_expired_file_batches_if_needed(self) -> None:
"""Run cleanup if enough time has passed since the last cleanup."""
current_time = int(time.time())
cleanup_interval = 24 * 60 * 60 # 1 day in seconds
# Check if enough time has passed since last cleanup
if current_time - self._last_cleanup_time >= cleanup_interval:
logger.info("Running throttled cleanup of expired file batches")
await self._cleanup_expired_file_batches()
self._last_cleanup_time = current_time
async def _resume_incomplete_batches(self) -> None:
"""Resume processing of incomplete file batches after server restart."""
for batch_id, batch_info in self.openai_file_batches.items():
if batch_info["status"] == "in_progress":
logger.info(f"Resuming incomplete file batch: {batch_id}")
# Restart the background processing task
asyncio.create_task(self._process_file_batch_async(batch_id, batch_info))
async def initialize_openai_vector_stores(self) -> None:
"""Load existing OpenAI vector stores into the in-memory cache."""
"""Load existing OpenAI vector stores and file batches into the in-memory cache."""
self.openai_vector_stores = await self._load_openai_vector_stores()
self.openai_file_batches = await self._load_openai_vector_store_file_batches()
# Resume any incomplete file batches
await self._resume_incomplete_batches()
# Initialize last cleanup time
self._last_cleanup_time = 0
@abstractmethod
async def delete_chunks(self, store_id: str, chunks_for_deletion: list[ChunkForDeletion]) -> None:
@ -615,7 +695,6 @@ class OpenAIVectorStoreMixin(ABC):
chunk_overlap_tokens,
attributes,
)
if not chunks:
vector_store_file_object.status = "failed"
vector_store_file_object.last_error = VectorStoreFileLastError(
@ -828,7 +907,165 @@ class OpenAIVectorStoreMixin(ABC):
chunking_strategy: VectorStoreChunkingStrategy | None = None,
) -> VectorStoreFileBatchObject:
"""Create a vector store file batch."""
raise NotImplementedError("openai_create_vector_store_file_batch is not implemented yet")
if vector_store_id not in self.openai_vector_stores:
raise VectorStoreNotFoundError(vector_store_id)
chunking_strategy = chunking_strategy or VectorStoreChunkingStrategyAuto()
created_at = int(time.time())
batch_id = f"batch_{uuid.uuid4()}"
# File batches expire after 7 days
expires_at = created_at + (7 * 24 * 60 * 60)
# Initialize batch file counts - all files start as in_progress
file_counts = VectorStoreFileCounts(
completed=0,
cancelled=0,
failed=0,
in_progress=len(file_ids),
total=len(file_ids),
)
# Create batch object immediately with in_progress status
batch_object = VectorStoreFileBatchObject(
id=batch_id,
created_at=created_at,
vector_store_id=vector_store_id,
status="in_progress",
file_counts=file_counts,
)
batch_info = {
**batch_object.model_dump(),
"file_ids": file_ids,
"attributes": attributes,
"chunking_strategy": chunking_strategy.model_dump(),
"expires_at": expires_at,
}
await self._save_openai_vector_store_file_batch(batch_id, batch_info)
# Start background processing of files
asyncio.create_task(self._process_file_batch_async(batch_id, batch_info))
# Run cleanup if needed (throttled to once every 7 days)
asyncio.create_task(self._cleanup_expired_file_batches_if_needed())
return batch_object
async def _process_file_batch_async(
self,
batch_id: str,
batch_info: dict[str, Any],
) -> None:
"""Process files in a batch asynchronously in the background."""
file_ids = batch_info["file_ids"]
attributes = batch_info["attributes"]
chunking_strategy = batch_info["chunking_strategy"]
vector_store_id = batch_info["vector_store_id"]
for file_id in file_ids:
try:
chunking_strategy_obj = (
VectorStoreChunkingStrategyStatic(**chunking_strategy)
if chunking_strategy.get("type") == "static"
else VectorStoreChunkingStrategyAuto(**chunking_strategy)
)
await self.openai_attach_file_to_vector_store(
vector_store_id=vector_store_id,
file_id=file_id,
attributes=attributes,
chunking_strategy=chunking_strategy_obj,
)
# Update counts atomically
batch_info["file_counts"]["completed"] += 1
batch_info["file_counts"]["in_progress"] -= 1
except Exception as e:
logger.error(f"Failed to process file {file_id} in batch {batch_id}: {e}")
batch_info["file_counts"]["failed"] += 1
batch_info["file_counts"]["in_progress"] -= 1
# Update final status when all files are processed
if batch_info["file_counts"]["failed"] == 0:
batch_info["status"] = "completed"
elif batch_info["file_counts"]["completed"] == 0:
batch_info["status"] = "failed"
else:
batch_info["status"] = "completed" # Partial success counts as completed
# Save final batch status to persistent storage (keep completed batches like vector stores)
await self._save_openai_vector_store_file_batch(batch_id, batch_info)
logger.info(f"File batch {batch_id} processing completed with status: {batch_info['status']}")
def _get_and_validate_batch(self, batch_id: str, vector_store_id: str) -> dict[str, Any]:
"""Get and validate batch exists and belongs to vector store."""
if vector_store_id not in self.openai_vector_stores:
raise VectorStoreNotFoundError(vector_store_id)
if batch_id not in self.openai_file_batches:
raise ValueError(f"File batch {batch_id} not found")
batch_info = self.openai_file_batches[batch_id]
# Check if batch has expired (read-only check)
expires_at = batch_info.get("expires_at")
if expires_at:
current_time = int(time.time())
if current_time > expires_at:
raise ValueError(f"File batch {batch_id} has expired after 7 days from creation")
if batch_info["vector_store_id"] != vector_store_id:
raise ValueError(f"File batch {batch_id} does not belong to vector store {vector_store_id}")
return batch_info
def _paginate_objects(
self,
objects: list[Any],
limit: int | None = 20,
after: str | None = None,
before: str | None = None,
) -> tuple[list[Any], bool, str | None, str | None]:
"""Apply pagination to a list of objects with id fields."""
limit = min(limit or 20, 100) # Cap at 100 as per OpenAI
# Find start index
start_idx = 0
if after:
for i, obj in enumerate(objects):
if obj.id == after:
start_idx = i + 1
break
# Find end index
end_idx = start_idx + limit
if before:
for i, obj in enumerate(objects[start_idx:], start_idx):
if obj.id == before:
end_idx = i
break
# Apply pagination
paginated_objects = objects[start_idx:end_idx]
# Determine pagination info
has_more = end_idx < len(objects)
first_id = paginated_objects[0].id if paginated_objects else None
last_id = paginated_objects[-1].id if paginated_objects else None
return paginated_objects, has_more, first_id, last_id
async def openai_retrieve_vector_store_file_batch(
self,
batch_id: str,
vector_store_id: str,
) -> VectorStoreFileBatchObject:
"""Retrieve a vector store file batch."""
batch_info = self._get_and_validate_batch(batch_id, vector_store_id)
# Convert dict back to Pydantic model for API response
return VectorStoreFileBatchObject(**batch_info)
async def openai_list_files_in_vector_store_file_batch(
self,
@ -841,15 +1078,39 @@ class OpenAIVectorStoreMixin(ABC):
order: str | None = "desc",
) -> VectorStoreFilesListInBatchResponse:
"""Returns a list of vector store files in a batch."""
raise NotImplementedError("openai_list_files_in_vector_store_file_batch is not implemented yet")
batch_info = self._get_and_validate_batch(batch_id, vector_store_id)
batch_file_ids = batch_info["file_ids"]
async def openai_retrieve_vector_store_file_batch(
self,
batch_id: str,
vector_store_id: str,
) -> VectorStoreFileBatchObject:
"""Retrieve a vector store file batch."""
raise NotImplementedError("openai_retrieve_vector_store_file_batch is not implemented yet")
# Load file objects for files in this batch
batch_file_objects = []
for file_id in batch_file_ids:
try:
file_info = await self._load_openai_vector_store_file(vector_store_id, file_id)
file_object = VectorStoreFileObject(**file_info)
# Apply status filter if provided
if filter and file_object.status != filter:
continue
batch_file_objects.append(file_object)
except Exception as e:
logger.warning(f"Could not load file {file_id} from batch {batch_id}: {e}")
continue
# Sort by created_at
reverse_order = order == "desc"
batch_file_objects.sort(key=lambda x: x.created_at, reverse=reverse_order)
# Apply pagination using helper
paginated_files, has_more, first_id, last_id = self._paginate_objects(batch_file_objects, limit, after, before)
return VectorStoreFilesListInBatchResponse(
data=paginated_files,
first_id=first_id,
last_id=last_id,
has_more=has_more,
)
async def openai_cancel_vector_store_file_batch(
self,
@ -857,4 +1118,19 @@ class OpenAIVectorStoreMixin(ABC):
vector_store_id: str,
) -> VectorStoreFileBatchObject:
"""Cancel a vector store file batch."""
raise NotImplementedError("openai_cancel_vector_store_file_batch is not implemented yet")
batch_info = self._get_and_validate_batch(batch_id, vector_store_id)
# Only allow cancellation if batch is in progress
if batch_info["status"] not in ["in_progress"]:
raise ValueError(f"Cannot cancel batch {batch_id} with status {batch_info['status']}")
# Update batch with cancelled status
batch_info["status"] = "cancelled"
# Save cancelled batch status to persistent storage (keep cancelled batches like vector stores)
await self._save_openai_vector_store_file_batch(batch_id, batch_info)
# Create updated batch object for API response
updated_batch = VectorStoreFileBatchObject(**batch_info)
return updated_batch