diff --git a/docs/_static/llama-stack-spec.html b/docs/_static/llama-stack-spec.html
index d480ff592..9896b36cd 100644
--- a/docs/_static/llama-stack-spec.html
+++ b/docs/_static/llama-stack-spec.html
@@ -14591,7 +14591,8 @@
"OpenAIFilePurpose": {
"type": "string",
"enum": [
- "assistants"
+ "assistants",
+ "batch"
],
"title": "OpenAIFilePurpose",
"description": "Valid purpose values for OpenAI Files API."
@@ -14668,7 +14669,8 @@
"purpose": {
"type": "string",
"enum": [
- "assistants"
+ "assistants",
+ "batch"
],
"description": "The intended purpose of the file"
}
diff --git a/docs/_static/llama-stack-spec.yaml b/docs/_static/llama-stack-spec.yaml
index 9c0fba554..15d491a65 100644
--- a/docs/_static/llama-stack-spec.yaml
+++ b/docs/_static/llama-stack-spec.yaml
@@ -10804,6 +10804,7 @@ components:
type: string
enum:
- assistants
+ - batch
title: OpenAIFilePurpose
description: >-
Valid purpose values for OpenAI Files API.
@@ -10872,6 +10873,7 @@ components:
type: string
enum:
- assistants
+ - batch
description: The intended purpose of the file
additionalProperties: false
required:
diff --git a/docs/source/providers/batches/index.md b/docs/source/providers/batches/index.md
new file mode 100644
index 000000000..d2405ecf7
--- /dev/null
+++ b/docs/source/providers/batches/index.md
@@ -0,0 +1,13 @@
+# Batches
+
+## Overview
+
+This section contains documentation for all available providers for the **batches** API.
+
+## Providers
+
+```{toctree}
+:maxdepth: 1
+
+inline_reference
+```
diff --git a/docs/source/providers/batches/inline_reference.md b/docs/source/providers/batches/inline_reference.md
new file mode 100644
index 000000000..a58e5124d
--- /dev/null
+++ b/docs/source/providers/batches/inline_reference.md
@@ -0,0 +1,23 @@
+# inline::reference
+
+## Description
+
+Reference implementation of batches API with KVStore persistence.
+
+## Configuration
+
+| Field | Type | Required | Default | Description |
+|-------|------|----------|---------|-------------|
+| `kvstore` | `utils.kvstore.config.RedisKVStoreConfig \| utils.kvstore.config.SqliteKVStoreConfig \| utils.kvstore.config.PostgresKVStoreConfig \| utils.kvstore.config.MongoDBKVStoreConfig` | No | sqlite | Configuration for the key-value store backend. |
+| `max_concurrent_batches` | `` | No | 1 | Maximum number of concurrent batches to process simultaneously. |
+| `max_concurrent_requests_per_batch` | `` | No | 10 | Maximum number of concurrent requests to process per batch. |
+
+## Sample Configuration
+
+```yaml
+kvstore:
+ type: sqlite
+ db_path: ${env.SQLITE_STORE_DIR:=~/.llama/dummy}/batches.db
+
+```
+
diff --git a/llama_stack/apis/batches/__init__.py b/llama_stack/apis/batches/__init__.py
new file mode 100644
index 000000000..d3efe3dba
--- /dev/null
+++ b/llama_stack/apis/batches/__init__.py
@@ -0,0 +1,9 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from .batches import Batches, BatchObject, CreateBatchRequest, ListBatchesResponse
+
+__all__ = ["Batches", "BatchObject", "CreateBatchRequest", "ListBatchesResponse"]
diff --git a/llama_stack/apis/batches/batches.py b/llama_stack/apis/batches/batches.py
new file mode 100644
index 000000000..72742d4fa
--- /dev/null
+++ b/llama_stack/apis/batches/batches.py
@@ -0,0 +1,92 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from typing import Any, Literal, Protocol, runtime_checkable
+
+from pydantic import BaseModel, Field
+
+from llama_stack.schema_utils import json_schema_type, webmethod
+
+try:
+ from openai.types import Batch as BatchObject
+except ImportError as e:
+ raise ImportError("OpenAI package is required for batches API. Please install it with: pip install openai") from e
+
+
+@json_schema_type
+class CreateBatchRequest(BaseModel):
+ """Request to create a new batch."""
+
+ input_file_id: str = Field(..., description="The ID of an uploaded file that contains requests for the new batch")
+ endpoint: str = Field(..., description="The endpoint to be used for all requests in the batch")
+ completion_window: str = Field(..., description="The time window within which the batch should be processed")
+ metadata: dict[str, Any] | None = Field(default=None, description="Optional metadata for the batch")
+
+
+@json_schema_type
+class ListBatchesResponse(BaseModel):
+ """Response containing a list of batch objects."""
+
+ object: Literal["list"] = "list"
+ data: list[BatchObject] = Field(..., description="List of batch objects")
+ first_id: str | None = Field(default=None, description="ID of the first batch in the list")
+ last_id: str | None = Field(default=None, description="ID of the last batch in the list")
+ has_more: bool = Field(default=False, description="Whether there are more batches available")
+
+
+@runtime_checkable
+class Batches(Protocol):
+ """Protocol for batch processing API operations."""
+
+ @webmethod(route="/openai/v1/batches", method="POST")
+ async def create_batch(
+ self,
+ input_file_id: str,
+ endpoint: str,
+ completion_window: str,
+ metadata: dict[str, str] | None = None,
+ ) -> BatchObject:
+ """Create a new batch for processing multiple API requests.
+
+ :param input_file_id: The ID of an uploaded file containing requests for the batch.
+ :param endpoint: The endpoint to be used for all requests in the batch.
+ :param completion_window: The time window within which the batch should be processed.
+ :param metadata: Optional metadata for the batch.
+ :returns: The created batch object.
+ """
+ ...
+
+ @webmethod(route="/openai/v1/batches/{batch_id}", method="GET")
+ async def retrieve_batch(self, batch_id: str) -> BatchObject:
+ """Retrieve information about a specific batch.
+
+ :param batch_id: The ID of the batch to retrieve.
+ :returns: The batch object.
+ """
+ ...
+
+ @webmethod(route="/openai/v1/batches/{batch_id}/cancel", method="POST")
+ async def cancel_batch(self, batch_id: str) -> BatchObject:
+ """Cancel a batch that is in progress.
+
+ :param batch_id: The ID of the batch to cancel.
+ :returns: The updated batch object.
+ """
+ ...
+
+ @webmethod(route="/openai/v1/batches", method="GET")
+ async def list_batches(
+ self,
+ after: str | None = None,
+ limit: int = 20,
+ ) -> ListBatchesResponse:
+ """List all batches for the current user.
+
+ :param after: A cursor for pagination; returns batches after this batch ID.
+ :param limit: Number of batches to return (default 20, max 100).
+ :returns: A list of batch objects.
+ """
+ ...
diff --git a/llama_stack/apis/common/errors.py b/llama_stack/apis/common/errors.py
index 95d6ac18e..c47c99f8d 100644
--- a/llama_stack/apis/common/errors.py
+++ b/llama_stack/apis/common/errors.py
@@ -62,3 +62,10 @@ class SessionNotFoundError(ValueError):
def __init__(self, session_name: str) -> None:
message = f"Session '{session_name}' not found or access denied."
super().__init__(message)
+
+
+class ConflictError(ValueError):
+ """raised when an operation cannot be performed due to a conflict with the current state"""
+
+ def __init__(self, message: str) -> None:
+ super().__init__(message)
diff --git a/llama_stack/apis/datatypes.py b/llama_stack/apis/datatypes.py
index cabe46a2f..87fc95917 100644
--- a/llama_stack/apis/datatypes.py
+++ b/llama_stack/apis/datatypes.py
@@ -86,6 +86,7 @@ class Api(Enum, metaclass=DynamicApiMeta):
:cvar inference: Text generation, chat completions, and embeddings
:cvar safety: Content moderation and safety shields
:cvar agents: Agent orchestration and execution
+ :cvar batches: Batch processing for asynchronous API requests
:cvar vector_io: Vector database operations and queries
:cvar datasetio: Dataset input/output operations
:cvar scoring: Model output evaluation and scoring
@@ -108,6 +109,7 @@ class Api(Enum, metaclass=DynamicApiMeta):
inference = "inference"
safety = "safety"
agents = "agents"
+ batches = "batches"
vector_io = "vector_io"
datasetio = "datasetio"
scoring = "scoring"
diff --git a/llama_stack/apis/files/files.py b/llama_stack/apis/files/files.py
index ba8701e23..a1b9dd4dc 100644
--- a/llama_stack/apis/files/files.py
+++ b/llama_stack/apis/files/files.py
@@ -22,6 +22,7 @@ class OpenAIFilePurpose(StrEnum):
"""
ASSISTANTS = "assistants"
+ BATCH = "batch"
# TODO: Add other purposes as needed
diff --git a/llama_stack/core/resolver.py b/llama_stack/core/resolver.py
index 70c78fb01..7ac98dac8 100644
--- a/llama_stack/core/resolver.py
+++ b/llama_stack/core/resolver.py
@@ -8,6 +8,7 @@ import inspect
from typing import Any
from llama_stack.apis.agents import Agents
+from llama_stack.apis.batches import Batches
from llama_stack.apis.benchmarks import Benchmarks
from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.datasets import Datasets
@@ -75,6 +76,7 @@ def api_protocol_map(external_apis: dict[Api, ExternalApiSpec] | None = None) ->
Api.agents: Agents,
Api.inference: Inference,
Api.inspect: Inspect,
+ Api.batches: Batches,
Api.vector_io: VectorIO,
Api.vector_dbs: VectorDBs,
Api.models: Models,
diff --git a/llama_stack/core/server/server.py b/llama_stack/core/server/server.py
index fe5cc68d7..f5ef40275 100644
--- a/llama_stack/core/server/server.py
+++ b/llama_stack/core/server/server.py
@@ -31,6 +31,7 @@ from fastapi.responses import JSONResponse, StreamingResponse
from openai import BadRequestError
from pydantic import BaseModel, ValidationError
+from llama_stack.apis.common.errors import ConflictError, ResourceNotFoundError
from llama_stack.apis.common.responses import PaginatedResponse
from llama_stack.cli.utils import add_config_distro_args, get_config_from_args
from llama_stack.core.access_control.access_control import AccessDeniedError
@@ -127,6 +128,10 @@ def translate_exception(exc: Exception) -> HTTPException | RequestValidationErro
]
},
)
+ elif isinstance(exc, ConflictError):
+ return HTTPException(status_code=409, detail=str(exc))
+ elif isinstance(exc, ResourceNotFoundError):
+ return HTTPException(status_code=404, detail=str(exc))
elif isinstance(exc, ValueError):
return HTTPException(status_code=400, detail=f"Invalid value: {str(exc)}")
elif isinstance(exc, BadRequestError):
diff --git a/llama_stack/providers/inline/batches/__init__.py b/llama_stack/providers/inline/batches/__init__.py
new file mode 100644
index 000000000..756f351d8
--- /dev/null
+++ b/llama_stack/providers/inline/batches/__init__.py
@@ -0,0 +1,5 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
diff --git a/llama_stack/providers/inline/batches/reference/__init__.py b/llama_stack/providers/inline/batches/reference/__init__.py
new file mode 100644
index 000000000..a8ae92eb2
--- /dev/null
+++ b/llama_stack/providers/inline/batches/reference/__init__.py
@@ -0,0 +1,36 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from typing import Any
+
+from llama_stack.apis.files import Files
+from llama_stack.apis.inference import Inference
+from llama_stack.apis.models import Models
+from llama_stack.core.datatypes import AccessRule, Api
+from llama_stack.providers.utils.kvstore import kvstore_impl
+
+from .batches import ReferenceBatchesImpl
+from .config import ReferenceBatchesImplConfig
+
+__all__ = ["ReferenceBatchesImpl", "ReferenceBatchesImplConfig"]
+
+
+async def get_provider_impl(config: ReferenceBatchesImplConfig, deps: dict[Api, Any], policy: list[AccessRule]):
+ kvstore = await kvstore_impl(config.kvstore)
+ inference_api: Inference | None = deps.get(Api.inference)
+ files_api: Files | None = deps.get(Api.files)
+ models_api: Models | None = deps.get(Api.models)
+
+ if inference_api is None:
+ raise ValueError("Inference API is required but not provided in dependencies")
+ if files_api is None:
+ raise ValueError("Files API is required but not provided in dependencies")
+ if models_api is None:
+ raise ValueError("Models API is required but not provided in dependencies")
+
+ impl = ReferenceBatchesImpl(config, inference_api, files_api, models_api, kvstore)
+ await impl.initialize()
+ return impl
diff --git a/llama_stack/providers/inline/batches/reference/batches.py b/llama_stack/providers/inline/batches/reference/batches.py
new file mode 100644
index 000000000..6e99a00a1
--- /dev/null
+++ b/llama_stack/providers/inline/batches/reference/batches.py
@@ -0,0 +1,553 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+import asyncio
+import itertools
+import json
+import time
+import uuid
+from io import BytesIO
+from typing import Any
+
+from openai.types.batch import BatchError, Errors
+from pydantic import BaseModel
+
+from llama_stack.apis.batches import Batches, BatchObject, ListBatchesResponse
+from llama_stack.apis.common.errors import ConflictError, ResourceNotFoundError
+from llama_stack.apis.files import Files, OpenAIFilePurpose
+from llama_stack.apis.inference import Inference
+from llama_stack.apis.models import Models
+from llama_stack.log import get_logger
+from llama_stack.providers.utils.kvstore import KVStore
+
+from .config import ReferenceBatchesImplConfig
+
+BATCH_PREFIX = "batch:"
+
+logger = get_logger(__name__)
+
+
+class AsyncBytesIO:
+ """
+ Async-compatible BytesIO wrapper to allow async file-like operations.
+
+ We use this when uploading files to the Files API, as it expects an
+ async file-like object.
+ """
+
+ def __init__(self, data: bytes):
+ self._buffer = BytesIO(data)
+
+ async def read(self, n=-1):
+ return self._buffer.read(n)
+
+ async def seek(self, pos, whence=0):
+ return self._buffer.seek(pos, whence)
+
+ def __enter__(self):
+ return self
+
+ def __exit__(self, exc_type, exc_val, exc_tb):
+ self._buffer.close()
+
+ def __getattr__(self, name):
+ return getattr(self._buffer, name)
+
+
+class BatchRequest(BaseModel):
+ line_num: int
+ custom_id: str
+ method: str
+ url: str
+ body: dict[str, Any]
+
+
+class ReferenceBatchesImpl(Batches):
+ """Reference implementation of the Batches API.
+
+ This implementation processes batch files by making individual requests
+ to the inference API and generates output files with results.
+ """
+
+ def __init__(
+ self,
+ config: ReferenceBatchesImplConfig,
+ inference_api: Inference,
+ files_api: Files,
+ models_api: Models,
+ kvstore: KVStore,
+ ) -> None:
+ self.config = config
+ self.kvstore = kvstore
+ self.inference_api = inference_api
+ self.files_api = files_api
+ self.models_api = models_api
+ self._processing_tasks: dict[str, asyncio.Task] = {}
+ self._batch_semaphore = asyncio.Semaphore(config.max_concurrent_batches)
+ self._update_batch_lock = asyncio.Lock()
+
+ # this is to allow tests to disable background processing
+ self.process_batches = True
+
+ async def initialize(self) -> None:
+ # TODO: start background processing of existing tasks
+ pass
+
+ async def shutdown(self) -> None:
+ """Shutdown the batches provider."""
+ if self._processing_tasks:
+ # don't cancel tasks - just let them stop naturally on shutdown
+ # cancelling would mark batches as "cancelled" in the database
+ logger.info(f"Shutdown initiated with {len(self._processing_tasks)} active batch processing tasks")
+
+ # TODO (SECURITY): this currently works w/ configured api keys, not with x-llamastack-provider-data or with user policy restrictions
+ async def create_batch(
+ self,
+ input_file_id: str,
+ endpoint: str,
+ completion_window: str,
+ metadata: dict[str, str] | None = None,
+ ) -> BatchObject:
+ """
+ Create a new batch for processing multiple API requests.
+
+ Error handling by levels -
+ 0. Input param handling, results in 40x errors before processing, e.g.
+ - Wrong completion_window
+ - Invalid metadata types
+ - Unknown endpoint
+ -> no batch created
+ 1. Errors preventing processing, result in BatchErrors aggregated in process_batch, e.g.
+ - input_file_id missing
+ - invalid json in file
+ - missing custom_id, method, url, body
+ - invalid model
+ - streaming
+ -> batch created, validation sends to failed status
+ 2. Processing errors, result in error_file_id entries, e.g.
+ - Any error returned from inference endpoint
+ -> batch created, goes to completed status
+ """
+
+ # TODO: set expiration time for garbage collection
+
+ if endpoint not in ["/v1/chat/completions"]:
+ raise ValueError(
+ f"Invalid endpoint: {endpoint}. Supported values: /v1/chat/completions. Code: invalid_value. Param: endpoint",
+ )
+
+ if completion_window != "24h":
+ raise ValueError(
+ f"Invalid completion_window: {completion_window}. Supported values are: 24h. Code: invalid_value. Param: completion_window",
+ )
+
+ batch_id = f"batch_{uuid.uuid4().hex[:16]}"
+ current_time = int(time.time())
+
+ batch = BatchObject(
+ id=batch_id,
+ object="batch",
+ endpoint=endpoint,
+ input_file_id=input_file_id,
+ completion_window=completion_window,
+ status="validating",
+ created_at=current_time,
+ metadata=metadata,
+ )
+
+ await self.kvstore.set(f"batch:{batch_id}", batch.to_json())
+
+ if self.process_batches:
+ task = asyncio.create_task(self._process_batch(batch_id))
+ self._processing_tasks[batch_id] = task
+
+ return batch
+
+ async def cancel_batch(self, batch_id: str) -> BatchObject:
+ """Cancel a batch that is in progress."""
+ batch = await self.retrieve_batch(batch_id)
+
+ if batch.status in ["cancelled", "cancelling"]:
+ return batch
+
+ if batch.status in ["completed", "failed", "expired"]:
+ raise ConflictError(f"Cannot cancel batch '{batch_id}' with status '{batch.status}'")
+
+ await self._update_batch(batch_id, status="cancelling", cancelling_at=int(time.time()))
+
+ if batch_id in self._processing_tasks:
+ self._processing_tasks[batch_id].cancel()
+ # note: task removal and status="cancelled" handled in finally block of _process_batch
+
+ return await self.retrieve_batch(batch_id)
+
+ async def list_batches(
+ self,
+ after: str | None = None,
+ limit: int = 20,
+ ) -> ListBatchesResponse:
+ """
+ List all batches, eventually only for the current user.
+
+ With no notion of user, we return all batches.
+ """
+ batch_values = await self.kvstore.values_in_range("batch:", "batch:\xff")
+
+ batches = []
+ for batch_data in batch_values:
+ if batch_data:
+ batches.append(BatchObject.model_validate_json(batch_data))
+
+ batches.sort(key=lambda b: b.created_at, reverse=True)
+
+ start_idx = 0
+ if after:
+ for i, batch in enumerate(batches):
+ if batch.id == after:
+ start_idx = i + 1
+ break
+
+ page_batches = batches[start_idx : start_idx + limit]
+ has_more = (start_idx + limit) < len(batches)
+
+ first_id = page_batches[0].id if page_batches else None
+ last_id = page_batches[-1].id if page_batches else None
+
+ return ListBatchesResponse(
+ data=page_batches,
+ first_id=first_id,
+ last_id=last_id,
+ has_more=has_more,
+ )
+
+ async def retrieve_batch(self, batch_id: str) -> BatchObject:
+ """Retrieve information about a specific batch."""
+ batch_data = await self.kvstore.get(f"batch:{batch_id}")
+ if not batch_data:
+ raise ResourceNotFoundError(batch_id, "Batch", "batches.list()")
+
+ return BatchObject.model_validate_json(batch_data)
+
+ async def _update_batch(self, batch_id: str, **updates) -> None:
+ """Update batch fields in kvstore."""
+ async with self._update_batch_lock:
+ try:
+ batch = await self.retrieve_batch(batch_id)
+
+ # batch processing is async. once cancelling, only allow "cancelled" status updates
+ if batch.status == "cancelling" and updates.get("status") != "cancelled":
+ logger.info(
+ f"Skipping status update for cancelled batch {batch_id}: attempted {updates.get('status')}"
+ )
+ return
+
+ if "errors" in updates:
+ updates["errors"] = updates["errors"].model_dump()
+
+ batch_dict = batch.model_dump()
+ batch_dict.update(updates)
+
+ await self.kvstore.set(f"batch:{batch_id}", json.dumps(batch_dict))
+ except Exception as e:
+ logger.error(f"Failed to update batch {batch_id}: {e}")
+
+ async def _validate_input(self, batch: BatchObject) -> tuple[list[BatchError], list[BatchRequest]]:
+ """
+ Read & validate input, return errors and valid input.
+
+ Validation of
+ - input_file_id existance
+ - valid json
+ - custom_id, method, url, body presence and valid
+ - no streaming
+ """
+ requests: list[BatchRequest] = []
+ errors: list[BatchError] = []
+ try:
+ await self.files_api.openai_retrieve_file(batch.input_file_id)
+ except Exception:
+ errors.append(
+ BatchError(
+ code="invalid_request",
+ line=None,
+ message=f"Cannot find file {batch.input_file_id}.",
+ param="input_file_id",
+ )
+ )
+ return errors, requests
+
+ # TODO(SECURITY): do something about large files
+ file_content_response = await self.files_api.openai_retrieve_file_content(batch.input_file_id)
+ file_content = file_content_response.body.decode("utf-8")
+ for line_num, line in enumerate(file_content.strip().split("\n"), 1):
+ if line.strip(): # skip empty lines
+ try:
+ request = json.loads(line)
+
+ if not isinstance(request, dict):
+ errors.append(
+ BatchError(
+ code="invalid_request",
+ line=line_num,
+ message="Each line must be a JSON dictionary object",
+ )
+ )
+ continue
+
+ valid = True
+
+ for param, expected_type, type_string in [
+ ("custom_id", str, "string"),
+ ("method", str, "string"),
+ ("url", str, "string"),
+ ("body", dict, "JSON dictionary object"),
+ ]:
+ if param not in request:
+ errors.append(
+ BatchError(
+ code="missing_required_parameter",
+ line=line_num,
+ message=f"Missing required parameter: {param}",
+ param=param,
+ )
+ )
+ valid = False
+ elif not isinstance(request[param], expected_type):
+ param_name = "URL" if param == "url" else param.capitalize()
+ errors.append(
+ BatchError(
+ code="invalid_request",
+ line=line_num,
+ message=f"{param_name} must be a {type_string}",
+ param=param,
+ )
+ )
+ valid = False
+
+ if (url := request.get("url")) and isinstance(url, str) and url != batch.endpoint:
+ errors.append(
+ BatchError(
+ code="invalid_url",
+ line=line_num,
+ message="URL provided for this request does not match the batch endpoint",
+ param="url",
+ )
+ )
+ valid = False
+
+ if (body := request.get("body")) and isinstance(body, dict):
+ if body.get("stream", False):
+ errors.append(
+ BatchError(
+ code="streaming_unsupported",
+ line=line_num,
+ message="Streaming is not supported in batch processing",
+ param="body.stream",
+ )
+ )
+ valid = False
+
+ for param, expected_type, type_string in [
+ ("model", str, "a string"),
+ # messages is specific to /v1/chat/completions
+ # we could skip validating messages here and let inference fail. however,
+ # that would be a very expensive way to find out messages is wrong.
+ ("messages", list, "an array"), # TODO: allow messages to be a string?
+ ]:
+ if param not in body:
+ errors.append(
+ BatchError(
+ code="invalid_request",
+ line=line_num,
+ message=f"{param.capitalize()} parameter is required",
+ param=f"body.{param}",
+ )
+ )
+ valid = False
+ elif not isinstance(body[param], expected_type):
+ errors.append(
+ BatchError(
+ code="invalid_request",
+ line=line_num,
+ message=f"{param.capitalize()} must be {type_string}",
+ param=f"body.{param}",
+ )
+ )
+ valid = False
+
+ if "model" in body and isinstance(body["model"], str):
+ try:
+ await self.models_api.get_model(body["model"])
+ except Exception:
+ errors.append(
+ BatchError(
+ code="model_not_found",
+ line=line_num,
+ message=f"Model '{body['model']}' does not exist or is not supported",
+ param="body.model",
+ )
+ )
+ valid = False
+
+ if valid:
+ assert isinstance(url, str), "URL must be a string" # for mypy
+ assert isinstance(body, dict), "Body must be a dictionary" # for mypy
+ requests.append(
+ BatchRequest(
+ line_num=line_num,
+ url=url,
+ method=request["method"],
+ custom_id=request["custom_id"],
+ body=body,
+ ),
+ )
+ except json.JSONDecodeError:
+ errors.append(
+ BatchError(
+ code="invalid_json_line",
+ line=line_num,
+ message="This line is not parseable as valid JSON.",
+ )
+ )
+
+ return errors, requests
+
+ async def _process_batch(self, batch_id: str) -> None:
+ """Background task to process a batch of requests."""
+ try:
+ logger.info(f"Starting batch processing for {batch_id}")
+ async with self._batch_semaphore: # semaphore to limit concurrency
+ logger.info(f"Acquired semaphore for batch {batch_id}")
+ await self._process_batch_impl(batch_id)
+ except asyncio.CancelledError:
+ logger.info(f"Batch processing cancelled for {batch_id}")
+ await self._update_batch(batch_id, status="cancelled", cancelled_at=int(time.time()))
+ except Exception as e:
+ logger.error(f"Batch processing failed for {batch_id}: {e}")
+ await self._update_batch(
+ batch_id,
+ status="failed",
+ failed_at=int(time.time()),
+ errors=Errors(data=[BatchError(code="internal_error", message=str(e))]),
+ )
+ finally:
+ self._processing_tasks.pop(batch_id, None)
+
+ async def _process_batch_impl(self, batch_id: str) -> None:
+ """Implementation of batch processing logic."""
+ errors: list[BatchError] = []
+ batch = await self.retrieve_batch(batch_id)
+
+ errors, requests = await self._validate_input(batch)
+ if errors:
+ await self._update_batch(batch_id, status="failed", failed_at=int(time.time()), errors=Errors(data=errors))
+ logger.info(f"Batch validation failed for {batch_id} with {len(errors)} errors")
+ return
+
+ logger.info(f"Processing {len(requests)} requests for batch {batch_id}")
+
+ total_requests = len(requests)
+ await self._update_batch(
+ batch_id,
+ status="in_progress",
+ request_counts={"total": total_requests, "completed": 0, "failed": 0},
+ )
+
+ error_results = []
+ success_results = []
+ completed_count = 0
+ failed_count = 0
+
+ for chunk in itertools.batched(requests, self.config.max_concurrent_requests_per_batch):
+ # we use a TaskGroup to ensure all process-single-request tasks are canceled when process-batch is cancelled
+ async with asyncio.TaskGroup() as tg:
+ chunk_tasks = [tg.create_task(self._process_single_request(batch_id, request)) for request in chunk]
+
+ chunk_results = await asyncio.gather(*chunk_tasks, return_exceptions=True)
+
+ for result in chunk_results:
+ if isinstance(result, dict) and result.get("error") is not None: # error response from inference
+ failed_count += 1
+ error_results.append(result)
+ elif isinstance(result, dict) and result.get("response") is not None: # successful inference
+ completed_count += 1
+ success_results.append(result)
+ else: # unexpected result
+ failed_count += 1
+ errors.append(BatchError(code="internal_error", message=f"Unexpected result: {result}"))
+
+ await self._update_batch(
+ batch_id,
+ request_counts={"total": total_requests, "completed": completed_count, "failed": failed_count},
+ )
+
+ if errors:
+ await self._update_batch(
+ batch_id, status="failed", failed_at=int(time.time()), errors=Errors(data=errors)
+ )
+ return
+
+ try:
+ output_file_id = await self._create_output_file(batch_id, success_results, "success")
+ await self._update_batch(batch_id, output_file_id=output_file_id)
+
+ error_file_id = await self._create_output_file(batch_id, error_results, "error")
+ await self._update_batch(batch_id, error_file_id=error_file_id)
+
+ await self._update_batch(batch_id, status="completed", completed_at=int(time.time()))
+
+ logger.info(
+ f"Batch processing completed for {batch_id}: {completed_count} completed, {failed_count} failed"
+ )
+ except Exception as e:
+ # note: errors is empty at this point, so we don't lose anything by ignoring it
+ await self._update_batch(
+ batch_id,
+ status="failed",
+ failed_at=int(time.time()),
+ errors=Errors(data=[BatchError(code="output_failed", message=str(e))]),
+ )
+
+ async def _process_single_request(self, batch_id: str, request: BatchRequest) -> dict:
+ """Process a single request from the batch."""
+ request_id = f"batch_req_{batch_id}_{request.line_num}"
+
+ try:
+ # TODO(SECURITY): review body for security issues
+ chat_response = await self.inference_api.openai_chat_completion(**request.body)
+
+ # this is for mypy, we don't allow streaming so we'll get the right type
+ assert hasattr(chat_response, "model_dump_json"), "Chat response must have model_dump_json method"
+ return {
+ "id": request_id,
+ "custom_id": request.custom_id,
+ "response": {
+ "status_code": 200,
+ "request_id": request_id, # TODO: should this be different?
+ "body": chat_response.model_dump_json(),
+ },
+ }
+ except Exception as e:
+ logger.info(f"Error processing request {request.custom_id} in batch {batch_id}: {e}")
+ return {
+ "id": request_id,
+ "custom_id": request.custom_id,
+ "error": {"type": "request_failed", "message": str(e)},
+ }
+
+ async def _create_output_file(self, batch_id: str, results: list[dict], file_type: str) -> str:
+ """
+ Create an output file with batch results.
+
+ This function filters results based on the specified file_type
+ and uploads the file to the Files API.
+ """
+ output_lines = [json.dumps(result) for result in results]
+
+ with AsyncBytesIO("\n".join(output_lines).encode("utf-8")) as file_buffer:
+ file_buffer.filename = f"{batch_id}_{file_type}.jsonl"
+ uploaded_file = await self.files_api.openai_upload_file(file=file_buffer, purpose=OpenAIFilePurpose.BATCH)
+ return uploaded_file.id
diff --git a/llama_stack/providers/inline/batches/reference/config.py b/llama_stack/providers/inline/batches/reference/config.py
new file mode 100644
index 000000000..d8d06868b
--- /dev/null
+++ b/llama_stack/providers/inline/batches/reference/config.py
@@ -0,0 +1,40 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from pydantic import BaseModel, Field
+
+from llama_stack.providers.utils.kvstore.config import KVStoreConfig, SqliteKVStoreConfig
+
+
+class ReferenceBatchesImplConfig(BaseModel):
+ """Configuration for the Reference Batches implementation."""
+
+ kvstore: KVStoreConfig = Field(
+ description="Configuration for the key-value store backend.",
+ )
+
+ max_concurrent_batches: int = Field(
+ default=1,
+ description="Maximum number of concurrent batches to process simultaneously.",
+ ge=1,
+ )
+
+ max_concurrent_requests_per_batch: int = Field(
+ default=10,
+ description="Maximum number of concurrent requests to process per batch.",
+ ge=1,
+ )
+
+ # TODO: add a max requests per second rate limiter
+
+ @classmethod
+ def sample_run_config(cls, __distro_dir__: str) -> dict:
+ return {
+ "kvstore": SqliteKVStoreConfig.sample_run_config(
+ __distro_dir__=__distro_dir__,
+ db_name="batches.db",
+ ),
+ }
diff --git a/llama_stack/providers/registry/batches.py b/llama_stack/providers/registry/batches.py
new file mode 100644
index 000000000..de7886efb
--- /dev/null
+++ b/llama_stack/providers/registry/batches.py
@@ -0,0 +1,26 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+
+from llama_stack.providers.datatypes import Api, InlineProviderSpec, ProviderSpec
+
+
+def available_providers() -> list[ProviderSpec]:
+ return [
+ InlineProviderSpec(
+ api=Api.batches,
+ provider_type="inline::reference",
+ pip_packages=["openai"],
+ module="llama_stack.providers.inline.batches.reference",
+ config_class="llama_stack.providers.inline.batches.reference.config.ReferenceBatchesImplConfig",
+ api_dependencies=[
+ Api.inference,
+ Api.files,
+ Api.models,
+ ],
+ description="Reference implementation of batches API with KVStore persistence.",
+ ),
+ ]
diff --git a/tests/integration/batches/__init__.py b/tests/integration/batches/__init__.py
new file mode 100644
index 000000000..756f351d8
--- /dev/null
+++ b/tests/integration/batches/__init__.py
@@ -0,0 +1,5 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
diff --git a/tests/integration/batches/conftest.py b/tests/integration/batches/conftest.py
new file mode 100644
index 000000000..974fe77ab
--- /dev/null
+++ b/tests/integration/batches/conftest.py
@@ -0,0 +1,122 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+"""Shared pytest fixtures for batch tests."""
+
+import json
+import time
+import warnings
+from contextlib import contextmanager
+from io import BytesIO
+
+import pytest
+
+from llama_stack.apis.files import OpenAIFilePurpose
+
+
+class BatchHelper:
+ """Helper class for creating and managing batch input files."""
+
+ def __init__(self, client):
+ """Initialize with either a batch_client or openai_client."""
+ self.client = client
+
+ @contextmanager
+ def create_file(self, content: str | list[dict], filename_prefix="batch_input"):
+ """Context manager for creating and cleaning up batch input files.
+
+ Args:
+ content: Either a list of batch request dictionaries or raw string content
+ filename_prefix: Prefix for the generated filename (or full filename if content is string)
+
+ Yields:
+ The uploaded file object
+ """
+ if isinstance(content, str):
+ # Handle raw string content (e.g., malformed JSONL, empty files)
+ file_content = content.encode("utf-8")
+ else:
+ # Handle list of batch request dictionaries
+ jsonl_content = "\n".join(json.dumps(req) for req in content)
+ file_content = jsonl_content.encode("utf-8")
+
+ filename = filename_prefix if filename_prefix.endswith(".jsonl") else f"{filename_prefix}.jsonl"
+
+ with BytesIO(file_content) as file_buffer:
+ file_buffer.name = filename
+ uploaded_file = self.client.files.create(file=file_buffer, purpose=OpenAIFilePurpose.BATCH)
+
+ try:
+ yield uploaded_file
+ finally:
+ try:
+ self.client.files.delete(uploaded_file.id)
+ except Exception:
+ warnings.warn(
+ f"Failed to cleanup file {uploaded_file.id}: {uploaded_file.filename}",
+ stacklevel=2,
+ )
+
+ def wait_for(
+ self,
+ batch_id: str,
+ max_wait_time: int = 60,
+ sleep_interval: int | None = None,
+ expected_statuses: set[str] | None = None,
+ timeout_action: str = "fail",
+ ):
+ """Wait for a batch to reach a terminal status.
+
+ Args:
+ batch_id: The batch ID to monitor
+ max_wait_time: Maximum time to wait in seconds (default: 60 seconds)
+ sleep_interval: Time to sleep between checks in seconds (default: 1/10th of max_wait_time, min 1s, max 15s)
+ expected_statuses: Set of expected terminal statuses (default: {"completed"})
+ timeout_action: Action on timeout - "fail" (pytest.fail) or "skip" (pytest.skip)
+
+ Returns:
+ The final batch object
+
+ Raises:
+ pytest.Failed: If batch reaches an unexpected status or timeout_action is "fail"
+ pytest.Skipped: If timeout_action is "skip" on timeout or unexpected status
+ """
+ if sleep_interval is None:
+ # Default to 1/10th of max_wait_time, with min 1s and max 15s
+ sleep_interval = max(1, min(15, max_wait_time // 10))
+
+ if expected_statuses is None:
+ expected_statuses = {"completed"}
+
+ terminal_statuses = {"completed", "failed", "cancelled", "expired"}
+ unexpected_statuses = terminal_statuses - expected_statuses
+
+ start_time = time.time()
+ while time.time() - start_time < max_wait_time:
+ current_batch = self.client.batches.retrieve(batch_id)
+
+ if current_batch.status in expected_statuses:
+ return current_batch
+ elif current_batch.status in unexpected_statuses:
+ error_msg = f"Batch reached unexpected status: {current_batch.status}"
+ if timeout_action == "skip":
+ pytest.skip(error_msg)
+ else:
+ pytest.fail(error_msg)
+
+ time.sleep(sleep_interval)
+
+ timeout_msg = f"Batch did not reach expected status {expected_statuses} within {max_wait_time} seconds"
+ if timeout_action == "skip":
+ pytest.skip(timeout_msg)
+ else:
+ pytest.fail(timeout_msg)
+
+
+@pytest.fixture
+def batch_helper(openai_client):
+ """Fixture that provides a BatchHelper instance for OpenAI client."""
+ return BatchHelper(openai_client)
diff --git a/tests/integration/batches/test_batches.py b/tests/integration/batches/test_batches.py
new file mode 100644
index 000000000..1ef3202d0
--- /dev/null
+++ b/tests/integration/batches/test_batches.py
@@ -0,0 +1,270 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+"""
+Integration tests for the Llama Stack batch processing functionality.
+
+This module contains comprehensive integration tests for the batch processing API,
+using the OpenAI-compatible client interface for consistency.
+
+Test Categories:
+ 1. Core Batch Operations:
+ - test_batch_creation_and_retrieval: Comprehensive batch creation, structure validation, and retrieval
+ - test_batch_listing: Basic batch listing functionality
+ - test_batch_immediate_cancellation: Batch cancellation workflow
+ # TODO: cancel during processing
+
+ 2. End-to-End Processing:
+ - test_batch_e2e_chat_completions: Full chat completions workflow with output and error validation
+
+Note: Error conditions and edge cases are primarily tested in test_batches_errors.py
+for better organization and separation of concerns.
+
+CLEANUP WARNING: These tests currently create batches that are not automatically
+cleaned up after test completion. This may lead to resource accumulation over
+multiple test runs. Only test_batch_immediate_cancellation properly cancels its batch.
+The test_batch_e2e_chat_completions test does clean up its output and error files.
+"""
+
+import json
+
+
+class TestBatchesIntegration:
+ """Integration tests for the batches API."""
+
+ def test_batch_creation_and_retrieval(self, openai_client, batch_helper, text_model_id):
+ """Test comprehensive batch creation and retrieval scenarios."""
+ test_metadata = {
+ "test_type": "comprehensive",
+ "purpose": "creation_and_retrieval_test",
+ "version": "1.0",
+ "tags": "test,batch",
+ }
+
+ batch_requests = [
+ {
+ "custom_id": "request-1",
+ "method": "POST",
+ "url": "/v1/chat/completions",
+ "body": {
+ "model": text_model_id,
+ "messages": [{"role": "user", "content": "Hello"}],
+ "max_tokens": 10,
+ },
+ }
+ ]
+
+ with batch_helper.create_file(batch_requests, "batch_creation_test") as uploaded_file:
+ batch = openai_client.batches.create(
+ input_file_id=uploaded_file.id,
+ endpoint="/v1/chat/completions",
+ completion_window="24h",
+ metadata=test_metadata,
+ )
+
+ assert batch.endpoint == "/v1/chat/completions"
+ assert batch.input_file_id == uploaded_file.id
+ assert batch.completion_window == "24h"
+ assert batch.metadata == test_metadata
+
+ retrieved_batch = openai_client.batches.retrieve(batch.id)
+
+ assert retrieved_batch.id == batch.id
+ assert retrieved_batch.object == batch.object
+ assert retrieved_batch.endpoint == batch.endpoint
+ assert retrieved_batch.input_file_id == batch.input_file_id
+ assert retrieved_batch.completion_window == batch.completion_window
+ assert retrieved_batch.metadata == batch.metadata
+
+ def test_batch_listing(self, openai_client, batch_helper, text_model_id):
+ """
+ Test batch listing.
+
+ This test creates multiple batches and verifies that they can be listed.
+ It also deletes the input files before execution, which means the batches
+ will appear as failed due to missing input files. This is expected and
+ a good thing, because it means no inference is performed.
+ """
+ batch_ids = []
+
+ for i in range(2):
+ batch_requests = [
+ {
+ "custom_id": f"request-{i}",
+ "method": "POST",
+ "url": "/v1/chat/completions",
+ "body": {
+ "model": text_model_id,
+ "messages": [{"role": "user", "content": f"Hello {i}"}],
+ "max_tokens": 10,
+ },
+ }
+ ]
+
+ with batch_helper.create_file(batch_requests, f"batch_input_{i}") as uploaded_file:
+ batch = openai_client.batches.create(
+ input_file_id=uploaded_file.id,
+ endpoint="/v1/chat/completions",
+ completion_window="24h",
+ )
+ batch_ids.append(batch.id)
+
+ batch_list = openai_client.batches.list()
+
+ assert isinstance(batch_list.data, list)
+
+ listed_batch_ids = {b.id for b in batch_list.data}
+ for batch_id in batch_ids:
+ assert batch_id in listed_batch_ids
+
+ def test_batch_immediate_cancellation(self, openai_client, batch_helper, text_model_id):
+ """Test immediate batch cancellation."""
+ batch_requests = [
+ {
+ "custom_id": "request-1",
+ "method": "POST",
+ "url": "/v1/chat/completions",
+ "body": {
+ "model": text_model_id,
+ "messages": [{"role": "user", "content": "Hello"}],
+ "max_tokens": 10,
+ },
+ }
+ ]
+
+ with batch_helper.create_file(batch_requests) as uploaded_file:
+ batch = openai_client.batches.create(
+ input_file_id=uploaded_file.id,
+ endpoint="/v1/chat/completions",
+ completion_window="24h",
+ )
+
+ # hopefully cancel the batch before it completes
+ cancelling_batch = openai_client.batches.cancel(batch.id)
+ assert cancelling_batch.status in ["cancelling", "cancelled"]
+ assert isinstance(cancelling_batch.cancelling_at, int), (
+ f"cancelling_at should be int, got {type(cancelling_batch.cancelling_at)}"
+ )
+
+ final_batch = batch_helper.wait_for(
+ batch.id,
+ max_wait_time=3 * 60, # often takes 10-11 minutes, give it 3 min
+ expected_statuses={"cancelled"},
+ timeout_action="skip",
+ )
+
+ assert final_batch.status == "cancelled"
+ assert isinstance(final_batch.cancelled_at, int), (
+ f"cancelled_at should be int, got {type(final_batch.cancelled_at)}"
+ )
+
+ def test_batch_e2e_chat_completions(self, openai_client, batch_helper, text_model_id):
+ """Test end-to-end batch processing for chat completions with both successful and failed operations."""
+ batch_requests = [
+ {
+ "custom_id": "success-1",
+ "method": "POST",
+ "url": "/v1/chat/completions",
+ "body": {
+ "model": text_model_id,
+ "messages": [{"role": "user", "content": "Say hello"}],
+ "max_tokens": 20,
+ },
+ },
+ {
+ "custom_id": "error-1",
+ "method": "POST",
+ "url": "/v1/chat/completions",
+ "body": {
+ "model": text_model_id,
+ "messages": [{"role": "user", "content": "This should fail"}],
+ "max_tokens": -1, # Invalid negative max_tokens will cause inference error
+ },
+ },
+ ]
+
+ with batch_helper.create_file(batch_requests) as uploaded_file:
+ batch = openai_client.batches.create(
+ input_file_id=uploaded_file.id,
+ endpoint="/v1/chat/completions",
+ completion_window="24h",
+ metadata={"test": "e2e_success_and_errors_test"},
+ )
+
+ final_batch = batch_helper.wait_for(
+ batch.id,
+ max_wait_time=3 * 60, # often takes 2-3 minutes
+ expected_statuses={"completed"},
+ timeout_action="skip",
+ )
+
+ # Expecting a completed batch with both successful and failed requests
+ # Batch(id='batch_xxx',
+ # completion_window='24h',
+ # created_at=...,
+ # endpoint='/v1/chat/completions',
+ # input_file_id='file-xxx',
+ # object='batch',
+ # status='completed',
+ # output_file_id='file-xxx',
+ # error_file_id='file-xxx',
+ # request_counts=BatchRequestCounts(completed=1, failed=1, total=2))
+
+ assert final_batch.status == "completed"
+ assert final_batch.request_counts is not None
+ assert final_batch.request_counts.total == 2
+ assert final_batch.request_counts.completed == 1
+ assert final_batch.request_counts.failed == 1
+
+ assert final_batch.output_file_id is not None, "Output file should exist for successful requests"
+
+ output_content = openai_client.files.content(final_batch.output_file_id)
+ if isinstance(output_content, str):
+ output_text = output_content
+ else:
+ output_text = output_content.content.decode("utf-8")
+
+ output_lines = output_text.strip().split("\n")
+
+ for line in output_lines:
+ result = json.loads(line)
+
+ assert "id" in result
+ assert "custom_id" in result
+ assert result["custom_id"] == "success-1"
+
+ assert "response" in result
+
+ assert result["response"]["status_code"] == 200
+ assert "body" in result["response"]
+ assert "choices" in result["response"]["body"]
+
+ assert final_batch.error_file_id is not None, "Error file should exist for failed requests"
+
+ error_content = openai_client.files.content(final_batch.error_file_id)
+ if isinstance(error_content, str):
+ error_text = error_content
+ else:
+ error_text = error_content.content.decode("utf-8")
+
+ error_lines = error_text.strip().split("\n")
+
+ for line in error_lines:
+ result = json.loads(line)
+
+ assert "id" in result
+ assert "custom_id" in result
+ assert result["custom_id"] == "error-1"
+ assert "error" in result
+ error = result["error"]
+ assert error is not None
+ assert "code" in error or "message" in error, "Error should have code or message"
+
+ deleted_output_file = openai_client.files.delete(final_batch.output_file_id)
+ assert deleted_output_file.deleted, f"Output file {final_batch.output_file_id} was not deleted successfully"
+
+ deleted_error_file = openai_client.files.delete(final_batch.error_file_id)
+ assert deleted_error_file.deleted, f"Error file {final_batch.error_file_id} was not deleted successfully"
diff --git a/tests/integration/batches/test_batches_errors.py b/tests/integration/batches/test_batches_errors.py
new file mode 100644
index 000000000..2cd1e561e
--- /dev/null
+++ b/tests/integration/batches/test_batches_errors.py
@@ -0,0 +1,694 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+"""
+Error handling and edge case tests for the Llama Stack batch processing functionality.
+
+This module focuses exclusively on testing error conditions, validation failures,
+and edge cases for batch operations to ensure robust error handling and graceful
+degradation.
+
+Test Categories:
+ 1. File and Input Validation:
+ - test_batch_nonexistent_file_id: Handling invalid file IDs
+ - test_batch_malformed_jsonl: Processing malformed JSONL input files
+ - test_file_malformed_batch_file: Handling malformed files at upload time
+ - test_batch_missing_required_fields: Validation of required request fields
+
+ 2. API Endpoint and Model Validation:
+ - test_batch_invalid_endpoint: Invalid endpoint handling during creation
+ - test_batch_error_handling_invalid_model: Error handling with nonexistent models
+ - test_batch_endpoint_mismatch: Validation of endpoint/URL consistency
+
+ 3. Batch Lifecycle Error Handling:
+ - test_batch_retrieve_nonexistent: Retrieving non-existent batches
+ - test_batch_cancel_nonexistent: Cancelling non-existent batches
+ - test_batch_cancel_completed: Attempting to cancel completed batches
+
+ 4. Parameter and Configuration Validation:
+ - test_batch_invalid_completion_window: Invalid completion window values
+ - test_batch_invalid_metadata_types: Invalid metadata type validation
+ - test_batch_missing_required_body_fields: Validation of required fields in request body
+
+ 5. Feature Restriction and Compatibility:
+ - test_batch_streaming_not_supported: Streaming request rejection
+ - test_batch_mixed_streaming_requests: Mixed streaming/non-streaming validation
+
+Note: Core functionality and OpenAI compatibility tests are located in
+test_batches_integration.py for better organization and separation of concerns.
+
+CLEANUP WARNING: These tests create batches to test error conditions but do not
+automatically clean them up after test completion. While most error tests create
+batches that fail quickly, some may create valid batches that consume resources.
+"""
+
+import pytest
+from openai import BadRequestError, ConflictError, NotFoundError
+
+
+class TestBatchesErrorHandling:
+ """Error handling and edge case tests for the batches API using OpenAI client."""
+
+ def test_batch_nonexistent_file_id(self, openai_client, batch_helper):
+ """Test batch creation with nonexistent input file ID."""
+
+ batch = openai_client.batches.create(
+ input_file_id="file-nonexistent-xyz",
+ endpoint="/v1/chat/completions",
+ completion_window="24h",
+ )
+
+ final_batch = batch_helper.wait_for(batch.id, expected_statuses={"failed"})
+
+ # Expecting -
+ # Batch(...,
+ # status='failed',
+ # errors=Errors(data=[
+ # BatchError(
+ # code='invalid_request',
+ # line=None,
+ # message='Cannot find file ..., or organization ... does not have access to it.',
+ # param='file_id')
+ # ], object='list'),
+ # failed_at=1754566971,
+ # ...)
+
+ assert final_batch.status == "failed"
+ assert final_batch.errors is not None
+ assert len(final_batch.errors.data) == 1
+ error = final_batch.errors.data[0]
+ assert error.code == "invalid_request"
+ assert "cannot find file" in error.message.lower()
+
+ def test_batch_invalid_endpoint(self, openai_client, batch_helper, text_model_id):
+ """Test batch creation with invalid endpoint."""
+ batch_requests = [
+ {
+ "custom_id": "invalid-endpoint",
+ "method": "POST",
+ "url": "/v1/chat/completions",
+ "body": {
+ "model": text_model_id,
+ "messages": [{"role": "user", "content": "Hello"}],
+ "max_tokens": 10,
+ },
+ }
+ ]
+
+ with batch_helper.create_file(batch_requests) as uploaded_file:
+ with pytest.raises(BadRequestError) as exc_info:
+ openai_client.batches.create(
+ input_file_id=uploaded_file.id,
+ endpoint="/v1/invalid/endpoint",
+ completion_window="24h",
+ )
+
+ # Expected -
+ # Error code: 400 - {
+ # 'error': {
+ # 'message': "Invalid value: '/v1/invalid/endpoint'. Supported values are: '/v1/chat/completions', '/v1/completions', '/v1/embeddings', and '/v1/responses'.",
+ # 'type': 'invalid_request_error',
+ # 'param': 'endpoint',
+ # 'code': 'invalid_value'
+ # }
+ # }
+
+ error_msg = str(exc_info.value).lower()
+ assert exc_info.value.status_code == 400
+ assert "invalid value" in error_msg
+ assert "/v1/invalid/endpoint" in error_msg
+ assert "supported values" in error_msg
+ assert "endpoint" in error_msg
+ assert "invalid_value" in error_msg
+
+ def test_batch_malformed_jsonl(self, openai_client, batch_helper):
+ """
+ Test batch with malformed JSONL input.
+
+ The /v1/files endpoint requires valid JSONL format, so we provide a well formed line
+ before a malformed line to ensure we get to the /v1/batches validation stage.
+ """
+ with batch_helper.create_file(
+ """{"custom_id": "valid", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "test"}}
+{invalid json here""",
+ "malformed_batch_input.jsonl",
+ ) as uploaded_file:
+ batch = openai_client.batches.create(
+ input_file_id=uploaded_file.id,
+ endpoint="/v1/chat/completions",
+ completion_window="24h",
+ )
+
+ final_batch = batch_helper.wait_for(batch.id, expected_statuses={"failed"})
+
+ # Expecting -
+ # Batch(...,
+ # status='failed',
+ # errors=Errors(data=[
+ # ...,
+ # BatchError(code='invalid_json_line',
+ # line=2,
+ # message='This line is not parseable as valid JSON.',
+ # param=None)
+ # ], object='list'),
+ # ...)
+
+ assert final_batch.status == "failed"
+ assert final_batch.errors is not None
+ assert len(final_batch.errors.data) > 0
+ error = final_batch.errors.data[-1] # get last error because first may be about the "test" model
+ assert error.code == "invalid_json_line"
+ assert error.line == 2
+ assert "not" in error.message.lower()
+ assert "valid json" in error.message.lower()
+
+ @pytest.mark.xfail(reason="Not all file providers validate content")
+ @pytest.mark.parametrize("batch_requests", ["", "{malformed json"], ids=["empty", "malformed"])
+ def test_file_malformed_batch_file(self, openai_client, batch_helper, batch_requests):
+ """Test file upload with malformed content."""
+
+ with pytest.raises(BadRequestError) as exc_info:
+ with batch_helper.create_file(batch_requests, "malformed_batch_input_file.jsonl"):
+ # /v1/files rejects the file, we don't get to batch creation
+ pass
+
+ error_msg = str(exc_info.value).lower()
+ assert exc_info.value.status_code == 400
+ assert "invalid file format" in error_msg
+ assert "jsonl" in error_msg
+
+ def test_batch_retrieve_nonexistent(self, openai_client):
+ """Test retrieving nonexistent batch."""
+ with pytest.raises(NotFoundError) as exc_info:
+ openai_client.batches.retrieve("batch-nonexistent-xyz")
+
+ error_msg = str(exc_info.value).lower()
+ assert exc_info.value.status_code == 404
+ assert "no batch found" in error_msg or "not found" in error_msg
+
+ def test_batch_cancel_nonexistent(self, openai_client):
+ """Test cancelling nonexistent batch."""
+ with pytest.raises(NotFoundError) as exc_info:
+ openai_client.batches.cancel("batch-nonexistent-xyz")
+
+ error_msg = str(exc_info.value).lower()
+ assert exc_info.value.status_code == 404
+ assert "no batch found" in error_msg or "not found" in error_msg
+
+ def test_batch_cancel_completed(self, openai_client, batch_helper, text_model_id):
+ """Test cancelling already completed batch."""
+ batch_requests = [
+ {
+ "custom_id": "cancel-completed",
+ "method": "POST",
+ "url": "/v1/chat/completions",
+ "body": {
+ "model": text_model_id,
+ "messages": [{"role": "user", "content": "Quick test"}],
+ "max_tokens": 5,
+ },
+ }
+ ]
+
+ with batch_helper.create_file(batch_requests, "cancel_test_batch_input") as uploaded_file:
+ batch = openai_client.batches.create(
+ input_file_id=uploaded_file.id,
+ endpoint="/v1/chat/completions",
+ completion_window="24h",
+ )
+
+ final_batch = batch_helper.wait_for(
+ batch.id,
+ max_wait_time=3 * 60, # often take 10-11 min, give it 3 min
+ expected_statuses={"completed"},
+ timeout_action="skip",
+ )
+
+ deleted_file = openai_client.files.delete(final_batch.output_file_id)
+ assert deleted_file.deleted, f"File {final_batch.output_file_id} was not deleted successfully"
+
+ with pytest.raises(ConflictError) as exc_info:
+ openai_client.batches.cancel(batch.id)
+
+ # Expecting -
+ # Error code: 409 - {
+ # 'error': {
+ # 'message': "Cannot cancel a batch with status 'completed'.",
+ # 'type': 'invalid_request_error',
+ # 'param': None,
+ # 'code': None
+ # }
+ # }
+ #
+ # NOTE: Same for "failed", cancelling "cancelled" batches is allowed
+
+ error_msg = str(exc_info.value).lower()
+ assert exc_info.value.status_code == 409
+ assert "cannot cancel" in error_msg
+
+ def test_batch_missing_required_fields(self, openai_client, batch_helper, text_model_id):
+ """Test batch with requests missing required fields."""
+ batch_requests = [
+ {
+ # Missing custom_id
+ "method": "POST",
+ "url": "/v1/chat/completions",
+ "body": {
+ "model": text_model_id,
+ "messages": [{"role": "user", "content": "No custom_id"}],
+ "max_tokens": 10,
+ },
+ },
+ {
+ "custom_id": "no-method",
+ "url": "/v1/chat/completions",
+ "body": {
+ "model": text_model_id,
+ "messages": [{"role": "user", "content": "No method"}],
+ "max_tokens": 10,
+ },
+ },
+ {
+ "custom_id": "no-url",
+ "method": "POST",
+ "body": {
+ "model": text_model_id,
+ "messages": [{"role": "user", "content": "No URL"}],
+ "max_tokens": 10,
+ },
+ },
+ {
+ "custom_id": "no-body",
+ "method": "POST",
+ "url": "/v1/chat/completions",
+ },
+ ]
+
+ with batch_helper.create_file(batch_requests, "missing_fields_batch_input") as uploaded_file:
+ batch = openai_client.batches.create(
+ input_file_id=uploaded_file.id,
+ endpoint="/v1/chat/completions",
+ completion_window="24h",
+ )
+
+ final_batch = batch_helper.wait_for(batch.id, expected_statuses={"failed"})
+
+ # Expecting -
+ # Batch(...,
+ # status='failed',
+ # errors=Errors(
+ # data=[
+ # BatchError(
+ # code='missing_required_parameter',
+ # line=1,
+ # message="Missing required parameter: 'custom_id'.",
+ # param='custom_id'
+ # ),
+ # BatchError(
+ # code='missing_required_parameter',
+ # line=2,
+ # message="Missing required parameter: 'method'.",
+ # param='method'
+ # ),
+ # BatchError(
+ # code='missing_required_parameter',
+ # line=3,
+ # message="Missing required parameter: 'url'.",
+ # param='url'
+ # ),
+ # BatchError(
+ # code='missing_required_parameter',
+ # line=4,
+ # message="Missing required parameter: 'body'.",
+ # param='body'
+ # )
+ # ], object='list'),
+ # failed_at=1754566945,
+ # ...)
+ # )
+
+ assert final_batch.status == "failed"
+ assert final_batch.errors is not None
+ assert len(final_batch.errors.data) == 4
+ no_custom_id_error = final_batch.errors.data[0]
+ assert no_custom_id_error.code == "missing_required_parameter"
+ assert no_custom_id_error.line == 1
+ assert "missing" in no_custom_id_error.message.lower()
+ assert "custom_id" in no_custom_id_error.message.lower()
+ no_method_error = final_batch.errors.data[1]
+ assert no_method_error.code == "missing_required_parameter"
+ assert no_method_error.line == 2
+ assert "missing" in no_method_error.message.lower()
+ assert "method" in no_method_error.message.lower()
+ no_url_error = final_batch.errors.data[2]
+ assert no_url_error.code == "missing_required_parameter"
+ assert no_url_error.line == 3
+ assert "missing" in no_url_error.message.lower()
+ assert "url" in no_url_error.message.lower()
+ no_body_error = final_batch.errors.data[3]
+ assert no_body_error.code == "missing_required_parameter"
+ assert no_body_error.line == 4
+ assert "missing" in no_body_error.message.lower()
+ assert "body" in no_body_error.message.lower()
+
+ def test_batch_invalid_completion_window(self, openai_client, batch_helper, text_model_id):
+ """Test batch creation with invalid completion window."""
+ batch_requests = [
+ {
+ "custom_id": "invalid-completion-window",
+ "method": "POST",
+ "url": "/v1/chat/completions",
+ "body": {
+ "model": text_model_id,
+ "messages": [{"role": "user", "content": "Hello"}],
+ "max_tokens": 10,
+ },
+ }
+ ]
+
+ with batch_helper.create_file(batch_requests) as uploaded_file:
+ for window in ["1h", "48h", "invalid", ""]:
+ with pytest.raises(BadRequestError) as exc_info:
+ openai_client.batches.create(
+ input_file_id=uploaded_file.id,
+ endpoint="/v1/chat/completions",
+ completion_window=window,
+ )
+ assert exc_info.value.status_code == 400
+ error_msg = str(exc_info.value).lower()
+ assert "invalid value" in error_msg
+ assert "completion_window" in error_msg
+ assert "supported values are" in error_msg
+
+ def test_batch_streaming_not_supported(self, openai_client, batch_helper, text_model_id):
+ """Test that streaming responses are not supported in batches."""
+ batch_requests = [
+ {
+ "custom_id": "streaming-test",
+ "method": "POST",
+ "url": "/v1/chat/completions",
+ "body": {
+ "model": text_model_id,
+ "messages": [{"role": "user", "content": "Hello"}],
+ "max_tokens": 10,
+ "stream": True, # Not supported
+ },
+ }
+ ]
+
+ with batch_helper.create_file(batch_requests, "streaming_batch_input") as uploaded_file:
+ batch = openai_client.batches.create(
+ input_file_id=uploaded_file.id,
+ endpoint="/v1/chat/completions",
+ completion_window="24h",
+ )
+
+ final_batch = batch_helper.wait_for(batch.id, expected_statuses={"failed"})
+
+ # Expecting -
+ # Batch(...,
+ # status='failed',
+ # errors=Errors(data=[
+ # BatchError(code='streaming_unsupported',
+ # line=1,
+ # message='Chat Completions: Streaming is not supported in the Batch API.',
+ # param='body.stream')
+ # ], object='list'),
+ # failed_at=1754566965,
+ # ...)
+
+ assert final_batch.status == "failed"
+ assert final_batch.errors is not None
+ assert len(final_batch.errors.data) == 1
+ error = final_batch.errors.data[0]
+ assert error.code == "streaming_unsupported"
+ assert error.line == 1
+ assert "streaming" in error.message.lower()
+ assert "not supported" in error.message.lower()
+ assert error.param == "body.stream"
+ assert final_batch.failed_at is not None
+
+ def test_batch_mixed_streaming_requests(self, openai_client, batch_helper, text_model_id):
+ """
+ Test batch with mixed streaming and non-streaming requests.
+
+ This is distinct from test_batch_streaming_not_supported, which tests a single
+ streaming request, to ensure an otherwise valid batch fails when a single
+ streaming request is included.
+ """
+ batch_requests = [
+ {
+ "custom_id": "valid-non-streaming-request",
+ "method": "POST",
+ "url": "/v1/chat/completions",
+ "body": {
+ "model": text_model_id,
+ "messages": [{"role": "user", "content": "Hello without streaming"}],
+ "max_tokens": 10,
+ },
+ },
+ {
+ "custom_id": "streaming-request",
+ "method": "POST",
+ "url": "/v1/chat/completions",
+ "body": {
+ "model": text_model_id,
+ "messages": [{"role": "user", "content": "Hello with streaming"}],
+ "max_tokens": 10,
+ "stream": True, # Not supported
+ },
+ },
+ ]
+
+ with batch_helper.create_file(batch_requests, "mixed_streaming_batch_input") as uploaded_file:
+ batch = openai_client.batches.create(
+ input_file_id=uploaded_file.id,
+ endpoint="/v1/chat/completions",
+ completion_window="24h",
+ )
+
+ final_batch = batch_helper.wait_for(batch.id, expected_statuses={"failed"})
+
+ # Expecting -
+ # Batch(...,
+ # status='failed',
+ # errors=Errors(data=[
+ # BatchError(
+ # code='streaming_unsupported',
+ # line=2,
+ # message='Chat Completions: Streaming is not supported in the Batch API.',
+ # param='body.stream')
+ # ], object='list'),
+ # failed_at=1754574442,
+ # ...)
+
+ assert final_batch.status == "failed"
+ assert final_batch.errors is not None
+ assert len(final_batch.errors.data) == 1
+ error = final_batch.errors.data[0]
+ assert error.code == "streaming_unsupported"
+ assert error.line == 2
+ assert "streaming" in error.message.lower()
+ assert "not supported" in error.message.lower()
+ assert error.param == "body.stream"
+ assert final_batch.failed_at is not None
+
+ def test_batch_endpoint_mismatch(self, openai_client, batch_helper, text_model_id):
+ """Test batch creation with mismatched endpoint and request URL."""
+ batch_requests = [
+ {
+ "custom_id": "endpoint-mismatch",
+ "method": "POST",
+ "url": "/v1/embeddings", # Different from batch endpoint
+ "body": {
+ "model": text_model_id,
+ "messages": [{"role": "user", "content": "Hello"}],
+ },
+ }
+ ]
+
+ with batch_helper.create_file(batch_requests, "endpoint_mismatch_batch_input") as uploaded_file:
+ batch = openai_client.batches.create(
+ input_file_id=uploaded_file.id,
+ endpoint="/v1/chat/completions", # Different from request URL
+ completion_window="24h",
+ )
+
+ final_batch = batch_helper.wait_for(batch.id, expected_statuses={"failed"})
+
+ # Expecting -
+ # Batch(...,
+ # status='failed',
+ # errors=Errors(data=[
+ # BatchError(
+ # code='invalid_url',
+ # line=1,
+ # message='The URL provided for this request does not match the batch endpoint.',
+ # param='url')
+ # ], object='list'),
+ # failed_at=1754566972,
+ # ...)
+
+ assert final_batch.status == "failed"
+ assert final_batch.errors is not None
+ assert len(final_batch.errors.data) == 1
+ error = final_batch.errors.data[0]
+ assert error.line == 1
+ assert error.code == "invalid_url"
+ assert "does not match" in error.message.lower()
+ assert "endpoint" in error.message.lower()
+ assert final_batch.failed_at is not None
+
+ def test_batch_error_handling_invalid_model(self, openai_client, batch_helper):
+ """Test batch error handling with invalid model."""
+ batch_requests = [
+ {
+ "custom_id": "invalid-model",
+ "method": "POST",
+ "url": "/v1/chat/completions",
+ "body": {
+ "model": "nonexistent-model-xyz",
+ "messages": [{"role": "user", "content": "Hello"}],
+ "max_tokens": 10,
+ },
+ }
+ ]
+
+ with batch_helper.create_file(batch_requests) as uploaded_file:
+ batch = openai_client.batches.create(
+ input_file_id=uploaded_file.id,
+ endpoint="/v1/chat/completions",
+ completion_window="24h",
+ )
+
+ final_batch = batch_helper.wait_for(batch.id, expected_statuses={"failed"})
+
+ # Expecting -
+ # Batch(...,
+ # status='failed',
+ # errors=Errors(data=[
+ # BatchError(code='model_not_found',
+ # line=1,
+ # message="The provided model 'nonexistent-model-xyz' is not supported by the Batch API.",
+ # param='body.model')
+ # ], object='list'),
+ # failed_at=1754566978,
+ # ...)
+
+ assert final_batch.status == "failed"
+ assert final_batch.errors is not None
+ assert len(final_batch.errors.data) == 1
+ error = final_batch.errors.data[0]
+ assert error.line == 1
+ assert error.code == "model_not_found"
+ assert "not supported" in error.message.lower()
+ assert error.param == "body.model"
+ assert final_batch.failed_at is not None
+
+ def test_batch_missing_required_body_fields(self, openai_client, batch_helper, text_model_id):
+ """Test batch with requests missing required fields in body (model and messages)."""
+ batch_requests = [
+ {
+ "custom_id": "missing-model",
+ "method": "POST",
+ "url": "/v1/chat/completions",
+ "body": {
+ # Missing model field
+ "messages": [{"role": "user", "content": "Hello without model"}],
+ "max_tokens": 10,
+ },
+ },
+ {
+ "custom_id": "missing-messages",
+ "method": "POST",
+ "url": "/v1/chat/completions",
+ "body": {
+ "model": text_model_id,
+ # Missing messages field
+ "max_tokens": 10,
+ },
+ },
+ ]
+
+ with batch_helper.create_file(batch_requests, "missing_body_fields_batch_input") as uploaded_file:
+ batch = openai_client.batches.create(
+ input_file_id=uploaded_file.id,
+ endpoint="/v1/chat/completions",
+ completion_window="24h",
+ )
+
+ final_batch = batch_helper.wait_for(batch.id, expected_statuses={"failed"})
+
+ # Expecting -
+ # Batch(...,
+ # status='failed',
+ # errors=Errors(data=[
+ # BatchError(
+ # code='invalid_request',
+ # line=1,
+ # message='Model parameter is required.',
+ # param='body.model'),
+ # BatchError(
+ # code='invalid_request',
+ # line=2,
+ # message='Messages parameter is required.',
+ # param='body.messages')
+ # ], object='list'),
+ # ...)
+
+ assert final_batch.status == "failed"
+ assert final_batch.errors is not None
+ assert len(final_batch.errors.data) == 2
+
+ model_error = final_batch.errors.data[0]
+ assert model_error.line == 1
+ assert "model" in model_error.message.lower()
+ assert model_error.param == "body.model"
+
+ messages_error = final_batch.errors.data[1]
+ assert messages_error.line == 2
+ assert "messages" in messages_error.message.lower()
+ assert messages_error.param == "body.messages"
+
+ assert final_batch.failed_at is not None
+
+ def test_batch_invalid_metadata_types(self, openai_client, batch_helper, text_model_id):
+ """Test batch creation with invalid metadata types (like lists)."""
+ batch_requests = [
+ {
+ "custom_id": "invalid-metadata-type",
+ "method": "POST",
+ "url": "/v1/chat/completions",
+ "body": {
+ "model": text_model_id,
+ "messages": [{"role": "user", "content": "Hello"}],
+ "max_tokens": 10,
+ },
+ }
+ ]
+
+ with batch_helper.create_file(batch_requests) as uploaded_file:
+ with pytest.raises(Exception) as exc_info:
+ openai_client.batches.create(
+ input_file_id=uploaded_file.id,
+ endpoint="/v1/chat/completions",
+ completion_window="24h",
+ metadata={
+ "tags": ["tag1", "tag2"], # Invalid type, should be a string
+ },
+ )
+
+ # Expecting -
+ # Error code: 400 - {'error':
+ # {'message': "Invalid type for 'metadata.tags': expected a string,
+ # but got an array instead.",
+ # 'type': 'invalid_request_error', 'param': 'metadata.tags',
+ # 'code': 'invalid_type'}}
+
+ error_msg = str(exc_info.value).lower()
+ assert "400" in error_msg
+ assert "tags" in error_msg
+ assert "string" in error_msg
diff --git a/tests/unit/providers/batches/test_reference.py b/tests/unit/providers/batches/test_reference.py
new file mode 100644
index 000000000..9fe0cc710
--- /dev/null
+++ b/tests/unit/providers/batches/test_reference.py
@@ -0,0 +1,753 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+"""
+Test suite for the reference implementation of the Batches API.
+
+The tests are categorized and outlined below, keep this updated:
+
+- Batch creation with various parameters and validation:
+ * test_create_and_retrieve_batch_success (positive)
+ * test_create_batch_without_metadata (positive)
+ * test_create_batch_completion_window (negative)
+ * test_create_batch_invalid_endpoints (negative)
+ * test_create_batch_invalid_metadata (negative)
+
+- Batch retrieval and error handling for non-existent batches:
+ * test_retrieve_batch_not_found (negative)
+
+- Batch cancellation with proper status transitions:
+ * test_cancel_batch_success (positive)
+ * test_cancel_batch_invalid_statuses (negative)
+ * test_cancel_batch_not_found (negative)
+
+- Batch listing with pagination and filtering:
+ * test_list_batches_empty (positive)
+ * test_list_batches_single_batch (positive)
+ * test_list_batches_multiple_batches (positive)
+ * test_list_batches_with_limit (positive)
+ * test_list_batches_with_pagination (positive)
+ * test_list_batches_invalid_after (negative)
+
+- Data persistence in the underlying key-value store:
+ * test_kvstore_persistence (positive)
+
+- Batch processing concurrency control:
+ * test_max_concurrent_batches (positive)
+
+- Input validation testing (direct _validate_input method tests):
+ * test_validate_input_file_not_found (negative)
+ * test_validate_input_file_exists_empty_content (positive)
+ * test_validate_input_file_mixed_valid_invalid_json (mixed)
+ * test_validate_input_invalid_model (negative)
+ * test_validate_input_url_mismatch (negative)
+ * test_validate_input_multiple_errors_per_request (negative)
+ * test_validate_input_invalid_request_format (negative)
+ * test_validate_input_missing_parameters (parametrized negative - custom_id, method, url, body, model, messages missing validation)
+ * test_validate_input_invalid_parameter_types (parametrized negative - custom_id, url, method, body, model, messages type validation)
+
+The tests use temporary SQLite databases for isolation and mock external
+dependencies like inference, files, and models APIs.
+"""
+
+import json
+import tempfile
+from pathlib import Path
+from unittest.mock import AsyncMock, MagicMock
+
+import pytest
+
+from llama_stack.apis.batches import BatchObject
+from llama_stack.apis.common.errors import ConflictError, ResourceNotFoundError
+from llama_stack.providers.inline.batches.reference.batches import ReferenceBatchesImpl
+from llama_stack.providers.inline.batches.reference.config import ReferenceBatchesImplConfig
+from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
+
+
+class TestReferenceBatchesImpl:
+ """Test the reference implementation of the Batches API."""
+
+ @pytest.fixture
+ async def provider(self):
+ """Create a test provider instance with temporary database."""
+ with tempfile.TemporaryDirectory() as tmpdir:
+ db_path = Path(tmpdir) / "test_batches.db"
+ kvstore_config = SqliteKVStoreConfig(db_path=str(db_path))
+ config = ReferenceBatchesImplConfig(kvstore=kvstore_config)
+
+ # Create kvstore and mock APIs
+ from unittest.mock import AsyncMock
+
+ from llama_stack.providers.utils.kvstore import kvstore_impl
+
+ kvstore = await kvstore_impl(config.kvstore)
+ mock_inference = AsyncMock()
+ mock_files = AsyncMock()
+ mock_models = AsyncMock()
+
+ provider = ReferenceBatchesImpl(config, mock_inference, mock_files, mock_models, kvstore)
+ await provider.initialize()
+
+ # unit tests should not require background processing
+ provider.process_batches = False
+
+ yield provider
+
+ await provider.shutdown()
+
+ @pytest.fixture
+ def sample_batch_data(self):
+ """Sample batch data for testing."""
+ return {
+ "input_file_id": "file_abc123",
+ "endpoint": "/v1/chat/completions",
+ "completion_window": "24h",
+ "metadata": {"test": "true", "priority": "high"},
+ }
+
+ def _validate_batch_type(self, batch, expected_metadata=None):
+ """
+ Helper function to validate batch object structure and field types.
+
+ Note: This validates the direct BatchObject from the provider, not the
+ client library response which has a different structure.
+
+ Args:
+ batch: The BatchObject instance to validate.
+ expected_metadata: Optional expected metadata dictionary to validate against.
+ """
+ assert isinstance(batch.id, str)
+ assert isinstance(batch.completion_window, str)
+ assert isinstance(batch.created_at, int)
+ assert isinstance(batch.endpoint, str)
+ assert isinstance(batch.input_file_id, str)
+ assert batch.object == "batch"
+ assert batch.status in [
+ "validating",
+ "failed",
+ "in_progress",
+ "finalizing",
+ "completed",
+ "expired",
+ "cancelling",
+ "cancelled",
+ ]
+
+ if expected_metadata is not None:
+ assert batch.metadata == expected_metadata
+
+ timestamp_fields = [
+ "cancelled_at",
+ "cancelling_at",
+ "completed_at",
+ "expired_at",
+ "expires_at",
+ "failed_at",
+ "finalizing_at",
+ "in_progress_at",
+ ]
+ for field in timestamp_fields:
+ field_value = getattr(batch, field, None)
+ if field_value is not None:
+ assert isinstance(field_value, int), f"{field} should be int or None, got {type(field_value)}"
+
+ file_id_fields = ["error_file_id", "output_file_id"]
+ for field in file_id_fields:
+ field_value = getattr(batch, field, None)
+ if field_value is not None:
+ assert isinstance(field_value, str), f"{field} should be str or None, got {type(field_value)}"
+
+ if hasattr(batch, "request_counts") and batch.request_counts is not None:
+ assert isinstance(batch.request_counts.completed, int), (
+ f"request_counts.completed should be int, got {type(batch.request_counts.completed)}"
+ )
+ assert isinstance(batch.request_counts.failed, int), (
+ f"request_counts.failed should be int, got {type(batch.request_counts.failed)}"
+ )
+ assert isinstance(batch.request_counts.total, int), (
+ f"request_counts.total should be int, got {type(batch.request_counts.total)}"
+ )
+
+ if hasattr(batch, "errors") and batch.errors is not None:
+ assert isinstance(batch.errors, dict), f"errors should be object or dict, got {type(batch.errors)}"
+
+ if hasattr(batch.errors, "data") and batch.errors.data is not None:
+ assert isinstance(batch.errors.data, list), (
+ f"errors.data should be list or None, got {type(batch.errors.data)}"
+ )
+
+ for i, error_item in enumerate(batch.errors.data):
+ assert isinstance(error_item, dict), (
+ f"errors.data[{i}] should be object or dict, got {type(error_item)}"
+ )
+
+ if hasattr(error_item, "code") and error_item.code is not None:
+ assert isinstance(error_item.code, str), (
+ f"errors.data[{i}].code should be str or None, got {type(error_item.code)}"
+ )
+
+ if hasattr(error_item, "line") and error_item.line is not None:
+ assert isinstance(error_item.line, int), (
+ f"errors.data[{i}].line should be int or None, got {type(error_item.line)}"
+ )
+
+ if hasattr(error_item, "message") and error_item.message is not None:
+ assert isinstance(error_item.message, str), (
+ f"errors.data[{i}].message should be str or None, got {type(error_item.message)}"
+ )
+
+ if hasattr(error_item, "param") and error_item.param is not None:
+ assert isinstance(error_item.param, str), (
+ f"errors.data[{i}].param should be str or None, got {type(error_item.param)}"
+ )
+
+ if hasattr(batch.errors, "object") and batch.errors.object is not None:
+ assert isinstance(batch.errors.object, str), (
+ f"errors.object should be str or None, got {type(batch.errors.object)}"
+ )
+ assert batch.errors.object == "list", f"errors.object should be 'list', got {batch.errors.object}"
+
+ async def test_create_and_retrieve_batch_success(self, provider, sample_batch_data):
+ """Test successful batch creation and retrieval."""
+ created_batch = await provider.create_batch(**sample_batch_data)
+
+ self._validate_batch_type(created_batch, expected_metadata=sample_batch_data["metadata"])
+
+ assert created_batch.id.startswith("batch_")
+ assert len(created_batch.id) > 13
+ assert created_batch.object == "batch"
+ assert created_batch.endpoint == sample_batch_data["endpoint"]
+ assert created_batch.input_file_id == sample_batch_data["input_file_id"]
+ assert created_batch.completion_window == sample_batch_data["completion_window"]
+ assert created_batch.status == "validating"
+ assert created_batch.metadata == sample_batch_data["metadata"]
+ assert isinstance(created_batch.created_at, int)
+ assert created_batch.created_at > 0
+
+ retrieved_batch = await provider.retrieve_batch(created_batch.id)
+
+ self._validate_batch_type(retrieved_batch, expected_metadata=sample_batch_data["metadata"])
+
+ assert retrieved_batch.id == created_batch.id
+ assert retrieved_batch.input_file_id == created_batch.input_file_id
+ assert retrieved_batch.endpoint == created_batch.endpoint
+ assert retrieved_batch.status == created_batch.status
+ assert retrieved_batch.metadata == created_batch.metadata
+
+ async def test_create_batch_without_metadata(self, provider):
+ """Test batch creation without optional metadata."""
+ batch = await provider.create_batch(
+ input_file_id="file_123", endpoint="/v1/chat/completions", completion_window="24h"
+ )
+
+ assert batch.metadata is None
+
+ async def test_create_batch_completion_window(self, provider):
+ """Test batch creation with invalid completion window."""
+ with pytest.raises(ValueError, match="Invalid completion_window"):
+ await provider.create_batch(
+ input_file_id="file_123", endpoint="/v1/chat/completions", completion_window="now"
+ )
+
+ @pytest.mark.parametrize(
+ "endpoint",
+ [
+ "/v1/embeddings",
+ "/v1/completions",
+ "/v1/invalid/endpoint",
+ "",
+ ],
+ )
+ async def test_create_batch_invalid_endpoints(self, provider, endpoint):
+ """Test batch creation with various invalid endpoints."""
+ with pytest.raises(ValueError, match="Invalid endpoint"):
+ await provider.create_batch(input_file_id="file_123", endpoint=endpoint, completion_window="24h")
+
+ async def test_create_batch_invalid_metadata(self, provider):
+ """Test that batch creation fails with invalid metadata."""
+ with pytest.raises(ValueError, match="should be a valid string"):
+ await provider.create_batch(
+ input_file_id="file_123",
+ endpoint="/v1/chat/completions",
+ completion_window="24h",
+ metadata={123: "invalid_key"}, # Non-string key
+ )
+
+ with pytest.raises(ValueError, match="should be a valid string"):
+ await provider.create_batch(
+ input_file_id="file_123",
+ endpoint="/v1/chat/completions",
+ completion_window="24h",
+ metadata={"valid_key": 456}, # Non-string value
+ )
+
+ async def test_retrieve_batch_not_found(self, provider):
+ """Test error when retrieving non-existent batch."""
+ with pytest.raises(ResourceNotFoundError, match=r"Batch 'nonexistent_batch' not found"):
+ await provider.retrieve_batch("nonexistent_batch")
+
+ async def test_cancel_batch_success(self, provider, sample_batch_data):
+ """Test successful batch cancellation."""
+ created_batch = await provider.create_batch(**sample_batch_data)
+ assert created_batch.status == "validating"
+
+ cancelled_batch = await provider.cancel_batch(created_batch.id)
+
+ assert cancelled_batch.id == created_batch.id
+ assert cancelled_batch.status in ["cancelling", "cancelled"]
+ assert isinstance(cancelled_batch.cancelling_at, int)
+ assert cancelled_batch.cancelling_at >= created_batch.created_at
+
+ @pytest.mark.parametrize("status", ["failed", "expired", "completed"])
+ async def test_cancel_batch_invalid_statuses(self, provider, sample_batch_data, status):
+ """Test error when cancelling batch in final states."""
+ provider.process_batches = False
+ created_batch = await provider.create_batch(**sample_batch_data)
+
+ # directly update status in kvstore
+ await provider._update_batch(created_batch.id, status=status)
+
+ with pytest.raises(ConflictError, match=f"Cannot cancel batch '{created_batch.id}' with status '{status}'"):
+ await provider.cancel_batch(created_batch.id)
+
+ async def test_cancel_batch_not_found(self, provider):
+ """Test error when cancelling non-existent batch."""
+ with pytest.raises(ResourceNotFoundError, match=r"Batch 'nonexistent_batch' not found"):
+ await provider.cancel_batch("nonexistent_batch")
+
+ async def test_list_batches_empty(self, provider):
+ """Test listing batches when none exist."""
+ response = await provider.list_batches()
+
+ assert response.object == "list"
+ assert response.data == []
+ assert response.first_id is None
+ assert response.last_id is None
+ assert response.has_more is False
+
+ async def test_list_batches_single_batch(self, provider, sample_batch_data):
+ """Test listing batches with single batch."""
+ created_batch = await provider.create_batch(**sample_batch_data)
+
+ response = await provider.list_batches()
+
+ assert len(response.data) == 1
+ self._validate_batch_type(response.data[0], expected_metadata=sample_batch_data["metadata"])
+ assert response.data[0].id == created_batch.id
+ assert response.first_id == created_batch.id
+ assert response.last_id == created_batch.id
+ assert response.has_more is False
+
+ async def test_list_batches_multiple_batches(self, provider):
+ """Test listing multiple batches."""
+ batches = [
+ await provider.create_batch(
+ input_file_id=f"file_{i}", endpoint="/v1/chat/completions", completion_window="24h"
+ )
+ for i in range(3)
+ ]
+
+ response = await provider.list_batches()
+
+ assert len(response.data) == 3
+
+ batch_ids = {batch.id for batch in response.data}
+ expected_ids = {batch.id for batch in batches}
+ assert batch_ids == expected_ids
+ assert response.has_more is False
+
+ assert response.first_id in expected_ids
+ assert response.last_id in expected_ids
+
+ async def test_list_batches_with_limit(self, provider):
+ """Test listing batches with limit parameter."""
+ batches = [
+ await provider.create_batch(
+ input_file_id=f"file_{i}", endpoint="/v1/chat/completions", completion_window="24h"
+ )
+ for i in range(3)
+ ]
+
+ response = await provider.list_batches(limit=2)
+
+ assert len(response.data) == 2
+ assert response.has_more is True
+ assert response.first_id == response.data[0].id
+ assert response.last_id == response.data[1].id
+ batch_ids = {batch.id for batch in response.data}
+ expected_ids = {batch.id for batch in batches}
+ assert batch_ids.issubset(expected_ids)
+
+ async def test_list_batches_with_pagination(self, provider):
+ """Test listing batches with pagination using 'after' parameter."""
+ for i in range(3):
+ await provider.create_batch(
+ input_file_id=f"file_{i}", endpoint="/v1/chat/completions", completion_window="24h"
+ )
+
+ # Get first page
+ first_page = await provider.list_batches(limit=1)
+ assert len(first_page.data) == 1
+ assert first_page.has_more is True
+
+ # Get second page using 'after'
+ second_page = await provider.list_batches(limit=1, after=first_page.data[0].id)
+ assert len(second_page.data) == 1
+ assert second_page.data[0].id != first_page.data[0].id
+
+ # Verify we got the next batch in order
+ all_batches = await provider.list_batches()
+ expected_second_batch_id = all_batches.data[1].id
+ assert second_page.data[0].id == expected_second_batch_id
+
+ async def test_list_batches_invalid_after(self, provider, sample_batch_data):
+ """Test listing batches with invalid 'after' parameter."""
+ await provider.create_batch(**sample_batch_data)
+
+ response = await provider.list_batches(after="nonexistent_batch")
+
+ # Should return all batches (no filtering when 'after' batch not found)
+ assert len(response.data) == 1
+
+ async def test_kvstore_persistence(self, provider, sample_batch_data):
+ """Test that batches are properly persisted in kvstore."""
+ batch = await provider.create_batch(**sample_batch_data)
+
+ stored_data = await provider.kvstore.get(f"batch:{batch.id}")
+ assert stored_data is not None
+
+ stored_batch_dict = json.loads(stored_data)
+ assert stored_batch_dict["id"] == batch.id
+ assert stored_batch_dict["input_file_id"] == sample_batch_data["input_file_id"]
+
+ async def test_validate_input_file_not_found(self, provider):
+ """Test _validate_input when input file does not exist."""
+ provider.files_api.openai_retrieve_file = AsyncMock(side_effect=Exception("File not found"))
+
+ batch = BatchObject(
+ id="batch_test",
+ object="batch",
+ endpoint="/v1/chat/completions",
+ input_file_id="nonexistent_file",
+ completion_window="24h",
+ status="validating",
+ created_at=1234567890,
+ )
+
+ errors, requests = await provider._validate_input(batch)
+
+ assert len(errors) == 1
+ assert len(requests) == 0
+ assert errors[0].code == "invalid_request"
+ assert errors[0].message == "Cannot find file nonexistent_file."
+ assert errors[0].param == "input_file_id"
+ assert errors[0].line is None
+
+ async def test_validate_input_file_exists_empty_content(self, provider):
+ """Test _validate_input when file exists but is empty."""
+ provider.files_api.openai_retrieve_file = AsyncMock()
+ mock_response = MagicMock()
+ mock_response.body = b""
+ provider.files_api.openai_retrieve_file_content = AsyncMock(return_value=mock_response)
+
+ batch = BatchObject(
+ id="batch_test",
+ object="batch",
+ endpoint="/v1/chat/completions",
+ input_file_id="empty_file",
+ completion_window="24h",
+ status="validating",
+ created_at=1234567890,
+ )
+
+ errors, requests = await provider._validate_input(batch)
+
+ assert len(errors) == 0
+ assert len(requests) == 0
+
+ async def test_validate_input_file_mixed_valid_invalid_json(self, provider):
+ """Test _validate_input when file contains valid and invalid JSON lines."""
+ provider.files_api.openai_retrieve_file = AsyncMock()
+ mock_response = MagicMock()
+ # Line 1: valid JSON with proper body args, Line 2: invalid JSON
+ mock_response.body = b'{"custom_id": "req-1", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "test-model", "messages": [{"role": "user", "content": "Hello"}]}}\n{invalid json'
+ provider.files_api.openai_retrieve_file_content = AsyncMock(return_value=mock_response)
+
+ batch = BatchObject(
+ id="batch_test",
+ object="batch",
+ endpoint="/v1/chat/completions",
+ input_file_id="mixed_file",
+ completion_window="24h",
+ status="validating",
+ created_at=1234567890,
+ )
+
+ errors, requests = await provider._validate_input(batch)
+
+ # Should have 1 JSON parsing error from line 2, and 1 valid request from line 1
+ assert len(errors) == 1
+ assert len(requests) == 1
+
+ assert errors[0].code == "invalid_json_line"
+ assert errors[0].line == 2
+ assert errors[0].message == "This line is not parseable as valid JSON."
+
+ assert requests[0].custom_id == "req-1"
+ assert requests[0].method == "POST"
+ assert requests[0].url == "/v1/chat/completions"
+ assert requests[0].body["model"] == "test-model"
+ assert requests[0].body["messages"] == [{"role": "user", "content": "Hello"}]
+
+ async def test_validate_input_invalid_model(self, provider):
+ """Test _validate_input when file contains request with non-existent model."""
+ provider.files_api.openai_retrieve_file = AsyncMock()
+ mock_response = MagicMock()
+ mock_response.body = b'{"custom_id": "req-1", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "nonexistent-model", "messages": [{"role": "user", "content": "Hello"}]}}'
+ provider.files_api.openai_retrieve_file_content = AsyncMock(return_value=mock_response)
+
+ provider.models_api.get_model = AsyncMock(side_effect=Exception("Model not found"))
+
+ batch = BatchObject(
+ id="batch_test",
+ object="batch",
+ endpoint="/v1/chat/completions",
+ input_file_id="invalid_model_file",
+ completion_window="24h",
+ status="validating",
+ created_at=1234567890,
+ )
+
+ errors, requests = await provider._validate_input(batch)
+
+ assert len(errors) == 1
+ assert len(requests) == 0
+
+ assert errors[0].code == "model_not_found"
+ assert errors[0].line == 1
+ assert errors[0].message == "Model 'nonexistent-model' does not exist or is not supported"
+ assert errors[0].param == "body.model"
+
+ @pytest.mark.parametrize(
+ "param_name,param_path,error_code,error_message",
+ [
+ ("custom_id", "custom_id", "missing_required_parameter", "Missing required parameter: custom_id"),
+ ("method", "method", "missing_required_parameter", "Missing required parameter: method"),
+ ("url", "url", "missing_required_parameter", "Missing required parameter: url"),
+ ("body", "body", "missing_required_parameter", "Missing required parameter: body"),
+ ("model", "body.model", "invalid_request", "Model parameter is required"),
+ ("messages", "body.messages", "invalid_request", "Messages parameter is required"),
+ ],
+ )
+ async def test_validate_input_missing_parameters(self, provider, param_name, param_path, error_code, error_message):
+ """Test _validate_input when file contains request with missing required parameters."""
+ provider.files_api.openai_retrieve_file = AsyncMock()
+ mock_response = MagicMock()
+
+ base_request = {
+ "custom_id": "req-1",
+ "method": "POST",
+ "url": "/v1/chat/completions",
+ "body": {"model": "test-model", "messages": [{"role": "user", "content": "Hello"}]},
+ }
+
+ # Remove the specific parameter being tested
+ if "." in param_path:
+ top_level, nested_param = param_path.split(".", 1)
+ del base_request[top_level][nested_param]
+ else:
+ del base_request[param_name]
+
+ mock_response.body = json.dumps(base_request).encode()
+ provider.files_api.openai_retrieve_file_content = AsyncMock(return_value=mock_response)
+
+ batch = BatchObject(
+ id="batch_test",
+ object="batch",
+ endpoint="/v1/chat/completions",
+ input_file_id=f"missing_{param_name}_file",
+ completion_window="24h",
+ status="validating",
+ created_at=1234567890,
+ )
+
+ errors, requests = await provider._validate_input(batch)
+
+ assert len(errors) == 1
+ assert len(requests) == 0
+
+ assert errors[0].code == error_code
+ assert errors[0].line == 1
+ assert errors[0].message == error_message
+ assert errors[0].param == param_path
+
+ async def test_validate_input_url_mismatch(self, provider):
+ """Test _validate_input when file contains request with URL that doesn't match batch endpoint."""
+ provider.files_api.openai_retrieve_file = AsyncMock()
+ mock_response = MagicMock()
+ mock_response.body = b'{"custom_id": "req-1", "method": "POST", "url": "/v1/embeddings", "body": {"model": "test-model", "messages": [{"role": "user", "content": "Hello"}]}}'
+ provider.files_api.openai_retrieve_file_content = AsyncMock(return_value=mock_response)
+
+ batch = BatchObject(
+ id="batch_test",
+ object="batch",
+ endpoint="/v1/chat/completions", # This doesn't match the URL in the request
+ input_file_id="url_mismatch_file",
+ completion_window="24h",
+ status="validating",
+ created_at=1234567890,
+ )
+
+ errors, requests = await provider._validate_input(batch)
+
+ assert len(errors) == 1
+ assert len(requests) == 0
+
+ assert errors[0].code == "invalid_url"
+ assert errors[0].line == 1
+ assert errors[0].message == "URL provided for this request does not match the batch endpoint"
+ assert errors[0].param == "url"
+
+ async def test_validate_input_multiple_errors_per_request(self, provider):
+ """Test _validate_input when a single request has multiple validation errors."""
+ provider.files_api.openai_retrieve_file = AsyncMock()
+ mock_response = MagicMock()
+ # Request missing custom_id, has invalid URL, and missing model in body
+ mock_response.body = (
+ b'{"method": "POST", "url": "/v1/embeddings", "body": {"messages": [{"role": "user", "content": "Hello"}]}}'
+ )
+ provider.files_api.openai_retrieve_file_content = AsyncMock(return_value=mock_response)
+
+ batch = BatchObject(
+ id="batch_test",
+ object="batch",
+ endpoint="/v1/chat/completions", # Doesn't match /v1/embeddings in request
+ input_file_id="multiple_errors_file",
+ completion_window="24h",
+ status="validating",
+ created_at=1234567890,
+ )
+
+ errors, requests = await provider._validate_input(batch)
+
+ assert len(errors) >= 2 # At least missing custom_id and URL mismatch
+ assert len(requests) == 0
+
+ for error in errors:
+ assert error.line == 1
+
+ error_codes = {error.code for error in errors}
+ assert "missing_required_parameter" in error_codes # missing custom_id
+ assert "invalid_url" in error_codes # URL mismatch
+
+ async def test_validate_input_invalid_request_format(self, provider):
+ """Test _validate_input when file contains non-object JSON (array, string, number)."""
+ provider.files_api.openai_retrieve_file = AsyncMock()
+ mock_response = MagicMock()
+ mock_response.body = b'["not", "a", "request", "object"]'
+ provider.files_api.openai_retrieve_file_content = AsyncMock(return_value=mock_response)
+
+ batch = BatchObject(
+ id="batch_test",
+ object="batch",
+ endpoint="/v1/chat/completions",
+ input_file_id="invalid_format_file",
+ completion_window="24h",
+ status="validating",
+ created_at=1234567890,
+ )
+
+ errors, requests = await provider._validate_input(batch)
+
+ assert len(errors) == 1
+ assert len(requests) == 0
+
+ assert errors[0].code == "invalid_request"
+ assert errors[0].line == 1
+ assert errors[0].message == "Each line must be a JSON dictionary object"
+
+ @pytest.mark.parametrize(
+ "param_name,param_path,invalid_value,error_message",
+ [
+ ("custom_id", "custom_id", 12345, "Custom_id must be a string"),
+ ("url", "url", 123, "URL must be a string"),
+ ("method", "method", ["POST"], "Method must be a string"),
+ ("body", "body", ["not", "valid"], "Body must be a JSON dictionary object"),
+ ("model", "body.model", 123, "Model must be a string"),
+ ("messages", "body.messages", "invalid messages format", "Messages must be an array"),
+ ],
+ )
+ async def test_validate_input_invalid_parameter_types(
+ self, provider, param_name, param_path, invalid_value, error_message
+ ):
+ """Test _validate_input when file contains request with parameters that have invalid types."""
+ provider.files_api.openai_retrieve_file = AsyncMock()
+ mock_response = MagicMock()
+
+ base_request = {
+ "custom_id": "req-1",
+ "method": "POST",
+ "url": "/v1/chat/completions",
+ "body": {"model": "test-model", "messages": [{"role": "user", "content": "Hello"}]},
+ }
+
+ # Override the specific parameter with invalid value
+ if "." in param_path:
+ top_level, nested_param = param_path.split(".", 1)
+ base_request[top_level][nested_param] = invalid_value
+ else:
+ base_request[param_name] = invalid_value
+
+ mock_response.body = json.dumps(base_request).encode()
+ provider.files_api.openai_retrieve_file_content = AsyncMock(return_value=mock_response)
+
+ batch = BatchObject(
+ id="batch_test",
+ object="batch",
+ endpoint="/v1/chat/completions",
+ input_file_id=f"invalid_{param_name}_type_file",
+ completion_window="24h",
+ status="validating",
+ created_at=1234567890,
+ )
+
+ errors, requests = await provider._validate_input(batch)
+
+ assert len(errors) == 1
+ assert len(requests) == 0
+
+ assert errors[0].code == "invalid_request"
+ assert errors[0].line == 1
+ assert errors[0].message == error_message
+ assert errors[0].param == param_path
+
+ async def test_max_concurrent_batches(self, provider):
+ """Test max_concurrent_batches configuration and concurrency control."""
+ import asyncio
+
+ provider._batch_semaphore = asyncio.Semaphore(2)
+
+ provider.process_batches = True # enable because we're testing background processing
+
+ active_batches = 0
+
+ async def add_and_wait(batch_id: str):
+ nonlocal active_batches
+ active_batches += 1
+ await asyncio.sleep(float("inf"))
+
+ # the first thing done in _process_batch is to acquire the semaphore, then call _process_batch_impl,
+ # so we can replace _process_batch_impl with our mock to control concurrency
+ provider._process_batch_impl = add_and_wait
+
+ for _ in range(3):
+ await provider.create_batch(
+ input_file_id="file_id", endpoint="/v1/chat/completions", completion_window="24h"
+ )
+
+ await asyncio.sleep(0.042) # let tasks start
+
+ assert active_batches == 2, f"Expected 2 active batches, got {active_batches}"