Revert "feat: add batches API with OpenAI compatibility" (#3149)

Reverts llamastack/llama-stack#3088

The PR broke integration tests.
Ashwin Bharambe 2025-08-14 10:08:54 -07:00 committed by GitHub
parent de692162af
commit ee7631b6cf
26 changed files with 2 additions and 2707 deletions

View file

@@ -14767,8 +14767,7 @@
"OpenAIFilePurpose": {
"type": "string",
"enum": [
"assistants",
"batch"
"assistants"
],
"title": "OpenAIFilePurpose",
"description": "Valid purpose values for OpenAI Files API."
@@ -14845,8 +14844,7 @@
"purpose": {
"type": "string",
"enum": [
"assistants",
"batch"
"assistants"
],
"description": "The intended purpose of the file"
}

View file

@@ -10951,7 +10951,6 @@ components:
type: string
enum:
- assistants
- batch
title: OpenAIFilePurpose
description: >-
Valid purpose values for OpenAI Files API.
@@ -11020,7 +11019,6 @@
type: string
enum:
- assistants
- batch
description: The intended purpose of the file
additionalProperties: false
required:

View file

@@ -18,4 +18,3 @@ We are working on adding a few more APIs to complete the application lifecycle.
- **Batch Inference**: run inference on a dataset of inputs
- **Batch Agents**: run agents on a dataset of inputs
- **Synthetic Data Generation**: generate synthetic data for model development
- **Batches**: OpenAI-compatible batch management for inference

View file

@@ -2,15 +2,6 @@
## Overview
Agents API for creating and interacting with agentic systems.
Main functionalities provided by this API:
- Create agents with specific instructions and ability to use tools.
- Interactions with agents are grouped into sessions ("threads"), and each interaction is called a "turn".
- Agents can be provided with various tools (see the ToolGroups and ToolRuntime APIs for more details).
- Agents can be provided with various shields (see the Safety API for more details).
- Agents can also use Memory to retrieve information from knowledge bases. See the RAG Tool and Vector IO APIs for more details.
This section contains documentation for all available providers for the **agents** API.
## Providers

View file

@@ -1,21 +0,0 @@
# Batches
## Overview
Protocol for batch processing API operations.
The Batches API enables efficient processing of multiple requests in a single operation,
particularly useful for processing large datasets, batch evaluation workflows, and
cost-effective inference at scale.
Note: This API is currently under active development and may undergo changes.
This section contains documentation for all available providers for the **batches** API.
## Providers
```{toctree}
:maxdepth: 1
inline_reference
```
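
For context, the overview above never shipped with a usage example. Below is a minimal sketch of how the integration tests in this commit exercised the API, using the OpenAI Python client pointed at a Llama Stack server; the base URL, API key, and model id are placeholders rather than values from the repository:

```python
# Hedged usage sketch based on the integration tests removed by this revert.
# base_url, api_key, and the model id are illustrative placeholders.
import json
from io import BytesIO

from openai import OpenAI

client = OpenAI(base_url="http://localhost:8321/openai/v1", api_key="none")

# One JSON object per line; each line needs custom_id, method, url, and body.
requests = [
    {
        "custom_id": "request-1",
        "method": "POST",
        "url": "/v1/chat/completions",
        "body": {
            "model": "my-model-id",
            "messages": [{"role": "user", "content": "Hello"}],
            "max_tokens": 10,
        },
    }
]
buffer = BytesIO("\n".join(json.dumps(r) for r in requests).encode("utf-8"))
buffer.name = "batch_input.jsonl"

input_file = client.files.create(file=buffer, purpose="batch")
batch = client.batches.create(
    input_file_id=input_file.id,
    endpoint="/v1/chat/completions",
    completion_window="24h",
)
print(client.batches.retrieve(batch.id).status)  # validating -> in_progress -> completed
```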

View file

@@ -1,23 +0,0 @@
# inline::reference
## Description
Reference implementation of batches API with KVStore persistence.
## Configuration
| Field | Type | Required | Default | Description |
|-------|------|----------|---------|-------------|
| `kvstore` | `utils.kvstore.config.RedisKVStoreConfig \| utils.kvstore.config.SqliteKVStoreConfig \| utils.kvstore.config.PostgresKVStoreConfig \| utils.kvstore.config.MongoDBKVStoreConfig` | No | sqlite | Configuration for the key-value store backend. |
| `max_concurrent_batches` | `<class 'int'>` | No | 1 | Maximum number of concurrent batches to process simultaneously. |
| `max_concurrent_requests_per_batch` | `<class 'int'>` | No | 10 | Maximum number of concurrent requests to process per batch. |
## Sample Configuration
```yaml
kvstore:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/dummy}/batches.db
```
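
The sample above only sets the key-value store. A fuller configuration, assuming the defaults listed in the table, might look like:

```yaml
# Sketch of a fuller inline::reference batches configuration; the two
# concurrency values shown are the documented defaults.
kvstore:
  type: sqlite
  db_path: ${env.SQLITE_STORE_DIR:=~/.llama/dummy}/batches.db
max_concurrent_batches: 1
max_concurrent_requests_per_batch: 10
```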

View file

@@ -2,8 +2,6 @@
## Overview
Llama Stack Evaluation API for running evaluations on model and agent candidates.
This section contains documentation for all available providers for the **eval** API.
## Providers

View file

@@ -2,12 +2,6 @@
## Overview
Llama Stack Inference API for generating completions, chat completions, and embeddings.
This API provides the raw interface to the underlying models. Two kinds of models are supported:
- LLM models: these models generate "raw" and "chat" (conversational) completions.
- Embedding models: these models generate embeddings to be used for semantic search.
This section contains documentation for all available providers for the **inference** API.
## Providers

View file

@@ -1,9 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .batches import Batches, BatchObject, ListBatchesResponse
__all__ = ["Batches", "BatchObject", "ListBatchesResponse"]

View file

@@ -1,89 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Literal, Protocol, runtime_checkable
from pydantic import BaseModel, Field
from llama_stack.schema_utils import json_schema_type, webmethod
try:
from openai.types import Batch as BatchObject
except ImportError as e:
raise ImportError("OpenAI package is required for batches API. Please install it with: pip install openai") from e
@json_schema_type
class ListBatchesResponse(BaseModel):
"""Response containing a list of batch objects."""
object: Literal["list"] = "list"
data: list[BatchObject] = Field(..., description="List of batch objects")
first_id: str | None = Field(default=None, description="ID of the first batch in the list")
last_id: str | None = Field(default=None, description="ID of the last batch in the list")
has_more: bool = Field(default=False, description="Whether there are more batches available")
@runtime_checkable
class Batches(Protocol):
"""Protocol for batch processing API operations.
The Batches API enables efficient processing of multiple requests in a single operation,
particularly useful for processing large datasets, batch evaluation workflows, and
cost-effective inference at scale.
Note: This API is currently under active development and may undergo changes.
"""
@webmethod(route="/openai/v1/batches", method="POST")
async def create_batch(
self,
input_file_id: str,
endpoint: str,
completion_window: Literal["24h"],
metadata: dict[str, str] | None = None,
) -> BatchObject:
"""Create a new batch for processing multiple API requests.
:param input_file_id: The ID of an uploaded file containing requests for the batch.
:param endpoint: The endpoint to be used for all requests in the batch.
:param completion_window: The time window within which the batch should be processed.
:param metadata: Optional metadata for the batch.
:returns: The created batch object.
"""
...
@webmethod(route="/openai/v1/batches/{batch_id}", method="GET")
async def retrieve_batch(self, batch_id: str) -> BatchObject:
"""Retrieve information about a specific batch.
:param batch_id: The ID of the batch to retrieve.
:returns: The batch object.
"""
...
@webmethod(route="/openai/v1/batches/{batch_id}/cancel", method="POST")
async def cancel_batch(self, batch_id: str) -> BatchObject:
"""Cancel a batch that is in progress.
:param batch_id: The ID of the batch to cancel.
:returns: The updated batch object.
"""
...
@webmethod(route="/openai/v1/batches", method="GET")
async def list_batches(
self,
after: str | None = None,
limit: int = 20,
) -> ListBatchesResponse:
"""List all batches for the current user.
:param after: A cursor for pagination; returns batches after this batch ID.
:param limit: Number of batches to return (default 20, max 100).
:returns: A list of batch objects.
"""
...
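
The protocol leaves the input file format implicit. Based on the reference provider's validation later in this diff, each line of the uploaded JSONL file must be a JSON object carrying custom_id, method, url (matching the batch endpoint), and a body with at least model and messages, and must not request streaming. An illustrative two-line input file, with a placeholder model id:

```jsonl
{"custom_id": "request-1", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "my-model-id", "messages": [{"role": "user", "content": "Hello"}], "max_tokens": 10}}
{"custom_id": "request-2", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "my-model-id", "messages": [{"role": "user", "content": "Say hi"}], "max_tokens": 10}}
```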

View file

@@ -64,12 +64,6 @@ class SessionNotFoundError(ValueError):
super().__init__(message)
class ConflictError(ValueError):
"""raised when an operation cannot be performed due to a conflict with the current state"""
pass
class ModelTypeError(TypeError):
"""raised when a model is present but not the correct type"""

View file

@@ -86,7 +86,6 @@ class Api(Enum, metaclass=DynamicApiMeta):
:cvar inference: Text generation, chat completions, and embeddings
:cvar safety: Content moderation and safety shields
:cvar agents: Agent orchestration and execution
:cvar batches: Batch processing for asynchronous API requests
:cvar vector_io: Vector database operations and queries
:cvar datasetio: Dataset input/output operations
:cvar scoring: Model output evaluation and scoring
@@ -109,7 +108,6 @@ class Api(Enum, metaclass=DynamicApiMeta):
inference = "inference"
safety = "safety"
agents = "agents"
batches = "batches"
vector_io = "vector_io"
datasetio = "datasetio"
scoring = "scoring"

View file

@@ -22,7 +22,6 @@ class OpenAIFilePurpose(StrEnum):
"""
ASSISTANTS = "assistants"
BATCH = "batch"
# TODO: Add other purposes as needed

View file

@@ -8,7 +8,6 @@ import inspect
from typing import Any
from llama_stack.apis.agents import Agents
from llama_stack.apis.batches import Batches
from llama_stack.apis.benchmarks import Benchmarks
from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.datasets import Datasets
@@ -76,7 +75,6 @@ def api_protocol_map(external_apis: dict[Api, ExternalApiSpec] | None = None) ->
Api.agents: Agents,
Api.inference: Inference,
Api.inspect: Inspect,
Api.batches: Batches,
Api.vector_io: VectorIO,
Api.vector_dbs: VectorDBs,
Api.models: Models,

View file

@@ -32,7 +32,6 @@ from fastapi.responses import JSONResponse, StreamingResponse
from openai import BadRequestError
from pydantic import BaseModel, ValidationError
from llama_stack.apis.common.errors import ConflictError, ResourceNotFoundError
from llama_stack.apis.common.responses import PaginatedResponse
from llama_stack.cli.utils import add_config_distro_args, get_config_from_args
from llama_stack.core.access_control.access_control import AccessDeniedError
@@ -129,10 +128,6 @@ def translate_exception(exc: Exception) -> HTTPException | RequestValidationErro
]
},
)
elif isinstance(exc, ConflictError):
return HTTPException(status_code=409, detail=str(exc))
elif isinstance(exc, ResourceNotFoundError):
return HTTPException(status_code=404, detail=str(exc))
elif isinstance(exc, ValueError):
return HTTPException(status_code=httpx.codes.BAD_REQUEST, detail=f"Invalid value: {str(exc)}")
elif isinstance(exc, BadRequestError):

View file

@@ -1,5 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@@ -1,36 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any
from llama_stack.apis.files import Files
from llama_stack.apis.inference import Inference
from llama_stack.apis.models import Models
from llama_stack.core.datatypes import AccessRule, Api
from llama_stack.providers.utils.kvstore import kvstore_impl
from .batches import ReferenceBatchesImpl
from .config import ReferenceBatchesImplConfig
__all__ = ["ReferenceBatchesImpl", "ReferenceBatchesImplConfig"]
async def get_provider_impl(config: ReferenceBatchesImplConfig, deps: dict[Api, Any], policy: list[AccessRule]):
kvstore = await kvstore_impl(config.kvstore)
inference_api: Inference | None = deps.get(Api.inference)
files_api: Files | None = deps.get(Api.files)
models_api: Models | None = deps.get(Api.models)
if inference_api is None:
raise ValueError("Inference API is required but not provided in dependencies")
if files_api is None:
raise ValueError("Files API is required but not provided in dependencies")
if models_api is None:
raise ValueError("Models API is required but not provided in dependencies")
impl = ReferenceBatchesImpl(config, inference_api, files_api, models_api, kvstore)
await impl.initialize()
return impl

View file

@@ -1,553 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import asyncio
import itertools
import json
import time
import uuid
from io import BytesIO
from typing import Any, Literal
from openai.types.batch import BatchError, Errors
from pydantic import BaseModel
from llama_stack.apis.batches import Batches, BatchObject, ListBatchesResponse
from llama_stack.apis.common.errors import ConflictError, ResourceNotFoundError
from llama_stack.apis.files import Files, OpenAIFilePurpose
from llama_stack.apis.inference import Inference
from llama_stack.apis.models import Models
from llama_stack.log import get_logger
from llama_stack.providers.utils.kvstore import KVStore
from .config import ReferenceBatchesImplConfig
BATCH_PREFIX = "batch:"
logger = get_logger(__name__)
class AsyncBytesIO:
"""
Async-compatible BytesIO wrapper to allow async file-like operations.
We use this when uploading files to the Files API, as it expects an
async file-like object.
"""
def __init__(self, data: bytes):
self._buffer = BytesIO(data)
async def read(self, n=-1):
return self._buffer.read(n)
async def seek(self, pos, whence=0):
return self._buffer.seek(pos, whence)
def __enter__(self):
return self
def __exit__(self, exc_type, exc_val, exc_tb):
self._buffer.close()
def __getattr__(self, name):
return getattr(self._buffer, name)
class BatchRequest(BaseModel):
line_num: int
custom_id: str
method: str
url: str
body: dict[str, Any]
class ReferenceBatchesImpl(Batches):
"""Reference implementation of the Batches API.
This implementation processes batch files by making individual requests
to the inference API and generates output files with results.
"""
def __init__(
self,
config: ReferenceBatchesImplConfig,
inference_api: Inference,
files_api: Files,
models_api: Models,
kvstore: KVStore,
) -> None:
self.config = config
self.kvstore = kvstore
self.inference_api = inference_api
self.files_api = files_api
self.models_api = models_api
self._processing_tasks: dict[str, asyncio.Task] = {}
self._batch_semaphore = asyncio.Semaphore(config.max_concurrent_batches)
self._update_batch_lock = asyncio.Lock()
# this is to allow tests to disable background processing
self.process_batches = True
async def initialize(self) -> None:
# TODO: start background processing of existing tasks
pass
async def shutdown(self) -> None:
"""Shutdown the batches provider."""
if self._processing_tasks:
# don't cancel tasks - just let them stop naturally on shutdown
# cancelling would mark batches as "cancelled" in the database
logger.info(f"Shutdown initiated with {len(self._processing_tasks)} active batch processing tasks")
# TODO (SECURITY): this currently works w/ configured api keys, not with x-llamastack-provider-data or with user policy restrictions
async def create_batch(
self,
input_file_id: str,
endpoint: str,
completion_window: Literal["24h"],
metadata: dict[str, str] | None = None,
) -> BatchObject:
"""
Create a new batch for processing multiple API requests.
Error handling by levels -
0. Input param handling, results in 4xx errors before processing, e.g.
- Wrong completion_window
- Invalid metadata types
- Unknown endpoint
-> no batch created
1. Errors preventing processing, result in BatchErrors aggregated in process_batch, e.g.
- input_file_id missing
- invalid json in file
- missing custom_id, method, url, body
- invalid model
- streaming
-> batch created, validation sends to failed status
2. Processing errors, result in error_file_id entries, e.g.
- Any error returned from inference endpoint
-> batch created, goes to completed status
"""
# TODO: set expiration time for garbage collection
if endpoint not in ["/v1/chat/completions"]:
raise ValueError(
f"Invalid endpoint: {endpoint}. Supported values: /v1/chat/completions. Code: invalid_value. Param: endpoint",
)
if completion_window != "24h":
raise ValueError(
f"Invalid completion_window: {completion_window}. Supported values are: 24h. Code: invalid_value. Param: completion_window",
)
batch_id = f"batch_{uuid.uuid4().hex[:16]}"
current_time = int(time.time())
batch = BatchObject(
id=batch_id,
object="batch",
endpoint=endpoint,
input_file_id=input_file_id,
completion_window=completion_window,
status="validating",
created_at=current_time,
metadata=metadata,
)
await self.kvstore.set(f"batch:{batch_id}", batch.to_json())
if self.process_batches:
task = asyncio.create_task(self._process_batch(batch_id))
self._processing_tasks[batch_id] = task
return batch
async def cancel_batch(self, batch_id: str) -> BatchObject:
"""Cancel a batch that is in progress."""
batch = await self.retrieve_batch(batch_id)
if batch.status in ["cancelled", "cancelling"]:
return batch
if batch.status in ["completed", "failed", "expired"]:
raise ConflictError(f"Cannot cancel batch '{batch_id}' with status '{batch.status}'")
await self._update_batch(batch_id, status="cancelling", cancelling_at=int(time.time()))
if batch_id in self._processing_tasks:
self._processing_tasks[batch_id].cancel()
# note: task removal and status="cancelled" handled in finally block of _process_batch
return await self.retrieve_batch(batch_id)
async def list_batches(
self,
after: str | None = None,
limit: int = 20,
) -> ListBatchesResponse:
"""
List all batches, eventually only for the current user.
With no notion of user, we return all batches.
"""
batch_values = await self.kvstore.values_in_range("batch:", "batch:\xff")
batches = []
for batch_data in batch_values:
if batch_data:
batches.append(BatchObject.model_validate_json(batch_data))
batches.sort(key=lambda b: b.created_at, reverse=True)
start_idx = 0
if after:
for i, batch in enumerate(batches):
if batch.id == after:
start_idx = i + 1
break
page_batches = batches[start_idx : start_idx + limit]
has_more = (start_idx + limit) < len(batches)
first_id = page_batches[0].id if page_batches else None
last_id = page_batches[-1].id if page_batches else None
return ListBatchesResponse(
data=page_batches,
first_id=first_id,
last_id=last_id,
has_more=has_more,
)
async def retrieve_batch(self, batch_id: str) -> BatchObject:
"""Retrieve information about a specific batch."""
batch_data = await self.kvstore.get(f"batch:{batch_id}")
if not batch_data:
raise ResourceNotFoundError(batch_id, "Batch", "batches.list()")
return BatchObject.model_validate_json(batch_data)
async def _update_batch(self, batch_id: str, **updates) -> None:
"""Update batch fields in kvstore."""
async with self._update_batch_lock:
try:
batch = await self.retrieve_batch(batch_id)
# batch processing is async. once cancelling, only allow "cancelled" status updates
if batch.status == "cancelling" and updates.get("status") != "cancelled":
logger.info(
f"Skipping status update for cancelled batch {batch_id}: attempted {updates.get('status')}"
)
return
if "errors" in updates:
updates["errors"] = updates["errors"].model_dump()
batch_dict = batch.model_dump()
batch_dict.update(updates)
await self.kvstore.set(f"batch:{batch_id}", json.dumps(batch_dict))
except Exception as e:
logger.error(f"Failed to update batch {batch_id}: {e}")
async def _validate_input(self, batch: BatchObject) -> tuple[list[BatchError], list[BatchRequest]]:
"""
Read & validate input, return errors and valid input.
Validation of
- input_file_id existence
- valid json
- custom_id, method, url, body presence and validity
- no streaming
"""
requests: list[BatchRequest] = []
errors: list[BatchError] = []
try:
await self.files_api.openai_retrieve_file(batch.input_file_id)
except Exception:
errors.append(
BatchError(
code="invalid_request",
line=None,
message=f"Cannot find file {batch.input_file_id}.",
param="input_file_id",
)
)
return errors, requests
# TODO(SECURITY): do something about large files
file_content_response = await self.files_api.openai_retrieve_file_content(batch.input_file_id)
file_content = file_content_response.body.decode("utf-8")
for line_num, line in enumerate(file_content.strip().split("\n"), 1):
if line.strip(): # skip empty lines
try:
request = json.loads(line)
if not isinstance(request, dict):
errors.append(
BatchError(
code="invalid_request",
line=line_num,
message="Each line must be a JSON dictionary object",
)
)
continue
valid = True
for param, expected_type, type_string in [
("custom_id", str, "string"),
("method", str, "string"),
("url", str, "string"),
("body", dict, "JSON dictionary object"),
]:
if param not in request:
errors.append(
BatchError(
code="missing_required_parameter",
line=line_num,
message=f"Missing required parameter: {param}",
param=param,
)
)
valid = False
elif not isinstance(request[param], expected_type):
param_name = "URL" if param == "url" else param.capitalize()
errors.append(
BatchError(
code="invalid_request",
line=line_num,
message=f"{param_name} must be a {type_string}",
param=param,
)
)
valid = False
if (url := request.get("url")) and isinstance(url, str) and url != batch.endpoint:
errors.append(
BatchError(
code="invalid_url",
line=line_num,
message="URL provided for this request does not match the batch endpoint",
param="url",
)
)
valid = False
if (body := request.get("body")) and isinstance(body, dict):
if body.get("stream", False):
errors.append(
BatchError(
code="streaming_unsupported",
line=line_num,
message="Streaming is not supported in batch processing",
param="body.stream",
)
)
valid = False
for param, expected_type, type_string in [
("model", str, "a string"),
# messages is specific to /v1/chat/completions
# we could skip validating messages here and let inference fail. however,
# that would be a very expensive way to find out messages is wrong.
("messages", list, "an array"), # TODO: allow messages to be a string?
]:
if param not in body:
errors.append(
BatchError(
code="invalid_request",
line=line_num,
message=f"{param.capitalize()} parameter is required",
param=f"body.{param}",
)
)
valid = False
elif not isinstance(body[param], expected_type):
errors.append(
BatchError(
code="invalid_request",
line=line_num,
message=f"{param.capitalize()} must be {type_string}",
param=f"body.{param}",
)
)
valid = False
if "model" in body and isinstance(body["model"], str):
try:
await self.models_api.get_model(body["model"])
except Exception:
errors.append(
BatchError(
code="model_not_found",
line=line_num,
message=f"Model '{body['model']}' does not exist or is not supported",
param="body.model",
)
)
valid = False
if valid:
assert isinstance(url, str), "URL must be a string" # for mypy
assert isinstance(body, dict), "Body must be a dictionary" # for mypy
requests.append(
BatchRequest(
line_num=line_num,
url=url,
method=request["method"],
custom_id=request["custom_id"],
body=body,
),
)
except json.JSONDecodeError:
errors.append(
BatchError(
code="invalid_json_line",
line=line_num,
message="This line is not parseable as valid JSON.",
)
)
return errors, requests
async def _process_batch(self, batch_id: str) -> None:
"""Background task to process a batch of requests."""
try:
logger.info(f"Starting batch processing for {batch_id}")
async with self._batch_semaphore: # semaphore to limit concurrency
logger.info(f"Acquired semaphore for batch {batch_id}")
await self._process_batch_impl(batch_id)
except asyncio.CancelledError:
logger.info(f"Batch processing cancelled for {batch_id}")
await self._update_batch(batch_id, status="cancelled", cancelled_at=int(time.time()))
except Exception as e:
logger.error(f"Batch processing failed for {batch_id}: {e}")
await self._update_batch(
batch_id,
status="failed",
failed_at=int(time.time()),
errors=Errors(data=[BatchError(code="internal_error", message=str(e))]),
)
finally:
self._processing_tasks.pop(batch_id, None)
async def _process_batch_impl(self, batch_id: str) -> None:
"""Implementation of batch processing logic."""
errors: list[BatchError] = []
batch = await self.retrieve_batch(batch_id)
errors, requests = await self._validate_input(batch)
if errors:
await self._update_batch(batch_id, status="failed", failed_at=int(time.time()), errors=Errors(data=errors))
logger.info(f"Batch validation failed for {batch_id} with {len(errors)} errors")
return
logger.info(f"Processing {len(requests)} requests for batch {batch_id}")
total_requests = len(requests)
await self._update_batch(
batch_id,
status="in_progress",
request_counts={"total": total_requests, "completed": 0, "failed": 0},
)
error_results = []
success_results = []
completed_count = 0
failed_count = 0
for chunk in itertools.batched(requests, self.config.max_concurrent_requests_per_batch):
# we use a TaskGroup to ensure all process-single-request tasks are canceled when process-batch is cancelled
async with asyncio.TaskGroup() as tg:
chunk_tasks = [tg.create_task(self._process_single_request(batch_id, request)) for request in chunk]
chunk_results = await asyncio.gather(*chunk_tasks, return_exceptions=True)
for result in chunk_results:
if isinstance(result, dict) and result.get("error") is not None: # error response from inference
failed_count += 1
error_results.append(result)
elif isinstance(result, dict) and result.get("response") is not None: # successful inference
completed_count += 1
success_results.append(result)
else: # unexpected result
failed_count += 1
errors.append(BatchError(code="internal_error", message=f"Unexpected result: {result}"))
await self._update_batch(
batch_id,
request_counts={"total": total_requests, "completed": completed_count, "failed": failed_count},
)
if errors:
await self._update_batch(
batch_id, status="failed", failed_at=int(time.time()), errors=Errors(data=errors)
)
return
try:
output_file_id = await self._create_output_file(batch_id, success_results, "success")
await self._update_batch(batch_id, output_file_id=output_file_id)
error_file_id = await self._create_output_file(batch_id, error_results, "error")
await self._update_batch(batch_id, error_file_id=error_file_id)
await self._update_batch(batch_id, status="completed", completed_at=int(time.time()))
logger.info(
f"Batch processing completed for {batch_id}: {completed_count} completed, {failed_count} failed"
)
except Exception as e:
# note: errors is empty at this point, so we don't lose anything by ignoring it
await self._update_batch(
batch_id,
status="failed",
failed_at=int(time.time()),
errors=Errors(data=[BatchError(code="output_failed", message=str(e))]),
)
async def _process_single_request(self, batch_id: str, request: BatchRequest) -> dict:
"""Process a single request from the batch."""
request_id = f"batch_req_{batch_id}_{request.line_num}"
try:
# TODO(SECURITY): review body for security issues
chat_response = await self.inference_api.openai_chat_completion(**request.body)
# this is for mypy, we don't allow streaming so we'll get the right type
assert hasattr(chat_response, "model_dump_json"), "Chat response must have model_dump_json method"
return {
"id": request_id,
"custom_id": request.custom_id,
"response": {
"status_code": 200,
"request_id": request_id, # TODO: should this be different?
"body": chat_response.model_dump_json(),
},
}
except Exception as e:
logger.info(f"Error processing request {request.custom_id} in batch {batch_id}: {e}")
return {
"id": request_id,
"custom_id": request.custom_id,
"error": {"type": "request_failed", "message": str(e)},
}
async def _create_output_file(self, batch_id: str, results: list[dict], file_type: str) -> str:
"""
Create an output file with batch results.
This function filters results based on the specified file_type
and uploads the file to the Files API.
"""
output_lines = [json.dumps(result) for result in results]
with AsyncBytesIO("\n".join(output_lines).encode("utf-8")) as file_buffer:
file_buffer.filename = f"{batch_id}_{file_type}.jsonl"
uploaded_file = await self.files_api.openai_upload_file(file=file_buffer, purpose=OpenAIFilePurpose.BATCH)
return uploaded_file.id
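For reference, each result line produced by _process_single_request is a flat JSON object; successful lines go to the output file and failed lines to the error file, and the response body is the JSON-encoded chat completion (so it appears as a string). A hedged sketch with illustrative ids:

```jsonl
{"id": "batch_req_batch_0123abcd_1", "custom_id": "success-1", "response": {"status_code": 200, "request_id": "batch_req_batch_0123abcd_1", "body": "{\"id\": \"chatcmpl-...\", \"choices\": [...]}"}}
{"id": "batch_req_batch_0123abcd_2", "custom_id": "error-1", "error": {"type": "request_failed", "message": "<error message from the inference call>"}}
```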

View file

@@ -1,40 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from pydantic import BaseModel, Field
from llama_stack.providers.utils.kvstore.config import KVStoreConfig, SqliteKVStoreConfig
class ReferenceBatchesImplConfig(BaseModel):
"""Configuration for the Reference Batches implementation."""
kvstore: KVStoreConfig = Field(
description="Configuration for the key-value store backend.",
)
max_concurrent_batches: int = Field(
default=1,
description="Maximum number of concurrent batches to process simultaneously.",
ge=1,
)
max_concurrent_requests_per_batch: int = Field(
default=10,
description="Maximum number of concurrent requests to process per batch.",
ge=1,
)
# TODO: add a max requests per second rate limiter
@classmethod
def sample_run_config(cls, __distro_dir__: str) -> dict:
return {
"kvstore": SqliteKVStoreConfig.sample_run_config(
__distro_dir__=__distro_dir__,
db_name="batches.db",
),
}

View file

@@ -1,26 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.providers.datatypes import Api, InlineProviderSpec, ProviderSpec
def available_providers() -> list[ProviderSpec]:
return [
InlineProviderSpec(
api=Api.batches,
provider_type="inline::reference",
pip_packages=["openai"],
module="llama_stack.providers.inline.batches.reference",
config_class="llama_stack.providers.inline.batches.reference.config.ReferenceBatchesImplConfig",
api_dependencies=[
Api.inference,
Api.files,
Api.models,
],
description="Reference implementation of batches API with KVStore persistence.",
),
]
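This registry entry only declares the provider. Enabling it in a distribution would go through the stack's run.yaml; a hedged sketch, where the provider_id and store path are placeholders rather than values from this commit:

```yaml
# Hypothetical run.yaml fragment enabling the inline batches provider.
apis:
  - batches
providers:
  batches:
    - provider_id: reference
      provider_type: inline::reference
      config:
        kvstore:
          type: sqlite
          db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distro}/batches.db
```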

View file

@@ -18,23 +18,6 @@ from llama_stack.core.distribution import get_provider_registry
REPO_ROOT = Path(__file__).parent.parent
def get_api_docstring(api_name: str) -> str | None:
"""Extract docstring from the API protocol class."""
try:
# Import the API module dynamically
api_module = __import__(f"llama_stack.apis.{api_name}", fromlist=[api_name.title()])
# Get the main protocol class (usually capitalized API name)
protocol_class_name = api_name.title()
if hasattr(api_module, protocol_class_name):
protocol_class = getattr(api_module, protocol_class_name)
return protocol_class.__doc__
except (ImportError, AttributeError):
pass
return None
class ChangedPathTracker:
"""Track a list of paths we may have changed."""
@@ -278,11 +261,6 @@ def process_provider_registry(progress, change_tracker: ChangedPathTracker) -> N
index_content.append(f"# {api_name.title()}\n")
index_content.append("## Overview\n")
api_docstring = get_api_docstring(api_name)
if api_docstring:
cleaned_docstring = api_docstring.strip()
index_content.append(f"{cleaned_docstring}\n")
index_content.append(
f"This section contains documentation for all available providers for the **{api_name}** API.\n"
)

View file

@@ -1,5 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@@ -1,122 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
"""Shared pytest fixtures for batch tests."""
import json
import time
import warnings
from contextlib import contextmanager
from io import BytesIO
import pytest
from llama_stack.apis.files import OpenAIFilePurpose
class BatchHelper:
"""Helper class for creating and managing batch input files."""
def __init__(self, client):
"""Initialize with either a batch_client or openai_client."""
self.client = client
@contextmanager
def create_file(self, content: str | list[dict], filename_prefix="batch_input"):
"""Context manager for creating and cleaning up batch input files.
Args:
content: Either a list of batch request dictionaries or raw string content
filename_prefix: Prefix for the generated filename (or full filename if content is string)
Yields:
The uploaded file object
"""
if isinstance(content, str):
# Handle raw string content (e.g., malformed JSONL, empty files)
file_content = content.encode("utf-8")
else:
# Handle list of batch request dictionaries
jsonl_content = "\n".join(json.dumps(req) for req in content)
file_content = jsonl_content.encode("utf-8")
filename = filename_prefix if filename_prefix.endswith(".jsonl") else f"{filename_prefix}.jsonl"
with BytesIO(file_content) as file_buffer:
file_buffer.name = filename
uploaded_file = self.client.files.create(file=file_buffer, purpose=OpenAIFilePurpose.BATCH)
try:
yield uploaded_file
finally:
try:
self.client.files.delete(uploaded_file.id)
except Exception:
warnings.warn(
f"Failed to cleanup file {uploaded_file.id}: {uploaded_file.filename}",
stacklevel=2,
)
def wait_for(
self,
batch_id: str,
max_wait_time: int = 60,
sleep_interval: int | None = None,
expected_statuses: set[str] | None = None,
timeout_action: str = "fail",
):
"""Wait for a batch to reach a terminal status.
Args:
batch_id: The batch ID to monitor
max_wait_time: Maximum time to wait in seconds (default: 60 seconds)
sleep_interval: Time to sleep between checks in seconds (default: 1/10th of max_wait_time, min 1s, max 15s)
expected_statuses: Set of expected terminal statuses (default: {"completed"})
timeout_action: Action on timeout - "fail" (pytest.fail) or "skip" (pytest.skip)
Returns:
The final batch object
Raises:
pytest.Failed: If batch reaches an unexpected status or timeout_action is "fail"
pytest.Skipped: If timeout_action is "skip" on timeout or unexpected status
"""
if sleep_interval is None:
# Default to 1/10th of max_wait_time, with min 1s and max 15s
sleep_interval = max(1, min(15, max_wait_time // 10))
if expected_statuses is None:
expected_statuses = {"completed"}
terminal_statuses = {"completed", "failed", "cancelled", "expired"}
unexpected_statuses = terminal_statuses - expected_statuses
start_time = time.time()
while time.time() - start_time < max_wait_time:
current_batch = self.client.batches.retrieve(batch_id)
if current_batch.status in expected_statuses:
return current_batch
elif current_batch.status in unexpected_statuses:
error_msg = f"Batch reached unexpected status: {current_batch.status}"
if timeout_action == "skip":
pytest.skip(error_msg)
else:
pytest.fail(error_msg)
time.sleep(sleep_interval)
timeout_msg = f"Batch did not reach expected status {expected_statuses} within {max_wait_time} seconds"
if timeout_action == "skip":
pytest.skip(timeout_msg)
else:
pytest.fail(timeout_msg)
@pytest.fixture
def batch_helper(openai_client):
"""Fixture that provides a BatchHelper instance for OpenAI client."""
return BatchHelper(openai_client)

View file

@@ -1,270 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
"""
Integration tests for the Llama Stack batch processing functionality.
This module contains comprehensive integration tests for the batch processing API,
using the OpenAI-compatible client interface for consistency.
Test Categories:
1. Core Batch Operations:
- test_batch_creation_and_retrieval: Comprehensive batch creation, structure validation, and retrieval
- test_batch_listing: Basic batch listing functionality
- test_batch_immediate_cancellation: Batch cancellation workflow
# TODO: cancel during processing
2. End-to-End Processing:
- test_batch_e2e_chat_completions: Full chat completions workflow with output and error validation
Note: Error conditions and edge cases are primarily tested in test_batches_errors.py
for better organization and separation of concerns.
CLEANUP WARNING: These tests currently create batches that are not automatically
cleaned up after test completion. This may lead to resource accumulation over
multiple test runs. Only test_batch_immediate_cancellation properly cancels its batch.
The test_batch_e2e_chat_completions test does clean up its output and error files.
"""
import json
class TestBatchesIntegration:
"""Integration tests for the batches API."""
def test_batch_creation_and_retrieval(self, openai_client, batch_helper, text_model_id):
"""Test comprehensive batch creation and retrieval scenarios."""
test_metadata = {
"test_type": "comprehensive",
"purpose": "creation_and_retrieval_test",
"version": "1.0",
"tags": "test,batch",
}
batch_requests = [
{
"custom_id": "request-1",
"method": "POST",
"url": "/v1/chat/completions",
"body": {
"model": text_model_id,
"messages": [{"role": "user", "content": "Hello"}],
"max_tokens": 10,
},
}
]
with batch_helper.create_file(batch_requests, "batch_creation_test") as uploaded_file:
batch = openai_client.batches.create(
input_file_id=uploaded_file.id,
endpoint="/v1/chat/completions",
completion_window="24h",
metadata=test_metadata,
)
assert batch.endpoint == "/v1/chat/completions"
assert batch.input_file_id == uploaded_file.id
assert batch.completion_window == "24h"
assert batch.metadata == test_metadata
retrieved_batch = openai_client.batches.retrieve(batch.id)
assert retrieved_batch.id == batch.id
assert retrieved_batch.object == batch.object
assert retrieved_batch.endpoint == batch.endpoint
assert retrieved_batch.input_file_id == batch.input_file_id
assert retrieved_batch.completion_window == batch.completion_window
assert retrieved_batch.metadata == batch.metadata
def test_batch_listing(self, openai_client, batch_helper, text_model_id):
"""
Test batch listing.
This test creates multiple batches and verifies that they can be listed.
It also deletes the input files before execution, which means the batches
will appear as failed due to missing input files. This is expected and
a good thing, because it means no inference is performed.
"""
batch_ids = []
for i in range(2):
batch_requests = [
{
"custom_id": f"request-{i}",
"method": "POST",
"url": "/v1/chat/completions",
"body": {
"model": text_model_id,
"messages": [{"role": "user", "content": f"Hello {i}"}],
"max_tokens": 10,
},
}
]
with batch_helper.create_file(batch_requests, f"batch_input_{i}") as uploaded_file:
batch = openai_client.batches.create(
input_file_id=uploaded_file.id,
endpoint="/v1/chat/completions",
completion_window="24h",
)
batch_ids.append(batch.id)
batch_list = openai_client.batches.list()
assert isinstance(batch_list.data, list)
listed_batch_ids = {b.id for b in batch_list.data}
for batch_id in batch_ids:
assert batch_id in listed_batch_ids
def test_batch_immediate_cancellation(self, openai_client, batch_helper, text_model_id):
"""Test immediate batch cancellation."""
batch_requests = [
{
"custom_id": "request-1",
"method": "POST",
"url": "/v1/chat/completions",
"body": {
"model": text_model_id,
"messages": [{"role": "user", "content": "Hello"}],
"max_tokens": 10,
},
}
]
with batch_helper.create_file(batch_requests) as uploaded_file:
batch = openai_client.batches.create(
input_file_id=uploaded_file.id,
endpoint="/v1/chat/completions",
completion_window="24h",
)
# hopefully cancel the batch before it completes
cancelling_batch = openai_client.batches.cancel(batch.id)
assert cancelling_batch.status in ["cancelling", "cancelled"]
assert isinstance(cancelling_batch.cancelling_at, int), (
f"cancelling_at should be int, got {type(cancelling_batch.cancelling_at)}"
)
final_batch = batch_helper.wait_for(
batch.id,
max_wait_time=3 * 60, # often takes 10-11 minutes, give it 3 min
expected_statuses={"cancelled"},
timeout_action="skip",
)
assert final_batch.status == "cancelled"
assert isinstance(final_batch.cancelled_at, int), (
f"cancelled_at should be int, got {type(final_batch.cancelled_at)}"
)
def test_batch_e2e_chat_completions(self, openai_client, batch_helper, text_model_id):
"""Test end-to-end batch processing for chat completions with both successful and failed operations."""
batch_requests = [
{
"custom_id": "success-1",
"method": "POST",
"url": "/v1/chat/completions",
"body": {
"model": text_model_id,
"messages": [{"role": "user", "content": "Say hello"}],
"max_tokens": 20,
},
},
{
"custom_id": "error-1",
"method": "POST",
"url": "/v1/chat/completions",
"body": {
"model": text_model_id,
"messages": [{"role": "user", "content": "This should fail"}],
"max_tokens": -1, # Invalid negative max_tokens will cause inference error
},
},
]
with batch_helper.create_file(batch_requests) as uploaded_file:
batch = openai_client.batches.create(
input_file_id=uploaded_file.id,
endpoint="/v1/chat/completions",
completion_window="24h",
metadata={"test": "e2e_success_and_errors_test"},
)
final_batch = batch_helper.wait_for(
batch.id,
max_wait_time=3 * 60, # often takes 2-3 minutes
expected_statuses={"completed"},
timeout_action="skip",
)
# Expecting a completed batch with both successful and failed requests
# Batch(id='batch_xxx',
# completion_window='24h',
# created_at=...,
# endpoint='/v1/chat/completions',
# input_file_id='file-xxx',
# object='batch',
# status='completed',
# output_file_id='file-xxx',
# error_file_id='file-xxx',
# request_counts=BatchRequestCounts(completed=1, failed=1, total=2))
assert final_batch.status == "completed"
assert final_batch.request_counts is not None
assert final_batch.request_counts.total == 2
assert final_batch.request_counts.completed == 1
assert final_batch.request_counts.failed == 1
assert final_batch.output_file_id is not None, "Output file should exist for successful requests"
output_content = openai_client.files.content(final_batch.output_file_id)
if isinstance(output_content, str):
output_text = output_content
else:
output_text = output_content.content.decode("utf-8")
output_lines = output_text.strip().split("\n")
for line in output_lines:
result = json.loads(line)
assert "id" in result
assert "custom_id" in result
assert result["custom_id"] == "success-1"
assert "response" in result
assert result["response"]["status_code"] == 200
assert "body" in result["response"]
assert "choices" in result["response"]["body"]
assert final_batch.error_file_id is not None, "Error file should exist for failed requests"
error_content = openai_client.files.content(final_batch.error_file_id)
if isinstance(error_content, str):
error_text = error_content
else:
error_text = error_content.content.decode("utf-8")
error_lines = error_text.strip().split("\n")
for line in error_lines:
result = json.loads(line)
assert "id" in result
assert "custom_id" in result
assert result["custom_id"] == "error-1"
assert "error" in result
error = result["error"]
assert error is not None
assert "code" in error or "message" in error, "Error should have code or message"
deleted_output_file = openai_client.files.delete(final_batch.output_file_id)
assert deleted_output_file.deleted, f"Output file {final_batch.output_file_id} was not deleted successfully"
deleted_error_file = openai_client.files.delete(final_batch.error_file_id)
assert deleted_error_file.deleted, f"Error file {final_batch.error_file_id} was not deleted successfully"

View file

@@ -1,693 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
"""
Error handling and edge case tests for the Llama Stack batch processing functionality.
This module focuses exclusively on testing error conditions, validation failures,
and edge cases for batch operations to ensure robust error handling and graceful
degradation.
Test Categories:
1. File and Input Validation:
- test_batch_nonexistent_file_id: Handling invalid file IDs
- test_batch_malformed_jsonl: Processing malformed JSONL input files
- test_file_malformed_batch_file: Handling malformed files at upload time
- test_batch_missing_required_fields: Validation of required request fields
2. API Endpoint and Model Validation:
- test_batch_invalid_endpoint: Invalid endpoint handling during creation
- test_batch_error_handling_invalid_model: Error handling with nonexistent models
- test_batch_endpoint_mismatch: Validation of endpoint/URL consistency
3. Batch Lifecycle Error Handling:
- test_batch_retrieve_nonexistent: Retrieving non-existent batches
- test_batch_cancel_nonexistent: Cancelling non-existent batches
- test_batch_cancel_completed: Attempting to cancel completed batches
4. Parameter and Configuration Validation:
- test_batch_invalid_completion_window: Invalid completion window values
- test_batch_invalid_metadata_types: Invalid metadata type validation
- test_batch_missing_required_body_fields: Validation of required fields in request body
5. Feature Restriction and Compatibility:
- test_batch_streaming_not_supported: Streaming request rejection
- test_batch_mixed_streaming_requests: Mixed streaming/non-streaming validation
Note: Core functionality and OpenAI compatibility tests are located in
test_batches_integration.py for better organization and separation of concerns.
CLEANUP WARNING: These tests create batches to test error conditions but do not
automatically clean them up after test completion. While most error tests create
batches that fail quickly, some may create valid batches that consume resources.
"""
import pytest
from openai import BadRequestError, ConflictError, NotFoundError
class TestBatchesErrorHandling:
"""Error handling and edge case tests for the batches API using OpenAI client."""
def test_batch_nonexistent_file_id(self, openai_client, batch_helper):
"""Test batch creation with nonexistent input file ID."""
batch = openai_client.batches.create(
input_file_id="file-nonexistent-xyz",
endpoint="/v1/chat/completions",
completion_window="24h",
)
final_batch = batch_helper.wait_for(batch.id, expected_statuses={"failed"})
# Expecting -
# Batch(...,
# status='failed',
# errors=Errors(data=[
# BatchError(
# code='invalid_request',
# line=None,
# message='Cannot find file ..., or organization ... does not have access to it.',
# param='file_id')
# ], object='list'),
# failed_at=1754566971,
# ...)
assert final_batch.status == "failed"
assert final_batch.errors is not None
assert len(final_batch.errors.data) == 1
error = final_batch.errors.data[0]
assert error.code == "invalid_request"
assert "cannot find file" in error.message.lower()
def test_batch_invalid_endpoint(self, openai_client, batch_helper, text_model_id):
"""Test batch creation with invalid endpoint."""
batch_requests = [
{
"custom_id": "invalid-endpoint",
"method": "POST",
"url": "/v1/chat/completions",
"body": {
"model": text_model_id,
"messages": [{"role": "user", "content": "Hello"}],
"max_tokens": 10,
},
}
]
with batch_helper.create_file(batch_requests) as uploaded_file:
with pytest.raises(BadRequestError) as exc_info:
openai_client.batches.create(
input_file_id=uploaded_file.id,
endpoint="/v1/invalid/endpoint",
completion_window="24h",
)
# Expected -
# Error code: 400 - {
# 'error': {
# 'message': "Invalid value: '/v1/invalid/endpoint'. Supported values are: '/v1/chat/completions', '/v1/completions', '/v1/embeddings', and '/v1/responses'.",
# 'type': 'invalid_request_error',
# 'param': 'endpoint',
# 'code': 'invalid_value'
# }
# }
error_msg = str(exc_info.value).lower()
assert exc_info.value.status_code == 400
assert "invalid value" in error_msg
assert "/v1/invalid/endpoint" in error_msg
assert "supported values" in error_msg
assert "endpoint" in error_msg
assert "invalid_value" in error_msg
def test_batch_malformed_jsonl(self, openai_client, batch_helper):
"""
Test batch with malformed JSONL input.
The /v1/files endpoint requires valid JSONL format, so we provide a well formed line
before a malformed line to ensure we get to the /v1/batches validation stage.
"""
with batch_helper.create_file(
"""{"custom_id": "valid", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "test"}}
{invalid json here""",
"malformed_batch_input.jsonl",
) as uploaded_file:
batch = openai_client.batches.create(
input_file_id=uploaded_file.id,
endpoint="/v1/chat/completions",
completion_window="24h",
)
final_batch = batch_helper.wait_for(batch.id, expected_statuses={"failed"})
# Expecting -
# Batch(...,
# status='failed',
# errors=Errors(data=[
# ...,
# BatchError(code='invalid_json_line',
# line=2,
# message='This line is not parseable as valid JSON.',
# param=None)
# ], object='list'),
# ...)
assert final_batch.status == "failed"
assert final_batch.errors is not None
assert len(final_batch.errors.data) > 0
error = final_batch.errors.data[-1] # get last error because first may be about the "test" model
assert error.code == "invalid_json_line"
assert error.line == 2
assert "not" in error.message.lower()
assert "valid json" in error.message.lower()
@pytest.mark.xfail(reason="Not all file providers validate content")
@pytest.mark.parametrize("batch_requests", ["", "{malformed json"], ids=["empty", "malformed"])
def test_file_malformed_batch_file(self, openai_client, batch_helper, batch_requests):
"""Test file upload with malformed content."""
with pytest.raises(BadRequestError) as exc_info:
with batch_helper.create_file(batch_requests, "malformed_batch_input_file.jsonl"):
# /v1/files rejects the file, we don't get to batch creation
pass
error_msg = str(exc_info.value).lower()
assert exc_info.value.status_code == 400
assert "invalid file format" in error_msg
assert "jsonl" in error_msg
def test_batch_retrieve_nonexistent(self, openai_client):
"""Test retrieving nonexistent batch."""
with pytest.raises(NotFoundError) as exc_info:
openai_client.batches.retrieve("batch-nonexistent-xyz")
error_msg = str(exc_info.value).lower()
assert exc_info.value.status_code == 404
assert "no batch found" in error_msg or "not found" in error_msg
def test_batch_cancel_nonexistent(self, openai_client):
"""Test cancelling nonexistent batch."""
with pytest.raises(NotFoundError) as exc_info:
openai_client.batches.cancel("batch-nonexistent-xyz")
error_msg = str(exc_info.value).lower()
assert exc_info.value.status_code == 404
assert "no batch found" in error_msg or "not found" in error_msg
def test_batch_cancel_completed(self, openai_client, batch_helper, text_model_id):
"""Test cancelling already completed batch."""
batch_requests = [
{
"custom_id": "cancel-completed",
"method": "POST",
"url": "/v1/chat/completions",
"body": {
"model": text_model_id,
"messages": [{"role": "user", "content": "Quick test"}],
"max_tokens": 5,
},
}
]
with batch_helper.create_file(batch_requests, "cancel_test_batch_input") as uploaded_file:
batch = openai_client.batches.create(
input_file_id=uploaded_file.id,
endpoint="/v1/chat/completions",
completion_window="24h",
)
final_batch = batch_helper.wait_for(
batch.id,
max_wait_time=3 * 60,  # often takes 10-11 min, give it 3 min
expected_statuses={"completed"},
timeout_action="skip",
)
deleted_file = openai_client.files.delete(final_batch.output_file_id)
assert deleted_file.deleted, f"File {final_batch.output_file_id} was not deleted successfully"
with pytest.raises(ConflictError) as exc_info:
openai_client.batches.cancel(batch.id)
# Expecting -
# Error code: 409 - {
# 'error': {
# 'message': "Cannot cancel a batch with status 'completed'.",
# 'type': 'invalid_request_error',
# 'param': None,
# 'code': None
# }
# }
#
# NOTE: Same for "failed", cancelling "cancelled" batches is allowed
error_msg = str(exc_info.value).lower()
assert exc_info.value.status_code == 409
assert "cannot cancel" in error_msg
def test_batch_missing_required_fields(self, openai_client, batch_helper, text_model_id):
"""Test batch with requests missing required fields."""
batch_requests = [
{
# Missing custom_id
"method": "POST",
"url": "/v1/chat/completions",
"body": {
"model": text_model_id,
"messages": [{"role": "user", "content": "No custom_id"}],
"max_tokens": 10,
},
},
{
"custom_id": "no-method",
"url": "/v1/chat/completions",
"body": {
"model": text_model_id,
"messages": [{"role": "user", "content": "No method"}],
"max_tokens": 10,
},
},
{
"custom_id": "no-url",
"method": "POST",
"body": {
"model": text_model_id,
"messages": [{"role": "user", "content": "No URL"}],
"max_tokens": 10,
},
},
{
"custom_id": "no-body",
"method": "POST",
"url": "/v1/chat/completions",
},
]
with batch_helper.create_file(batch_requests, "missing_fields_batch_input") as uploaded_file:
batch = openai_client.batches.create(
input_file_id=uploaded_file.id,
endpoint="/v1/chat/completions",
completion_window="24h",
)
final_batch = batch_helper.wait_for(batch.id, expected_statuses={"failed"})
# Expecting -
# Batch(...,
# status='failed',
# errors=Errors(
# data=[
# BatchError(
# code='missing_required_parameter',
# line=1,
# message="Missing required parameter: 'custom_id'.",
# param='custom_id'
# ),
# BatchError(
# code='missing_required_parameter',
# line=2,
# message="Missing required parameter: 'method'.",
# param='method'
# ),
# BatchError(
# code='missing_required_parameter',
# line=3,
# message="Missing required parameter: 'url'.",
# param='url'
# ),
# BatchError(
# code='missing_required_parameter',
# line=4,
# message="Missing required parameter: 'body'.",
# param='body'
# )
# ], object='list'),
# failed_at=1754566945,
# ...)
# )
assert final_batch.status == "failed"
assert final_batch.errors is not None
assert len(final_batch.errors.data) == 4
no_custom_id_error = final_batch.errors.data[0]
assert no_custom_id_error.code == "missing_required_parameter"
assert no_custom_id_error.line == 1
assert "missing" in no_custom_id_error.message.lower()
assert "custom_id" in no_custom_id_error.message.lower()
no_method_error = final_batch.errors.data[1]
assert no_method_error.code == "missing_required_parameter"
assert no_method_error.line == 2
assert "missing" in no_method_error.message.lower()
assert "method" in no_method_error.message.lower()
no_url_error = final_batch.errors.data[2]
assert no_url_error.code == "missing_required_parameter"
assert no_url_error.line == 3
assert "missing" in no_url_error.message.lower()
assert "url" in no_url_error.message.lower()
no_body_error = final_batch.errors.data[3]
assert no_body_error.code == "missing_required_parameter"
assert no_body_error.line == 4
assert "missing" in no_body_error.message.lower()
assert "body" in no_body_error.message.lower()
def test_batch_invalid_completion_window(self, openai_client, batch_helper, text_model_id):
"""Test batch creation with invalid completion window."""
batch_requests = [
{
"custom_id": "invalid-completion-window",
"method": "POST",
"url": "/v1/chat/completions",
"body": {
"model": text_model_id,
"messages": [{"role": "user", "content": "Hello"}],
"max_tokens": 10,
},
}
]
with batch_helper.create_file(batch_requests) as uploaded_file:
for window in ["1h", "48h", "invalid", ""]:
with pytest.raises(BadRequestError) as exc_info:
openai_client.batches.create(
input_file_id=uploaded_file.id,
endpoint="/v1/chat/completions",
completion_window=window,
)
assert exc_info.value.status_code == 400
error_msg = str(exc_info.value).lower()
assert "error" in error_msg
assert "completion_window" in error_msg
def test_batch_streaming_not_supported(self, openai_client, batch_helper, text_model_id):
"""Test that streaming responses are not supported in batches."""
batch_requests = [
{
"custom_id": "streaming-test",
"method": "POST",
"url": "/v1/chat/completions",
"body": {
"model": text_model_id,
"messages": [{"role": "user", "content": "Hello"}],
"max_tokens": 10,
"stream": True, # Not supported
},
}
]
with batch_helper.create_file(batch_requests, "streaming_batch_input") as uploaded_file:
batch = openai_client.batches.create(
input_file_id=uploaded_file.id,
endpoint="/v1/chat/completions",
completion_window="24h",
)
final_batch = batch_helper.wait_for(batch.id, expected_statuses={"failed"})
# Expecting -
# Batch(...,
# status='failed',
# errors=Errors(data=[
# BatchError(code='streaming_unsupported',
# line=1,
# message='Chat Completions: Streaming is not supported in the Batch API.',
# param='body.stream')
# ], object='list'),
# failed_at=1754566965,
# ...)
assert final_batch.status == "failed"
assert final_batch.errors is not None
assert len(final_batch.errors.data) == 1
error = final_batch.errors.data[0]
assert error.code == "streaming_unsupported"
assert error.line == 1
assert "streaming" in error.message.lower()
assert "not supported" in error.message.lower()
assert error.param == "body.stream"
assert final_batch.failed_at is not None
def test_batch_mixed_streaming_requests(self, openai_client, batch_helper, text_model_id):
"""
Test batch with mixed streaming and non-streaming requests.
This is distinct from test_batch_streaming_not_supported, which tests a single
streaming request, to ensure an otherwise valid batch fails when a single
streaming request is included.
"""
batch_requests = [
{
"custom_id": "valid-non-streaming-request",
"method": "POST",
"url": "/v1/chat/completions",
"body": {
"model": text_model_id,
"messages": [{"role": "user", "content": "Hello without streaming"}],
"max_tokens": 10,
},
},
{
"custom_id": "streaming-request",
"method": "POST",
"url": "/v1/chat/completions",
"body": {
"model": text_model_id,
"messages": [{"role": "user", "content": "Hello with streaming"}],
"max_tokens": 10,
"stream": True, # Not supported
},
},
]
with batch_helper.create_file(batch_requests, "mixed_streaming_batch_input") as uploaded_file:
batch = openai_client.batches.create(
input_file_id=uploaded_file.id,
endpoint="/v1/chat/completions",
completion_window="24h",
)
final_batch = batch_helper.wait_for(batch.id, expected_statuses={"failed"})
# Expecting -
# Batch(...,
# status='failed',
# errors=Errors(data=[
# BatchError(
# code='streaming_unsupported',
# line=2,
# message='Chat Completions: Streaming is not supported in the Batch API.',
# param='body.stream')
# ], object='list'),
# failed_at=1754574442,
# ...)
assert final_batch.status == "failed"
assert final_batch.errors is not None
assert len(final_batch.errors.data) == 1
error = final_batch.errors.data[0]
assert error.code == "streaming_unsupported"
assert error.line == 2
assert "streaming" in error.message.lower()
assert "not supported" in error.message.lower()
assert error.param == "body.stream"
assert final_batch.failed_at is not None
def test_batch_endpoint_mismatch(self, openai_client, batch_helper, text_model_id):
"""Test batch creation with mismatched endpoint and request URL."""
batch_requests = [
{
"custom_id": "endpoint-mismatch",
"method": "POST",
"url": "/v1/embeddings", # Different from batch endpoint
"body": {
"model": text_model_id,
"messages": [{"role": "user", "content": "Hello"}],
},
}
]
with batch_helper.create_file(batch_requests, "endpoint_mismatch_batch_input") as uploaded_file:
batch = openai_client.batches.create(
input_file_id=uploaded_file.id,
endpoint="/v1/chat/completions", # Different from request URL
completion_window="24h",
)
final_batch = batch_helper.wait_for(batch.id, expected_statuses={"failed"})
# Expecting -
# Batch(...,
# status='failed',
# errors=Errors(data=[
# BatchError(
# code='invalid_url',
# line=1,
# message='The URL provided for this request does not match the batch endpoint.',
# param='url')
# ], object='list'),
# failed_at=1754566972,
# ...)
assert final_batch.status == "failed"
assert final_batch.errors is not None
assert len(final_batch.errors.data) == 1
error = final_batch.errors.data[0]
assert error.line == 1
assert error.code == "invalid_url"
assert "does not match" in error.message.lower()
assert "endpoint" in error.message.lower()
assert final_batch.failed_at is not None
def test_batch_error_handling_invalid_model(self, openai_client, batch_helper):
"""Test batch error handling with invalid model."""
batch_requests = [
{
"custom_id": "invalid-model",
"method": "POST",
"url": "/v1/chat/completions",
"body": {
"model": "nonexistent-model-xyz",
"messages": [{"role": "user", "content": "Hello"}],
"max_tokens": 10,
},
}
]
with batch_helper.create_file(batch_requests) as uploaded_file:
batch = openai_client.batches.create(
input_file_id=uploaded_file.id,
endpoint="/v1/chat/completions",
completion_window="24h",
)
final_batch = batch_helper.wait_for(batch.id, expected_statuses={"failed"})
# Expecting -
# Batch(...,
# status='failed',
# errors=Errors(data=[
# BatchError(code='model_not_found',
# line=1,
# message="The provided model 'nonexistent-model-xyz' is not supported by the Batch API.",
# param='body.model')
# ], object='list'),
# failed_at=1754566978,
# ...)
assert final_batch.status == "failed"
assert final_batch.errors is not None
assert len(final_batch.errors.data) == 1
error = final_batch.errors.data[0]
assert error.line == 1
assert error.code == "model_not_found"
assert "not supported" in error.message.lower()
assert error.param == "body.model"
assert final_batch.failed_at is not None
def test_batch_missing_required_body_fields(self, openai_client, batch_helper, text_model_id):
"""Test batch with requests missing required fields in body (model and messages)."""
batch_requests = [
{
"custom_id": "missing-model",
"method": "POST",
"url": "/v1/chat/completions",
"body": {
# Missing model field
"messages": [{"role": "user", "content": "Hello without model"}],
"max_tokens": 10,
},
},
{
"custom_id": "missing-messages",
"method": "POST",
"url": "/v1/chat/completions",
"body": {
"model": text_model_id,
# Missing messages field
"max_tokens": 10,
},
},
]
with batch_helper.create_file(batch_requests, "missing_body_fields_batch_input") as uploaded_file:
batch = openai_client.batches.create(
input_file_id=uploaded_file.id,
endpoint="/v1/chat/completions",
completion_window="24h",
)
final_batch = batch_helper.wait_for(batch.id, expected_statuses={"failed"})
# Expecting -
# Batch(...,
# status='failed',
# errors=Errors(data=[
# BatchError(
# code='invalid_request',
# line=1,
# message='Model parameter is required.',
# param='body.model'),
# BatchError(
# code='invalid_request',
# line=2,
# message='Messages parameter is required.',
# param='body.messages')
# ], object='list'),
# ...)
assert final_batch.status == "failed"
assert final_batch.errors is not None
assert len(final_batch.errors.data) == 2
model_error = final_batch.errors.data[0]
assert model_error.line == 1
assert "model" in model_error.message.lower()
assert model_error.param == "body.model"
messages_error = final_batch.errors.data[1]
assert messages_error.line == 2
assert "messages" in messages_error.message.lower()
assert messages_error.param == "body.messages"
assert final_batch.failed_at is not None
def test_batch_invalid_metadata_types(self, openai_client, batch_helper, text_model_id):
"""Test batch creation with invalid metadata types (like lists)."""
batch_requests = [
{
"custom_id": "invalid-metadata-type",
"method": "POST",
"url": "/v1/chat/completions",
"body": {
"model": text_model_id,
"messages": [{"role": "user", "content": "Hello"}],
"max_tokens": 10,
},
}
]
with batch_helper.create_file(batch_requests) as uploaded_file:
with pytest.raises(Exception) as exc_info:
openai_client.batches.create(
input_file_id=uploaded_file.id,
endpoint="/v1/chat/completions",
completion_window="24h",
metadata={
"tags": ["tag1", "tag2"], # Invalid type, should be a string
},
)
# Expecting -
# Error code: 400 - {'error':
# {'message': "Invalid type for 'metadata.tags': expected a string,
# but got an array instead.",
# 'type': 'invalid_request_error', 'param': 'metadata.tags',
# 'code': 'invalid_type'}}
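# For contrast, metadata values must be plain strings; lists (like "tags" above)
# are rejected with an 'invalid_type' error before the batch is created.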
error_msg = str(exc_info.value).lower()
assert "400" in error_msg
assert "tags" in error_msg
assert "string" in error_msg

View file

@@ -1,753 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
"""
Test suite for the reference implementation of the Batches API.
The tests are categorized and outlined below; keep this outline updated:
- Batch creation with various parameters and validation:
* test_create_and_retrieve_batch_success (positive)
* test_create_batch_without_metadata (positive)
* test_create_batch_completion_window (negative)
* test_create_batch_invalid_endpoints (negative)
* test_create_batch_invalid_metadata (negative)
- Batch retrieval and error handling for non-existent batches:
* test_retrieve_batch_not_found (negative)
- Batch cancellation with proper status transitions:
* test_cancel_batch_success (positive)
* test_cancel_batch_invalid_statuses (negative)
* test_cancel_batch_not_found (negative)
- Batch listing with pagination and filtering:
* test_list_batches_empty (positive)
* test_list_batches_single_batch (positive)
* test_list_batches_multiple_batches (positive)
* test_list_batches_with_limit (positive)
* test_list_batches_with_pagination (positive)
* test_list_batches_invalid_after (negative)
- Data persistence in the underlying key-value store:
* test_kvstore_persistence (positive)
- Batch processing concurrency control:
* test_max_concurrent_batches (positive)
- Input validation testing (direct _validate_input method tests):
* test_validate_input_file_not_found (negative)
* test_validate_input_file_exists_empty_content (positive)
* test_validate_input_file_mixed_valid_invalid_json (mixed)
* test_validate_input_invalid_model (negative)
* test_validate_input_url_mismatch (negative)
* test_validate_input_multiple_errors_per_request (negative)
* test_validate_input_invalid_request_format (negative)
* test_validate_input_missing_parameters (parametrized negative - validates missing custom_id, method, url, body, model, and messages)
* test_validate_input_invalid_parameter_types (parametrized negative - validates invalid types for custom_id, url, method, body, model, and messages)
The tests use temporary SQLite databases for isolation and mock external
dependencies like inference, files, and models APIs.
"""
import json
import tempfile
from pathlib import Path
from unittest.mock import AsyncMock, MagicMock
import pytest
from llama_stack.apis.batches import BatchObject
from llama_stack.apis.common.errors import ConflictError, ResourceNotFoundError
from llama_stack.providers.inline.batches.reference.batches import ReferenceBatchesImpl
from llama_stack.providers.inline.batches.reference.config import ReferenceBatchesImplConfig
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
class TestReferenceBatchesImpl:
"""Test the reference implementation of the Batches API."""
@pytest.fixture
async def provider(self):
"""Create a test provider instance with temporary database."""
with tempfile.TemporaryDirectory() as tmpdir:
db_path = Path(tmpdir) / "test_batches.db"
kvstore_config = SqliteKVStoreConfig(db_path=str(db_path))
config = ReferenceBatchesImplConfig(kvstore=kvstore_config)
# Create kvstore and mock APIs
from unittest.mock import AsyncMock
from llama_stack.providers.utils.kvstore import kvstore_impl
kvstore = await kvstore_impl(config.kvstore)
mock_inference = AsyncMock()
mock_files = AsyncMock()
mock_models = AsyncMock()
provider = ReferenceBatchesImpl(config, mock_inference, mock_files, mock_models, kvstore)
await provider.initialize()
# unit tests should not require background processing
provider.process_batches = False
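# With background processing disabled, created batches stay in "validating",
# so tests can drive status transitions directly (e.g. via _update_batch).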
yield provider
await provider.shutdown()
@pytest.fixture
def sample_batch_data(self):
"""Sample batch data for testing."""
return {
"input_file_id": "file_abc123",
"endpoint": "/v1/chat/completions",
"completion_window": "24h",
"metadata": {"test": "true", "priority": "high"},
}
def _validate_batch_type(self, batch, expected_metadata=None):
"""
Helper function to validate batch object structure and field types.
Note: This validates the direct BatchObject from the provider, not the
client library response which has a different structure.
Args:
batch: The BatchObject instance to validate.
expected_metadata: Optional expected metadata dictionary to validate against.
"""
assert isinstance(batch.id, str)
assert isinstance(batch.completion_window, str)
assert isinstance(batch.created_at, int)
assert isinstance(batch.endpoint, str)
assert isinstance(batch.input_file_id, str)
assert batch.object == "batch"
assert batch.status in [
"validating",
"failed",
"in_progress",
"finalizing",
"completed",
"expired",
"cancelling",
"cancelled",
]
if expected_metadata is not None:
assert batch.metadata == expected_metadata
timestamp_fields = [
"cancelled_at",
"cancelling_at",
"completed_at",
"expired_at",
"expires_at",
"failed_at",
"finalizing_at",
"in_progress_at",
]
for field in timestamp_fields:
field_value = getattr(batch, field, None)
if field_value is not None:
assert isinstance(field_value, int), f"{field} should be int or None, got {type(field_value)}"
file_id_fields = ["error_file_id", "output_file_id"]
for field in file_id_fields:
field_value = getattr(batch, field, None)
if field_value is not None:
assert isinstance(field_value, str), f"{field} should be str or None, got {type(field_value)}"
if hasattr(batch, "request_counts") and batch.request_counts is not None:
assert isinstance(batch.request_counts.completed, int), (
f"request_counts.completed should be int, got {type(batch.request_counts.completed)}"
)
assert isinstance(batch.request_counts.failed, int), (
f"request_counts.failed should be int, got {type(batch.request_counts.failed)}"
)
assert isinstance(batch.request_counts.total, int), (
f"request_counts.total should be int, got {type(batch.request_counts.total)}"
)
if hasattr(batch, "errors") and batch.errors is not None:
assert isinstance(batch.errors, dict), f"errors should be object or dict, got {type(batch.errors)}"
if hasattr(batch.errors, "data") and batch.errors.data is not None:
assert isinstance(batch.errors.data, list), (
f"errors.data should be list or None, got {type(batch.errors.data)}"
)
for i, error_item in enumerate(batch.errors.data):
assert isinstance(error_item, dict), (
f"errors.data[{i}] should be object or dict, got {type(error_item)}"
)
if hasattr(error_item, "code") and error_item.code is not None:
assert isinstance(error_item.code, str), (
f"errors.data[{i}].code should be str or None, got {type(error_item.code)}"
)
if hasattr(error_item, "line") and error_item.line is not None:
assert isinstance(error_item.line, int), (
f"errors.data[{i}].line should be int or None, got {type(error_item.line)}"
)
if hasattr(error_item, "message") and error_item.message is not None:
assert isinstance(error_item.message, str), (
f"errors.data[{i}].message should be str or None, got {type(error_item.message)}"
)
if hasattr(error_item, "param") and error_item.param is not None:
assert isinstance(error_item.param, str), (
f"errors.data[{i}].param should be str or None, got {type(error_item.param)}"
)
if hasattr(batch.errors, "object") and batch.errors.object is not None:
assert isinstance(batch.errors.object, str), (
f"errors.object should be str or None, got {type(batch.errors.object)}"
)
assert batch.errors.object == "list", f"errors.object should be 'list', got {batch.errors.object}"
async def test_create_and_retrieve_batch_success(self, provider, sample_batch_data):
"""Test successful batch creation and retrieval."""
created_batch = await provider.create_batch(**sample_batch_data)
self._validate_batch_type(created_batch, expected_metadata=sample_batch_data["metadata"])
assert created_batch.id.startswith("batch_")
assert len(created_batch.id) > 13
assert created_batch.object == "batch"
assert created_batch.endpoint == sample_batch_data["endpoint"]
assert created_batch.input_file_id == sample_batch_data["input_file_id"]
assert created_batch.completion_window == sample_batch_data["completion_window"]
assert created_batch.status == "validating"
assert created_batch.metadata == sample_batch_data["metadata"]
assert isinstance(created_batch.created_at, int)
assert created_batch.created_at > 0
retrieved_batch = await provider.retrieve_batch(created_batch.id)
self._validate_batch_type(retrieved_batch, expected_metadata=sample_batch_data["metadata"])
assert retrieved_batch.id == created_batch.id
assert retrieved_batch.input_file_id == created_batch.input_file_id
assert retrieved_batch.endpoint == created_batch.endpoint
assert retrieved_batch.status == created_batch.status
assert retrieved_batch.metadata == created_batch.metadata
async def test_create_batch_without_metadata(self, provider):
"""Test batch creation without optional metadata."""
batch = await provider.create_batch(
input_file_id="file_123", endpoint="/v1/chat/completions", completion_window="24h"
)
assert batch.metadata is None
async def test_create_batch_completion_window(self, provider):
"""Test batch creation with invalid completion window."""
with pytest.raises(ValueError, match="Invalid completion_window"):
await provider.create_batch(
input_file_id="file_123", endpoint="/v1/chat/completions", completion_window="now"
)
@pytest.mark.parametrize(
"endpoint",
[
"/v1/embeddings",
"/v1/completions",
"/v1/invalid/endpoint",
"",
],
)
async def test_create_batch_invalid_endpoints(self, provider, endpoint):
"""Test batch creation with various invalid endpoints."""
with pytest.raises(ValueError, match="Invalid endpoint"):
await provider.create_batch(input_file_id="file_123", endpoint=endpoint, completion_window="24h")
async def test_create_batch_invalid_metadata(self, provider):
"""Test that batch creation fails with invalid metadata."""
with pytest.raises(ValueError, match="should be a valid string"):
await provider.create_batch(
input_file_id="file_123",
endpoint="/v1/chat/completions",
completion_window="24h",
metadata={123: "invalid_key"}, # Non-string key
)
with pytest.raises(ValueError, match="should be a valid string"):
await provider.create_batch(
input_file_id="file_123",
endpoint="/v1/chat/completions",
completion_window="24h",
metadata={"valid_key": 456}, # Non-string value
)
async def test_retrieve_batch_not_found(self, provider):
"""Test error when retrieving non-existent batch."""
with pytest.raises(ResourceNotFoundError, match=r"Batch 'nonexistent_batch' not found"):
await provider.retrieve_batch("nonexistent_batch")
async def test_cancel_batch_success(self, provider, sample_batch_data):
"""Test successful batch cancellation."""
created_batch = await provider.create_batch(**sample_batch_data)
assert created_batch.status == "validating"
cancelled_batch = await provider.cancel_batch(created_batch.id)
assert cancelled_batch.id == created_batch.id
assert cancelled_batch.status in ["cancelling", "cancelled"]
assert isinstance(cancelled_batch.cancelling_at, int)
assert cancelled_batch.cancelling_at >= created_batch.created_at
@pytest.mark.parametrize("status", ["failed", "expired", "completed"])
async def test_cancel_batch_invalid_statuses(self, provider, sample_batch_data, status):
"""Test error when cancelling batch in final states."""
provider.process_batches = False
created_batch = await provider.create_batch(**sample_batch_data)
# directly update status in kvstore
await provider._update_batch(created_batch.id, status=status)
with pytest.raises(ConflictError, match=f"Cannot cancel batch '{created_batch.id}' with status '{status}'"):
await provider.cancel_batch(created_batch.id)
async def test_cancel_batch_not_found(self, provider):
"""Test error when cancelling non-existent batch."""
with pytest.raises(ResourceNotFoundError, match=r"Batch 'nonexistent_batch' not found"):
await provider.cancel_batch("nonexistent_batch")
async def test_list_batches_empty(self, provider):
"""Test listing batches when none exist."""
response = await provider.list_batches()
assert response.object == "list"
assert response.data == []
assert response.first_id is None
assert response.last_id is None
assert response.has_more is False
async def test_list_batches_single_batch(self, provider, sample_batch_data):
"""Test listing batches with single batch."""
created_batch = await provider.create_batch(**sample_batch_data)
response = await provider.list_batches()
assert len(response.data) == 1
self._validate_batch_type(response.data[0], expected_metadata=sample_batch_data["metadata"])
assert response.data[0].id == created_batch.id
assert response.first_id == created_batch.id
assert response.last_id == created_batch.id
assert response.has_more is False
async def test_list_batches_multiple_batches(self, provider):
"""Test listing multiple batches."""
batches = [
await provider.create_batch(
input_file_id=f"file_{i}", endpoint="/v1/chat/completions", completion_window="24h"
)
for i in range(3)
]
response = await provider.list_batches()
assert len(response.data) == 3
batch_ids = {batch.id for batch in response.data}
expected_ids = {batch.id for batch in batches}
assert batch_ids == expected_ids
assert response.has_more is False
assert response.first_id in expected_ids
assert response.last_id in expected_ids
async def test_list_batches_with_limit(self, provider):
"""Test listing batches with limit parameter."""
batches = [
await provider.create_batch(
input_file_id=f"file_{i}", endpoint="/v1/chat/completions", completion_window="24h"
)
for i in range(3)
]
response = await provider.list_batches(limit=2)
assert len(response.data) == 2
assert response.has_more is True
assert response.first_id == response.data[0].id
assert response.last_id == response.data[1].id
batch_ids = {batch.id for batch in response.data}
expected_ids = {batch.id for batch in batches}
assert batch_ids.issubset(expected_ids)
async def test_list_batches_with_pagination(self, provider):
"""Test listing batches with pagination using 'after' parameter."""
for i in range(3):
await provider.create_batch(
input_file_id=f"file_{i}", endpoint="/v1/chat/completions", completion_window="24h"
)
# Get first page
first_page = await provider.list_batches(limit=1)
assert len(first_page.data) == 1
assert first_page.has_more is True
# Get second page using 'after'
second_page = await provider.list_batches(limit=1, after=first_page.data[0].id)
assert len(second_page.data) == 1
assert second_page.data[0].id != first_page.data[0].id
# Verify we got the next batch in order
all_batches = await provider.list_batches()
expected_second_batch_id = all_batches.data[1].id
assert second_page.data[0].id == expected_second_batch_id
async def test_list_batches_invalid_after(self, provider, sample_batch_data):
"""Test listing batches with invalid 'after' parameter."""
await provider.create_batch(**sample_batch_data)
response = await provider.list_batches(after="nonexistent_batch")
# Should return all batches (no filtering when 'after' batch not found)
assert len(response.data) == 1
async def test_kvstore_persistence(self, provider, sample_batch_data):
"""Test that batches are properly persisted in kvstore."""
batch = await provider.create_batch(**sample_batch_data)
stored_data = await provider.kvstore.get(f"batch:{batch.id}")
assert stored_data is not None
stored_batch_dict = json.loads(stored_data)
assert stored_batch_dict["id"] == batch.id
assert stored_batch_dict["input_file_id"] == sample_batch_data["input_file_id"]
async def test_validate_input_file_not_found(self, provider):
"""Test _validate_input when input file does not exist."""
provider.files_api.openai_retrieve_file = AsyncMock(side_effect=Exception("File not found"))
batch = BatchObject(
id="batch_test",
object="batch",
endpoint="/v1/chat/completions",
input_file_id="nonexistent_file",
completion_window="24h",
status="validating",
created_at=1234567890,
)
errors, requests = await provider._validate_input(batch)
assert len(errors) == 1
assert len(requests) == 0
assert errors[0].code == "invalid_request"
assert errors[0].message == "Cannot find file nonexistent_file."
assert errors[0].param == "input_file_id"
assert errors[0].line is None
async def test_validate_input_file_exists_empty_content(self, provider):
"""Test _validate_input when file exists but is empty."""
provider.files_api.openai_retrieve_file = AsyncMock()
mock_response = MagicMock()
mock_response.body = b""
provider.files_api.openai_retrieve_file_content = AsyncMock(return_value=mock_response)
batch = BatchObject(
id="batch_test",
object="batch",
endpoint="/v1/chat/completions",
input_file_id="empty_file",
completion_window="24h",
status="validating",
created_at=1234567890,
)
errors, requests = await provider._validate_input(batch)
assert len(errors) == 0
assert len(requests) == 0
async def test_validate_input_file_mixed_valid_invalid_json(self, provider):
"""Test _validate_input when file contains valid and invalid JSON lines."""
provider.files_api.openai_retrieve_file = AsyncMock()
mock_response = MagicMock()
# Line 1: valid JSON with proper body args, Line 2: invalid JSON
mock_response.body = b'{"custom_id": "req-1", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "test-model", "messages": [{"role": "user", "content": "Hello"}]}}\n{invalid json'
provider.files_api.openai_retrieve_file_content = AsyncMock(return_value=mock_response)
batch = BatchObject(
id="batch_test",
object="batch",
endpoint="/v1/chat/completions",
input_file_id="mixed_file",
completion_window="24h",
status="validating",
created_at=1234567890,
)
errors, requests = await provider._validate_input(batch)
# Should have 1 JSON parsing error from line 2, and 1 valid request from line 1
assert len(errors) == 1
assert len(requests) == 1
assert errors[0].code == "invalid_json_line"
assert errors[0].line == 2
assert errors[0].message == "This line is not parseable as valid JSON."
assert requests[0].custom_id == "req-1"
assert requests[0].method == "POST"
assert requests[0].url == "/v1/chat/completions"
assert requests[0].body["model"] == "test-model"
assert requests[0].body["messages"] == [{"role": "user", "content": "Hello"}]
async def test_validate_input_invalid_model(self, provider):
"""Test _validate_input when file contains request with non-existent model."""
provider.files_api.openai_retrieve_file = AsyncMock()
mock_response = MagicMock()
mock_response.body = b'{"custom_id": "req-1", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "nonexistent-model", "messages": [{"role": "user", "content": "Hello"}]}}'
provider.files_api.openai_retrieve_file_content = AsyncMock(return_value=mock_response)
provider.models_api.get_model = AsyncMock(side_effect=Exception("Model not found"))
batch = BatchObject(
id="batch_test",
object="batch",
endpoint="/v1/chat/completions",
input_file_id="invalid_model_file",
completion_window="24h",
status="validating",
created_at=1234567890,
)
errors, requests = await provider._validate_input(batch)
assert len(errors) == 1
assert len(requests) == 0
assert errors[0].code == "model_not_found"
assert errors[0].line == 1
assert errors[0].message == "Model 'nonexistent-model' does not exist or is not supported"
assert errors[0].param == "body.model"
@pytest.mark.parametrize(
"param_name,param_path,error_code,error_message",
[
("custom_id", "custom_id", "missing_required_parameter", "Missing required parameter: custom_id"),
("method", "method", "missing_required_parameter", "Missing required parameter: method"),
("url", "url", "missing_required_parameter", "Missing required parameter: url"),
("body", "body", "missing_required_parameter", "Missing required parameter: body"),
("model", "body.model", "invalid_request", "Model parameter is required"),
("messages", "body.messages", "invalid_request", "Messages parameter is required"),
],
)
async def test_validate_input_missing_parameters(self, provider, param_name, param_path, error_code, error_message):
"""Test _validate_input when file contains request with missing required parameters."""
provider.files_api.openai_retrieve_file = AsyncMock()
mock_response = MagicMock()
base_request = {
"custom_id": "req-1",
"method": "POST",
"url": "/v1/chat/completions",
"body": {"model": "test-model", "messages": [{"role": "user", "content": "Hello"}]},
}
# Remove the specific parameter being tested
if "." in param_path:
top_level, nested_param = param_path.split(".", 1)
del base_request[top_level][nested_param]
else:
del base_request[param_name]
mock_response.body = json.dumps(base_request).encode()
provider.files_api.openai_retrieve_file_content = AsyncMock(return_value=mock_response)
batch = BatchObject(
id="batch_test",
object="batch",
endpoint="/v1/chat/completions",
input_file_id=f"missing_{param_name}_file",
completion_window="24h",
status="validating",
created_at=1234567890,
)
errors, requests = await provider._validate_input(batch)
assert len(errors) == 1
assert len(requests) == 0
assert errors[0].code == error_code
assert errors[0].line == 1
assert errors[0].message == error_message
assert errors[0].param == param_path
async def test_validate_input_url_mismatch(self, provider):
"""Test _validate_input when file contains request with URL that doesn't match batch endpoint."""
provider.files_api.openai_retrieve_file = AsyncMock()
mock_response = MagicMock()
mock_response.body = b'{"custom_id": "req-1", "method": "POST", "url": "/v1/embeddings", "body": {"model": "test-model", "messages": [{"role": "user", "content": "Hello"}]}}'
provider.files_api.openai_retrieve_file_content = AsyncMock(return_value=mock_response)
batch = BatchObject(
id="batch_test",
object="batch",
endpoint="/v1/chat/completions", # This doesn't match the URL in the request
input_file_id="url_mismatch_file",
completion_window="24h",
status="validating",
created_at=1234567890,
)
errors, requests = await provider._validate_input(batch)
assert len(errors) == 1
assert len(requests) == 0
assert errors[0].code == "invalid_url"
assert errors[0].line == 1
assert errors[0].message == "URL provided for this request does not match the batch endpoint"
assert errors[0].param == "url"
async def test_validate_input_multiple_errors_per_request(self, provider):
"""Test _validate_input when a single request has multiple validation errors."""
provider.files_api.openai_retrieve_file = AsyncMock()
mock_response = MagicMock()
# Request missing custom_id, has invalid URL, and missing model in body
mock_response.body = (
b'{"method": "POST", "url": "/v1/embeddings", "body": {"messages": [{"role": "user", "content": "Hello"}]}}'
)
provider.files_api.openai_retrieve_file_content = AsyncMock(return_value=mock_response)
batch = BatchObject(
id="batch_test",
object="batch",
endpoint="/v1/chat/completions", # Doesn't match /v1/embeddings in request
input_file_id="multiple_errors_file",
completion_window="24h",
status="validating",
created_at=1234567890,
)
errors, requests = await provider._validate_input(batch)
assert len(errors) >= 2 # At least missing custom_id and URL mismatch
assert len(requests) == 0
for error in errors:
assert error.line == 1
error_codes = {error.code for error in errors}
assert "missing_required_parameter" in error_codes # missing custom_id
assert "invalid_url" in error_codes # URL mismatch
async def test_validate_input_invalid_request_format(self, provider):
"""Test _validate_input when file contains non-object JSON (array, string, number)."""
provider.files_api.openai_retrieve_file = AsyncMock()
mock_response = MagicMock()
mock_response.body = b'["not", "a", "request", "object"]'
provider.files_api.openai_retrieve_file_content = AsyncMock(return_value=mock_response)
batch = BatchObject(
id="batch_test",
object="batch",
endpoint="/v1/chat/completions",
input_file_id="invalid_format_file",
completion_window="24h",
status="validating",
created_at=1234567890,
)
errors, requests = await provider._validate_input(batch)
assert len(errors) == 1
assert len(requests) == 0
assert errors[0].code == "invalid_request"
assert errors[0].line == 1
assert errors[0].message == "Each line must be a JSON dictionary object"
@pytest.mark.parametrize(
"param_name,param_path,invalid_value,error_message",
[
("custom_id", "custom_id", 12345, "Custom_id must be a string"),
("url", "url", 123, "URL must be a string"),
("method", "method", ["POST"], "Method must be a string"),
("body", "body", ["not", "valid"], "Body must be a JSON dictionary object"),
("model", "body.model", 123, "Model must be a string"),
("messages", "body.messages", "invalid messages format", "Messages must be an array"),
],
)
async def test_validate_input_invalid_parameter_types(
self, provider, param_name, param_path, invalid_value, error_message
):
"""Test _validate_input when file contains request with parameters that have invalid types."""
provider.files_api.openai_retrieve_file = AsyncMock()
mock_response = MagicMock()
base_request = {
"custom_id": "req-1",
"method": "POST",
"url": "/v1/chat/completions",
"body": {"model": "test-model", "messages": [{"role": "user", "content": "Hello"}]},
}
# Override the specific parameter with invalid value
if "." in param_path:
top_level, nested_param = param_path.split(".", 1)
base_request[top_level][nested_param] = invalid_value
else:
base_request[param_name] = invalid_value
mock_response.body = json.dumps(base_request).encode()
provider.files_api.openai_retrieve_file_content = AsyncMock(return_value=mock_response)
batch = BatchObject(
id="batch_test",
object="batch",
endpoint="/v1/chat/completions",
input_file_id=f"invalid_{param_name}_type_file",
completion_window="24h",
status="validating",
created_at=1234567890,
)
errors, requests = await provider._validate_input(batch)
assert len(errors) == 1
assert len(requests) == 0
assert errors[0].code == "invalid_request"
assert errors[0].line == 1
assert errors[0].message == error_message
assert errors[0].param == param_path
async def test_max_concurrent_batches(self, provider):
"""Test max_concurrent_batches configuration and concurrency control."""
import asyncio
provider._batch_semaphore = asyncio.Semaphore(2)
provider.process_batches = True # enable because we're testing background processing
active_batches = 0
async def add_and_wait(batch_id: str):
nonlocal active_batches
active_batches += 1
await asyncio.sleep(float("inf"))
# _process_batch first acquires the semaphore and then calls _process_batch_impl,
# so replacing _process_batch_impl with our mock lets us control concurrency.
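# Assumed shape of the provider internals this relies on (not asserted here):
#   async def _process_batch(self, batch_id):
#       async with self._batch_semaphore:
#           await self._process_batch_impl(batch_id)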
provider._process_batch_impl = add_and_wait
for _ in range(3):
await provider.create_batch(
input_file_id="file_id", endpoint="/v1/chat/completions", completion_window="24h"
)
await asyncio.sleep(0.042) # let tasks start
assert active_batches == 2, f"Expected 2 active batches, got {active_batches}"
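# With a semaphore of 2 and three batches queued, the third _process_batch_impl
# call never starts because the semaphore caps concurrent processing at 2.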