mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-15 06:00:48 +00:00
feat: add batches API with OpenAI compatibility (#3088)
Add complete batches API implementation with protocol, providers, and tests: Core Infrastructure: - Add batches API protocol using OpenAI Batch types directly - Add Api.batches enum value and protocol mapping in resolver - Add OpenAI "batch" file purpose support - Include proper error handling (ConflictError, ResourceNotFoundError) Reference Provider: - Add ReferenceBatchesImpl with full CRUD operations (create, retrieve, cancel, list) - Implement background batch processing with configurable concurrency - Add SQLite KVStore backend for persistence - Support /v1/chat/completions endpoint with request validation Comprehensive Test Suite: - Add unit tests for provider implementation with validation - Add integration tests for end-to-end batch processing workflows - Add error handling tests for validation, malformed inputs, and edge cases Configuration: - Add max_concurrent_batches and max_concurrent_requests_per_batch options - Add provider documentation with sample configurations Test with - ``` $ uv run llama stack build --image-type venv --providers inference=YOU_PICK,files=inline::localfs,batches=inline::reference --run & $ LLAMA_STACK_CONFIG=http://localhost:8321 uv run pytest tests/unit/providers/batches tests/integration/batches --text-model YOU_PICK ``` addresses #3066
This commit is contained in:
parent
61582f327c
commit
ed0b7216d0
26 changed files with 2707 additions and 2 deletions
6
docs/_static/llama-stack-spec.html
vendored
6
docs/_static/llama-stack-spec.html
vendored
|
@ -14767,7 +14767,8 @@
|
||||||
"OpenAIFilePurpose": {
|
"OpenAIFilePurpose": {
|
||||||
"type": "string",
|
"type": "string",
|
||||||
"enum": [
|
"enum": [
|
||||||
"assistants"
|
"assistants",
|
||||||
|
"batch"
|
||||||
],
|
],
|
||||||
"title": "OpenAIFilePurpose",
|
"title": "OpenAIFilePurpose",
|
||||||
"description": "Valid purpose values for OpenAI Files API."
|
"description": "Valid purpose values for OpenAI Files API."
|
||||||
|
@ -14844,7 +14845,8 @@
|
||||||
"purpose": {
|
"purpose": {
|
||||||
"type": "string",
|
"type": "string",
|
||||||
"enum": [
|
"enum": [
|
||||||
"assistants"
|
"assistants",
|
||||||
|
"batch"
|
||||||
],
|
],
|
||||||
"description": "The intended purpose of the file"
|
"description": "The intended purpose of the file"
|
||||||
}
|
}
|
||||||
|
|
2
docs/_static/llama-stack-spec.yaml
vendored
2
docs/_static/llama-stack-spec.yaml
vendored
|
@ -10951,6 +10951,7 @@ components:
|
||||||
type: string
|
type: string
|
||||||
enum:
|
enum:
|
||||||
- assistants
|
- assistants
|
||||||
|
- batch
|
||||||
title: OpenAIFilePurpose
|
title: OpenAIFilePurpose
|
||||||
description: >-
|
description: >-
|
||||||
Valid purpose values for OpenAI Files API.
|
Valid purpose values for OpenAI Files API.
|
||||||
|
@ -11019,6 +11020,7 @@ components:
|
||||||
type: string
|
type: string
|
||||||
enum:
|
enum:
|
||||||
- assistants
|
- assistants
|
||||||
|
- batch
|
||||||
description: The intended purpose of the file
|
description: The intended purpose of the file
|
||||||
additionalProperties: false
|
additionalProperties: false
|
||||||
required:
|
required:
|
||||||
|
|
|
@ -18,3 +18,4 @@ We are working on adding a few more APIs to complete the application lifecycle.
|
||||||
- **Batch Inference**: run inference on a dataset of inputs
|
- **Batch Inference**: run inference on a dataset of inputs
|
||||||
- **Batch Agents**: run agents on a dataset of inputs
|
- **Batch Agents**: run agents on a dataset of inputs
|
||||||
- **Synthetic Data Generation**: generate synthetic data for model development
|
- **Synthetic Data Generation**: generate synthetic data for model development
|
||||||
|
- **Batches**: OpenAI-compatible batch management for inference
|
||||||
|
|
|
@ -2,6 +2,15 @@
|
||||||
|
|
||||||
## Overview
|
## Overview
|
||||||
|
|
||||||
|
Agents API for creating and interacting with agentic systems.
|
||||||
|
|
||||||
|
Main functionalities provided by this API:
|
||||||
|
- Create agents with specific instructions and ability to use tools.
|
||||||
|
- Interactions with agents are grouped into sessions ("threads"), and each interaction is called a "turn".
|
||||||
|
- Agents can be provided with various tools (see the ToolGroups and ToolRuntime APIs for more details).
|
||||||
|
- Agents can be provided with various shields (see the Safety API for more details).
|
||||||
|
- Agents can also use Memory to retrieve information from knowledge bases. See the RAG Tool and Vector IO APIs for more details.
|
||||||
|
|
||||||
This section contains documentation for all available providers for the **agents** API.
|
This section contains documentation for all available providers for the **agents** API.
|
||||||
|
|
||||||
## Providers
|
## Providers
|
||||||
|
|
21
docs/source/providers/batches/index.md
Normal file
21
docs/source/providers/batches/index.md
Normal file
|
@ -0,0 +1,21 @@
|
||||||
|
# Batches
|
||||||
|
|
||||||
|
## Overview
|
||||||
|
|
||||||
|
Protocol for batch processing API operations.
|
||||||
|
|
||||||
|
The Batches API enables efficient processing of multiple requests in a single operation,
|
||||||
|
particularly useful for processing large datasets, batch evaluation workflows, and
|
||||||
|
cost-effective inference at scale.
|
||||||
|
|
||||||
|
Note: This API is currently under active development and may undergo changes.
|
||||||
|
|
||||||
|
This section contains documentation for all available providers for the **batches** API.
|
||||||
|
|
||||||
|
## Providers
|
||||||
|
|
||||||
|
```{toctree}
|
||||||
|
:maxdepth: 1
|
||||||
|
|
||||||
|
inline_reference
|
||||||
|
```
|
23
docs/source/providers/batches/inline_reference.md
Normal file
23
docs/source/providers/batches/inline_reference.md
Normal file
|
@ -0,0 +1,23 @@
|
||||||
|
# inline::reference
|
||||||
|
|
||||||
|
## Description
|
||||||
|
|
||||||
|
Reference implementation of batches API with KVStore persistence.
|
||||||
|
|
||||||
|
## Configuration
|
||||||
|
|
||||||
|
| Field | Type | Required | Default | Description |
|
||||||
|
|-------|------|----------|---------|-------------|
|
||||||
|
| `kvstore` | `utils.kvstore.config.RedisKVStoreConfig \| utils.kvstore.config.SqliteKVStoreConfig \| utils.kvstore.config.PostgresKVStoreConfig \| utils.kvstore.config.MongoDBKVStoreConfig` | No | sqlite | Configuration for the key-value store backend. |
|
||||||
|
| `max_concurrent_batches` | `<class 'int'>` | No | 1 | Maximum number of concurrent batches to process simultaneously. |
|
||||||
|
| `max_concurrent_requests_per_batch` | `<class 'int'>` | No | 10 | Maximum number of concurrent requests to process per batch. |
|
||||||
|
|
||||||
|
## Sample Configuration
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
kvstore:
|
||||||
|
type: sqlite
|
||||||
|
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/dummy}/batches.db
|
||||||
|
|
||||||
|
```
|
||||||
|
|
|
@ -2,6 +2,8 @@
|
||||||
|
|
||||||
## Overview
|
## Overview
|
||||||
|
|
||||||
|
Llama Stack Evaluation API for running evaluations on model and agent candidates.
|
||||||
|
|
||||||
This section contains documentation for all available providers for the **eval** API.
|
This section contains documentation for all available providers for the **eval** API.
|
||||||
|
|
||||||
## Providers
|
## Providers
|
||||||
|
|
|
@ -2,6 +2,12 @@
|
||||||
|
|
||||||
## Overview
|
## Overview
|
||||||
|
|
||||||
|
Llama Stack Inference API for generating completions, chat completions, and embeddings.
|
||||||
|
|
||||||
|
This API provides the raw interface to the underlying models. Two kinds of models are supported:
|
||||||
|
- LLM models: these models generate "raw" and "chat" (conversational) completions.
|
||||||
|
- Embedding models: these models generate embeddings to be used for semantic search.
|
||||||
|
|
||||||
This section contains documentation for all available providers for the **inference** API.
|
This section contains documentation for all available providers for the **inference** API.
|
||||||
|
|
||||||
## Providers
|
## Providers
|
||||||
|
|
9
llama_stack/apis/batches/__init__.py
Normal file
9
llama_stack/apis/batches/__init__.py
Normal file
|
@ -0,0 +1,9 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
from .batches import Batches, BatchObject, ListBatchesResponse
|
||||||
|
|
||||||
|
__all__ = ["Batches", "BatchObject", "ListBatchesResponse"]
|
89
llama_stack/apis/batches/batches.py
Normal file
89
llama_stack/apis/batches/batches.py
Normal file
|
@ -0,0 +1,89 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
from typing import Literal, Protocol, runtime_checkable
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
from llama_stack.schema_utils import json_schema_type, webmethod
|
||||||
|
|
||||||
|
try:
|
||||||
|
from openai.types import Batch as BatchObject
|
||||||
|
except ImportError as e:
|
||||||
|
raise ImportError("OpenAI package is required for batches API. Please install it with: pip install openai") from e
|
||||||
|
|
||||||
|
|
||||||
|
@json_schema_type
|
||||||
|
class ListBatchesResponse(BaseModel):
|
||||||
|
"""Response containing a list of batch objects."""
|
||||||
|
|
||||||
|
object: Literal["list"] = "list"
|
||||||
|
data: list[BatchObject] = Field(..., description="List of batch objects")
|
||||||
|
first_id: str | None = Field(default=None, description="ID of the first batch in the list")
|
||||||
|
last_id: str | None = Field(default=None, description="ID of the last batch in the list")
|
||||||
|
has_more: bool = Field(default=False, description="Whether there are more batches available")
|
||||||
|
|
||||||
|
|
||||||
|
@runtime_checkable
|
||||||
|
class Batches(Protocol):
|
||||||
|
"""Protocol for batch processing API operations.
|
||||||
|
|
||||||
|
The Batches API enables efficient processing of multiple requests in a single operation,
|
||||||
|
particularly useful for processing large datasets, batch evaluation workflows, and
|
||||||
|
cost-effective inference at scale.
|
||||||
|
|
||||||
|
Note: This API is currently under active development and may undergo changes.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@webmethod(route="/openai/v1/batches", method="POST")
|
||||||
|
async def create_batch(
|
||||||
|
self,
|
||||||
|
input_file_id: str,
|
||||||
|
endpoint: str,
|
||||||
|
completion_window: Literal["24h"],
|
||||||
|
metadata: dict[str, str] | None = None,
|
||||||
|
) -> BatchObject:
|
||||||
|
"""Create a new batch for processing multiple API requests.
|
||||||
|
|
||||||
|
:param input_file_id: The ID of an uploaded file containing requests for the batch.
|
||||||
|
:param endpoint: The endpoint to be used for all requests in the batch.
|
||||||
|
:param completion_window: The time window within which the batch should be processed.
|
||||||
|
:param metadata: Optional metadata for the batch.
|
||||||
|
:returns: The created batch object.
|
||||||
|
"""
|
||||||
|
...
|
||||||
|
|
||||||
|
@webmethod(route="/openai/v1/batches/{batch_id}", method="GET")
|
||||||
|
async def retrieve_batch(self, batch_id: str) -> BatchObject:
|
||||||
|
"""Retrieve information about a specific batch.
|
||||||
|
|
||||||
|
:param batch_id: The ID of the batch to retrieve.
|
||||||
|
:returns: The batch object.
|
||||||
|
"""
|
||||||
|
...
|
||||||
|
|
||||||
|
@webmethod(route="/openai/v1/batches/{batch_id}/cancel", method="POST")
|
||||||
|
async def cancel_batch(self, batch_id: str) -> BatchObject:
|
||||||
|
"""Cancel a batch that is in progress.
|
||||||
|
|
||||||
|
:param batch_id: The ID of the batch to cancel.
|
||||||
|
:returns: The updated batch object.
|
||||||
|
"""
|
||||||
|
...
|
||||||
|
|
||||||
|
@webmethod(route="/openai/v1/batches", method="GET")
|
||||||
|
async def list_batches(
|
||||||
|
self,
|
||||||
|
after: str | None = None,
|
||||||
|
limit: int = 20,
|
||||||
|
) -> ListBatchesResponse:
|
||||||
|
"""List all batches for the current user.
|
||||||
|
|
||||||
|
:param after: A cursor for pagination; returns batches after this batch ID.
|
||||||
|
:param limit: Number of batches to return (default 20, max 100).
|
||||||
|
:returns: A list of batch objects.
|
||||||
|
"""
|
||||||
|
...
|
|
@ -64,6 +64,12 @@ class SessionNotFoundError(ValueError):
|
||||||
super().__init__(message)
|
super().__init__(message)
|
||||||
|
|
||||||
|
|
||||||
|
class ConflictError(ValueError):
|
||||||
|
"""raised when an operation cannot be performed due to a conflict with the current state"""
|
||||||
|
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
class ModelTypeError(TypeError):
|
class ModelTypeError(TypeError):
|
||||||
"""raised when a model is present but not the correct type"""
|
"""raised when a model is present but not the correct type"""
|
||||||
|
|
||||||
|
|
|
@ -86,6 +86,7 @@ class Api(Enum, metaclass=DynamicApiMeta):
|
||||||
:cvar inference: Text generation, chat completions, and embeddings
|
:cvar inference: Text generation, chat completions, and embeddings
|
||||||
:cvar safety: Content moderation and safety shields
|
:cvar safety: Content moderation and safety shields
|
||||||
:cvar agents: Agent orchestration and execution
|
:cvar agents: Agent orchestration and execution
|
||||||
|
:cvar batches: Batch processing for asynchronous API requests
|
||||||
:cvar vector_io: Vector database operations and queries
|
:cvar vector_io: Vector database operations and queries
|
||||||
:cvar datasetio: Dataset input/output operations
|
:cvar datasetio: Dataset input/output operations
|
||||||
:cvar scoring: Model output evaluation and scoring
|
:cvar scoring: Model output evaluation and scoring
|
||||||
|
@ -108,6 +109,7 @@ class Api(Enum, metaclass=DynamicApiMeta):
|
||||||
inference = "inference"
|
inference = "inference"
|
||||||
safety = "safety"
|
safety = "safety"
|
||||||
agents = "agents"
|
agents = "agents"
|
||||||
|
batches = "batches"
|
||||||
vector_io = "vector_io"
|
vector_io = "vector_io"
|
||||||
datasetio = "datasetio"
|
datasetio = "datasetio"
|
||||||
scoring = "scoring"
|
scoring = "scoring"
|
||||||
|
|
|
@ -22,6 +22,7 @@ class OpenAIFilePurpose(StrEnum):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
ASSISTANTS = "assistants"
|
ASSISTANTS = "assistants"
|
||||||
|
BATCH = "batch"
|
||||||
# TODO: Add other purposes as needed
|
# TODO: Add other purposes as needed
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -8,6 +8,7 @@ import inspect
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from llama_stack.apis.agents import Agents
|
from llama_stack.apis.agents import Agents
|
||||||
|
from llama_stack.apis.batches import Batches
|
||||||
from llama_stack.apis.benchmarks import Benchmarks
|
from llama_stack.apis.benchmarks import Benchmarks
|
||||||
from llama_stack.apis.datasetio import DatasetIO
|
from llama_stack.apis.datasetio import DatasetIO
|
||||||
from llama_stack.apis.datasets import Datasets
|
from llama_stack.apis.datasets import Datasets
|
||||||
|
@ -75,6 +76,7 @@ def api_protocol_map(external_apis: dict[Api, ExternalApiSpec] | None = None) ->
|
||||||
Api.agents: Agents,
|
Api.agents: Agents,
|
||||||
Api.inference: Inference,
|
Api.inference: Inference,
|
||||||
Api.inspect: Inspect,
|
Api.inspect: Inspect,
|
||||||
|
Api.batches: Batches,
|
||||||
Api.vector_io: VectorIO,
|
Api.vector_io: VectorIO,
|
||||||
Api.vector_dbs: VectorDBs,
|
Api.vector_dbs: VectorDBs,
|
||||||
Api.models: Models,
|
Api.models: Models,
|
||||||
|
|
|
@ -32,6 +32,7 @@ from fastapi.responses import JSONResponse, StreamingResponse
|
||||||
from openai import BadRequestError
|
from openai import BadRequestError
|
||||||
from pydantic import BaseModel, ValidationError
|
from pydantic import BaseModel, ValidationError
|
||||||
|
|
||||||
|
from llama_stack.apis.common.errors import ConflictError, ResourceNotFoundError
|
||||||
from llama_stack.apis.common.responses import PaginatedResponse
|
from llama_stack.apis.common.responses import PaginatedResponse
|
||||||
from llama_stack.cli.utils import add_config_distro_args, get_config_from_args
|
from llama_stack.cli.utils import add_config_distro_args, get_config_from_args
|
||||||
from llama_stack.core.access_control.access_control import AccessDeniedError
|
from llama_stack.core.access_control.access_control import AccessDeniedError
|
||||||
|
@ -128,6 +129,10 @@ def translate_exception(exc: Exception) -> HTTPException | RequestValidationErro
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
elif isinstance(exc, ConflictError):
|
||||||
|
return HTTPException(status_code=409, detail=str(exc))
|
||||||
|
elif isinstance(exc, ResourceNotFoundError):
|
||||||
|
return HTTPException(status_code=404, detail=str(exc))
|
||||||
elif isinstance(exc, ValueError):
|
elif isinstance(exc, ValueError):
|
||||||
return HTTPException(status_code=httpx.codes.BAD_REQUEST, detail=f"Invalid value: {str(exc)}")
|
return HTTPException(status_code=httpx.codes.BAD_REQUEST, detail=f"Invalid value: {str(exc)}")
|
||||||
elif isinstance(exc, BadRequestError):
|
elif isinstance(exc, BadRequestError):
|
||||||
|
|
5
llama_stack/providers/inline/batches/__init__.py
Normal file
5
llama_stack/providers/inline/batches/__init__.py
Normal file
|
@ -0,0 +1,5 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
36
llama_stack/providers/inline/batches/reference/__init__.py
Normal file
36
llama_stack/providers/inline/batches/reference/__init__.py
Normal file
|
@ -0,0 +1,36 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from llama_stack.apis.files import Files
|
||||||
|
from llama_stack.apis.inference import Inference
|
||||||
|
from llama_stack.apis.models import Models
|
||||||
|
from llama_stack.core.datatypes import AccessRule, Api
|
||||||
|
from llama_stack.providers.utils.kvstore import kvstore_impl
|
||||||
|
|
||||||
|
from .batches import ReferenceBatchesImpl
|
||||||
|
from .config import ReferenceBatchesImplConfig
|
||||||
|
|
||||||
|
__all__ = ["ReferenceBatchesImpl", "ReferenceBatchesImplConfig"]
|
||||||
|
|
||||||
|
|
||||||
|
async def get_provider_impl(config: ReferenceBatchesImplConfig, deps: dict[Api, Any], policy: list[AccessRule]):
|
||||||
|
kvstore = await kvstore_impl(config.kvstore)
|
||||||
|
inference_api: Inference | None = deps.get(Api.inference)
|
||||||
|
files_api: Files | None = deps.get(Api.files)
|
||||||
|
models_api: Models | None = deps.get(Api.models)
|
||||||
|
|
||||||
|
if inference_api is None:
|
||||||
|
raise ValueError("Inference API is required but not provided in dependencies")
|
||||||
|
if files_api is None:
|
||||||
|
raise ValueError("Files API is required but not provided in dependencies")
|
||||||
|
if models_api is None:
|
||||||
|
raise ValueError("Models API is required but not provided in dependencies")
|
||||||
|
|
||||||
|
impl = ReferenceBatchesImpl(config, inference_api, files_api, models_api, kvstore)
|
||||||
|
await impl.initialize()
|
||||||
|
return impl
|
553
llama_stack/providers/inline/batches/reference/batches.py
Normal file
553
llama_stack/providers/inline/batches/reference/batches.py
Normal file
|
@ -0,0 +1,553 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import itertools
|
||||||
|
import json
|
||||||
|
import time
|
||||||
|
import uuid
|
||||||
|
from io import BytesIO
|
||||||
|
from typing import Any, Literal
|
||||||
|
|
||||||
|
from openai.types.batch import BatchError, Errors
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from llama_stack.apis.batches import Batches, BatchObject, ListBatchesResponse
|
||||||
|
from llama_stack.apis.common.errors import ConflictError, ResourceNotFoundError
|
||||||
|
from llama_stack.apis.files import Files, OpenAIFilePurpose
|
||||||
|
from llama_stack.apis.inference import Inference
|
||||||
|
from llama_stack.apis.models import Models
|
||||||
|
from llama_stack.log import get_logger
|
||||||
|
from llama_stack.providers.utils.kvstore import KVStore
|
||||||
|
|
||||||
|
from .config import ReferenceBatchesImplConfig
|
||||||
|
|
||||||
|
BATCH_PREFIX = "batch:"
|
||||||
|
|
||||||
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class AsyncBytesIO:
|
||||||
|
"""
|
||||||
|
Async-compatible BytesIO wrapper to allow async file-like operations.
|
||||||
|
|
||||||
|
We use this when uploading files to the Files API, as it expects an
|
||||||
|
async file-like object.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, data: bytes):
|
||||||
|
self._buffer = BytesIO(data)
|
||||||
|
|
||||||
|
async def read(self, n=-1):
|
||||||
|
return self._buffer.read(n)
|
||||||
|
|
||||||
|
async def seek(self, pos, whence=0):
|
||||||
|
return self._buffer.seek(pos, whence)
|
||||||
|
|
||||||
|
def __enter__(self):
|
||||||
|
return self
|
||||||
|
|
||||||
|
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||||
|
self._buffer.close()
|
||||||
|
|
||||||
|
def __getattr__(self, name):
|
||||||
|
return getattr(self._buffer, name)
|
||||||
|
|
||||||
|
|
||||||
|
class BatchRequest(BaseModel):
|
||||||
|
line_num: int
|
||||||
|
custom_id: str
|
||||||
|
method: str
|
||||||
|
url: str
|
||||||
|
body: dict[str, Any]
|
||||||
|
|
||||||
|
|
||||||
|
class ReferenceBatchesImpl(Batches):
|
||||||
|
"""Reference implementation of the Batches API.
|
||||||
|
|
||||||
|
This implementation processes batch files by making individual requests
|
||||||
|
to the inference API and generates output files with results.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
config: ReferenceBatchesImplConfig,
|
||||||
|
inference_api: Inference,
|
||||||
|
files_api: Files,
|
||||||
|
models_api: Models,
|
||||||
|
kvstore: KVStore,
|
||||||
|
) -> None:
|
||||||
|
self.config = config
|
||||||
|
self.kvstore = kvstore
|
||||||
|
self.inference_api = inference_api
|
||||||
|
self.files_api = files_api
|
||||||
|
self.models_api = models_api
|
||||||
|
self._processing_tasks: dict[str, asyncio.Task] = {}
|
||||||
|
self._batch_semaphore = asyncio.Semaphore(config.max_concurrent_batches)
|
||||||
|
self._update_batch_lock = asyncio.Lock()
|
||||||
|
|
||||||
|
# this is to allow tests to disable background processing
|
||||||
|
self.process_batches = True
|
||||||
|
|
||||||
|
async def initialize(self) -> None:
|
||||||
|
# TODO: start background processing of existing tasks
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def shutdown(self) -> None:
|
||||||
|
"""Shutdown the batches provider."""
|
||||||
|
if self._processing_tasks:
|
||||||
|
# don't cancel tasks - just let them stop naturally on shutdown
|
||||||
|
# cancelling would mark batches as "cancelled" in the database
|
||||||
|
logger.info(f"Shutdown initiated with {len(self._processing_tasks)} active batch processing tasks")
|
||||||
|
|
||||||
|
# TODO (SECURITY): this currently works w/ configured api keys, not with x-llamastack-provider-data or with user policy restrictions
|
||||||
|
async def create_batch(
|
||||||
|
self,
|
||||||
|
input_file_id: str,
|
||||||
|
endpoint: str,
|
||||||
|
completion_window: Literal["24h"],
|
||||||
|
metadata: dict[str, str] | None = None,
|
||||||
|
) -> BatchObject:
|
||||||
|
"""
|
||||||
|
Create a new batch for processing multiple API requests.
|
||||||
|
|
||||||
|
Error handling by levels -
|
||||||
|
0. Input param handling, results in 40x errors before processing, e.g.
|
||||||
|
- Wrong completion_window
|
||||||
|
- Invalid metadata types
|
||||||
|
- Unknown endpoint
|
||||||
|
-> no batch created
|
||||||
|
1. Errors preventing processing, result in BatchErrors aggregated in process_batch, e.g.
|
||||||
|
- input_file_id missing
|
||||||
|
- invalid json in file
|
||||||
|
- missing custom_id, method, url, body
|
||||||
|
- invalid model
|
||||||
|
- streaming
|
||||||
|
-> batch created, validation sends to failed status
|
||||||
|
2. Processing errors, result in error_file_id entries, e.g.
|
||||||
|
- Any error returned from inference endpoint
|
||||||
|
-> batch created, goes to completed status
|
||||||
|
"""
|
||||||
|
|
||||||
|
# TODO: set expiration time for garbage collection
|
||||||
|
|
||||||
|
if endpoint not in ["/v1/chat/completions"]:
|
||||||
|
raise ValueError(
|
||||||
|
f"Invalid endpoint: {endpoint}. Supported values: /v1/chat/completions. Code: invalid_value. Param: endpoint",
|
||||||
|
)
|
||||||
|
|
||||||
|
if completion_window != "24h":
|
||||||
|
raise ValueError(
|
||||||
|
f"Invalid completion_window: {completion_window}. Supported values are: 24h. Code: invalid_value. Param: completion_window",
|
||||||
|
)
|
||||||
|
|
||||||
|
batch_id = f"batch_{uuid.uuid4().hex[:16]}"
|
||||||
|
current_time = int(time.time())
|
||||||
|
|
||||||
|
batch = BatchObject(
|
||||||
|
id=batch_id,
|
||||||
|
object="batch",
|
||||||
|
endpoint=endpoint,
|
||||||
|
input_file_id=input_file_id,
|
||||||
|
completion_window=completion_window,
|
||||||
|
status="validating",
|
||||||
|
created_at=current_time,
|
||||||
|
metadata=metadata,
|
||||||
|
)
|
||||||
|
|
||||||
|
await self.kvstore.set(f"batch:{batch_id}", batch.to_json())
|
||||||
|
|
||||||
|
if self.process_batches:
|
||||||
|
task = asyncio.create_task(self._process_batch(batch_id))
|
||||||
|
self._processing_tasks[batch_id] = task
|
||||||
|
|
||||||
|
return batch
|
||||||
|
|
||||||
|
async def cancel_batch(self, batch_id: str) -> BatchObject:
|
||||||
|
"""Cancel a batch that is in progress."""
|
||||||
|
batch = await self.retrieve_batch(batch_id)
|
||||||
|
|
||||||
|
if batch.status in ["cancelled", "cancelling"]:
|
||||||
|
return batch
|
||||||
|
|
||||||
|
if batch.status in ["completed", "failed", "expired"]:
|
||||||
|
raise ConflictError(f"Cannot cancel batch '{batch_id}' with status '{batch.status}'")
|
||||||
|
|
||||||
|
await self._update_batch(batch_id, status="cancelling", cancelling_at=int(time.time()))
|
||||||
|
|
||||||
|
if batch_id in self._processing_tasks:
|
||||||
|
self._processing_tasks[batch_id].cancel()
|
||||||
|
# note: task removal and status="cancelled" handled in finally block of _process_batch
|
||||||
|
|
||||||
|
return await self.retrieve_batch(batch_id)
|
||||||
|
|
||||||
|
async def list_batches(
|
||||||
|
self,
|
||||||
|
after: str | None = None,
|
||||||
|
limit: int = 20,
|
||||||
|
) -> ListBatchesResponse:
|
||||||
|
"""
|
||||||
|
List all batches, eventually only for the current user.
|
||||||
|
|
||||||
|
With no notion of user, we return all batches.
|
||||||
|
"""
|
||||||
|
batch_values = await self.kvstore.values_in_range("batch:", "batch:\xff")
|
||||||
|
|
||||||
|
batches = []
|
||||||
|
for batch_data in batch_values:
|
||||||
|
if batch_data:
|
||||||
|
batches.append(BatchObject.model_validate_json(batch_data))
|
||||||
|
|
||||||
|
batches.sort(key=lambda b: b.created_at, reverse=True)
|
||||||
|
|
||||||
|
start_idx = 0
|
||||||
|
if after:
|
||||||
|
for i, batch in enumerate(batches):
|
||||||
|
if batch.id == after:
|
||||||
|
start_idx = i + 1
|
||||||
|
break
|
||||||
|
|
||||||
|
page_batches = batches[start_idx : start_idx + limit]
|
||||||
|
has_more = (start_idx + limit) < len(batches)
|
||||||
|
|
||||||
|
first_id = page_batches[0].id if page_batches else None
|
||||||
|
last_id = page_batches[-1].id if page_batches else None
|
||||||
|
|
||||||
|
return ListBatchesResponse(
|
||||||
|
data=page_batches,
|
||||||
|
first_id=first_id,
|
||||||
|
last_id=last_id,
|
||||||
|
has_more=has_more,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def retrieve_batch(self, batch_id: str) -> BatchObject:
|
||||||
|
"""Retrieve information about a specific batch."""
|
||||||
|
batch_data = await self.kvstore.get(f"batch:{batch_id}")
|
||||||
|
if not batch_data:
|
||||||
|
raise ResourceNotFoundError(batch_id, "Batch", "batches.list()")
|
||||||
|
|
||||||
|
return BatchObject.model_validate_json(batch_data)
|
||||||
|
|
||||||
|
async def _update_batch(self, batch_id: str, **updates) -> None:
|
||||||
|
"""Update batch fields in kvstore."""
|
||||||
|
async with self._update_batch_lock:
|
||||||
|
try:
|
||||||
|
batch = await self.retrieve_batch(batch_id)
|
||||||
|
|
||||||
|
# batch processing is async. once cancelling, only allow "cancelled" status updates
|
||||||
|
if batch.status == "cancelling" and updates.get("status") != "cancelled":
|
||||||
|
logger.info(
|
||||||
|
f"Skipping status update for cancelled batch {batch_id}: attempted {updates.get('status')}"
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
if "errors" in updates:
|
||||||
|
updates["errors"] = updates["errors"].model_dump()
|
||||||
|
|
||||||
|
batch_dict = batch.model_dump()
|
||||||
|
batch_dict.update(updates)
|
||||||
|
|
||||||
|
await self.kvstore.set(f"batch:{batch_id}", json.dumps(batch_dict))
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to update batch {batch_id}: {e}")
|
||||||
|
|
||||||
|
async def _validate_input(self, batch: BatchObject) -> tuple[list[BatchError], list[BatchRequest]]:
|
||||||
|
"""
|
||||||
|
Read & validate input, return errors and valid input.
|
||||||
|
|
||||||
|
Validation of
|
||||||
|
- input_file_id existance
|
||||||
|
- valid json
|
||||||
|
- custom_id, method, url, body presence and valid
|
||||||
|
- no streaming
|
||||||
|
"""
|
||||||
|
requests: list[BatchRequest] = []
|
||||||
|
errors: list[BatchError] = []
|
||||||
|
try:
|
||||||
|
await self.files_api.openai_retrieve_file(batch.input_file_id)
|
||||||
|
except Exception:
|
||||||
|
errors.append(
|
||||||
|
BatchError(
|
||||||
|
code="invalid_request",
|
||||||
|
line=None,
|
||||||
|
message=f"Cannot find file {batch.input_file_id}.",
|
||||||
|
param="input_file_id",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return errors, requests
|
||||||
|
|
||||||
|
# TODO(SECURITY): do something about large files
|
||||||
|
file_content_response = await self.files_api.openai_retrieve_file_content(batch.input_file_id)
|
||||||
|
file_content = file_content_response.body.decode("utf-8")
|
||||||
|
for line_num, line in enumerate(file_content.strip().split("\n"), 1):
|
||||||
|
if line.strip(): # skip empty lines
|
||||||
|
try:
|
||||||
|
request = json.loads(line)
|
||||||
|
|
||||||
|
if not isinstance(request, dict):
|
||||||
|
errors.append(
|
||||||
|
BatchError(
|
||||||
|
code="invalid_request",
|
||||||
|
line=line_num,
|
||||||
|
message="Each line must be a JSON dictionary object",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
|
||||||
|
valid = True
|
||||||
|
|
||||||
|
for param, expected_type, type_string in [
|
||||||
|
("custom_id", str, "string"),
|
||||||
|
("method", str, "string"),
|
||||||
|
("url", str, "string"),
|
||||||
|
("body", dict, "JSON dictionary object"),
|
||||||
|
]:
|
||||||
|
if param not in request:
|
||||||
|
errors.append(
|
||||||
|
BatchError(
|
||||||
|
code="missing_required_parameter",
|
||||||
|
line=line_num,
|
||||||
|
message=f"Missing required parameter: {param}",
|
||||||
|
param=param,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
valid = False
|
||||||
|
elif not isinstance(request[param], expected_type):
|
||||||
|
param_name = "URL" if param == "url" else param.capitalize()
|
||||||
|
errors.append(
|
||||||
|
BatchError(
|
||||||
|
code="invalid_request",
|
||||||
|
line=line_num,
|
||||||
|
message=f"{param_name} must be a {type_string}",
|
||||||
|
param=param,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
valid = False
|
||||||
|
|
||||||
|
if (url := request.get("url")) and isinstance(url, str) and url != batch.endpoint:
|
||||||
|
errors.append(
|
||||||
|
BatchError(
|
||||||
|
code="invalid_url",
|
||||||
|
line=line_num,
|
||||||
|
message="URL provided for this request does not match the batch endpoint",
|
||||||
|
param="url",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
valid = False
|
||||||
|
|
||||||
|
if (body := request.get("body")) and isinstance(body, dict):
|
||||||
|
if body.get("stream", False):
|
||||||
|
errors.append(
|
||||||
|
BatchError(
|
||||||
|
code="streaming_unsupported",
|
||||||
|
line=line_num,
|
||||||
|
message="Streaming is not supported in batch processing",
|
||||||
|
param="body.stream",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
valid = False
|
||||||
|
|
||||||
|
for param, expected_type, type_string in [
|
||||||
|
("model", str, "a string"),
|
||||||
|
# messages is specific to /v1/chat/completions
|
||||||
|
# we could skip validating messages here and let inference fail. however,
|
||||||
|
# that would be a very expensive way to find out messages is wrong.
|
||||||
|
("messages", list, "an array"), # TODO: allow messages to be a string?
|
||||||
|
]:
|
||||||
|
if param not in body:
|
||||||
|
errors.append(
|
||||||
|
BatchError(
|
||||||
|
code="invalid_request",
|
||||||
|
line=line_num,
|
||||||
|
message=f"{param.capitalize()} parameter is required",
|
||||||
|
param=f"body.{param}",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
valid = False
|
||||||
|
elif not isinstance(body[param], expected_type):
|
||||||
|
errors.append(
|
||||||
|
BatchError(
|
||||||
|
code="invalid_request",
|
||||||
|
line=line_num,
|
||||||
|
message=f"{param.capitalize()} must be {type_string}",
|
||||||
|
param=f"body.{param}",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
valid = False
|
||||||
|
|
||||||
|
if "model" in body and isinstance(body["model"], str):
|
||||||
|
try:
|
||||||
|
await self.models_api.get_model(body["model"])
|
||||||
|
except Exception:
|
||||||
|
errors.append(
|
||||||
|
BatchError(
|
||||||
|
code="model_not_found",
|
||||||
|
line=line_num,
|
||||||
|
message=f"Model '{body['model']}' does not exist or is not supported",
|
||||||
|
param="body.model",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
valid = False
|
||||||
|
|
||||||
|
if valid:
|
||||||
|
assert isinstance(url, str), "URL must be a string" # for mypy
|
||||||
|
assert isinstance(body, dict), "Body must be a dictionary" # for mypy
|
||||||
|
requests.append(
|
||||||
|
BatchRequest(
|
||||||
|
line_num=line_num,
|
||||||
|
url=url,
|
||||||
|
method=request["method"],
|
||||||
|
custom_id=request["custom_id"],
|
||||||
|
body=body,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
errors.append(
|
||||||
|
BatchError(
|
||||||
|
code="invalid_json_line",
|
||||||
|
line=line_num,
|
||||||
|
message="This line is not parseable as valid JSON.",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
return errors, requests
|
||||||
|
|
||||||
|
async def _process_batch(self, batch_id: str) -> None:
|
||||||
|
"""Background task to process a batch of requests."""
|
||||||
|
try:
|
||||||
|
logger.info(f"Starting batch processing for {batch_id}")
|
||||||
|
async with self._batch_semaphore: # semaphore to limit concurrency
|
||||||
|
logger.info(f"Acquired semaphore for batch {batch_id}")
|
||||||
|
await self._process_batch_impl(batch_id)
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
logger.info(f"Batch processing cancelled for {batch_id}")
|
||||||
|
await self._update_batch(batch_id, status="cancelled", cancelled_at=int(time.time()))
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Batch processing failed for {batch_id}: {e}")
|
||||||
|
await self._update_batch(
|
||||||
|
batch_id,
|
||||||
|
status="failed",
|
||||||
|
failed_at=int(time.time()),
|
||||||
|
errors=Errors(data=[BatchError(code="internal_error", message=str(e))]),
|
||||||
|
)
|
||||||
|
finally:
|
||||||
|
self._processing_tasks.pop(batch_id, None)
|
||||||
|
|
||||||
|
async def _process_batch_impl(self, batch_id: str) -> None:
|
||||||
|
"""Implementation of batch processing logic."""
|
||||||
|
errors: list[BatchError] = []
|
||||||
|
batch = await self.retrieve_batch(batch_id)
|
||||||
|
|
||||||
|
errors, requests = await self._validate_input(batch)
|
||||||
|
if errors:
|
||||||
|
await self._update_batch(batch_id, status="failed", failed_at=int(time.time()), errors=Errors(data=errors))
|
||||||
|
logger.info(f"Batch validation failed for {batch_id} with {len(errors)} errors")
|
||||||
|
return
|
||||||
|
|
||||||
|
logger.info(f"Processing {len(requests)} requests for batch {batch_id}")
|
||||||
|
|
||||||
|
total_requests = len(requests)
|
||||||
|
await self._update_batch(
|
||||||
|
batch_id,
|
||||||
|
status="in_progress",
|
||||||
|
request_counts={"total": total_requests, "completed": 0, "failed": 0},
|
||||||
|
)
|
||||||
|
|
||||||
|
error_results = []
|
||||||
|
success_results = []
|
||||||
|
completed_count = 0
|
||||||
|
failed_count = 0
|
||||||
|
|
||||||
|
for chunk in itertools.batched(requests, self.config.max_concurrent_requests_per_batch):
|
||||||
|
# we use a TaskGroup to ensure all process-single-request tasks are canceled when process-batch is cancelled
|
||||||
|
async with asyncio.TaskGroup() as tg:
|
||||||
|
chunk_tasks = [tg.create_task(self._process_single_request(batch_id, request)) for request in chunk]
|
||||||
|
|
||||||
|
chunk_results = await asyncio.gather(*chunk_tasks, return_exceptions=True)
|
||||||
|
|
||||||
|
for result in chunk_results:
|
||||||
|
if isinstance(result, dict) and result.get("error") is not None: # error response from inference
|
||||||
|
failed_count += 1
|
||||||
|
error_results.append(result)
|
||||||
|
elif isinstance(result, dict) and result.get("response") is not None: # successful inference
|
||||||
|
completed_count += 1
|
||||||
|
success_results.append(result)
|
||||||
|
else: # unexpected result
|
||||||
|
failed_count += 1
|
||||||
|
errors.append(BatchError(code="internal_error", message=f"Unexpected result: {result}"))
|
||||||
|
|
||||||
|
await self._update_batch(
|
||||||
|
batch_id,
|
||||||
|
request_counts={"total": total_requests, "completed": completed_count, "failed": failed_count},
|
||||||
|
)
|
||||||
|
|
||||||
|
if errors:
|
||||||
|
await self._update_batch(
|
||||||
|
batch_id, status="failed", failed_at=int(time.time()), errors=Errors(data=errors)
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
output_file_id = await self._create_output_file(batch_id, success_results, "success")
|
||||||
|
await self._update_batch(batch_id, output_file_id=output_file_id)
|
||||||
|
|
||||||
|
error_file_id = await self._create_output_file(batch_id, error_results, "error")
|
||||||
|
await self._update_batch(batch_id, error_file_id=error_file_id)
|
||||||
|
|
||||||
|
await self._update_batch(batch_id, status="completed", completed_at=int(time.time()))
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"Batch processing completed for {batch_id}: {completed_count} completed, {failed_count} failed"
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
# note: errors is empty at this point, so we don't lose anything by ignoring it
|
||||||
|
await self._update_batch(
|
||||||
|
batch_id,
|
||||||
|
status="failed",
|
||||||
|
failed_at=int(time.time()),
|
||||||
|
errors=Errors(data=[BatchError(code="output_failed", message=str(e))]),
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _process_single_request(self, batch_id: str, request: BatchRequest) -> dict:
|
||||||
|
"""Process a single request from the batch."""
|
||||||
|
request_id = f"batch_req_{batch_id}_{request.line_num}"
|
||||||
|
|
||||||
|
try:
|
||||||
|
# TODO(SECURITY): review body for security issues
|
||||||
|
chat_response = await self.inference_api.openai_chat_completion(**request.body)
|
||||||
|
|
||||||
|
# this is for mypy, we don't allow streaming so we'll get the right type
|
||||||
|
assert hasattr(chat_response, "model_dump_json"), "Chat response must have model_dump_json method"
|
||||||
|
return {
|
||||||
|
"id": request_id,
|
||||||
|
"custom_id": request.custom_id,
|
||||||
|
"response": {
|
||||||
|
"status_code": 200,
|
||||||
|
"request_id": request_id, # TODO: should this be different?
|
||||||
|
"body": chat_response.model_dump_json(),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
except Exception as e:
|
||||||
|
logger.info(f"Error processing request {request.custom_id} in batch {batch_id}: {e}")
|
||||||
|
return {
|
||||||
|
"id": request_id,
|
||||||
|
"custom_id": request.custom_id,
|
||||||
|
"error": {"type": "request_failed", "message": str(e)},
|
||||||
|
}
|
||||||
|
|
||||||
|
async def _create_output_file(self, batch_id: str, results: list[dict], file_type: str) -> str:
|
||||||
|
"""
|
||||||
|
Create an output file with batch results.
|
||||||
|
|
||||||
|
This function filters results based on the specified file_type
|
||||||
|
and uploads the file to the Files API.
|
||||||
|
"""
|
||||||
|
output_lines = [json.dumps(result) for result in results]
|
||||||
|
|
||||||
|
with AsyncBytesIO("\n".join(output_lines).encode("utf-8")) as file_buffer:
|
||||||
|
file_buffer.filename = f"{batch_id}_{file_type}.jsonl"
|
||||||
|
uploaded_file = await self.files_api.openai_upload_file(file=file_buffer, purpose=OpenAIFilePurpose.BATCH)
|
||||||
|
return uploaded_file.id
|
40
llama_stack/providers/inline/batches/reference/config.py
Normal file
40
llama_stack/providers/inline/batches/reference/config.py
Normal file
|
@ -0,0 +1,40 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
from llama_stack.providers.utils.kvstore.config import KVStoreConfig, SqliteKVStoreConfig
|
||||||
|
|
||||||
|
|
||||||
|
class ReferenceBatchesImplConfig(BaseModel):
|
||||||
|
"""Configuration for the Reference Batches implementation."""
|
||||||
|
|
||||||
|
kvstore: KVStoreConfig = Field(
|
||||||
|
description="Configuration for the key-value store backend.",
|
||||||
|
)
|
||||||
|
|
||||||
|
max_concurrent_batches: int = Field(
|
||||||
|
default=1,
|
||||||
|
description="Maximum number of concurrent batches to process simultaneously.",
|
||||||
|
ge=1,
|
||||||
|
)
|
||||||
|
|
||||||
|
max_concurrent_requests_per_batch: int = Field(
|
||||||
|
default=10,
|
||||||
|
description="Maximum number of concurrent requests to process per batch.",
|
||||||
|
ge=1,
|
||||||
|
)
|
||||||
|
|
||||||
|
# TODO: add a max requests per second rate limiter
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def sample_run_config(cls, __distro_dir__: str) -> dict:
|
||||||
|
return {
|
||||||
|
"kvstore": SqliteKVStoreConfig.sample_run_config(
|
||||||
|
__distro_dir__=__distro_dir__,
|
||||||
|
db_name="batches.db",
|
||||||
|
),
|
||||||
|
}
|
26
llama_stack/providers/registry/batches.py
Normal file
26
llama_stack/providers/registry/batches.py
Normal file
|
@ -0,0 +1,26 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
|
||||||
|
from llama_stack.providers.datatypes import Api, InlineProviderSpec, ProviderSpec
|
||||||
|
|
||||||
|
|
||||||
|
def available_providers() -> list[ProviderSpec]:
|
||||||
|
return [
|
||||||
|
InlineProviderSpec(
|
||||||
|
api=Api.batches,
|
||||||
|
provider_type="inline::reference",
|
||||||
|
pip_packages=["openai"],
|
||||||
|
module="llama_stack.providers.inline.batches.reference",
|
||||||
|
config_class="llama_stack.providers.inline.batches.reference.config.ReferenceBatchesImplConfig",
|
||||||
|
api_dependencies=[
|
||||||
|
Api.inference,
|
||||||
|
Api.files,
|
||||||
|
Api.models,
|
||||||
|
],
|
||||||
|
description="Reference implementation of batches API with KVStore persistence.",
|
||||||
|
),
|
||||||
|
]
|
|
@ -18,6 +18,23 @@ from llama_stack.core.distribution import get_provider_registry
|
||||||
REPO_ROOT = Path(__file__).parent.parent
|
REPO_ROOT = Path(__file__).parent.parent
|
||||||
|
|
||||||
|
|
||||||
|
def get_api_docstring(api_name: str) -> str | None:
|
||||||
|
"""Extract docstring from the API protocol class."""
|
||||||
|
try:
|
||||||
|
# Import the API module dynamically
|
||||||
|
api_module = __import__(f"llama_stack.apis.{api_name}", fromlist=[api_name.title()])
|
||||||
|
|
||||||
|
# Get the main protocol class (usually capitalized API name)
|
||||||
|
protocol_class_name = api_name.title()
|
||||||
|
if hasattr(api_module, protocol_class_name):
|
||||||
|
protocol_class = getattr(api_module, protocol_class_name)
|
||||||
|
return protocol_class.__doc__
|
||||||
|
except (ImportError, AttributeError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
class ChangedPathTracker:
|
class ChangedPathTracker:
|
||||||
"""Track a list of paths we may have changed."""
|
"""Track a list of paths we may have changed."""
|
||||||
|
|
||||||
|
@ -261,6 +278,11 @@ def process_provider_registry(progress, change_tracker: ChangedPathTracker) -> N
|
||||||
index_content.append(f"# {api_name.title()}\n")
|
index_content.append(f"# {api_name.title()}\n")
|
||||||
index_content.append("## Overview\n")
|
index_content.append("## Overview\n")
|
||||||
|
|
||||||
|
api_docstring = get_api_docstring(api_name)
|
||||||
|
if api_docstring:
|
||||||
|
cleaned_docstring = api_docstring.strip()
|
||||||
|
index_content.append(f"{cleaned_docstring}\n")
|
||||||
|
|
||||||
index_content.append(
|
index_content.append(
|
||||||
f"This section contains documentation for all available providers for the **{api_name}** API.\n"
|
f"This section contains documentation for all available providers for the **{api_name}** API.\n"
|
||||||
)
|
)
|
||||||
|
|
5
tests/integration/batches/__init__.py
Normal file
5
tests/integration/batches/__init__.py
Normal file
|
@ -0,0 +1,5 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
122
tests/integration/batches/conftest.py
Normal file
122
tests/integration/batches/conftest.py
Normal file
|
@ -0,0 +1,122 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
"""Shared pytest fixtures for batch tests."""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import time
|
||||||
|
import warnings
|
||||||
|
from contextlib import contextmanager
|
||||||
|
from io import BytesIO
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from llama_stack.apis.files import OpenAIFilePurpose
|
||||||
|
|
||||||
|
|
||||||
|
class BatchHelper:
|
||||||
|
"""Helper class for creating and managing batch input files."""
|
||||||
|
|
||||||
|
def __init__(self, client):
|
||||||
|
"""Initialize with either a batch_client or openai_client."""
|
||||||
|
self.client = client
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def create_file(self, content: str | list[dict], filename_prefix="batch_input"):
|
||||||
|
"""Context manager for creating and cleaning up batch input files.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
content: Either a list of batch request dictionaries or raw string content
|
||||||
|
filename_prefix: Prefix for the generated filename (or full filename if content is string)
|
||||||
|
|
||||||
|
Yields:
|
||||||
|
The uploaded file object
|
||||||
|
"""
|
||||||
|
if isinstance(content, str):
|
||||||
|
# Handle raw string content (e.g., malformed JSONL, empty files)
|
||||||
|
file_content = content.encode("utf-8")
|
||||||
|
else:
|
||||||
|
# Handle list of batch request dictionaries
|
||||||
|
jsonl_content = "\n".join(json.dumps(req) for req in content)
|
||||||
|
file_content = jsonl_content.encode("utf-8")
|
||||||
|
|
||||||
|
filename = filename_prefix if filename_prefix.endswith(".jsonl") else f"{filename_prefix}.jsonl"
|
||||||
|
|
||||||
|
with BytesIO(file_content) as file_buffer:
|
||||||
|
file_buffer.name = filename
|
||||||
|
uploaded_file = self.client.files.create(file=file_buffer, purpose=OpenAIFilePurpose.BATCH)
|
||||||
|
|
||||||
|
try:
|
||||||
|
yield uploaded_file
|
||||||
|
finally:
|
||||||
|
try:
|
||||||
|
self.client.files.delete(uploaded_file.id)
|
||||||
|
except Exception:
|
||||||
|
warnings.warn(
|
||||||
|
f"Failed to cleanup file {uploaded_file.id}: {uploaded_file.filename}",
|
||||||
|
stacklevel=2,
|
||||||
|
)
|
||||||
|
|
||||||
|
def wait_for(
|
||||||
|
self,
|
||||||
|
batch_id: str,
|
||||||
|
max_wait_time: int = 60,
|
||||||
|
sleep_interval: int | None = None,
|
||||||
|
expected_statuses: set[str] | None = None,
|
||||||
|
timeout_action: str = "fail",
|
||||||
|
):
|
||||||
|
"""Wait for a batch to reach a terminal status.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
batch_id: The batch ID to monitor
|
||||||
|
max_wait_time: Maximum time to wait in seconds (default: 60 seconds)
|
||||||
|
sleep_interval: Time to sleep between checks in seconds (default: 1/10th of max_wait_time, min 1s, max 15s)
|
||||||
|
expected_statuses: Set of expected terminal statuses (default: {"completed"})
|
||||||
|
timeout_action: Action on timeout - "fail" (pytest.fail) or "skip" (pytest.skip)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The final batch object
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
pytest.Failed: If batch reaches an unexpected status or timeout_action is "fail"
|
||||||
|
pytest.Skipped: If timeout_action is "skip" on timeout or unexpected status
|
||||||
|
"""
|
||||||
|
if sleep_interval is None:
|
||||||
|
# Default to 1/10th of max_wait_time, with min 1s and max 15s
|
||||||
|
sleep_interval = max(1, min(15, max_wait_time // 10))
|
||||||
|
|
||||||
|
if expected_statuses is None:
|
||||||
|
expected_statuses = {"completed"}
|
||||||
|
|
||||||
|
terminal_statuses = {"completed", "failed", "cancelled", "expired"}
|
||||||
|
unexpected_statuses = terminal_statuses - expected_statuses
|
||||||
|
|
||||||
|
start_time = time.time()
|
||||||
|
while time.time() - start_time < max_wait_time:
|
||||||
|
current_batch = self.client.batches.retrieve(batch_id)
|
||||||
|
|
||||||
|
if current_batch.status in expected_statuses:
|
||||||
|
return current_batch
|
||||||
|
elif current_batch.status in unexpected_statuses:
|
||||||
|
error_msg = f"Batch reached unexpected status: {current_batch.status}"
|
||||||
|
if timeout_action == "skip":
|
||||||
|
pytest.skip(error_msg)
|
||||||
|
else:
|
||||||
|
pytest.fail(error_msg)
|
||||||
|
|
||||||
|
time.sleep(sleep_interval)
|
||||||
|
|
||||||
|
timeout_msg = f"Batch did not reach expected status {expected_statuses} within {max_wait_time} seconds"
|
||||||
|
if timeout_action == "skip":
|
||||||
|
pytest.skip(timeout_msg)
|
||||||
|
else:
|
||||||
|
pytest.fail(timeout_msg)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def batch_helper(openai_client):
|
||||||
|
"""Fixture that provides a BatchHelper instance for OpenAI client."""
|
||||||
|
return BatchHelper(openai_client)
|
270
tests/integration/batches/test_batches.py
Normal file
270
tests/integration/batches/test_batches.py
Normal file
|
@ -0,0 +1,270 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
"""
|
||||||
|
Integration tests for the Llama Stack batch processing functionality.
|
||||||
|
|
||||||
|
This module contains comprehensive integration tests for the batch processing API,
|
||||||
|
using the OpenAI-compatible client interface for consistency.
|
||||||
|
|
||||||
|
Test Categories:
|
||||||
|
1. Core Batch Operations:
|
||||||
|
- test_batch_creation_and_retrieval: Comprehensive batch creation, structure validation, and retrieval
|
||||||
|
- test_batch_listing: Basic batch listing functionality
|
||||||
|
- test_batch_immediate_cancellation: Batch cancellation workflow
|
||||||
|
# TODO: cancel during processing
|
||||||
|
|
||||||
|
2. End-to-End Processing:
|
||||||
|
- test_batch_e2e_chat_completions: Full chat completions workflow with output and error validation
|
||||||
|
|
||||||
|
Note: Error conditions and edge cases are primarily tested in test_batches_errors.py
|
||||||
|
for better organization and separation of concerns.
|
||||||
|
|
||||||
|
CLEANUP WARNING: These tests currently create batches that are not automatically
|
||||||
|
cleaned up after test completion. This may lead to resource accumulation over
|
||||||
|
multiple test runs. Only test_batch_immediate_cancellation properly cancels its batch.
|
||||||
|
The test_batch_e2e_chat_completions test does clean up its output and error files.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
|
||||||
|
|
||||||
|
class TestBatchesIntegration:
|
||||||
|
"""Integration tests for the batches API."""
|
||||||
|
|
||||||
|
def test_batch_creation_and_retrieval(self, openai_client, batch_helper, text_model_id):
|
||||||
|
"""Test comprehensive batch creation and retrieval scenarios."""
|
||||||
|
test_metadata = {
|
||||||
|
"test_type": "comprehensive",
|
||||||
|
"purpose": "creation_and_retrieval_test",
|
||||||
|
"version": "1.0",
|
||||||
|
"tags": "test,batch",
|
||||||
|
}
|
||||||
|
|
||||||
|
batch_requests = [
|
||||||
|
{
|
||||||
|
"custom_id": "request-1",
|
||||||
|
"method": "POST",
|
||||||
|
"url": "/v1/chat/completions",
|
||||||
|
"body": {
|
||||||
|
"model": text_model_id,
|
||||||
|
"messages": [{"role": "user", "content": "Hello"}],
|
||||||
|
"max_tokens": 10,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
with batch_helper.create_file(batch_requests, "batch_creation_test") as uploaded_file:
|
||||||
|
batch = openai_client.batches.create(
|
||||||
|
input_file_id=uploaded_file.id,
|
||||||
|
endpoint="/v1/chat/completions",
|
||||||
|
completion_window="24h",
|
||||||
|
metadata=test_metadata,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert batch.endpoint == "/v1/chat/completions"
|
||||||
|
assert batch.input_file_id == uploaded_file.id
|
||||||
|
assert batch.completion_window == "24h"
|
||||||
|
assert batch.metadata == test_metadata
|
||||||
|
|
||||||
|
retrieved_batch = openai_client.batches.retrieve(batch.id)
|
||||||
|
|
||||||
|
assert retrieved_batch.id == batch.id
|
||||||
|
assert retrieved_batch.object == batch.object
|
||||||
|
assert retrieved_batch.endpoint == batch.endpoint
|
||||||
|
assert retrieved_batch.input_file_id == batch.input_file_id
|
||||||
|
assert retrieved_batch.completion_window == batch.completion_window
|
||||||
|
assert retrieved_batch.metadata == batch.metadata
|
||||||
|
|
||||||
|
def test_batch_listing(self, openai_client, batch_helper, text_model_id):
|
||||||
|
"""
|
||||||
|
Test batch listing.
|
||||||
|
|
||||||
|
This test creates multiple batches and verifies that they can be listed.
|
||||||
|
It also deletes the input files before execution, which means the batches
|
||||||
|
will appear as failed due to missing input files. This is expected and
|
||||||
|
a good thing, because it means no inference is performed.
|
||||||
|
"""
|
||||||
|
batch_ids = []
|
||||||
|
|
||||||
|
for i in range(2):
|
||||||
|
batch_requests = [
|
||||||
|
{
|
||||||
|
"custom_id": f"request-{i}",
|
||||||
|
"method": "POST",
|
||||||
|
"url": "/v1/chat/completions",
|
||||||
|
"body": {
|
||||||
|
"model": text_model_id,
|
||||||
|
"messages": [{"role": "user", "content": f"Hello {i}"}],
|
||||||
|
"max_tokens": 10,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
with batch_helper.create_file(batch_requests, f"batch_input_{i}") as uploaded_file:
|
||||||
|
batch = openai_client.batches.create(
|
||||||
|
input_file_id=uploaded_file.id,
|
||||||
|
endpoint="/v1/chat/completions",
|
||||||
|
completion_window="24h",
|
||||||
|
)
|
||||||
|
batch_ids.append(batch.id)
|
||||||
|
|
||||||
|
batch_list = openai_client.batches.list()
|
||||||
|
|
||||||
|
assert isinstance(batch_list.data, list)
|
||||||
|
|
||||||
|
listed_batch_ids = {b.id for b in batch_list.data}
|
||||||
|
for batch_id in batch_ids:
|
||||||
|
assert batch_id in listed_batch_ids
|
||||||
|
|
||||||
|
def test_batch_immediate_cancellation(self, openai_client, batch_helper, text_model_id):
|
||||||
|
"""Test immediate batch cancellation."""
|
||||||
|
batch_requests = [
|
||||||
|
{
|
||||||
|
"custom_id": "request-1",
|
||||||
|
"method": "POST",
|
||||||
|
"url": "/v1/chat/completions",
|
||||||
|
"body": {
|
||||||
|
"model": text_model_id,
|
||||||
|
"messages": [{"role": "user", "content": "Hello"}],
|
||||||
|
"max_tokens": 10,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
with batch_helper.create_file(batch_requests) as uploaded_file:
|
||||||
|
batch = openai_client.batches.create(
|
||||||
|
input_file_id=uploaded_file.id,
|
||||||
|
endpoint="/v1/chat/completions",
|
||||||
|
completion_window="24h",
|
||||||
|
)
|
||||||
|
|
||||||
|
# attempt to cancel the batch before it completes (intentionally racy)
|
||||||
|
cancelling_batch = openai_client.batches.cancel(batch.id)
|
||||||
|
assert cancelling_batch.status in ["cancelling", "cancelled"]
|
||||||
|
assert isinstance(cancelling_batch.cancelling_at, int), (
|
||||||
|
f"cancelling_at should be int, got {type(cancelling_batch.cancelling_at)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
final_batch = batch_helper.wait_for(
|
||||||
|
batch.id,
|
||||||
|
max_wait_time=3 * 60,  # cancellation often takes 10-11 minutes; wait 3 minutes and skip if it hasn't finished
|
||||||
|
expected_statuses={"cancelled"},
|
||||||
|
timeout_action="skip",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert final_batch.status == "cancelled"
|
||||||
|
assert isinstance(final_batch.cancelled_at, int), (
|
||||||
|
f"cancelled_at should be int, got {type(final_batch.cancelled_at)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_batch_e2e_chat_completions(self, openai_client, batch_helper, text_model_id):
|
||||||
|
"""Test end-to-end batch processing for chat completions with both successful and failed operations."""
|
||||||
|
batch_requests = [
|
||||||
|
{
|
||||||
|
"custom_id": "success-1",
|
||||||
|
"method": "POST",
|
||||||
|
"url": "/v1/chat/completions",
|
||||||
|
"body": {
|
||||||
|
"model": text_model_id,
|
||||||
|
"messages": [{"role": "user", "content": "Say hello"}],
|
||||||
|
"max_tokens": 20,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"custom_id": "error-1",
|
||||||
|
"method": "POST",
|
||||||
|
"url": "/v1/chat/completions",
|
||||||
|
"body": {
|
||||||
|
"model": text_model_id,
|
||||||
|
"messages": [{"role": "user", "content": "This should fail"}],
|
||||||
|
"max_tokens": -1, # Invalid negative max_tokens will cause inference error
|
||||||
|
},
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
with batch_helper.create_file(batch_requests) as uploaded_file:
|
||||||
|
batch = openai_client.batches.create(
|
||||||
|
input_file_id=uploaded_file.id,
|
||||||
|
endpoint="/v1/chat/completions",
|
||||||
|
completion_window="24h",
|
||||||
|
metadata={"test": "e2e_success_and_errors_test"},
|
||||||
|
)
|
||||||
|
|
||||||
|
final_batch = batch_helper.wait_for(
|
||||||
|
batch.id,
|
||||||
|
max_wait_time=3 * 60, # often takes 2-3 minutes
|
||||||
|
expected_statuses={"completed"},
|
||||||
|
timeout_action="skip",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Expecting a completed batch with both successful and failed requests
|
||||||
|
# Batch(id='batch_xxx',
|
||||||
|
# completion_window='24h',
|
||||||
|
# created_at=...,
|
||||||
|
# endpoint='/v1/chat/completions',
|
||||||
|
# input_file_id='file-xxx',
|
||||||
|
# object='batch',
|
||||||
|
# status='completed',
|
||||||
|
# output_file_id='file-xxx',
|
||||||
|
# error_file_id='file-xxx',
|
||||||
|
# request_counts=BatchRequestCounts(completed=1, failed=1, total=2))
|
||||||
|
|
||||||
|
assert final_batch.status == "completed"
|
||||||
|
assert final_batch.request_counts is not None
|
||||||
|
assert final_batch.request_counts.total == 2
|
||||||
|
assert final_batch.request_counts.completed == 1
|
||||||
|
assert final_batch.request_counts.failed == 1
|
||||||
|
|
||||||
|
assert final_batch.output_file_id is not None, "Output file should exist for successful requests"
|
||||||
|
|
||||||
|
output_content = openai_client.files.content(final_batch.output_file_id)
|
||||||
|
if isinstance(output_content, str):
|
||||||
|
output_text = output_content
|
||||||
|
else:
|
||||||
|
output_text = output_content.content.decode("utf-8")
|
||||||
|
|
||||||
|
output_lines = output_text.strip().split("\n")
|
||||||
|
|
||||||
|
for line in output_lines:
|
||||||
|
result = json.loads(line)
|
||||||
|
|
||||||
|
assert "id" in result
|
||||||
|
assert "custom_id" in result
|
||||||
|
assert result["custom_id"] == "success-1"
|
||||||
|
|
||||||
|
assert "response" in result
|
||||||
|
|
||||||
|
assert result["response"]["status_code"] == 200
|
||||||
|
assert "body" in result["response"]
|
||||||
|
assert "choices" in result["response"]["body"]
|
||||||
|
|
||||||
|
assert final_batch.error_file_id is not None, "Error file should exist for failed requests"
|
||||||
|
|
||||||
|
error_content = openai_client.files.content(final_batch.error_file_id)
|
||||||
|
if isinstance(error_content, str):
|
||||||
|
error_text = error_content
|
||||||
|
else:
|
||||||
|
error_text = error_content.content.decode("utf-8")
|
||||||
|
|
||||||
|
error_lines = error_text.strip().split("\n")
|
||||||
|
|
||||||
|
for line in error_lines:
|
||||||
|
result = json.loads(line)
|
||||||
|
|
||||||
|
assert "id" in result
|
||||||
|
assert "custom_id" in result
|
||||||
|
assert result["custom_id"] == "error-1"
|
||||||
|
assert "error" in result
|
||||||
|
error = result["error"]
|
||||||
|
assert error is not None
|
||||||
|
assert "code" in error or "message" in error, "Error should have code or message"
|
||||||
|
|
||||||
|
deleted_output_file = openai_client.files.delete(final_batch.output_file_id)
|
||||||
|
assert deleted_output_file.deleted, f"Output file {final_batch.output_file_id} was not deleted successfully"
|
||||||
|
|
||||||
|
deleted_error_file = openai_client.files.delete(final_batch.error_file_id)
|
||||||
|
assert deleted_error_file.deleted, f"Error file {final_batch.error_file_id} was not deleted successfully"
|
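For orientation, the output and error files parsed by test_batch_e2e_chat_completions are JSONL, one result object per request. The sketch below reconstructs the shape implied by the assertions above; the id values and message contents are placeholders, not actual provider output.

```python
import json

# Hypothetical output-file line for the successful request ("success-1").
# Only the asserted fields (id, custom_id, response.status_code,
# response.body.choices) are guaranteed; the rest is illustrative.
output_line = json.dumps(
    {
        "id": "batch_req_out_1",  # placeholder identifier
        "custom_id": "success-1",
        "response": {
            "status_code": 200,
            "body": {"choices": [{"message": {"role": "assistant", "content": "Hello!"}}]},
        },
    }
)

# Hypothetical error-file line for the failed request ("error-1").
error_line = json.dumps(
    {
        "id": "batch_req_err_1",  # placeholder identifier
        "custom_id": "error-1",
        "error": {"code": "invalid_request", "message": "max_tokens must be positive"},
    }
)

for line in (output_line, error_line):
    parsed = json.loads(line)
    assert "custom_id" in parsed and ("response" in parsed or "error" in parsed)
```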
693
tests/integration/batches/test_batches_errors.py
Normal file
|
@ -0,0 +1,693 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
"""
|
||||||
|
Error handling and edge case tests for the Llama Stack batch processing functionality.
|
||||||
|
|
||||||
|
This module focuses exclusively on testing error conditions, validation failures,
|
||||||
|
and edge cases for batch operations to ensure robust error handling and graceful
|
||||||
|
degradation.
|
||||||
|
|
||||||
|
Test Categories:
|
||||||
|
1. File and Input Validation:
|
||||||
|
- test_batch_nonexistent_file_id: Handling invalid file IDs
|
||||||
|
- test_batch_malformed_jsonl: Processing malformed JSONL input files
|
||||||
|
- test_file_malformed_batch_file: Handling malformed files at upload time
|
||||||
|
- test_batch_missing_required_fields: Validation of required request fields
|
||||||
|
|
||||||
|
2. API Endpoint and Model Validation:
|
||||||
|
- test_batch_invalid_endpoint: Invalid endpoint handling during creation
|
||||||
|
- test_batch_error_handling_invalid_model: Error handling with nonexistent models
|
||||||
|
- test_batch_endpoint_mismatch: Validation of endpoint/URL consistency
|
||||||
|
|
||||||
|
3. Batch Lifecycle Error Handling:
|
||||||
|
- test_batch_retrieve_nonexistent: Retrieving non-existent batches
|
||||||
|
- test_batch_cancel_nonexistent: Cancelling non-existent batches
|
||||||
|
- test_batch_cancel_completed: Attempting to cancel completed batches
|
||||||
|
|
||||||
|
4. Parameter and Configuration Validation:
|
||||||
|
- test_batch_invalid_completion_window: Invalid completion window values
|
||||||
|
- test_batch_invalid_metadata_types: Invalid metadata type validation
|
||||||
|
- test_batch_missing_required_body_fields: Validation of required fields in request body
|
||||||
|
|
||||||
|
5. Feature Restriction and Compatibility:
|
||||||
|
- test_batch_streaming_not_supported: Streaming request rejection
|
||||||
|
- test_batch_mixed_streaming_requests: Mixed streaming/non-streaming validation
|
||||||
|
|
||||||
|
Note: Core functionality and OpenAI compatibility tests are located in
|
||||||
|
test_batches_integration.py for better organization and separation of concerns.
|
||||||
|
|
||||||
|
CLEANUP WARNING: These tests create batches to test error conditions but do not
|
||||||
|
automatically clean them up after test completion. While most error tests create
|
||||||
|
batches that fail quickly, some may create valid batches that consume resources.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from openai import BadRequestError, ConflictError, NotFoundError
|
||||||
|
|
||||||
|
|
||||||
|
class TestBatchesErrorHandling:
|
||||||
|
"""Error handling and edge case tests for the batches API using OpenAI client."""
|
||||||
|
|
||||||
|
def test_batch_nonexistent_file_id(self, openai_client, batch_helper):
|
||||||
|
"""Test batch creation with nonexistent input file ID."""
|
||||||
|
|
||||||
|
batch = openai_client.batches.create(
|
||||||
|
input_file_id="file-nonexistent-xyz",
|
||||||
|
endpoint="/v1/chat/completions",
|
||||||
|
completion_window="24h",
|
||||||
|
)
|
||||||
|
|
||||||
|
final_batch = batch_helper.wait_for(batch.id, expected_statuses={"failed"})
|
||||||
|
|
||||||
|
# Expecting -
|
||||||
|
# Batch(...,
|
||||||
|
# status='failed',
|
||||||
|
# errors=Errors(data=[
|
||||||
|
# BatchError(
|
||||||
|
# code='invalid_request',
|
||||||
|
# line=None,
|
||||||
|
# message='Cannot find file ..., or organization ... does not have access to it.',
|
||||||
|
# param='file_id')
|
||||||
|
# ], object='list'),
|
||||||
|
# failed_at=1754566971,
|
||||||
|
# ...)
|
||||||
|
|
||||||
|
assert final_batch.status == "failed"
|
||||||
|
assert final_batch.errors is not None
|
||||||
|
assert len(final_batch.errors.data) == 1
|
||||||
|
error = final_batch.errors.data[0]
|
||||||
|
assert error.code == "invalid_request"
|
||||||
|
assert "cannot find file" in error.message.lower()
|
||||||
|
|
||||||
|
def test_batch_invalid_endpoint(self, openai_client, batch_helper, text_model_id):
|
||||||
|
"""Test batch creation with invalid endpoint."""
|
||||||
|
batch_requests = [
|
||||||
|
{
|
||||||
|
"custom_id": "invalid-endpoint",
|
||||||
|
"method": "POST",
|
||||||
|
"url": "/v1/chat/completions",
|
||||||
|
"body": {
|
||||||
|
"model": text_model_id,
|
||||||
|
"messages": [{"role": "user", "content": "Hello"}],
|
||||||
|
"max_tokens": 10,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
with batch_helper.create_file(batch_requests) as uploaded_file:
|
||||||
|
with pytest.raises(BadRequestError) as exc_info:
|
||||||
|
openai_client.batches.create(
|
||||||
|
input_file_id=uploaded_file.id,
|
||||||
|
endpoint="/v1/invalid/endpoint",
|
||||||
|
completion_window="24h",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Expected -
|
||||||
|
# Error code: 400 - {
|
||||||
|
# 'error': {
|
||||||
|
# 'message': "Invalid value: '/v1/invalid/endpoint'. Supported values are: '/v1/chat/completions', '/v1/completions', '/v1/embeddings', and '/v1/responses'.",
|
||||||
|
# 'type': 'invalid_request_error',
|
||||||
|
# 'param': 'endpoint',
|
||||||
|
# 'code': 'invalid_value'
|
||||||
|
# }
|
||||||
|
# }
|
||||||
|
|
||||||
|
error_msg = str(exc_info.value).lower()
|
||||||
|
assert exc_info.value.status_code == 400
|
||||||
|
assert "invalid value" in error_msg
|
||||||
|
assert "/v1/invalid/endpoint" in error_msg
|
||||||
|
assert "supported values" in error_msg
|
||||||
|
assert "endpoint" in error_msg
|
||||||
|
assert "invalid_value" in error_msg
|
||||||
|
|
||||||
|
def test_batch_malformed_jsonl(self, openai_client, batch_helper):
|
||||||
|
"""
|
||||||
|
Test batch with malformed JSONL input.
|
||||||
|
|
||||||
|
The /v1/files endpoint requires valid JSONL format, so we provide a well-formed line
|
||||||
|
before a malformed line to ensure we get to the /v1/batches validation stage.
|
||||||
|
"""
|
||||||
|
with batch_helper.create_file(
|
||||||
|
"""{"custom_id": "valid", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "test"}}
|
||||||
|
{invalid json here""",
|
||||||
|
"malformed_batch_input.jsonl",
|
||||||
|
) as uploaded_file:
|
||||||
|
batch = openai_client.batches.create(
|
||||||
|
input_file_id=uploaded_file.id,
|
||||||
|
endpoint="/v1/chat/completions",
|
||||||
|
completion_window="24h",
|
||||||
|
)
|
||||||
|
|
||||||
|
final_batch = batch_helper.wait_for(batch.id, expected_statuses={"failed"})
|
||||||
|
|
||||||
|
# Expecting -
|
||||||
|
# Batch(...,
|
||||||
|
# status='failed',
|
||||||
|
# errors=Errors(data=[
|
||||||
|
# ...,
|
||||||
|
# BatchError(code='invalid_json_line',
|
||||||
|
# line=2,
|
||||||
|
# message='This line is not parseable as valid JSON.',
|
||||||
|
# param=None)
|
||||||
|
# ], object='list'),
|
||||||
|
# ...)
|
||||||
|
|
||||||
|
assert final_batch.status == "failed"
|
||||||
|
assert final_batch.errors is not None
|
||||||
|
assert len(final_batch.errors.data) > 0
|
||||||
|
error = final_batch.errors.data[-1] # get last error because first may be about the "test" model
|
||||||
|
assert error.code == "invalid_json_line"
|
||||||
|
assert error.line == 2
|
||||||
|
assert "not" in error.message.lower()
|
||||||
|
assert "valid json" in error.message.lower()
|
||||||
|
|
||||||
|
@pytest.mark.xfail(reason="Not all file providers validate content")
|
||||||
|
@pytest.mark.parametrize("batch_requests", ["", "{malformed json"], ids=["empty", "malformed"])
|
||||||
|
def test_file_malformed_batch_file(self, openai_client, batch_helper, batch_requests):
|
||||||
|
"""Test file upload with malformed content."""
|
||||||
|
|
||||||
|
with pytest.raises(BadRequestError) as exc_info:
|
||||||
|
with batch_helper.create_file(batch_requests, "malformed_batch_input_file.jsonl"):
|
||||||
|
# /v1/files rejects the file, so we never reach batch creation
|
||||||
|
pass
|
||||||
|
|
||||||
|
error_msg = str(exc_info.value).lower()
|
||||||
|
assert exc_info.value.status_code == 400
|
||||||
|
assert "invalid file format" in error_msg
|
||||||
|
assert "jsonl" in error_msg
|
||||||
|
|
||||||
|
def test_batch_retrieve_nonexistent(self, openai_client):
|
||||||
|
"""Test retrieving nonexistent batch."""
|
||||||
|
with pytest.raises(NotFoundError) as exc_info:
|
||||||
|
openai_client.batches.retrieve("batch-nonexistent-xyz")
|
||||||
|
|
||||||
|
error_msg = str(exc_info.value).lower()
|
||||||
|
assert exc_info.value.status_code == 404
|
||||||
|
assert "no batch found" in error_msg or "not found" in error_msg
|
||||||
|
|
||||||
|
def test_batch_cancel_nonexistent(self, openai_client):
|
||||||
|
"""Test cancelling nonexistent batch."""
|
||||||
|
with pytest.raises(NotFoundError) as exc_info:
|
||||||
|
openai_client.batches.cancel("batch-nonexistent-xyz")
|
||||||
|
|
||||||
|
error_msg = str(exc_info.value).lower()
|
||||||
|
assert exc_info.value.status_code == 404
|
||||||
|
assert "no batch found" in error_msg or "not found" in error_msg
|
||||||
|
|
||||||
|
def test_batch_cancel_completed(self, openai_client, batch_helper, text_model_id):
|
||||||
|
"""Test cancelling already completed batch."""
|
||||||
|
batch_requests = [
|
||||||
|
{
|
||||||
|
"custom_id": "cancel-completed",
|
||||||
|
"method": "POST",
|
||||||
|
"url": "/v1/chat/completions",
|
||||||
|
"body": {
|
||||||
|
"model": text_model_id,
|
||||||
|
"messages": [{"role": "user", "content": "Quick test"}],
|
||||||
|
"max_tokens": 5,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
with batch_helper.create_file(batch_requests, "cancel_test_batch_input") as uploaded_file:
|
||||||
|
batch = openai_client.batches.create(
|
||||||
|
input_file_id=uploaded_file.id,
|
||||||
|
endpoint="/v1/chat/completions",
|
||||||
|
completion_window="24h",
|
||||||
|
)
|
||||||
|
|
||||||
|
final_batch = batch_helper.wait_for(
|
||||||
|
batch.id,
|
||||||
|
max_wait_time=3 * 60,  # completion often takes 10-11 min; wait 3 min and skip if it hasn't finished
|
||||||
|
expected_statuses={"completed"},
|
||||||
|
timeout_action="skip",
|
||||||
|
)
|
||||||
|
|
||||||
|
deleted_file = openai_client.files.delete(final_batch.output_file_id)
|
||||||
|
assert deleted_file.deleted, f"File {final_batch.output_file_id} was not deleted successfully"
|
||||||
|
|
||||||
|
with pytest.raises(ConflictError) as exc_info:
|
||||||
|
openai_client.batches.cancel(batch.id)
|
||||||
|
|
||||||
|
# Expecting -
|
||||||
|
# Error code: 409 - {
|
||||||
|
# 'error': {
|
||||||
|
# 'message': "Cannot cancel a batch with status 'completed'.",
|
||||||
|
# 'type': 'invalid_request_error',
|
||||||
|
# 'param': None,
|
||||||
|
# 'code': None
|
||||||
|
# }
|
||||||
|
# }
|
||||||
|
#
|
||||||
|
# NOTE: The same applies to "failed" batches; cancelling already-"cancelled" batches is allowed
|
||||||
|
|
||||||
|
error_msg = str(exc_info.value).lower()
|
||||||
|
assert exc_info.value.status_code == 409
|
||||||
|
assert "cannot cancel" in error_msg
|
||||||
|
|
||||||
|
def test_batch_missing_required_fields(self, openai_client, batch_helper, text_model_id):
|
||||||
|
"""Test batch with requests missing required fields."""
|
||||||
|
batch_requests = [
|
||||||
|
{
|
||||||
|
# Missing custom_id
|
||||||
|
"method": "POST",
|
||||||
|
"url": "/v1/chat/completions",
|
||||||
|
"body": {
|
||||||
|
"model": text_model_id,
|
||||||
|
"messages": [{"role": "user", "content": "No custom_id"}],
|
||||||
|
"max_tokens": 10,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"custom_id": "no-method",
|
||||||
|
"url": "/v1/chat/completions",
|
||||||
|
"body": {
|
||||||
|
"model": text_model_id,
|
||||||
|
"messages": [{"role": "user", "content": "No method"}],
|
||||||
|
"max_tokens": 10,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"custom_id": "no-url",
|
||||||
|
"method": "POST",
|
||||||
|
"body": {
|
||||||
|
"model": text_model_id,
|
||||||
|
"messages": [{"role": "user", "content": "No URL"}],
|
||||||
|
"max_tokens": 10,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"custom_id": "no-body",
|
||||||
|
"method": "POST",
|
||||||
|
"url": "/v1/chat/completions",
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
with batch_helper.create_file(batch_requests, "missing_fields_batch_input") as uploaded_file:
|
||||||
|
batch = openai_client.batches.create(
|
||||||
|
input_file_id=uploaded_file.id,
|
||||||
|
endpoint="/v1/chat/completions",
|
||||||
|
completion_window="24h",
|
||||||
|
)
|
||||||
|
|
||||||
|
final_batch = batch_helper.wait_for(batch.id, expected_statuses={"failed"})
|
||||||
|
|
||||||
|
# Expecting -
|
||||||
|
# Batch(...,
|
||||||
|
# status='failed',
|
||||||
|
# errors=Errors(
|
||||||
|
# data=[
|
||||||
|
# BatchError(
|
||||||
|
# code='missing_required_parameter',
|
||||||
|
# line=1,
|
||||||
|
# message="Missing required parameter: 'custom_id'.",
|
||||||
|
# param='custom_id'
|
||||||
|
# ),
|
||||||
|
# BatchError(
|
||||||
|
# code='missing_required_parameter',
|
||||||
|
# line=2,
|
||||||
|
# message="Missing required parameter: 'method'.",
|
||||||
|
# param='method'
|
||||||
|
# ),
|
||||||
|
# BatchError(
|
||||||
|
# code='missing_required_parameter',
|
||||||
|
# line=3,
|
||||||
|
# message="Missing required parameter: 'url'.",
|
||||||
|
# param='url'
|
||||||
|
# ),
|
||||||
|
# BatchError(
|
||||||
|
# code='missing_required_parameter',
|
||||||
|
# line=4,
|
||||||
|
# message="Missing required parameter: 'body'.",
|
||||||
|
# param='body'
|
||||||
|
# )
|
||||||
|
# ], object='list'),
|
||||||
|
# failed_at=1754566945,
|
||||||
|
# ...)
|
||||||
|
# )
|
||||||
|
|
||||||
|
assert final_batch.status == "failed"
|
||||||
|
assert final_batch.errors is not None
|
||||||
|
assert len(final_batch.errors.data) == 4
|
||||||
|
no_custom_id_error = final_batch.errors.data[0]
|
||||||
|
assert no_custom_id_error.code == "missing_required_parameter"
|
||||||
|
assert no_custom_id_error.line == 1
|
||||||
|
assert "missing" in no_custom_id_error.message.lower()
|
||||||
|
assert "custom_id" in no_custom_id_error.message.lower()
|
||||||
|
no_method_error = final_batch.errors.data[1]
|
||||||
|
assert no_method_error.code == "missing_required_parameter"
|
||||||
|
assert no_method_error.line == 2
|
||||||
|
assert "missing" in no_method_error.message.lower()
|
||||||
|
assert "method" in no_method_error.message.lower()
|
||||||
|
no_url_error = final_batch.errors.data[2]
|
||||||
|
assert no_url_error.code == "missing_required_parameter"
|
||||||
|
assert no_url_error.line == 3
|
||||||
|
assert "missing" in no_url_error.message.lower()
|
||||||
|
assert "url" in no_url_error.message.lower()
|
||||||
|
no_body_error = final_batch.errors.data[3]
|
||||||
|
assert no_body_error.code == "missing_required_parameter"
|
||||||
|
assert no_body_error.line == 4
|
||||||
|
assert "missing" in no_body_error.message.lower()
|
||||||
|
assert "body" in no_body_error.message.lower()
|
||||||
|
|
||||||
|
def test_batch_invalid_completion_window(self, openai_client, batch_helper, text_model_id):
|
||||||
|
"""Test batch creation with invalid completion window."""
|
||||||
|
batch_requests = [
|
||||||
|
{
|
||||||
|
"custom_id": "invalid-completion-window",
|
||||||
|
"method": "POST",
|
||||||
|
"url": "/v1/chat/completions",
|
||||||
|
"body": {
|
||||||
|
"model": text_model_id,
|
||||||
|
"messages": [{"role": "user", "content": "Hello"}],
|
||||||
|
"max_tokens": 10,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
with batch_helper.create_file(batch_requests) as uploaded_file:
|
||||||
|
for window in ["1h", "48h", "invalid", ""]:
|
||||||
|
with pytest.raises(BadRequestError) as exc_info:
|
||||||
|
openai_client.batches.create(
|
||||||
|
input_file_id=uploaded_file.id,
|
||||||
|
endpoint="/v1/chat/completions",
|
||||||
|
completion_window=window,
|
||||||
|
)
|
||||||
|
assert exc_info.value.status_code == 400
|
||||||
|
error_msg = str(exc_info.value).lower()
|
||||||
|
assert "error" in error_msg
|
||||||
|
assert "completion_window" in error_msg
|
||||||
|
|
||||||
|
def test_batch_streaming_not_supported(self, openai_client, batch_helper, text_model_id):
|
||||||
|
"""Test that streaming responses are not supported in batches."""
|
||||||
|
batch_requests = [
|
||||||
|
{
|
||||||
|
"custom_id": "streaming-test",
|
||||||
|
"method": "POST",
|
||||||
|
"url": "/v1/chat/completions",
|
||||||
|
"body": {
|
||||||
|
"model": text_model_id,
|
||||||
|
"messages": [{"role": "user", "content": "Hello"}],
|
||||||
|
"max_tokens": 10,
|
||||||
|
"stream": True, # Not supported
|
||||||
|
},
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
with batch_helper.create_file(batch_requests, "streaming_batch_input") as uploaded_file:
|
||||||
|
batch = openai_client.batches.create(
|
||||||
|
input_file_id=uploaded_file.id,
|
||||||
|
endpoint="/v1/chat/completions",
|
||||||
|
completion_window="24h",
|
||||||
|
)
|
||||||
|
|
||||||
|
final_batch = batch_helper.wait_for(batch.id, expected_statuses={"failed"})
|
||||||
|
|
||||||
|
# Expecting -
|
||||||
|
# Batch(...,
|
||||||
|
# status='failed',
|
||||||
|
# errors=Errors(data=[
|
||||||
|
# BatchError(code='streaming_unsupported',
|
||||||
|
# line=1,
|
||||||
|
# message='Chat Completions: Streaming is not supported in the Batch API.',
|
||||||
|
# param='body.stream')
|
||||||
|
# ], object='list'),
|
||||||
|
# failed_at=1754566965,
|
||||||
|
# ...)
|
||||||
|
|
||||||
|
assert final_batch.status == "failed"
|
||||||
|
assert final_batch.errors is not None
|
||||||
|
assert len(final_batch.errors.data) == 1
|
||||||
|
error = final_batch.errors.data[0]
|
||||||
|
assert error.code == "streaming_unsupported"
|
||||||
|
assert error.line == 1
|
||||||
|
assert "streaming" in error.message.lower()
|
||||||
|
assert "not supported" in error.message.lower()
|
||||||
|
assert error.param == "body.stream"
|
||||||
|
assert final_batch.failed_at is not None
|
||||||
|
|
||||||
|
def test_batch_mixed_streaming_requests(self, openai_client, batch_helper, text_model_id):
|
||||||
|
"""
|
||||||
|
Test batch with mixed streaming and non-streaming requests.
|
||||||
|
|
||||||
|
This is distinct from test_batch_streaming_not_supported, which tests a single
|
||||||
|
streaming request, to ensure an otherwise valid batch fails when a single
|
||||||
|
streaming request is included.
|
||||||
|
"""
|
||||||
|
batch_requests = [
|
||||||
|
{
|
||||||
|
"custom_id": "valid-non-streaming-request",
|
||||||
|
"method": "POST",
|
||||||
|
"url": "/v1/chat/completions",
|
||||||
|
"body": {
|
||||||
|
"model": text_model_id,
|
||||||
|
"messages": [{"role": "user", "content": "Hello without streaming"}],
|
||||||
|
"max_tokens": 10,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"custom_id": "streaming-request",
|
||||||
|
"method": "POST",
|
||||||
|
"url": "/v1/chat/completions",
|
||||||
|
"body": {
|
||||||
|
"model": text_model_id,
|
||||||
|
"messages": [{"role": "user", "content": "Hello with streaming"}],
|
||||||
|
"max_tokens": 10,
|
||||||
|
"stream": True, # Not supported
|
||||||
|
},
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
with batch_helper.create_file(batch_requests, "mixed_streaming_batch_input") as uploaded_file:
|
||||||
|
batch = openai_client.batches.create(
|
||||||
|
input_file_id=uploaded_file.id,
|
||||||
|
endpoint="/v1/chat/completions",
|
||||||
|
completion_window="24h",
|
||||||
|
)
|
||||||
|
|
||||||
|
final_batch = batch_helper.wait_for(batch.id, expected_statuses={"failed"})
|
||||||
|
|
||||||
|
# Expecting -
|
||||||
|
# Batch(...,
|
||||||
|
# status='failed',
|
||||||
|
# errors=Errors(data=[
|
||||||
|
# BatchError(
|
||||||
|
# code='streaming_unsupported',
|
||||||
|
# line=2,
|
||||||
|
# message='Chat Completions: Streaming is not supported in the Batch API.',
|
||||||
|
# param='body.stream')
|
||||||
|
# ], object='list'),
|
||||||
|
# failed_at=1754574442,
|
||||||
|
# ...)
|
||||||
|
|
||||||
|
assert final_batch.status == "failed"
|
||||||
|
assert final_batch.errors is not None
|
||||||
|
assert len(final_batch.errors.data) == 1
|
||||||
|
error = final_batch.errors.data[0]
|
||||||
|
assert error.code == "streaming_unsupported"
|
||||||
|
assert error.line == 2
|
||||||
|
assert "streaming" in error.message.lower()
|
||||||
|
assert "not supported" in error.message.lower()
|
||||||
|
assert error.param == "body.stream"
|
||||||
|
assert final_batch.failed_at is not None
|
||||||
|
|
||||||
|
def test_batch_endpoint_mismatch(self, openai_client, batch_helper, text_model_id):
|
||||||
|
"""Test batch creation with mismatched endpoint and request URL."""
|
||||||
|
batch_requests = [
|
||||||
|
{
|
||||||
|
"custom_id": "endpoint-mismatch",
|
||||||
|
"method": "POST",
|
||||||
|
"url": "/v1/embeddings", # Different from batch endpoint
|
||||||
|
"body": {
|
||||||
|
"model": text_model_id,
|
||||||
|
"messages": [{"role": "user", "content": "Hello"}],
|
||||||
|
},
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
with batch_helper.create_file(batch_requests, "endpoint_mismatch_batch_input") as uploaded_file:
|
||||||
|
batch = openai_client.batches.create(
|
||||||
|
input_file_id=uploaded_file.id,
|
||||||
|
endpoint="/v1/chat/completions", # Different from request URL
|
||||||
|
completion_window="24h",
|
||||||
|
)
|
||||||
|
|
||||||
|
final_batch = batch_helper.wait_for(batch.id, expected_statuses={"failed"})
|
||||||
|
|
||||||
|
# Expecting -
|
||||||
|
# Batch(...,
|
||||||
|
# status='failed',
|
||||||
|
# errors=Errors(data=[
|
||||||
|
# BatchError(
|
||||||
|
# code='invalid_url',
|
||||||
|
# line=1,
|
||||||
|
# message='The URL provided for this request does not match the batch endpoint.',
|
||||||
|
# param='url')
|
||||||
|
# ], object='list'),
|
||||||
|
# failed_at=1754566972,
|
||||||
|
# ...)
|
||||||
|
|
||||||
|
assert final_batch.status == "failed"
|
||||||
|
assert final_batch.errors is not None
|
||||||
|
assert len(final_batch.errors.data) == 1
|
||||||
|
error = final_batch.errors.data[0]
|
||||||
|
assert error.line == 1
|
||||||
|
assert error.code == "invalid_url"
|
||||||
|
assert "does not match" in error.message.lower()
|
||||||
|
assert "endpoint" in error.message.lower()
|
||||||
|
assert final_batch.failed_at is not None
|
||||||
|
|
||||||
|
def test_batch_error_handling_invalid_model(self, openai_client, batch_helper):
|
||||||
|
"""Test batch error handling with invalid model."""
|
||||||
|
batch_requests = [
|
||||||
|
{
|
||||||
|
"custom_id": "invalid-model",
|
||||||
|
"method": "POST",
|
||||||
|
"url": "/v1/chat/completions",
|
||||||
|
"body": {
|
||||||
|
"model": "nonexistent-model-xyz",
|
||||||
|
"messages": [{"role": "user", "content": "Hello"}],
|
||||||
|
"max_tokens": 10,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
with batch_helper.create_file(batch_requests) as uploaded_file:
|
||||||
|
batch = openai_client.batches.create(
|
||||||
|
input_file_id=uploaded_file.id,
|
||||||
|
endpoint="/v1/chat/completions",
|
||||||
|
completion_window="24h",
|
||||||
|
)
|
||||||
|
|
||||||
|
final_batch = batch_helper.wait_for(batch.id, expected_statuses={"failed"})
|
||||||
|
|
||||||
|
# Expecting -
|
||||||
|
# Batch(...,
|
||||||
|
# status='failed',
|
||||||
|
# errors=Errors(data=[
|
||||||
|
# BatchError(code='model_not_found',
|
||||||
|
# line=1,
|
||||||
|
# message="The provided model 'nonexistent-model-xyz' is not supported by the Batch API.",
|
||||||
|
# param='body.model')
|
||||||
|
# ], object='list'),
|
||||||
|
# failed_at=1754566978,
|
||||||
|
# ...)
|
||||||
|
|
||||||
|
assert final_batch.status == "failed"
|
||||||
|
assert final_batch.errors is not None
|
||||||
|
assert len(final_batch.errors.data) == 1
|
||||||
|
error = final_batch.errors.data[0]
|
||||||
|
assert error.line == 1
|
||||||
|
assert error.code == "model_not_found"
|
||||||
|
assert "not supported" in error.message.lower()
|
||||||
|
assert error.param == "body.model"
|
||||||
|
assert final_batch.failed_at is not None
|
||||||
|
|
||||||
|
def test_batch_missing_required_body_fields(self, openai_client, batch_helper, text_model_id):
|
||||||
|
"""Test batch with requests missing required fields in body (model and messages)."""
|
||||||
|
batch_requests = [
|
||||||
|
{
|
||||||
|
"custom_id": "missing-model",
|
||||||
|
"method": "POST",
|
||||||
|
"url": "/v1/chat/completions",
|
||||||
|
"body": {
|
||||||
|
# Missing model field
|
||||||
|
"messages": [{"role": "user", "content": "Hello without model"}],
|
||||||
|
"max_tokens": 10,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"custom_id": "missing-messages",
|
||||||
|
"method": "POST",
|
||||||
|
"url": "/v1/chat/completions",
|
||||||
|
"body": {
|
||||||
|
"model": text_model_id,
|
||||||
|
# Missing messages field
|
||||||
|
"max_tokens": 10,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
with batch_helper.create_file(batch_requests, "missing_body_fields_batch_input") as uploaded_file:
|
||||||
|
batch = openai_client.batches.create(
|
||||||
|
input_file_id=uploaded_file.id,
|
||||||
|
endpoint="/v1/chat/completions",
|
||||||
|
completion_window="24h",
|
||||||
|
)
|
||||||
|
|
||||||
|
final_batch = batch_helper.wait_for(batch.id, expected_statuses={"failed"})
|
||||||
|
|
||||||
|
# Expecting -
|
||||||
|
# Batch(...,
|
||||||
|
# status='failed',
|
||||||
|
# errors=Errors(data=[
|
||||||
|
# BatchError(
|
||||||
|
# code='invalid_request',
|
||||||
|
# line=1,
|
||||||
|
# message='Model parameter is required.',
|
||||||
|
# param='body.model'),
|
||||||
|
# BatchError(
|
||||||
|
# code='invalid_request',
|
||||||
|
# line=2,
|
||||||
|
# message='Messages parameter is required.',
|
||||||
|
# param='body.messages')
|
||||||
|
# ], object='list'),
|
||||||
|
# ...)
|
||||||
|
|
||||||
|
assert final_batch.status == "failed"
|
||||||
|
assert final_batch.errors is not None
|
||||||
|
assert len(final_batch.errors.data) == 2
|
||||||
|
|
||||||
|
model_error = final_batch.errors.data[0]
|
||||||
|
assert model_error.line == 1
|
||||||
|
assert "model" in model_error.message.lower()
|
||||||
|
assert model_error.param == "body.model"
|
||||||
|
|
||||||
|
messages_error = final_batch.errors.data[1]
|
||||||
|
assert messages_error.line == 2
|
||||||
|
assert "messages" in messages_error.message.lower()
|
||||||
|
assert messages_error.param == "body.messages"
|
||||||
|
|
||||||
|
assert final_batch.failed_at is not None
|
||||||
|
|
||||||
|
def test_batch_invalid_metadata_types(self, openai_client, batch_helper, text_model_id):
|
||||||
|
"""Test batch creation with invalid metadata types (like lists)."""
|
||||||
|
batch_requests = [
|
||||||
|
{
|
||||||
|
"custom_id": "invalid-metadata-type",
|
||||||
|
"method": "POST",
|
||||||
|
"url": "/v1/chat/completions",
|
||||||
|
"body": {
|
||||||
|
"model": text_model_id,
|
||||||
|
"messages": [{"role": "user", "content": "Hello"}],
|
||||||
|
"max_tokens": 10,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
with batch_helper.create_file(batch_requests) as uploaded_file:
|
||||||
|
with pytest.raises(Exception) as exc_info:
|
||||||
|
openai_client.batches.create(
|
||||||
|
input_file_id=uploaded_file.id,
|
||||||
|
endpoint="/v1/chat/completions",
|
||||||
|
completion_window="24h",
|
||||||
|
metadata={
|
||||||
|
"tags": ["tag1", "tag2"], # Invalid type, should be a string
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Expecting -
|
||||||
|
# Error code: 400 - {'error':
|
||||||
|
# {'message': "Invalid type for 'metadata.tags': expected a string,
|
||||||
|
# but got an array instead.",
|
||||||
|
# 'type': 'invalid_request_error', 'param': 'metadata.tags',
|
||||||
|
# 'code': 'invalid_type'}}
|
||||||
|
|
||||||
|
error_msg = str(exc_info.value).lower()
|
||||||
|
assert "400" in error_msg
|
||||||
|
assert "tags" in error_msg
|
||||||
|
assert "string" in error_msg
|
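Both integration modules warn that created batches are not cleaned up automatically. A minimal sketch of one way to make that automatic is shown below; it assumes the openai_client fixture used above, and the tracked_batches bookkeeping is hypothetical, not part of this change.

```python
import contextlib

import pytest


@pytest.fixture
def tracked_batches(openai_client):
    """Collect batch IDs during a test and best-effort cancel them on teardown."""
    created_ids: list[str] = []
    yield created_ids
    for batch_id in created_ids:
        # Cancelling a batch that already reached a terminal state raises
        # ConflictError; suppress it since this is only best-effort cleanup.
        with contextlib.suppress(Exception):
            openai_client.batches.cancel(batch_id)
```

Tests would append batch.id to tracked_batches right after openai_client.batches.create(...); nothing else in the existing tests would need to change.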
753
tests/unit/providers/batches/test_reference.py
Normal file
|
@ -0,0 +1,753 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
"""
|
||||||
|
Test suite for the reference implementation of the Batches API.
|
||||||
|
|
||||||
|
The tests are categorized and outlined below; keep this list updated:
|
||||||
|
|
||||||
|
- Batch creation with various parameters and validation:
|
||||||
|
* test_create_and_retrieve_batch_success (positive)
|
||||||
|
* test_create_batch_without_metadata (positive)
|
||||||
|
* test_create_batch_completion_window (negative)
|
||||||
|
* test_create_batch_invalid_endpoints (negative)
|
||||||
|
* test_create_batch_invalid_metadata (negative)
|
||||||
|
|
||||||
|
- Batch retrieval and error handling for non-existent batches:
|
||||||
|
* test_retrieve_batch_not_found (negative)
|
||||||
|
|
||||||
|
- Batch cancellation with proper status transitions:
|
||||||
|
* test_cancel_batch_success (positive)
|
||||||
|
* test_cancel_batch_invalid_statuses (negative)
|
||||||
|
* test_cancel_batch_not_found (negative)
|
||||||
|
|
||||||
|
- Batch listing with pagination and filtering:
|
||||||
|
* test_list_batches_empty (positive)
|
||||||
|
* test_list_batches_single_batch (positive)
|
||||||
|
* test_list_batches_multiple_batches (positive)
|
||||||
|
* test_list_batches_with_limit (positive)
|
||||||
|
* test_list_batches_with_pagination (positive)
|
||||||
|
* test_list_batches_invalid_after (negative)
|
||||||
|
|
||||||
|
- Data persistence in the underlying key-value store:
|
||||||
|
* test_kvstore_persistence (positive)
|
||||||
|
|
||||||
|
- Batch processing concurrency control:
|
||||||
|
* test_max_concurrent_batches (positive)
|
||||||
|
|
||||||
|
- Input validation testing (direct _validate_input method tests):
|
||||||
|
* test_validate_input_file_not_found (negative)
|
||||||
|
* test_validate_input_file_exists_empty_content (positive)
|
||||||
|
* test_validate_input_file_mixed_valid_invalid_json (mixed)
|
||||||
|
* test_validate_input_invalid_model (negative)
|
||||||
|
* test_validate_input_url_mismatch (negative)
|
||||||
|
* test_validate_input_multiple_errors_per_request (negative)
|
||||||
|
* test_validate_input_invalid_request_format (negative)
|
||||||
|
* test_validate_input_missing_parameters (parametrized negative - custom_id, method, url, body, model, messages missing validation)
|
||||||
|
* test_validate_input_invalid_parameter_types (parametrized negative - custom_id, url, method, body, model, messages type validation)
|
||||||
|
|
||||||
|
The tests use temporary SQLite databases for isolation and mock external
|
||||||
|
dependencies like inference, files, and models APIs.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import tempfile
|
||||||
|
from pathlib import Path
|
||||||
|
from unittest.mock import AsyncMock, MagicMock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from llama_stack.apis.batches import BatchObject
|
||||||
|
from llama_stack.apis.common.errors import ConflictError, ResourceNotFoundError
|
||||||
|
from llama_stack.providers.inline.batches.reference.batches import ReferenceBatchesImpl
|
||||||
|
from llama_stack.providers.inline.batches.reference.config import ReferenceBatchesImplConfig
|
||||||
|
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
|
||||||
|
|
||||||
|
|
||||||
|
class TestReferenceBatchesImpl:
|
||||||
|
"""Test the reference implementation of the Batches API."""
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
async def provider(self):
|
||||||
|
"""Create a test provider instance with temporary database."""
|
||||||
|
with tempfile.TemporaryDirectory() as tmpdir:
|
||||||
|
db_path = Path(tmpdir) / "test_batches.db"
|
||||||
|
kvstore_config = SqliteKVStoreConfig(db_path=str(db_path))
|
||||||
|
config = ReferenceBatchesImplConfig(kvstore=kvstore_config)
|
||||||
|
|
||||||
|
# Create kvstore and mock APIs
|
||||||
|
from unittest.mock import AsyncMock
|
||||||
|
|
||||||
|
from llama_stack.providers.utils.kvstore import kvstore_impl
|
||||||
|
|
||||||
|
kvstore = await kvstore_impl(config.kvstore)
|
||||||
|
mock_inference = AsyncMock()
|
||||||
|
mock_files = AsyncMock()
|
||||||
|
mock_models = AsyncMock()
|
||||||
|
|
||||||
|
provider = ReferenceBatchesImpl(config, mock_inference, mock_files, mock_models, kvstore)
|
||||||
|
await provider.initialize()
|
||||||
|
|
||||||
|
# unit tests should not require background processing
|
||||||
|
provider.process_batches = False
|
||||||
|
|
||||||
|
yield provider
|
||||||
|
|
||||||
|
await provider.shutdown()
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def sample_batch_data(self):
|
||||||
|
"""Sample batch data for testing."""
|
||||||
|
return {
|
||||||
|
"input_file_id": "file_abc123",
|
||||||
|
"endpoint": "/v1/chat/completions",
|
||||||
|
"completion_window": "24h",
|
||||||
|
"metadata": {"test": "true", "priority": "high"},
|
||||||
|
}
|
||||||
|
|
||||||
|
def _validate_batch_type(self, batch, expected_metadata=None):
|
||||||
|
"""
|
||||||
|
Helper function to validate batch object structure and field types.
|
||||||
|
|
||||||
|
Note: This validates the direct BatchObject from the provider, not the
|
||||||
|
client library response which has a different structure.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
batch: The BatchObject instance to validate.
|
||||||
|
expected_metadata: Optional expected metadata dictionary to validate against.
|
||||||
|
"""
|
||||||
|
assert isinstance(batch.id, str)
|
||||||
|
assert isinstance(batch.completion_window, str)
|
||||||
|
assert isinstance(batch.created_at, int)
|
||||||
|
assert isinstance(batch.endpoint, str)
|
||||||
|
assert isinstance(batch.input_file_id, str)
|
||||||
|
assert batch.object == "batch"
|
||||||
|
assert batch.status in [
|
||||||
|
"validating",
|
||||||
|
"failed",
|
||||||
|
"in_progress",
|
||||||
|
"finalizing",
|
||||||
|
"completed",
|
||||||
|
"expired",
|
||||||
|
"cancelling",
|
||||||
|
"cancelled",
|
||||||
|
]
|
||||||
|
|
||||||
|
if expected_metadata is not None:
|
||||||
|
assert batch.metadata == expected_metadata
|
||||||
|
|
||||||
|
timestamp_fields = [
|
||||||
|
"cancelled_at",
|
||||||
|
"cancelling_at",
|
||||||
|
"completed_at",
|
||||||
|
"expired_at",
|
||||||
|
"expires_at",
|
||||||
|
"failed_at",
|
||||||
|
"finalizing_at",
|
||||||
|
"in_progress_at",
|
||||||
|
]
|
||||||
|
for field in timestamp_fields:
|
||||||
|
field_value = getattr(batch, field, None)
|
||||||
|
if field_value is not None:
|
||||||
|
assert isinstance(field_value, int), f"{field} should be int or None, got {type(field_value)}"
|
||||||
|
|
||||||
|
file_id_fields = ["error_file_id", "output_file_id"]
|
||||||
|
for field in file_id_fields:
|
||||||
|
field_value = getattr(batch, field, None)
|
||||||
|
if field_value is not None:
|
||||||
|
assert isinstance(field_value, str), f"{field} should be str or None, got {type(field_value)}"
|
||||||
|
|
||||||
|
if hasattr(batch, "request_counts") and batch.request_counts is not None:
|
||||||
|
assert isinstance(batch.request_counts.completed, int), (
|
||||||
|
f"request_counts.completed should be int, got {type(batch.request_counts.completed)}"
|
||||||
|
)
|
||||||
|
assert isinstance(batch.request_counts.failed, int), (
|
||||||
|
f"request_counts.failed should be int, got {type(batch.request_counts.failed)}"
|
||||||
|
)
|
||||||
|
assert isinstance(batch.request_counts.total, int), (
|
||||||
|
f"request_counts.total should be int, got {type(batch.request_counts.total)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
if hasattr(batch, "errors") and batch.errors is not None:
|
||||||
|
assert isinstance(batch.errors, dict), f"errors should be object or dict, got {type(batch.errors)}"
|
||||||
|
|
||||||
|
if hasattr(batch.errors, "data") and batch.errors.data is not None:
|
||||||
|
assert isinstance(batch.errors.data, list), (
|
||||||
|
f"errors.data should be list or None, got {type(batch.errors.data)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
for i, error_item in enumerate(batch.errors.data):
|
||||||
|
assert isinstance(error_item, dict), (
|
||||||
|
f"errors.data[{i}] should be object or dict, got {type(error_item)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
if hasattr(error_item, "code") and error_item.code is not None:
|
||||||
|
assert isinstance(error_item.code, str), (
|
||||||
|
f"errors.data[{i}].code should be str or None, got {type(error_item.code)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
if hasattr(error_item, "line") and error_item.line is not None:
|
||||||
|
assert isinstance(error_item.line, int), (
|
||||||
|
f"errors.data[{i}].line should be int or None, got {type(error_item.line)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
if hasattr(error_item, "message") and error_item.message is not None:
|
||||||
|
assert isinstance(error_item.message, str), (
|
||||||
|
f"errors.data[{i}].message should be str or None, got {type(error_item.message)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
if hasattr(error_item, "param") and error_item.param is not None:
|
||||||
|
assert isinstance(error_item.param, str), (
|
||||||
|
f"errors.data[{i}].param should be str or None, got {type(error_item.param)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
if hasattr(batch.errors, "object") and batch.errors.object is not None:
|
||||||
|
assert isinstance(batch.errors.object, str), (
|
||||||
|
f"errors.object should be str or None, got {type(batch.errors.object)}"
|
||||||
|
)
|
||||||
|
assert batch.errors.object == "list", f"errors.object should be 'list', got {batch.errors.object}"
|
||||||
|
|
||||||
|
async def test_create_and_retrieve_batch_success(self, provider, sample_batch_data):
|
||||||
|
"""Test successful batch creation and retrieval."""
|
||||||
|
created_batch = await provider.create_batch(**sample_batch_data)
|
||||||
|
|
||||||
|
self._validate_batch_type(created_batch, expected_metadata=sample_batch_data["metadata"])
|
||||||
|
|
||||||
|
assert created_batch.id.startswith("batch_")
|
||||||
|
assert len(created_batch.id) > 13
|
||||||
|
assert created_batch.object == "batch"
|
||||||
|
assert created_batch.endpoint == sample_batch_data["endpoint"]
|
||||||
|
assert created_batch.input_file_id == sample_batch_data["input_file_id"]
|
||||||
|
assert created_batch.completion_window == sample_batch_data["completion_window"]
|
||||||
|
assert created_batch.status == "validating"
|
||||||
|
assert created_batch.metadata == sample_batch_data["metadata"]
|
||||||
|
assert isinstance(created_batch.created_at, int)
|
||||||
|
assert created_batch.created_at > 0
|
||||||
|
|
||||||
|
retrieved_batch = await provider.retrieve_batch(created_batch.id)
|
||||||
|
|
||||||
|
self._validate_batch_type(retrieved_batch, expected_metadata=sample_batch_data["metadata"])
|
||||||
|
|
||||||
|
assert retrieved_batch.id == created_batch.id
|
||||||
|
assert retrieved_batch.input_file_id == created_batch.input_file_id
|
||||||
|
assert retrieved_batch.endpoint == created_batch.endpoint
|
||||||
|
assert retrieved_batch.status == created_batch.status
|
||||||
|
assert retrieved_batch.metadata == created_batch.metadata
|
||||||
|
|
||||||
|
async def test_create_batch_without_metadata(self, provider):
|
||||||
|
"""Test batch creation without optional metadata."""
|
||||||
|
batch = await provider.create_batch(
|
||||||
|
input_file_id="file_123", endpoint="/v1/chat/completions", completion_window="24h"
|
||||||
|
)
|
||||||
|
|
||||||
|
assert batch.metadata is None
|
||||||
|
|
||||||
|
async def test_create_batch_completion_window(self, provider):
|
||||||
|
"""Test batch creation with invalid completion window."""
|
||||||
|
with pytest.raises(ValueError, match="Invalid completion_window"):
|
||||||
|
await provider.create_batch(
|
||||||
|
input_file_id="file_123", endpoint="/v1/chat/completions", completion_window="now"
|
||||||
|
)
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"endpoint",
|
||||||
|
[
|
||||||
|
"/v1/embeddings",
|
||||||
|
"/v1/completions",
|
||||||
|
"/v1/invalid/endpoint",
|
||||||
|
"",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
async def test_create_batch_invalid_endpoints(self, provider, endpoint):
|
||||||
|
"""Test batch creation with various invalid endpoints."""
|
||||||
|
with pytest.raises(ValueError, match="Invalid endpoint"):
|
||||||
|
await provider.create_batch(input_file_id="file_123", endpoint=endpoint, completion_window="24h")
|
||||||
|
|
||||||
|
async def test_create_batch_invalid_metadata(self, provider):
|
||||||
|
"""Test that batch creation fails with invalid metadata."""
|
||||||
|
with pytest.raises(ValueError, match="should be a valid string"):
|
||||||
|
await provider.create_batch(
|
||||||
|
input_file_id="file_123",
|
||||||
|
endpoint="/v1/chat/completions",
|
||||||
|
completion_window="24h",
|
||||||
|
metadata={123: "invalid_key"}, # Non-string key
|
||||||
|
)
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match="should be a valid string"):
|
||||||
|
await provider.create_batch(
|
||||||
|
input_file_id="file_123",
|
||||||
|
endpoint="/v1/chat/completions",
|
||||||
|
completion_window="24h",
|
||||||
|
metadata={"valid_key": 456}, # Non-string value
|
||||||
|
)
|
||||||
|
|
||||||
|
async def test_retrieve_batch_not_found(self, provider):
|
||||||
|
"""Test error when retrieving non-existent batch."""
|
||||||
|
with pytest.raises(ResourceNotFoundError, match=r"Batch 'nonexistent_batch' not found"):
|
||||||
|
await provider.retrieve_batch("nonexistent_batch")
|
||||||
|
|
||||||
|
async def test_cancel_batch_success(self, provider, sample_batch_data):
|
||||||
|
"""Test successful batch cancellation."""
|
||||||
|
created_batch = await provider.create_batch(**sample_batch_data)
|
||||||
|
assert created_batch.status == "validating"
|
||||||
|
|
||||||
|
cancelled_batch = await provider.cancel_batch(created_batch.id)
|
||||||
|
|
||||||
|
assert cancelled_batch.id == created_batch.id
|
||||||
|
assert cancelled_batch.status in ["cancelling", "cancelled"]
|
||||||
|
assert isinstance(cancelled_batch.cancelling_at, int)
|
||||||
|
assert cancelled_batch.cancelling_at >= created_batch.created_at
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("status", ["failed", "expired", "completed"])
|
||||||
|
async def test_cancel_batch_invalid_statuses(self, provider, sample_batch_data, status):
|
||||||
|
"""Test error when cancelling batch in final states."""
|
||||||
|
provider.process_batches = False
|
||||||
|
created_batch = await provider.create_batch(**sample_batch_data)
|
||||||
|
|
||||||
|
# directly update status in kvstore
|
||||||
|
await provider._update_batch(created_batch.id, status=status)
|
||||||
|
|
||||||
|
with pytest.raises(ConflictError, match=f"Cannot cancel batch '{created_batch.id}' with status '{status}'"):
|
||||||
|
await provider.cancel_batch(created_batch.id)
|
||||||
|
|
||||||
|
async def test_cancel_batch_not_found(self, provider):
|
||||||
|
"""Test error when cancelling non-existent batch."""
|
||||||
|
with pytest.raises(ResourceNotFoundError, match=r"Batch 'nonexistent_batch' not found"):
|
||||||
|
await provider.cancel_batch("nonexistent_batch")
|
||||||
|
|
||||||
|
async def test_list_batches_empty(self, provider):
|
||||||
|
"""Test listing batches when none exist."""
|
||||||
|
response = await provider.list_batches()
|
||||||
|
|
||||||
|
assert response.object == "list"
|
||||||
|
assert response.data == []
|
||||||
|
assert response.first_id is None
|
||||||
|
assert response.last_id is None
|
||||||
|
assert response.has_more is False
|
||||||
|
|
||||||
|
async def test_list_batches_single_batch(self, provider, sample_batch_data):
|
||||||
|
"""Test listing batches with single batch."""
|
||||||
|
created_batch = await provider.create_batch(**sample_batch_data)
|
||||||
|
|
||||||
|
response = await provider.list_batches()
|
||||||
|
|
||||||
|
assert len(response.data) == 1
|
||||||
|
self._validate_batch_type(response.data[0], expected_metadata=sample_batch_data["metadata"])
|
||||||
|
assert response.data[0].id == created_batch.id
|
||||||
|
assert response.first_id == created_batch.id
|
||||||
|
assert response.last_id == created_batch.id
|
||||||
|
assert response.has_more is False
|
||||||
|
|
||||||
|
async def test_list_batches_multiple_batches(self, provider):
|
||||||
|
"""Test listing multiple batches."""
|
||||||
|
batches = [
|
||||||
|
await provider.create_batch(
|
||||||
|
input_file_id=f"file_{i}", endpoint="/v1/chat/completions", completion_window="24h"
|
||||||
|
)
|
||||||
|
for i in range(3)
|
||||||
|
]
|
||||||
|
|
||||||
|
response = await provider.list_batches()
|
||||||
|
|
||||||
|
assert len(response.data) == 3
|
||||||
|
|
||||||
|
batch_ids = {batch.id for batch in response.data}
|
||||||
|
expected_ids = {batch.id for batch in batches}
|
||||||
|
assert batch_ids == expected_ids
|
||||||
|
assert response.has_more is False
|
||||||
|
|
||||||
|
assert response.first_id in expected_ids
|
||||||
|
assert response.last_id in expected_ids
|
||||||
|
|
||||||
|
async def test_list_batches_with_limit(self, provider):
|
||||||
|
"""Test listing batches with limit parameter."""
|
||||||
|
batches = [
|
||||||
|
await provider.create_batch(
|
||||||
|
input_file_id=f"file_{i}", endpoint="/v1/chat/completions", completion_window="24h"
|
||||||
|
)
|
||||||
|
for i in range(3)
|
||||||
|
]
|
||||||
|
|
||||||
|
response = await provider.list_batches(limit=2)
|
||||||
|
|
||||||
|
assert len(response.data) == 2
|
||||||
|
assert response.has_more is True
|
||||||
|
assert response.first_id == response.data[0].id
|
||||||
|
assert response.last_id == response.data[1].id
|
||||||
|
batch_ids = {batch.id for batch in response.data}
|
||||||
|
expected_ids = {batch.id for batch in batches}
|
||||||
|
assert batch_ids.issubset(expected_ids)
|
||||||
|
|
||||||
|
async def test_list_batches_with_pagination(self, provider):
|
||||||
|
"""Test listing batches with pagination using 'after' parameter."""
|
||||||
|
for i in range(3):
|
||||||
|
await provider.create_batch(
|
||||||
|
input_file_id=f"file_{i}", endpoint="/v1/chat/completions", completion_window="24h"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Get first page
|
||||||
|
first_page = await provider.list_batches(limit=1)
|
||||||
|
assert len(first_page.data) == 1
|
||||||
|
assert first_page.has_more is True
|
||||||
|
|
||||||
|
# Get second page using 'after'
|
||||||
|
second_page = await provider.list_batches(limit=1, after=first_page.data[0].id)
|
||||||
|
assert len(second_page.data) == 1
|
||||||
|
assert second_page.data[0].id != first_page.data[0].id
|
||||||
|
|
||||||
|
# Verify we got the next batch in order
|
||||||
|
all_batches = await provider.list_batches()
|
||||||
|
expected_second_batch_id = all_batches.data[1].id
|
||||||
|
assert second_page.data[0].id == expected_second_batch_id
|
||||||
|
|
||||||
|
async def test_list_batches_invalid_after(self, provider, sample_batch_data):
|
||||||
|
"""Test listing batches with invalid 'after' parameter."""
|
||||||
|
await provider.create_batch(**sample_batch_data)
|
||||||
|
|
||||||
|
response = await provider.list_batches(after="nonexistent_batch")
|
||||||
|
|
||||||
|
# Should return all batches (no filtering when 'after' batch not found)
|
||||||
|
assert len(response.data) == 1
|
||||||
|
|
||||||
|
async def test_kvstore_persistence(self, provider, sample_batch_data):
|
||||||
|
"""Test that batches are properly persisted in kvstore."""
|
||||||
|
batch = await provider.create_batch(**sample_batch_data)
|
||||||
|
|
||||||
|
stored_data = await provider.kvstore.get(f"batch:{batch.id}")
|
||||||
|
assert stored_data is not None
|
||||||
|
|
||||||
|
stored_batch_dict = json.loads(stored_data)
|
||||||
|
assert stored_batch_dict["id"] == batch.id
|
||||||
|
assert stored_batch_dict["input_file_id"] == sample_batch_data["input_file_id"]
|
||||||
|
|
||||||
|
async def test_validate_input_file_not_found(self, provider):
|
||||||
|
"""Test _validate_input when input file does not exist."""
|
||||||
|
provider.files_api.openai_retrieve_file = AsyncMock(side_effect=Exception("File not found"))
|
||||||
|
|
||||||
|
batch = BatchObject(
|
||||||
|
id="batch_test",
|
||||||
|
object="batch",
|
||||||
|
endpoint="/v1/chat/completions",
|
||||||
|
input_file_id="nonexistent_file",
|
||||||
|
completion_window="24h",
|
||||||
|
status="validating",
|
||||||
|
created_at=1234567890,
|
||||||
|
)
|
||||||
|
|
||||||
|
errors, requests = await provider._validate_input(batch)
|
||||||
|
|
||||||
|
assert len(errors) == 1
|
||||||
|
assert len(requests) == 0
|
||||||
|
assert errors[0].code == "invalid_request"
|
||||||
|
assert errors[0].message == "Cannot find file nonexistent_file."
|
||||||
|
assert errors[0].param == "input_file_id"
|
||||||
|
assert errors[0].line is None
|
||||||
|
|
||||||
|
async def test_validate_input_file_exists_empty_content(self, provider):
|
||||||
|
"""Test _validate_input when file exists but is empty."""
|
||||||
|
provider.files_api.openai_retrieve_file = AsyncMock()
|
||||||
|
mock_response = MagicMock()
|
||||||
|
mock_response.body = b""
|
||||||
|
provider.files_api.openai_retrieve_file_content = AsyncMock(return_value=mock_response)
|
||||||
|
|
||||||
|
batch = BatchObject(
|
||||||
|
id="batch_test",
|
||||||
|
object="batch",
|
||||||
|
endpoint="/v1/chat/completions",
|
||||||
|
input_file_id="empty_file",
|
||||||
|
completion_window="24h",
|
||||||
|
status="validating",
|
||||||
|
created_at=1234567890,
|
||||||
|
)
|
||||||
|
|
||||||
|
errors, requests = await provider._validate_input(batch)
|
||||||
|
|
||||||
|
assert len(errors) == 0
|
||||||
|
assert len(requests) == 0
|
||||||
|
|
||||||
|

    async def test_validate_input_file_mixed_valid_invalid_json(self, provider):
        """Test _validate_input when file contains valid and invalid JSON lines."""
        provider.files_api.openai_retrieve_file = AsyncMock()
        mock_response = MagicMock()
        # Line 1: valid JSON with proper body args, Line 2: invalid JSON
        mock_response.body = b'{"custom_id": "req-1", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "test-model", "messages": [{"role": "user", "content": "Hello"}]}}\n{invalid json'
        provider.files_api.openai_retrieve_file_content = AsyncMock(return_value=mock_response)

        batch = BatchObject(
            id="batch_test",
            object="batch",
            endpoint="/v1/chat/completions",
            input_file_id="mixed_file",
            completion_window="24h",
            status="validating",
            created_at=1234567890,
        )

        errors, requests = await provider._validate_input(batch)

        # Should have 1 JSON parsing error from line 2, and 1 valid request from line 1
        assert len(errors) == 1
        assert len(requests) == 1

        assert errors[0].code == "invalid_json_line"
        assert errors[0].line == 2
        assert errors[0].message == "This line is not parseable as valid JSON."

        assert requests[0].custom_id == "req-1"
        assert requests[0].method == "POST"
        assert requests[0].url == "/v1/chat/completions"
        assert requests[0].body["model"] == "test-model"
        assert requests[0].body["messages"] == [{"role": "user", "content": "Hello"}]
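        # For reference, each well-formed input line follows the shape used in the
        # mock above (one JSON object per line):
        #   {"custom_id": "...", "method": "POST", "url": "/v1/chat/completions",
        #    "body": {"model": "...", "messages": [...]}}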

    async def test_validate_input_invalid_model(self, provider):
        """Test _validate_input when file contains request with non-existent model."""
        provider.files_api.openai_retrieve_file = AsyncMock()
        mock_response = MagicMock()
        mock_response.body = b'{"custom_id": "req-1", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "nonexistent-model", "messages": [{"role": "user", "content": "Hello"}]}}'
        provider.files_api.openai_retrieve_file_content = AsyncMock(return_value=mock_response)

        provider.models_api.get_model = AsyncMock(side_effect=Exception("Model not found"))

        batch = BatchObject(
            id="batch_test",
            object="batch",
            endpoint="/v1/chat/completions",
            input_file_id="invalid_model_file",
            completion_window="24h",
            status="validating",
            created_at=1234567890,
        )

        errors, requests = await provider._validate_input(batch)

        assert len(errors) == 1
        assert len(requests) == 0

        assert errors[0].code == "model_not_found"
        assert errors[0].line == 1
        assert errors[0].message == "Model 'nonexistent-model' does not exist or is not supported"
        assert errors[0].param == "body.model"
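        # model validation goes through models_api.get_model; an exception there is
        # surfaced as a "model_not_found" error on the offending line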

    @pytest.mark.parametrize(
        "param_name,param_path,error_code,error_message",
        [
            ("custom_id", "custom_id", "missing_required_parameter", "Missing required parameter: custom_id"),
            ("method", "method", "missing_required_parameter", "Missing required parameter: method"),
            ("url", "url", "missing_required_parameter", "Missing required parameter: url"),
            ("body", "body", "missing_required_parameter", "Missing required parameter: body"),
            ("model", "body.model", "invalid_request", "Model parameter is required"),
            ("messages", "body.messages", "invalid_request", "Messages parameter is required"),
        ],
    )
    async def test_validate_input_missing_parameters(self, provider, param_name, param_path, error_code, error_message):
        """Test _validate_input when file contains request with missing required parameters."""
        provider.files_api.openai_retrieve_file = AsyncMock()
        mock_response = MagicMock()

        base_request = {
            "custom_id": "req-1",
            "method": "POST",
            "url": "/v1/chat/completions",
            "body": {"model": "test-model", "messages": [{"role": "user", "content": "Hello"}]},
        }

        # Remove the specific parameter being tested
        if "." in param_path:
            top_level, nested_param = param_path.split(".", 1)
            del base_request[top_level][nested_param]
        else:
            del base_request[param_name]

        mock_response.body = json.dumps(base_request).encode()
        provider.files_api.openai_retrieve_file_content = AsyncMock(return_value=mock_response)

        batch = BatchObject(
            id="batch_test",
            object="batch",
            endpoint="/v1/chat/completions",
            input_file_id=f"missing_{param_name}_file",
            completion_window="24h",
            status="validating",
            created_at=1234567890,
        )

        errors, requests = await provider._validate_input(batch)

        assert len(errors) == 1
        assert len(requests) == 0

        assert errors[0].code == error_code
        assert errors[0].line == 1
        assert errors[0].message == error_message
        assert errors[0].param == param_path
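        # each reported error exposes .code, .line, .message and .param; nested body
        # fields are addressed with dotted paths such as "body.model"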

    async def test_validate_input_url_mismatch(self, provider):
        """Test _validate_input when file contains request with URL that doesn't match batch endpoint."""
        provider.files_api.openai_retrieve_file = AsyncMock()
        mock_response = MagicMock()
        mock_response.body = b'{"custom_id": "req-1", "method": "POST", "url": "/v1/embeddings", "body": {"model": "test-model", "messages": [{"role": "user", "content": "Hello"}]}}'
        provider.files_api.openai_retrieve_file_content = AsyncMock(return_value=mock_response)

        batch = BatchObject(
            id="batch_test",
            object="batch",
            endpoint="/v1/chat/completions",  # This doesn't match the URL in the request
            input_file_id="url_mismatch_file",
            completion_window="24h",
            status="validating",
            created_at=1234567890,
        )

        errors, requests = await provider._validate_input(batch)

        assert len(errors) == 1
        assert len(requests) == 0

        assert errors[0].code == "invalid_url"
        assert errors[0].line == 1
        assert errors[0].message == "URL provided for this request does not match the batch endpoint"
        assert errors[0].param == "url"

    async def test_validate_input_multiple_errors_per_request(self, provider):
        """Test _validate_input when a single request has multiple validation errors."""
        provider.files_api.openai_retrieve_file = AsyncMock()
        mock_response = MagicMock()
        # Request missing custom_id, has invalid URL, and missing model in body
        mock_response.body = (
            b'{"method": "POST", "url": "/v1/embeddings", "body": {"messages": [{"role": "user", "content": "Hello"}]}}'
        )
        provider.files_api.openai_retrieve_file_content = AsyncMock(return_value=mock_response)

        batch = BatchObject(
            id="batch_test",
            object="batch",
            endpoint="/v1/chat/completions",  # Doesn't match /v1/embeddings in request
            input_file_id="multiple_errors_file",
            completion_window="24h",
            status="validating",
            created_at=1234567890,
        )

        errors, requests = await provider._validate_input(batch)

        assert len(errors) >= 2  # At least missing custom_id and URL mismatch
        assert len(requests) == 0

        for error in errors:
            assert error.line == 1

        error_codes = {error.code for error in errors}
        assert "missing_required_parameter" in error_codes  # missing custom_id
        assert "invalid_url" in error_codes  # URL mismatch

    async def test_validate_input_invalid_request_format(self, provider):
        """Test _validate_input when file contains non-object JSON (array, string, number)."""
        provider.files_api.openai_retrieve_file = AsyncMock()
        mock_response = MagicMock()
        mock_response.body = b'["not", "a", "request", "object"]'
        provider.files_api.openai_retrieve_file_content = AsyncMock(return_value=mock_response)

        batch = BatchObject(
            id="batch_test",
            object="batch",
            endpoint="/v1/chat/completions",
            input_file_id="invalid_format_file",
            completion_window="24h",
            status="validating",
            created_at=1234567890,
        )

        errors, requests = await provider._validate_input(batch)

        assert len(errors) == 1
        assert len(requests) == 0

        assert errors[0].code == "invalid_request"
        assert errors[0].line == 1
        assert errors[0].message == "Each line must be a JSON dictionary object"

    @pytest.mark.parametrize(
        "param_name,param_path,invalid_value,error_message",
        [
            ("custom_id", "custom_id", 12345, "Custom_id must be a string"),
            ("url", "url", 123, "URL must be a string"),
            ("method", "method", ["POST"], "Method must be a string"),
            ("body", "body", ["not", "valid"], "Body must be a JSON dictionary object"),
            ("model", "body.model", 123, "Model must be a string"),
            ("messages", "body.messages", "invalid messages format", "Messages must be an array"),
        ],
    )
    async def test_validate_input_invalid_parameter_types(
        self, provider, param_name, param_path, invalid_value, error_message
    ):
        """Test _validate_input when file contains request with parameters that have invalid types."""
        provider.files_api.openai_retrieve_file = AsyncMock()
        mock_response = MagicMock()

        base_request = {
            "custom_id": "req-1",
            "method": "POST",
            "url": "/v1/chat/completions",
            "body": {"model": "test-model", "messages": [{"role": "user", "content": "Hello"}]},
        }

        # Override the specific parameter with invalid value
        if "." in param_path:
            top_level, nested_param = param_path.split(".", 1)
            base_request[top_level][nested_param] = invalid_value
        else:
            base_request[param_name] = invalid_value

        mock_response.body = json.dumps(base_request).encode()
        provider.files_api.openai_retrieve_file_content = AsyncMock(return_value=mock_response)

        batch = BatchObject(
            id="batch_test",
            object="batch",
            endpoint="/v1/chat/completions",
            input_file_id=f"invalid_{param_name}_type_file",
            completion_window="24h",
            status="validating",
            created_at=1234567890,
        )

        errors, requests = await provider._validate_input(batch)

        assert len(errors) == 1
        assert len(requests) == 0

        assert errors[0].code == "invalid_request"
        assert errors[0].line == 1
        assert errors[0].message == error_message
        assert errors[0].param == param_path

    async def test_max_concurrent_batches(self, provider):
        """Test max_concurrent_batches configuration and concurrency control."""
        import asyncio

        provider._batch_semaphore = asyncio.Semaphore(2)

        provider.process_batches = True  # enable because we're testing background processing

        active_batches = 0
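
        # add_and_wait never returns, so each batch it "processes" keeps holding its
        # semaphore slot; with Semaphore(2), only two of the three batches created
        # below should ever become active at once.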
        async def add_and_wait(batch_id: str):
            nonlocal active_batches
            active_batches += 1
            await asyncio.sleep(float("inf"))

        # the first thing done in _process_batch is to acquire the semaphore, then call _process_batch_impl,
        # so we can replace _process_batch_impl with our mock to control concurrency
        provider._process_batch_impl = add_and_wait

        for _ in range(3):
            await provider.create_batch(
                input_file_id="file_id", endpoint="/v1/chat/completions", completion_window="24h"
            )

        await asyncio.sleep(0.042)  # let tasks start

        assert active_batches == 2, f"Expected 2 active batches, got {active_batches}"