mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-15 06:00:48 +00:00
feat: add batches API with OpenAI compatibility (#3088)
Some checks failed
Test External Providers Installed via Module / test-external-providers-from-module (venv) (push) Has been skipped
Integration Tests (Replay) / discover-tests (push) Successful in 12s
Integration Auth Tests / test-matrix (oauth2_token) (push) Failing after 15s
Python Package Build Test / build (3.12) (push) Failing after 16s
Vector IO Integration Tests / test-matrix (3.12, inline::milvus) (push) Failing after 25s
Vector IO Integration Tests / test-matrix (3.12, remote::chromadb) (push) Failing after 23s
Python Package Build Test / build (3.13) (push) Failing after 17s
SqlStore Integration Tests / test-postgres (3.13) (push) Failing after 29s
Vector IO Integration Tests / test-matrix (3.13, inline::milvus) (push) Failing after 21s
Vector IO Integration Tests / test-matrix (3.13, remote::pgvector) (push) Failing after 25s
SqlStore Integration Tests / test-postgres (3.12) (push) Failing after 28s
Vector IO Integration Tests / test-matrix (3.12, remote::pgvector) (push) Failing after 29s
Unit Tests / unit-tests (3.12) (push) Failing after 20s
Integration Tests (Replay) / Integration Tests (, , , client=, vision=) (push) Failing after 12s
Test External API and Providers / test-external (venv) (push) Failing after 22s
Unit Tests / unit-tests (3.13) (push) Failing after 18s
Vector IO Integration Tests / test-matrix (3.13, inline::faiss) (push) Failing after 23s
Vector IO Integration Tests / test-matrix (3.12, remote::qdrant) (push) Failing after 24s
Vector IO Integration Tests / test-matrix (3.12, remote::weaviate) (push) Failing after 27s
Vector IO Integration Tests / test-matrix (3.12, inline::sqlite-vec) (push) Failing after 24s
Vector IO Integration Tests / test-matrix (3.13, remote::weaviate) (push) Failing after 23s
Vector IO Integration Tests / test-matrix (3.13, remote::qdrant) (push) Failing after 24s
Vector IO Integration Tests / test-matrix (3.13, remote::chromadb) (push) Failing after 25s
Vector IO Integration Tests / test-matrix (3.13, inline::sqlite-vec) (push) Failing after 27s
Vector IO Integration Tests / test-matrix (3.12, inline::faiss) (push) Failing after 24s
Update ReadTheDocs / update-readthedocs (push) Failing after 38s
Pre-commit / pre-commit (push) Successful in 1m53s
Some checks failed
Test External Providers Installed via Module / test-external-providers-from-module (venv) (push) Has been skipped
Integration Tests (Replay) / discover-tests (push) Successful in 12s
Integration Auth Tests / test-matrix (oauth2_token) (push) Failing after 15s
Python Package Build Test / build (3.12) (push) Failing after 16s
Vector IO Integration Tests / test-matrix (3.12, inline::milvus) (push) Failing after 25s
Vector IO Integration Tests / test-matrix (3.12, remote::chromadb) (push) Failing after 23s
Python Package Build Test / build (3.13) (push) Failing after 17s
SqlStore Integration Tests / test-postgres (3.13) (push) Failing after 29s
Vector IO Integration Tests / test-matrix (3.13, inline::milvus) (push) Failing after 21s
Vector IO Integration Tests / test-matrix (3.13, remote::pgvector) (push) Failing after 25s
SqlStore Integration Tests / test-postgres (3.12) (push) Failing after 28s
Vector IO Integration Tests / test-matrix (3.12, remote::pgvector) (push) Failing after 29s
Unit Tests / unit-tests (3.12) (push) Failing after 20s
Integration Tests (Replay) / Integration Tests (, , , client=, vision=) (push) Failing after 12s
Test External API and Providers / test-external (venv) (push) Failing after 22s
Unit Tests / unit-tests (3.13) (push) Failing after 18s
Vector IO Integration Tests / test-matrix (3.13, inline::faiss) (push) Failing after 23s
Vector IO Integration Tests / test-matrix (3.12, remote::qdrant) (push) Failing after 24s
Vector IO Integration Tests / test-matrix (3.12, remote::weaviate) (push) Failing after 27s
Vector IO Integration Tests / test-matrix (3.12, inline::sqlite-vec) (push) Failing after 24s
Vector IO Integration Tests / test-matrix (3.13, remote::weaviate) (push) Failing after 23s
Vector IO Integration Tests / test-matrix (3.13, remote::qdrant) (push) Failing after 24s
Vector IO Integration Tests / test-matrix (3.13, remote::chromadb) (push) Failing after 25s
Vector IO Integration Tests / test-matrix (3.13, inline::sqlite-vec) (push) Failing after 27s
Vector IO Integration Tests / test-matrix (3.12, inline::faiss) (push) Failing after 24s
Update ReadTheDocs / update-readthedocs (push) Failing after 38s
Pre-commit / pre-commit (push) Successful in 1m53s
Add complete batches API implementation with protocol, providers, and tests: Core Infrastructure: - Add batches API protocol using OpenAI Batch types directly - Add Api.batches enum value and protocol mapping in resolver - Add OpenAI "batch" file purpose support - Include proper error handling (ConflictError, ResourceNotFoundError) Reference Provider: - Add ReferenceBatchesImpl with full CRUD operations (create, retrieve, cancel, list) - Implement background batch processing with configurable concurrency - Add SQLite KVStore backend for persistence - Support /v1/chat/completions endpoint with request validation Comprehensive Test Suite: - Add unit tests for provider implementation with validation - Add integration tests for end-to-end batch processing workflows - Add error handling tests for validation, malformed inputs, and edge cases Configuration: - Add max_concurrent_batches and max_concurrent_requests_per_batch options - Add provider documentation with sample configurations Test with - ``` $ uv run llama stack build --image-type venv --providers inference=YOU_PICK,files=inline::localfs,batches=inline::reference --run & $ LLAMA_STACK_CONFIG=http://localhost:8321 uv run pytest tests/unit/providers/batches tests/integration/batches --text-model YOU_PICK ``` addresses #3066
This commit is contained in:
parent
46ff302d87
commit
de692162af
26 changed files with 2707 additions and 2 deletions
6
docs/_static/llama-stack-spec.html
vendored
6
docs/_static/llama-stack-spec.html
vendored
|
@ -14767,7 +14767,8 @@
|
|||
"OpenAIFilePurpose": {
|
||||
"type": "string",
|
||||
"enum": [
|
||||
"assistants"
|
||||
"assistants",
|
||||
"batch"
|
||||
],
|
||||
"title": "OpenAIFilePurpose",
|
||||
"description": "Valid purpose values for OpenAI Files API."
|
||||
|
@ -14844,7 +14845,8 @@
|
|||
"purpose": {
|
||||
"type": "string",
|
||||
"enum": [
|
||||
"assistants"
|
||||
"assistants",
|
||||
"batch"
|
||||
],
|
||||
"description": "The intended purpose of the file"
|
||||
}
|
||||
|
|
2
docs/_static/llama-stack-spec.yaml
vendored
2
docs/_static/llama-stack-spec.yaml
vendored
|
@ -10951,6 +10951,7 @@ components:
|
|||
type: string
|
||||
enum:
|
||||
- assistants
|
||||
- batch
|
||||
title: OpenAIFilePurpose
|
||||
description: >-
|
||||
Valid purpose values for OpenAI Files API.
|
||||
|
@ -11019,6 +11020,7 @@ components:
|
|||
type: string
|
||||
enum:
|
||||
- assistants
|
||||
- batch
|
||||
description: The intended purpose of the file
|
||||
additionalProperties: false
|
||||
required:
|
||||
|
|
|
@ -18,3 +18,4 @@ We are working on adding a few more APIs to complete the application lifecycle.
|
|||
- **Batch Inference**: run inference on a dataset of inputs
|
||||
- **Batch Agents**: run agents on a dataset of inputs
|
||||
- **Synthetic Data Generation**: generate synthetic data for model development
|
||||
- **Batches**: OpenAI-compatible batch management for inference
|
||||
|
|
|
@ -2,6 +2,15 @@
|
|||
|
||||
## Overview
|
||||
|
||||
Agents API for creating and interacting with agentic systems.
|
||||
|
||||
Main functionalities provided by this API:
|
||||
- Create agents with specific instructions and ability to use tools.
|
||||
- Interactions with agents are grouped into sessions ("threads"), and each interaction is called a "turn".
|
||||
- Agents can be provided with various tools (see the ToolGroups and ToolRuntime APIs for more details).
|
||||
- Agents can be provided with various shields (see the Safety API for more details).
|
||||
- Agents can also use Memory to retrieve information from knowledge bases. See the RAG Tool and Vector IO APIs for more details.
|
||||
|
||||
This section contains documentation for all available providers for the **agents** API.
|
||||
|
||||
## Providers
|
||||
|
|
21
docs/source/providers/batches/index.md
Normal file
21
docs/source/providers/batches/index.md
Normal file
|
@ -0,0 +1,21 @@
|
|||
# Batches
|
||||
|
||||
## Overview
|
||||
|
||||
Protocol for batch processing API operations.
|
||||
|
||||
The Batches API enables efficient processing of multiple requests in a single operation,
|
||||
particularly useful for processing large datasets, batch evaluation workflows, and
|
||||
cost-effective inference at scale.
|
||||
|
||||
Note: This API is currently under active development and may undergo changes.
|
||||
|
||||
This section contains documentation for all available providers for the **batches** API.
|
||||
|
||||
## Providers
|
||||
|
||||
```{toctree}
|
||||
:maxdepth: 1
|
||||
|
||||
inline_reference
|
||||
```
|
23
docs/source/providers/batches/inline_reference.md
Normal file
23
docs/source/providers/batches/inline_reference.md
Normal file
|
@ -0,0 +1,23 @@
|
|||
# inline::reference
|
||||
|
||||
## Description
|
||||
|
||||
Reference implementation of batches API with KVStore persistence.
|
||||
|
||||
## Configuration
|
||||
|
||||
| Field | Type | Required | Default | Description |
|
||||
|-------|------|----------|---------|-------------|
|
||||
| `kvstore` | `utils.kvstore.config.RedisKVStoreConfig \| utils.kvstore.config.SqliteKVStoreConfig \| utils.kvstore.config.PostgresKVStoreConfig \| utils.kvstore.config.MongoDBKVStoreConfig` | No | sqlite | Configuration for the key-value store backend. |
|
||||
| `max_concurrent_batches` | `<class 'int'>` | No | 1 | Maximum number of concurrent batches to process simultaneously. |
|
||||
| `max_concurrent_requests_per_batch` | `<class 'int'>` | No | 10 | Maximum number of concurrent requests to process per batch. |
|
||||
|
||||
## Sample Configuration
|
||||
|
||||
```yaml
|
||||
kvstore:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/dummy}/batches.db
|
||||
|
||||
```
|
||||
|
|
@ -2,6 +2,8 @@
|
|||
|
||||
## Overview
|
||||
|
||||
Llama Stack Evaluation API for running evaluations on model and agent candidates.
|
||||
|
||||
This section contains documentation for all available providers for the **eval** API.
|
||||
|
||||
## Providers
|
||||
|
|
|
@ -2,6 +2,12 @@
|
|||
|
||||
## Overview
|
||||
|
||||
Llama Stack Inference API for generating completions, chat completions, and embeddings.
|
||||
|
||||
This API provides the raw interface to the underlying models. Two kinds of models are supported:
|
||||
- LLM models: these models generate "raw" and "chat" (conversational) completions.
|
||||
- Embedding models: these models generate embeddings to be used for semantic search.
|
||||
|
||||
This section contains documentation for all available providers for the **inference** API.
|
||||
|
||||
## Providers
|
||||
|
|
9
llama_stack/apis/batches/__init__.py
Normal file
9
llama_stack/apis/batches/__init__.py
Normal file
|
@ -0,0 +1,9 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from .batches import Batches, BatchObject, ListBatchesResponse
|
||||
|
||||
__all__ = ["Batches", "BatchObject", "ListBatchesResponse"]
|
89
llama_stack/apis/batches/batches.py
Normal file
89
llama_stack/apis/batches/batches.py
Normal file
|
@ -0,0 +1,89 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Literal, Protocol, runtime_checkable
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from llama_stack.schema_utils import json_schema_type, webmethod
|
||||
|
||||
try:
|
||||
from openai.types import Batch as BatchObject
|
||||
except ImportError as e:
|
||||
raise ImportError("OpenAI package is required for batches API. Please install it with: pip install openai") from e
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ListBatchesResponse(BaseModel):
|
||||
"""Response containing a list of batch objects."""
|
||||
|
||||
object: Literal["list"] = "list"
|
||||
data: list[BatchObject] = Field(..., description="List of batch objects")
|
||||
first_id: str | None = Field(default=None, description="ID of the first batch in the list")
|
||||
last_id: str | None = Field(default=None, description="ID of the last batch in the list")
|
||||
has_more: bool = Field(default=False, description="Whether there are more batches available")
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class Batches(Protocol):
|
||||
"""Protocol for batch processing API operations.
|
||||
|
||||
The Batches API enables efficient processing of multiple requests in a single operation,
|
||||
particularly useful for processing large datasets, batch evaluation workflows, and
|
||||
cost-effective inference at scale.
|
||||
|
||||
Note: This API is currently under active development and may undergo changes.
|
||||
"""
|
||||
|
||||
@webmethod(route="/openai/v1/batches", method="POST")
|
||||
async def create_batch(
|
||||
self,
|
||||
input_file_id: str,
|
||||
endpoint: str,
|
||||
completion_window: Literal["24h"],
|
||||
metadata: dict[str, str] | None = None,
|
||||
) -> BatchObject:
|
||||
"""Create a new batch for processing multiple API requests.
|
||||
|
||||
:param input_file_id: The ID of an uploaded file containing requests for the batch.
|
||||
:param endpoint: The endpoint to be used for all requests in the batch.
|
||||
:param completion_window: The time window within which the batch should be processed.
|
||||
:param metadata: Optional metadata for the batch.
|
||||
:returns: The created batch object.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/openai/v1/batches/{batch_id}", method="GET")
|
||||
async def retrieve_batch(self, batch_id: str) -> BatchObject:
|
||||
"""Retrieve information about a specific batch.
|
||||
|
||||
:param batch_id: The ID of the batch to retrieve.
|
||||
:returns: The batch object.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/openai/v1/batches/{batch_id}/cancel", method="POST")
|
||||
async def cancel_batch(self, batch_id: str) -> BatchObject:
|
||||
"""Cancel a batch that is in progress.
|
||||
|
||||
:param batch_id: The ID of the batch to cancel.
|
||||
:returns: The updated batch object.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/openai/v1/batches", method="GET")
|
||||
async def list_batches(
|
||||
self,
|
||||
after: str | None = None,
|
||||
limit: int = 20,
|
||||
) -> ListBatchesResponse:
|
||||
"""List all batches for the current user.
|
||||
|
||||
:param after: A cursor for pagination; returns batches after this batch ID.
|
||||
:param limit: Number of batches to return (default 20, max 100).
|
||||
:returns: A list of batch objects.
|
||||
"""
|
||||
...
|
|
@ -64,6 +64,12 @@ class SessionNotFoundError(ValueError):
|
|||
super().__init__(message)
|
||||
|
||||
|
||||
class ConflictError(ValueError):
|
||||
"""raised when an operation cannot be performed due to a conflict with the current state"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class ModelTypeError(TypeError):
|
||||
"""raised when a model is present but not the correct type"""
|
||||
|
||||
|
|
|
@ -86,6 +86,7 @@ class Api(Enum, metaclass=DynamicApiMeta):
|
|||
:cvar inference: Text generation, chat completions, and embeddings
|
||||
:cvar safety: Content moderation and safety shields
|
||||
:cvar agents: Agent orchestration and execution
|
||||
:cvar batches: Batch processing for asynchronous API requests
|
||||
:cvar vector_io: Vector database operations and queries
|
||||
:cvar datasetio: Dataset input/output operations
|
||||
:cvar scoring: Model output evaluation and scoring
|
||||
|
@ -108,6 +109,7 @@ class Api(Enum, metaclass=DynamicApiMeta):
|
|||
inference = "inference"
|
||||
safety = "safety"
|
||||
agents = "agents"
|
||||
batches = "batches"
|
||||
vector_io = "vector_io"
|
||||
datasetio = "datasetio"
|
||||
scoring = "scoring"
|
||||
|
|
|
@ -22,6 +22,7 @@ class OpenAIFilePurpose(StrEnum):
|
|||
"""
|
||||
|
||||
ASSISTANTS = "assistants"
|
||||
BATCH = "batch"
|
||||
# TODO: Add other purposes as needed
|
||||
|
||||
|
||||
|
|
|
@ -8,6 +8,7 @@ import inspect
|
|||
from typing import Any
|
||||
|
||||
from llama_stack.apis.agents import Agents
|
||||
from llama_stack.apis.batches import Batches
|
||||
from llama_stack.apis.benchmarks import Benchmarks
|
||||
from llama_stack.apis.datasetio import DatasetIO
|
||||
from llama_stack.apis.datasets import Datasets
|
||||
|
@ -75,6 +76,7 @@ def api_protocol_map(external_apis: dict[Api, ExternalApiSpec] | None = None) ->
|
|||
Api.agents: Agents,
|
||||
Api.inference: Inference,
|
||||
Api.inspect: Inspect,
|
||||
Api.batches: Batches,
|
||||
Api.vector_io: VectorIO,
|
||||
Api.vector_dbs: VectorDBs,
|
||||
Api.models: Models,
|
||||
|
|
|
@ -32,6 +32,7 @@ from fastapi.responses import JSONResponse, StreamingResponse
|
|||
from openai import BadRequestError
|
||||
from pydantic import BaseModel, ValidationError
|
||||
|
||||
from llama_stack.apis.common.errors import ConflictError, ResourceNotFoundError
|
||||
from llama_stack.apis.common.responses import PaginatedResponse
|
||||
from llama_stack.cli.utils import add_config_distro_args, get_config_from_args
|
||||
from llama_stack.core.access_control.access_control import AccessDeniedError
|
||||
|
@ -128,6 +129,10 @@ def translate_exception(exc: Exception) -> HTTPException | RequestValidationErro
|
|||
]
|
||||
},
|
||||
)
|
||||
elif isinstance(exc, ConflictError):
|
||||
return HTTPException(status_code=409, detail=str(exc))
|
||||
elif isinstance(exc, ResourceNotFoundError):
|
||||
return HTTPException(status_code=404, detail=str(exc))
|
||||
elif isinstance(exc, ValueError):
|
||||
return HTTPException(status_code=httpx.codes.BAD_REQUEST, detail=f"Invalid value: {str(exc)}")
|
||||
elif isinstance(exc, BadRequestError):
|
||||
|
|
5
llama_stack/providers/inline/batches/__init__.py
Normal file
5
llama_stack/providers/inline/batches/__init__.py
Normal file
|
@ -0,0 +1,5 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
36
llama_stack/providers/inline/batches/reference/__init__.py
Normal file
36
llama_stack/providers/inline/batches/reference/__init__.py
Normal file
|
@ -0,0 +1,36 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any
|
||||
|
||||
from llama_stack.apis.files import Files
|
||||
from llama_stack.apis.inference import Inference
|
||||
from llama_stack.apis.models import Models
|
||||
from llama_stack.core.datatypes import AccessRule, Api
|
||||
from llama_stack.providers.utils.kvstore import kvstore_impl
|
||||
|
||||
from .batches import ReferenceBatchesImpl
|
||||
from .config import ReferenceBatchesImplConfig
|
||||
|
||||
__all__ = ["ReferenceBatchesImpl", "ReferenceBatchesImplConfig"]
|
||||
|
||||
|
||||
async def get_provider_impl(config: ReferenceBatchesImplConfig, deps: dict[Api, Any], policy: list[AccessRule]):
|
||||
kvstore = await kvstore_impl(config.kvstore)
|
||||
inference_api: Inference | None = deps.get(Api.inference)
|
||||
files_api: Files | None = deps.get(Api.files)
|
||||
models_api: Models | None = deps.get(Api.models)
|
||||
|
||||
if inference_api is None:
|
||||
raise ValueError("Inference API is required but not provided in dependencies")
|
||||
if files_api is None:
|
||||
raise ValueError("Files API is required but not provided in dependencies")
|
||||
if models_api is None:
|
||||
raise ValueError("Models API is required but not provided in dependencies")
|
||||
|
||||
impl = ReferenceBatchesImpl(config, inference_api, files_api, models_api, kvstore)
|
||||
await impl.initialize()
|
||||
return impl
|
553
llama_stack/providers/inline/batches/reference/batches.py
Normal file
553
llama_stack/providers/inline/batches/reference/batches.py
Normal file
|
@ -0,0 +1,553 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import asyncio
|
||||
import itertools
|
||||
import json
|
||||
import time
|
||||
import uuid
|
||||
from io import BytesIO
|
||||
from typing import Any, Literal
|
||||
|
||||
from openai.types.batch import BatchError, Errors
|
||||
from pydantic import BaseModel
|
||||
|
||||
from llama_stack.apis.batches import Batches, BatchObject, ListBatchesResponse
|
||||
from llama_stack.apis.common.errors import ConflictError, ResourceNotFoundError
|
||||
from llama_stack.apis.files import Files, OpenAIFilePurpose
|
||||
from llama_stack.apis.inference import Inference
|
||||
from llama_stack.apis.models import Models
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.utils.kvstore import KVStore
|
||||
|
||||
from .config import ReferenceBatchesImplConfig
|
||||
|
||||
BATCH_PREFIX = "batch:"
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class AsyncBytesIO:
|
||||
"""
|
||||
Async-compatible BytesIO wrapper to allow async file-like operations.
|
||||
|
||||
We use this when uploading files to the Files API, as it expects an
|
||||
async file-like object.
|
||||
"""
|
||||
|
||||
def __init__(self, data: bytes):
|
||||
self._buffer = BytesIO(data)
|
||||
|
||||
async def read(self, n=-1):
|
||||
return self._buffer.read(n)
|
||||
|
||||
async def seek(self, pos, whence=0):
|
||||
return self._buffer.seek(pos, whence)
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
self._buffer.close()
|
||||
|
||||
def __getattr__(self, name):
|
||||
return getattr(self._buffer, name)
|
||||
|
||||
|
||||
class BatchRequest(BaseModel):
|
||||
line_num: int
|
||||
custom_id: str
|
||||
method: str
|
||||
url: str
|
||||
body: dict[str, Any]
|
||||
|
||||
|
||||
class ReferenceBatchesImpl(Batches):
|
||||
"""Reference implementation of the Batches API.
|
||||
|
||||
This implementation processes batch files by making individual requests
|
||||
to the inference API and generates output files with results.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: ReferenceBatchesImplConfig,
|
||||
inference_api: Inference,
|
||||
files_api: Files,
|
||||
models_api: Models,
|
||||
kvstore: KVStore,
|
||||
) -> None:
|
||||
self.config = config
|
||||
self.kvstore = kvstore
|
||||
self.inference_api = inference_api
|
||||
self.files_api = files_api
|
||||
self.models_api = models_api
|
||||
self._processing_tasks: dict[str, asyncio.Task] = {}
|
||||
self._batch_semaphore = asyncio.Semaphore(config.max_concurrent_batches)
|
||||
self._update_batch_lock = asyncio.Lock()
|
||||
|
||||
# this is to allow tests to disable background processing
|
||||
self.process_batches = True
|
||||
|
||||
async def initialize(self) -> None:
|
||||
# TODO: start background processing of existing tasks
|
||||
pass
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
"""Shutdown the batches provider."""
|
||||
if self._processing_tasks:
|
||||
# don't cancel tasks - just let them stop naturally on shutdown
|
||||
# cancelling would mark batches as "cancelled" in the database
|
||||
logger.info(f"Shutdown initiated with {len(self._processing_tasks)} active batch processing tasks")
|
||||
|
||||
# TODO (SECURITY): this currently works w/ configured api keys, not with x-llamastack-provider-data or with user policy restrictions
|
||||
async def create_batch(
|
||||
self,
|
||||
input_file_id: str,
|
||||
endpoint: str,
|
||||
completion_window: Literal["24h"],
|
||||
metadata: dict[str, str] | None = None,
|
||||
) -> BatchObject:
|
||||
"""
|
||||
Create a new batch for processing multiple API requests.
|
||||
|
||||
Error handling by levels -
|
||||
0. Input param handling, results in 40x errors before processing, e.g.
|
||||
- Wrong completion_window
|
||||
- Invalid metadata types
|
||||
- Unknown endpoint
|
||||
-> no batch created
|
||||
1. Errors preventing processing, result in BatchErrors aggregated in process_batch, e.g.
|
||||
- input_file_id missing
|
||||
- invalid json in file
|
||||
- missing custom_id, method, url, body
|
||||
- invalid model
|
||||
- streaming
|
||||
-> batch created, validation sends to failed status
|
||||
2. Processing errors, result in error_file_id entries, e.g.
|
||||
- Any error returned from inference endpoint
|
||||
-> batch created, goes to completed status
|
||||
"""
|
||||
|
||||
# TODO: set expiration time for garbage collection
|
||||
|
||||
if endpoint not in ["/v1/chat/completions"]:
|
||||
raise ValueError(
|
||||
f"Invalid endpoint: {endpoint}. Supported values: /v1/chat/completions. Code: invalid_value. Param: endpoint",
|
||||
)
|
||||
|
||||
if completion_window != "24h":
|
||||
raise ValueError(
|
||||
f"Invalid completion_window: {completion_window}. Supported values are: 24h. Code: invalid_value. Param: completion_window",
|
||||
)
|
||||
|
||||
batch_id = f"batch_{uuid.uuid4().hex[:16]}"
|
||||
current_time = int(time.time())
|
||||
|
||||
batch = BatchObject(
|
||||
id=batch_id,
|
||||
object="batch",
|
||||
endpoint=endpoint,
|
||||
input_file_id=input_file_id,
|
||||
completion_window=completion_window,
|
||||
status="validating",
|
||||
created_at=current_time,
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
await self.kvstore.set(f"batch:{batch_id}", batch.to_json())
|
||||
|
||||
if self.process_batches:
|
||||
task = asyncio.create_task(self._process_batch(batch_id))
|
||||
self._processing_tasks[batch_id] = task
|
||||
|
||||
return batch
|
||||
|
||||
async def cancel_batch(self, batch_id: str) -> BatchObject:
|
||||
"""Cancel a batch that is in progress."""
|
||||
batch = await self.retrieve_batch(batch_id)
|
||||
|
||||
if batch.status in ["cancelled", "cancelling"]:
|
||||
return batch
|
||||
|
||||
if batch.status in ["completed", "failed", "expired"]:
|
||||
raise ConflictError(f"Cannot cancel batch '{batch_id}' with status '{batch.status}'")
|
||||
|
||||
await self._update_batch(batch_id, status="cancelling", cancelling_at=int(time.time()))
|
||||
|
||||
if batch_id in self._processing_tasks:
|
||||
self._processing_tasks[batch_id].cancel()
|
||||
# note: task removal and status="cancelled" handled in finally block of _process_batch
|
||||
|
||||
return await self.retrieve_batch(batch_id)
|
||||
|
||||
async def list_batches(
|
||||
self,
|
||||
after: str | None = None,
|
||||
limit: int = 20,
|
||||
) -> ListBatchesResponse:
|
||||
"""
|
||||
List all batches, eventually only for the current user.
|
||||
|
||||
With no notion of user, we return all batches.
|
||||
"""
|
||||
batch_values = await self.kvstore.values_in_range("batch:", "batch:\xff")
|
||||
|
||||
batches = []
|
||||
for batch_data in batch_values:
|
||||
if batch_data:
|
||||
batches.append(BatchObject.model_validate_json(batch_data))
|
||||
|
||||
batches.sort(key=lambda b: b.created_at, reverse=True)
|
||||
|
||||
start_idx = 0
|
||||
if after:
|
||||
for i, batch in enumerate(batches):
|
||||
if batch.id == after:
|
||||
start_idx = i + 1
|
||||
break
|
||||
|
||||
page_batches = batches[start_idx : start_idx + limit]
|
||||
has_more = (start_idx + limit) < len(batches)
|
||||
|
||||
first_id = page_batches[0].id if page_batches else None
|
||||
last_id = page_batches[-1].id if page_batches else None
|
||||
|
||||
return ListBatchesResponse(
|
||||
data=page_batches,
|
||||
first_id=first_id,
|
||||
last_id=last_id,
|
||||
has_more=has_more,
|
||||
)
|
||||
|
||||
async def retrieve_batch(self, batch_id: str) -> BatchObject:
|
||||
"""Retrieve information about a specific batch."""
|
||||
batch_data = await self.kvstore.get(f"batch:{batch_id}")
|
||||
if not batch_data:
|
||||
raise ResourceNotFoundError(batch_id, "Batch", "batches.list()")
|
||||
|
||||
return BatchObject.model_validate_json(batch_data)
|
||||
|
||||
async def _update_batch(self, batch_id: str, **updates) -> None:
|
||||
"""Update batch fields in kvstore."""
|
||||
async with self._update_batch_lock:
|
||||
try:
|
||||
batch = await self.retrieve_batch(batch_id)
|
||||
|
||||
# batch processing is async. once cancelling, only allow "cancelled" status updates
|
||||
if batch.status == "cancelling" and updates.get("status") != "cancelled":
|
||||
logger.info(
|
||||
f"Skipping status update for cancelled batch {batch_id}: attempted {updates.get('status')}"
|
||||
)
|
||||
return
|
||||
|
||||
if "errors" in updates:
|
||||
updates["errors"] = updates["errors"].model_dump()
|
||||
|
||||
batch_dict = batch.model_dump()
|
||||
batch_dict.update(updates)
|
||||
|
||||
await self.kvstore.set(f"batch:{batch_id}", json.dumps(batch_dict))
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to update batch {batch_id}: {e}")
|
||||
|
||||
async def _validate_input(self, batch: BatchObject) -> tuple[list[BatchError], list[BatchRequest]]:
|
||||
"""
|
||||
Read & validate input, return errors and valid input.
|
||||
|
||||
Validation of
|
||||
- input_file_id existance
|
||||
- valid json
|
||||
- custom_id, method, url, body presence and valid
|
||||
- no streaming
|
||||
"""
|
||||
requests: list[BatchRequest] = []
|
||||
errors: list[BatchError] = []
|
||||
try:
|
||||
await self.files_api.openai_retrieve_file(batch.input_file_id)
|
||||
except Exception:
|
||||
errors.append(
|
||||
BatchError(
|
||||
code="invalid_request",
|
||||
line=None,
|
||||
message=f"Cannot find file {batch.input_file_id}.",
|
||||
param="input_file_id",
|
||||
)
|
||||
)
|
||||
return errors, requests
|
||||
|
||||
# TODO(SECURITY): do something about large files
|
||||
file_content_response = await self.files_api.openai_retrieve_file_content(batch.input_file_id)
|
||||
file_content = file_content_response.body.decode("utf-8")
|
||||
for line_num, line in enumerate(file_content.strip().split("\n"), 1):
|
||||
if line.strip(): # skip empty lines
|
||||
try:
|
||||
request = json.loads(line)
|
||||
|
||||
if not isinstance(request, dict):
|
||||
errors.append(
|
||||
BatchError(
|
||||
code="invalid_request",
|
||||
line=line_num,
|
||||
message="Each line must be a JSON dictionary object",
|
||||
)
|
||||
)
|
||||
continue
|
||||
|
||||
valid = True
|
||||
|
||||
for param, expected_type, type_string in [
|
||||
("custom_id", str, "string"),
|
||||
("method", str, "string"),
|
||||
("url", str, "string"),
|
||||
("body", dict, "JSON dictionary object"),
|
||||
]:
|
||||
if param not in request:
|
||||
errors.append(
|
||||
BatchError(
|
||||
code="missing_required_parameter",
|
||||
line=line_num,
|
||||
message=f"Missing required parameter: {param}",
|
||||
param=param,
|
||||
)
|
||||
)
|
||||
valid = False
|
||||
elif not isinstance(request[param], expected_type):
|
||||
param_name = "URL" if param == "url" else param.capitalize()
|
||||
errors.append(
|
||||
BatchError(
|
||||
code="invalid_request",
|
||||
line=line_num,
|
||||
message=f"{param_name} must be a {type_string}",
|
||||
param=param,
|
||||
)
|
||||
)
|
||||
valid = False
|
||||
|
||||
if (url := request.get("url")) and isinstance(url, str) and url != batch.endpoint:
|
||||
errors.append(
|
||||
BatchError(
|
||||
code="invalid_url",
|
||||
line=line_num,
|
||||
message="URL provided for this request does not match the batch endpoint",
|
||||
param="url",
|
||||
)
|
||||
)
|
||||
valid = False
|
||||
|
||||
if (body := request.get("body")) and isinstance(body, dict):
|
||||
if body.get("stream", False):
|
||||
errors.append(
|
||||
BatchError(
|
||||
code="streaming_unsupported",
|
||||
line=line_num,
|
||||
message="Streaming is not supported in batch processing",
|
||||
param="body.stream",
|
||||
)
|
||||
)
|
||||
valid = False
|
||||
|
||||
for param, expected_type, type_string in [
|
||||
("model", str, "a string"),
|
||||
# messages is specific to /v1/chat/completions
|
||||
# we could skip validating messages here and let inference fail. however,
|
||||
# that would be a very expensive way to find out messages is wrong.
|
||||
("messages", list, "an array"), # TODO: allow messages to be a string?
|
||||
]:
|
||||
if param not in body:
|
||||
errors.append(
|
||||
BatchError(
|
||||
code="invalid_request",
|
||||
line=line_num,
|
||||
message=f"{param.capitalize()} parameter is required",
|
||||
param=f"body.{param}",
|
||||
)
|
||||
)
|
||||
valid = False
|
||||
elif not isinstance(body[param], expected_type):
|
||||
errors.append(
|
||||
BatchError(
|
||||
code="invalid_request",
|
||||
line=line_num,
|
||||
message=f"{param.capitalize()} must be {type_string}",
|
||||
param=f"body.{param}",
|
||||
)
|
||||
)
|
||||
valid = False
|
||||
|
||||
if "model" in body and isinstance(body["model"], str):
|
||||
try:
|
||||
await self.models_api.get_model(body["model"])
|
||||
except Exception:
|
||||
errors.append(
|
||||
BatchError(
|
||||
code="model_not_found",
|
||||
line=line_num,
|
||||
message=f"Model '{body['model']}' does not exist or is not supported",
|
||||
param="body.model",
|
||||
)
|
||||
)
|
||||
valid = False
|
||||
|
||||
if valid:
|
||||
assert isinstance(url, str), "URL must be a string" # for mypy
|
||||
assert isinstance(body, dict), "Body must be a dictionary" # for mypy
|
||||
requests.append(
|
||||
BatchRequest(
|
||||
line_num=line_num,
|
||||
url=url,
|
||||
method=request["method"],
|
||||
custom_id=request["custom_id"],
|
||||
body=body,
|
||||
),
|
||||
)
|
||||
except json.JSONDecodeError:
|
||||
errors.append(
|
||||
BatchError(
|
||||
code="invalid_json_line",
|
||||
line=line_num,
|
||||
message="This line is not parseable as valid JSON.",
|
||||
)
|
||||
)
|
||||
|
||||
return errors, requests
|
||||
|
||||
async def _process_batch(self, batch_id: str) -> None:
|
||||
"""Background task to process a batch of requests."""
|
||||
try:
|
||||
logger.info(f"Starting batch processing for {batch_id}")
|
||||
async with self._batch_semaphore: # semaphore to limit concurrency
|
||||
logger.info(f"Acquired semaphore for batch {batch_id}")
|
||||
await self._process_batch_impl(batch_id)
|
||||
except asyncio.CancelledError:
|
||||
logger.info(f"Batch processing cancelled for {batch_id}")
|
||||
await self._update_batch(batch_id, status="cancelled", cancelled_at=int(time.time()))
|
||||
except Exception as e:
|
||||
logger.error(f"Batch processing failed for {batch_id}: {e}")
|
||||
await self._update_batch(
|
||||
batch_id,
|
||||
status="failed",
|
||||
failed_at=int(time.time()),
|
||||
errors=Errors(data=[BatchError(code="internal_error", message=str(e))]),
|
||||
)
|
||||
finally:
|
||||
self._processing_tasks.pop(batch_id, None)
|
||||
|
||||
async def _process_batch_impl(self, batch_id: str) -> None:
|
||||
"""Implementation of batch processing logic."""
|
||||
errors: list[BatchError] = []
|
||||
batch = await self.retrieve_batch(batch_id)
|
||||
|
||||
errors, requests = await self._validate_input(batch)
|
||||
if errors:
|
||||
await self._update_batch(batch_id, status="failed", failed_at=int(time.time()), errors=Errors(data=errors))
|
||||
logger.info(f"Batch validation failed for {batch_id} with {len(errors)} errors")
|
||||
return
|
||||
|
||||
logger.info(f"Processing {len(requests)} requests for batch {batch_id}")
|
||||
|
||||
total_requests = len(requests)
|
||||
await self._update_batch(
|
||||
batch_id,
|
||||
status="in_progress",
|
||||
request_counts={"total": total_requests, "completed": 0, "failed": 0},
|
||||
)
|
||||
|
||||
error_results = []
|
||||
success_results = []
|
||||
completed_count = 0
|
||||
failed_count = 0
|
||||
|
||||
for chunk in itertools.batched(requests, self.config.max_concurrent_requests_per_batch):
|
||||
# we use a TaskGroup to ensure all process-single-request tasks are canceled when process-batch is cancelled
|
||||
async with asyncio.TaskGroup() as tg:
|
||||
chunk_tasks = [tg.create_task(self._process_single_request(batch_id, request)) for request in chunk]
|
||||
|
||||
chunk_results = await asyncio.gather(*chunk_tasks, return_exceptions=True)
|
||||
|
||||
for result in chunk_results:
|
||||
if isinstance(result, dict) and result.get("error") is not None: # error response from inference
|
||||
failed_count += 1
|
||||
error_results.append(result)
|
||||
elif isinstance(result, dict) and result.get("response") is not None: # successful inference
|
||||
completed_count += 1
|
||||
success_results.append(result)
|
||||
else: # unexpected result
|
||||
failed_count += 1
|
||||
errors.append(BatchError(code="internal_error", message=f"Unexpected result: {result}"))
|
||||
|
||||
await self._update_batch(
|
||||
batch_id,
|
||||
request_counts={"total": total_requests, "completed": completed_count, "failed": failed_count},
|
||||
)
|
||||
|
||||
if errors:
|
||||
await self._update_batch(
|
||||
batch_id, status="failed", failed_at=int(time.time()), errors=Errors(data=errors)
|
||||
)
|
||||
return
|
||||
|
||||
try:
|
||||
output_file_id = await self._create_output_file(batch_id, success_results, "success")
|
||||
await self._update_batch(batch_id, output_file_id=output_file_id)
|
||||
|
||||
error_file_id = await self._create_output_file(batch_id, error_results, "error")
|
||||
await self._update_batch(batch_id, error_file_id=error_file_id)
|
||||
|
||||
await self._update_batch(batch_id, status="completed", completed_at=int(time.time()))
|
||||
|
||||
logger.info(
|
||||
f"Batch processing completed for {batch_id}: {completed_count} completed, {failed_count} failed"
|
||||
)
|
||||
except Exception as e:
|
||||
# note: errors is empty at this point, so we don't lose anything by ignoring it
|
||||
await self._update_batch(
|
||||
batch_id,
|
||||
status="failed",
|
||||
failed_at=int(time.time()),
|
||||
errors=Errors(data=[BatchError(code="output_failed", message=str(e))]),
|
||||
)
|
||||
|
||||
async def _process_single_request(self, batch_id: str, request: BatchRequest) -> dict:
|
||||
"""Process a single request from the batch."""
|
||||
request_id = f"batch_req_{batch_id}_{request.line_num}"
|
||||
|
||||
try:
|
||||
# TODO(SECURITY): review body for security issues
|
||||
chat_response = await self.inference_api.openai_chat_completion(**request.body)
|
||||
|
||||
# this is for mypy, we don't allow streaming so we'll get the right type
|
||||
assert hasattr(chat_response, "model_dump_json"), "Chat response must have model_dump_json method"
|
||||
return {
|
||||
"id": request_id,
|
||||
"custom_id": request.custom_id,
|
||||
"response": {
|
||||
"status_code": 200,
|
||||
"request_id": request_id, # TODO: should this be different?
|
||||
"body": chat_response.model_dump_json(),
|
||||
},
|
||||
}
|
||||
except Exception as e:
|
||||
logger.info(f"Error processing request {request.custom_id} in batch {batch_id}: {e}")
|
||||
return {
|
||||
"id": request_id,
|
||||
"custom_id": request.custom_id,
|
||||
"error": {"type": "request_failed", "message": str(e)},
|
||||
}
|
||||
|
||||
async def _create_output_file(self, batch_id: str, results: list[dict], file_type: str) -> str:
|
||||
"""
|
||||
Create an output file with batch results.
|
||||
|
||||
This function filters results based on the specified file_type
|
||||
and uploads the file to the Files API.
|
||||
"""
|
||||
output_lines = [json.dumps(result) for result in results]
|
||||
|
||||
with AsyncBytesIO("\n".join(output_lines).encode("utf-8")) as file_buffer:
|
||||
file_buffer.filename = f"{batch_id}_{file_type}.jsonl"
|
||||
uploaded_file = await self.files_api.openai_upload_file(file=file_buffer, purpose=OpenAIFilePurpose.BATCH)
|
||||
return uploaded_file.id
|
40
llama_stack/providers/inline/batches/reference/config.py
Normal file
40
llama_stack/providers/inline/batches/reference/config.py
Normal file
|
@ -0,0 +1,40 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from llama_stack.providers.utils.kvstore.config import KVStoreConfig, SqliteKVStoreConfig
|
||||
|
||||
|
||||
class ReferenceBatchesImplConfig(BaseModel):
|
||||
"""Configuration for the Reference Batches implementation."""
|
||||
|
||||
kvstore: KVStoreConfig = Field(
|
||||
description="Configuration for the key-value store backend.",
|
||||
)
|
||||
|
||||
max_concurrent_batches: int = Field(
|
||||
default=1,
|
||||
description="Maximum number of concurrent batches to process simultaneously.",
|
||||
ge=1,
|
||||
)
|
||||
|
||||
max_concurrent_requests_per_batch: int = Field(
|
||||
default=10,
|
||||
description="Maximum number of concurrent requests to process per batch.",
|
||||
ge=1,
|
||||
)
|
||||
|
||||
# TODO: add a max requests per second rate limiter
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, __distro_dir__: str) -> dict:
|
||||
return {
|
||||
"kvstore": SqliteKVStoreConfig.sample_run_config(
|
||||
__distro_dir__=__distro_dir__,
|
||||
db_name="batches.db",
|
||||
),
|
||||
}
|
26
llama_stack/providers/registry/batches.py
Normal file
26
llama_stack/providers/registry/batches.py
Normal file
|
@ -0,0 +1,26 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
|
||||
from llama_stack.providers.datatypes import Api, InlineProviderSpec, ProviderSpec
|
||||
|
||||
|
||||
def available_providers() -> list[ProviderSpec]:
|
||||
return [
|
||||
InlineProviderSpec(
|
||||
api=Api.batches,
|
||||
provider_type="inline::reference",
|
||||
pip_packages=["openai"],
|
||||
module="llama_stack.providers.inline.batches.reference",
|
||||
config_class="llama_stack.providers.inline.batches.reference.config.ReferenceBatchesImplConfig",
|
||||
api_dependencies=[
|
||||
Api.inference,
|
||||
Api.files,
|
||||
Api.models,
|
||||
],
|
||||
description="Reference implementation of batches API with KVStore persistence.",
|
||||
),
|
||||
]
|
|
@ -18,6 +18,23 @@ from llama_stack.core.distribution import get_provider_registry
|
|||
REPO_ROOT = Path(__file__).parent.parent
|
||||
|
||||
|
||||
def get_api_docstring(api_name: str) -> str | None:
|
||||
"""Extract docstring from the API protocol class."""
|
||||
try:
|
||||
# Import the API module dynamically
|
||||
api_module = __import__(f"llama_stack.apis.{api_name}", fromlist=[api_name.title()])
|
||||
|
||||
# Get the main protocol class (usually capitalized API name)
|
||||
protocol_class_name = api_name.title()
|
||||
if hasattr(api_module, protocol_class_name):
|
||||
protocol_class = getattr(api_module, protocol_class_name)
|
||||
return protocol_class.__doc__
|
||||
except (ImportError, AttributeError):
|
||||
pass
|
||||
|
||||
return None
|
||||
|
||||
|
||||
class ChangedPathTracker:
|
||||
"""Track a list of paths we may have changed."""
|
||||
|
||||
|
@ -261,6 +278,11 @@ def process_provider_registry(progress, change_tracker: ChangedPathTracker) -> N
|
|||
index_content.append(f"# {api_name.title()}\n")
|
||||
index_content.append("## Overview\n")
|
||||
|
||||
api_docstring = get_api_docstring(api_name)
|
||||
if api_docstring:
|
||||
cleaned_docstring = api_docstring.strip()
|
||||
index_content.append(f"{cleaned_docstring}\n")
|
||||
|
||||
index_content.append(
|
||||
f"This section contains documentation for all available providers for the **{api_name}** API.\n"
|
||||
)
|
||||
|
|
5
tests/integration/batches/__init__.py
Normal file
5
tests/integration/batches/__init__.py
Normal file
|
@ -0,0 +1,5 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
122
tests/integration/batches/conftest.py
Normal file
122
tests/integration/batches/conftest.py
Normal file
|
@ -0,0 +1,122 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
"""Shared pytest fixtures for batch tests."""
|
||||
|
||||
import json
|
||||
import time
|
||||
import warnings
|
||||
from contextlib import contextmanager
|
||||
from io import BytesIO
|
||||
|
||||
import pytest
|
||||
|
||||
from llama_stack.apis.files import OpenAIFilePurpose
|
||||
|
||||
|
||||
class BatchHelper:
|
||||
"""Helper class for creating and managing batch input files."""
|
||||
|
||||
def __init__(self, client):
|
||||
"""Initialize with either a batch_client or openai_client."""
|
||||
self.client = client
|
||||
|
||||
@contextmanager
|
||||
def create_file(self, content: str | list[dict], filename_prefix="batch_input"):
|
||||
"""Context manager for creating and cleaning up batch input files.
|
||||
|
||||
Args:
|
||||
content: Either a list of batch request dictionaries or raw string content
|
||||
filename_prefix: Prefix for the generated filename (or full filename if content is string)
|
||||
|
||||
Yields:
|
||||
The uploaded file object
|
||||
"""
|
||||
if isinstance(content, str):
|
||||
# Handle raw string content (e.g., malformed JSONL, empty files)
|
||||
file_content = content.encode("utf-8")
|
||||
else:
|
||||
# Handle list of batch request dictionaries
|
||||
jsonl_content = "\n".join(json.dumps(req) for req in content)
|
||||
file_content = jsonl_content.encode("utf-8")
|
||||
|
||||
filename = filename_prefix if filename_prefix.endswith(".jsonl") else f"{filename_prefix}.jsonl"
|
||||
|
||||
with BytesIO(file_content) as file_buffer:
|
||||
file_buffer.name = filename
|
||||
uploaded_file = self.client.files.create(file=file_buffer, purpose=OpenAIFilePurpose.BATCH)
|
||||
|
||||
try:
|
||||
yield uploaded_file
|
||||
finally:
|
||||
try:
|
||||
self.client.files.delete(uploaded_file.id)
|
||||
except Exception:
|
||||
warnings.warn(
|
||||
f"Failed to cleanup file {uploaded_file.id}: {uploaded_file.filename}",
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
def wait_for(
|
||||
self,
|
||||
batch_id: str,
|
||||
max_wait_time: int = 60,
|
||||
sleep_interval: int | None = None,
|
||||
expected_statuses: set[str] | None = None,
|
||||
timeout_action: str = "fail",
|
||||
):
|
||||
"""Wait for a batch to reach a terminal status.
|
||||
|
||||
Args:
|
||||
batch_id: The batch ID to monitor
|
||||
max_wait_time: Maximum time to wait in seconds (default: 60 seconds)
|
||||
sleep_interval: Time to sleep between checks in seconds (default: 1/10th of max_wait_time, min 1s, max 15s)
|
||||
expected_statuses: Set of expected terminal statuses (default: {"completed"})
|
||||
timeout_action: Action on timeout - "fail" (pytest.fail) or "skip" (pytest.skip)
|
||||
|
||||
Returns:
|
||||
The final batch object
|
||||
|
||||
Raises:
|
||||
pytest.Failed: If batch reaches an unexpected status or timeout_action is "fail"
|
||||
pytest.Skipped: If timeout_action is "skip" on timeout or unexpected status
|
||||
"""
|
||||
if sleep_interval is None:
|
||||
# Default to 1/10th of max_wait_time, with min 1s and max 15s
|
||||
sleep_interval = max(1, min(15, max_wait_time // 10))
|
||||
|
||||
if expected_statuses is None:
|
||||
expected_statuses = {"completed"}
|
||||
|
||||
terminal_statuses = {"completed", "failed", "cancelled", "expired"}
|
||||
unexpected_statuses = terminal_statuses - expected_statuses
|
||||
|
||||
start_time = time.time()
|
||||
while time.time() - start_time < max_wait_time:
|
||||
current_batch = self.client.batches.retrieve(batch_id)
|
||||
|
||||
if current_batch.status in expected_statuses:
|
||||
return current_batch
|
||||
elif current_batch.status in unexpected_statuses:
|
||||
error_msg = f"Batch reached unexpected status: {current_batch.status}"
|
||||
if timeout_action == "skip":
|
||||
pytest.skip(error_msg)
|
||||
else:
|
||||
pytest.fail(error_msg)
|
||||
|
||||
time.sleep(sleep_interval)
|
||||
|
||||
timeout_msg = f"Batch did not reach expected status {expected_statuses} within {max_wait_time} seconds"
|
||||
if timeout_action == "skip":
|
||||
pytest.skip(timeout_msg)
|
||||
else:
|
||||
pytest.fail(timeout_msg)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def batch_helper(openai_client):
|
||||
"""Fixture that provides a BatchHelper instance for OpenAI client."""
|
||||
return BatchHelper(openai_client)
|
270
tests/integration/batches/test_batches.py
Normal file
270
tests/integration/batches/test_batches.py
Normal file
|
@ -0,0 +1,270 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
"""
|
||||
Integration tests for the Llama Stack batch processing functionality.
|
||||
|
||||
This module contains comprehensive integration tests for the batch processing API,
|
||||
using the OpenAI-compatible client interface for consistency.
|
||||
|
||||
Test Categories:
|
||||
1. Core Batch Operations:
|
||||
- test_batch_creation_and_retrieval: Comprehensive batch creation, structure validation, and retrieval
|
||||
- test_batch_listing: Basic batch listing functionality
|
||||
- test_batch_immediate_cancellation: Batch cancellation workflow
|
||||
# TODO: cancel during processing
|
||||
|
||||
2. End-to-End Processing:
|
||||
- test_batch_e2e_chat_completions: Full chat completions workflow with output and error validation
|
||||
|
||||
Note: Error conditions and edge cases are primarily tested in test_batches_errors.py
|
||||
for better organization and separation of concerns.
|
||||
|
||||
CLEANUP WARNING: These tests currently create batches that are not automatically
|
||||
cleaned up after test completion. This may lead to resource accumulation over
|
||||
multiple test runs. Only test_batch_immediate_cancellation properly cancels its batch.
|
||||
The test_batch_e2e_chat_completions test does clean up its output and error files.
|
||||
"""
|
||||
|
||||
import json
|
||||
|
||||
|
||||
class TestBatchesIntegration:
|
||||
"""Integration tests for the batches API."""
|
||||
|
||||
def test_batch_creation_and_retrieval(self, openai_client, batch_helper, text_model_id):
|
||||
"""Test comprehensive batch creation and retrieval scenarios."""
|
||||
test_metadata = {
|
||||
"test_type": "comprehensive",
|
||||
"purpose": "creation_and_retrieval_test",
|
||||
"version": "1.0",
|
||||
"tags": "test,batch",
|
||||
}
|
||||
|
||||
batch_requests = [
|
||||
{
|
||||
"custom_id": "request-1",
|
||||
"method": "POST",
|
||||
"url": "/v1/chat/completions",
|
||||
"body": {
|
||||
"model": text_model_id,
|
||||
"messages": [{"role": "user", "content": "Hello"}],
|
||||
"max_tokens": 10,
|
||||
},
|
||||
}
|
||||
]
|
||||
|
||||
with batch_helper.create_file(batch_requests, "batch_creation_test") as uploaded_file:
|
||||
batch = openai_client.batches.create(
|
||||
input_file_id=uploaded_file.id,
|
||||
endpoint="/v1/chat/completions",
|
||||
completion_window="24h",
|
||||
metadata=test_metadata,
|
||||
)
|
||||
|
||||
assert batch.endpoint == "/v1/chat/completions"
|
||||
assert batch.input_file_id == uploaded_file.id
|
||||
assert batch.completion_window == "24h"
|
||||
assert batch.metadata == test_metadata
|
||||
|
||||
retrieved_batch = openai_client.batches.retrieve(batch.id)
|
||||
|
||||
assert retrieved_batch.id == batch.id
|
||||
assert retrieved_batch.object == batch.object
|
||||
assert retrieved_batch.endpoint == batch.endpoint
|
||||
assert retrieved_batch.input_file_id == batch.input_file_id
|
||||
assert retrieved_batch.completion_window == batch.completion_window
|
||||
assert retrieved_batch.metadata == batch.metadata
|
||||
|
||||
def test_batch_listing(self, openai_client, batch_helper, text_model_id):
|
||||
"""
|
||||
Test batch listing.
|
||||
|
||||
This test creates multiple batches and verifies that they can be listed.
|
||||
It also deletes the input files before execution, which means the batches
|
||||
will appear as failed due to missing input files. This is expected and
|
||||
a good thing, because it means no inference is performed.
|
||||
"""
|
||||
batch_ids = []
|
||||
|
||||
for i in range(2):
|
||||
batch_requests = [
|
||||
{
|
||||
"custom_id": f"request-{i}",
|
||||
"method": "POST",
|
||||
"url": "/v1/chat/completions",
|
||||
"body": {
|
||||
"model": text_model_id,
|
||||
"messages": [{"role": "user", "content": f"Hello {i}"}],
|
||||
"max_tokens": 10,
|
||||
},
|
||||
}
|
||||
]
|
||||
|
||||
with batch_helper.create_file(batch_requests, f"batch_input_{i}") as uploaded_file:
|
||||
batch = openai_client.batches.create(
|
||||
input_file_id=uploaded_file.id,
|
||||
endpoint="/v1/chat/completions",
|
||||
completion_window="24h",
|
||||
)
|
||||
batch_ids.append(batch.id)
|
||||
|
||||
batch_list = openai_client.batches.list()
|
||||
|
||||
assert isinstance(batch_list.data, list)
|
||||
|
||||
listed_batch_ids = {b.id for b in batch_list.data}
|
||||
for batch_id in batch_ids:
|
||||
assert batch_id in listed_batch_ids
|
||||
|
||||
def test_batch_immediate_cancellation(self, openai_client, batch_helper, text_model_id):
|
||||
"""Test immediate batch cancellation."""
|
||||
batch_requests = [
|
||||
{
|
||||
"custom_id": "request-1",
|
||||
"method": "POST",
|
||||
"url": "/v1/chat/completions",
|
||||
"body": {
|
||||
"model": text_model_id,
|
||||
"messages": [{"role": "user", "content": "Hello"}],
|
||||
"max_tokens": 10,
|
||||
},
|
||||
}
|
||||
]
|
||||
|
||||
with batch_helper.create_file(batch_requests) as uploaded_file:
|
||||
batch = openai_client.batches.create(
|
||||
input_file_id=uploaded_file.id,
|
||||
endpoint="/v1/chat/completions",
|
||||
completion_window="24h",
|
||||
)
|
||||
|
||||
# hopefully cancel the batch before it completes
|
||||
cancelling_batch = openai_client.batches.cancel(batch.id)
|
||||
assert cancelling_batch.status in ["cancelling", "cancelled"]
|
||||
assert isinstance(cancelling_batch.cancelling_at, int), (
|
||||
f"cancelling_at should be int, got {type(cancelling_batch.cancelling_at)}"
|
||||
)
|
||||
|
||||
final_batch = batch_helper.wait_for(
|
||||
batch.id,
|
||||
max_wait_time=3 * 60, # often takes 10-11 minutes, give it 3 min
|
||||
expected_statuses={"cancelled"},
|
||||
timeout_action="skip",
|
||||
)
|
||||
|
||||
assert final_batch.status == "cancelled"
|
||||
assert isinstance(final_batch.cancelled_at, int), (
|
||||
f"cancelled_at should be int, got {type(final_batch.cancelled_at)}"
|
||||
)
|
||||
|
||||
def test_batch_e2e_chat_completions(self, openai_client, batch_helper, text_model_id):
|
||||
"""Test end-to-end batch processing for chat completions with both successful and failed operations."""
|
||||
batch_requests = [
|
||||
{
|
||||
"custom_id": "success-1",
|
||||
"method": "POST",
|
||||
"url": "/v1/chat/completions",
|
||||
"body": {
|
||||
"model": text_model_id,
|
||||
"messages": [{"role": "user", "content": "Say hello"}],
|
||||
"max_tokens": 20,
|
||||
},
|
||||
},
|
||||
{
|
||||
"custom_id": "error-1",
|
||||
"method": "POST",
|
||||
"url": "/v1/chat/completions",
|
||||
"body": {
|
||||
"model": text_model_id,
|
||||
"messages": [{"role": "user", "content": "This should fail"}],
|
||||
"max_tokens": -1, # Invalid negative max_tokens will cause inference error
|
||||
},
|
||||
},
|
||||
]
|
||||
|
||||
with batch_helper.create_file(batch_requests) as uploaded_file:
|
||||
batch = openai_client.batches.create(
|
||||
input_file_id=uploaded_file.id,
|
||||
endpoint="/v1/chat/completions",
|
||||
completion_window="24h",
|
||||
metadata={"test": "e2e_success_and_errors_test"},
|
||||
)
|
||||
|
||||
final_batch = batch_helper.wait_for(
|
||||
batch.id,
|
||||
max_wait_time=3 * 60, # often takes 2-3 minutes
|
||||
expected_statuses={"completed"},
|
||||
timeout_action="skip",
|
||||
)
|
||||
|
||||
# Expecting a completed batch with both successful and failed requests
|
||||
# Batch(id='batch_xxx',
|
||||
# completion_window='24h',
|
||||
# created_at=...,
|
||||
# endpoint='/v1/chat/completions',
|
||||
# input_file_id='file-xxx',
|
||||
# object='batch',
|
||||
# status='completed',
|
||||
# output_file_id='file-xxx',
|
||||
# error_file_id='file-xxx',
|
||||
# request_counts=BatchRequestCounts(completed=1, failed=1, total=2))
|
||||
|
||||
assert final_batch.status == "completed"
|
||||
assert final_batch.request_counts is not None
|
||||
assert final_batch.request_counts.total == 2
|
||||
assert final_batch.request_counts.completed == 1
|
||||
assert final_batch.request_counts.failed == 1
|
||||
|
||||
assert final_batch.output_file_id is not None, "Output file should exist for successful requests"
|
||||
|
||||
output_content = openai_client.files.content(final_batch.output_file_id)
|
||||
if isinstance(output_content, str):
|
||||
output_text = output_content
|
||||
else:
|
||||
output_text = output_content.content.decode("utf-8")
|
||||
|
||||
output_lines = output_text.strip().split("\n")
|
||||
|
||||
for line in output_lines:
|
||||
result = json.loads(line)
|
||||
|
||||
assert "id" in result
|
||||
assert "custom_id" in result
|
||||
assert result["custom_id"] == "success-1"
|
||||
|
||||
assert "response" in result
|
||||
|
||||
assert result["response"]["status_code"] == 200
|
||||
assert "body" in result["response"]
|
||||
assert "choices" in result["response"]["body"]
|
||||
|
||||
assert final_batch.error_file_id is not None, "Error file should exist for failed requests"
|
||||
|
||||
error_content = openai_client.files.content(final_batch.error_file_id)
|
||||
if isinstance(error_content, str):
|
||||
error_text = error_content
|
||||
else:
|
||||
error_text = error_content.content.decode("utf-8")
|
||||
|
||||
error_lines = error_text.strip().split("\n")
|
||||
|
||||
for line in error_lines:
|
||||
result = json.loads(line)
|
||||
|
||||
assert "id" in result
|
||||
assert "custom_id" in result
|
||||
assert result["custom_id"] == "error-1"
|
||||
assert "error" in result
|
||||
error = result["error"]
|
||||
assert error is not None
|
||||
assert "code" in error or "message" in error, "Error should have code or message"
|
||||
|
||||
deleted_output_file = openai_client.files.delete(final_batch.output_file_id)
|
||||
assert deleted_output_file.deleted, f"Output file {final_batch.output_file_id} was not deleted successfully"
|
||||
|
||||
deleted_error_file = openai_client.files.delete(final_batch.error_file_id)
|
||||
assert deleted_error_file.deleted, f"Error file {final_batch.error_file_id} was not deleted successfully"
|
693
tests/integration/batches/test_batches_errors.py
Normal file
693
tests/integration/batches/test_batches_errors.py
Normal file
|
@ -0,0 +1,693 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
"""
|
||||
Error handling and edge case tests for the Llama Stack batch processing functionality.
|
||||
|
||||
This module focuses exclusively on testing error conditions, validation failures,
|
||||
and edge cases for batch operations to ensure robust error handling and graceful
|
||||
degradation.
|
||||
|
||||
Test Categories:
|
||||
1. File and Input Validation:
|
||||
- test_batch_nonexistent_file_id: Handling invalid file IDs
|
||||
- test_batch_malformed_jsonl: Processing malformed JSONL input files
|
||||
- test_file_malformed_batch_file: Handling malformed files at upload time
|
||||
- test_batch_missing_required_fields: Validation of required request fields
|
||||
|
||||
2. API Endpoint and Model Validation:
|
||||
- test_batch_invalid_endpoint: Invalid endpoint handling during creation
|
||||
- test_batch_error_handling_invalid_model: Error handling with nonexistent models
|
||||
- test_batch_endpoint_mismatch: Validation of endpoint/URL consistency
|
||||
|
||||
3. Batch Lifecycle Error Handling:
|
||||
- test_batch_retrieve_nonexistent: Retrieving non-existent batches
|
||||
- test_batch_cancel_nonexistent: Cancelling non-existent batches
|
||||
- test_batch_cancel_completed: Attempting to cancel completed batches
|
||||
|
||||
4. Parameter and Configuration Validation:
|
||||
- test_batch_invalid_completion_window: Invalid completion window values
|
||||
- test_batch_invalid_metadata_types: Invalid metadata type validation
|
||||
- test_batch_missing_required_body_fields: Validation of required fields in request body
|
||||
|
||||
5. Feature Restriction and Compatibility:
|
||||
- test_batch_streaming_not_supported: Streaming request rejection
|
||||
- test_batch_mixed_streaming_requests: Mixed streaming/non-streaming validation
|
||||
|
||||
Note: Core functionality and OpenAI compatibility tests are located in
|
||||
test_batches_integration.py for better organization and separation of concerns.
|
||||
|
||||
CLEANUP WARNING: These tests create batches to test error conditions but do not
|
||||
automatically clean them up after test completion. While most error tests create
|
||||
batches that fail quickly, some may create valid batches that consume resources.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from openai import BadRequestError, ConflictError, NotFoundError
|
||||
|
||||
|
||||
class TestBatchesErrorHandling:
|
||||
"""Error handling and edge case tests for the batches API using OpenAI client."""
|
||||
|
||||
def test_batch_nonexistent_file_id(self, openai_client, batch_helper):
|
||||
"""Test batch creation with nonexistent input file ID."""
|
||||
|
||||
batch = openai_client.batches.create(
|
||||
input_file_id="file-nonexistent-xyz",
|
||||
endpoint="/v1/chat/completions",
|
||||
completion_window="24h",
|
||||
)
|
||||
|
||||
final_batch = batch_helper.wait_for(batch.id, expected_statuses={"failed"})
|
||||
|
||||
# Expecting -
|
||||
# Batch(...,
|
||||
# status='failed',
|
||||
# errors=Errors(data=[
|
||||
# BatchError(
|
||||
# code='invalid_request',
|
||||
# line=None,
|
||||
# message='Cannot find file ..., or organization ... does not have access to it.',
|
||||
# param='file_id')
|
||||
# ], object='list'),
|
||||
# failed_at=1754566971,
|
||||
# ...)
|
||||
|
||||
assert final_batch.status == "failed"
|
||||
assert final_batch.errors is not None
|
||||
assert len(final_batch.errors.data) == 1
|
||||
error = final_batch.errors.data[0]
|
||||
assert error.code == "invalid_request"
|
||||
assert "cannot find file" in error.message.lower()
|
||||
|
||||
def test_batch_invalid_endpoint(self, openai_client, batch_helper, text_model_id):
|
||||
"""Test batch creation with invalid endpoint."""
|
||||
batch_requests = [
|
||||
{
|
||||
"custom_id": "invalid-endpoint",
|
||||
"method": "POST",
|
||||
"url": "/v1/chat/completions",
|
||||
"body": {
|
||||
"model": text_model_id,
|
||||
"messages": [{"role": "user", "content": "Hello"}],
|
||||
"max_tokens": 10,
|
||||
},
|
||||
}
|
||||
]
|
||||
|
||||
with batch_helper.create_file(batch_requests) as uploaded_file:
|
||||
with pytest.raises(BadRequestError) as exc_info:
|
||||
openai_client.batches.create(
|
||||
input_file_id=uploaded_file.id,
|
||||
endpoint="/v1/invalid/endpoint",
|
||||
completion_window="24h",
|
||||
)
|
||||
|
||||
# Expected -
|
||||
# Error code: 400 - {
|
||||
# 'error': {
|
||||
# 'message': "Invalid value: '/v1/invalid/endpoint'. Supported values are: '/v1/chat/completions', '/v1/completions', '/v1/embeddings', and '/v1/responses'.",
|
||||
# 'type': 'invalid_request_error',
|
||||
# 'param': 'endpoint',
|
||||
# 'code': 'invalid_value'
|
||||
# }
|
||||
# }
|
||||
|
||||
error_msg = str(exc_info.value).lower()
|
||||
assert exc_info.value.status_code == 400
|
||||
assert "invalid value" in error_msg
|
||||
assert "/v1/invalid/endpoint" in error_msg
|
||||
assert "supported values" in error_msg
|
||||
assert "endpoint" in error_msg
|
||||
assert "invalid_value" in error_msg
|
||||
|
||||
def test_batch_malformed_jsonl(self, openai_client, batch_helper):
|
||||
"""
|
||||
Test batch with malformed JSONL input.
|
||||
|
||||
The /v1/files endpoint requires valid JSONL format, so we provide a well formed line
|
||||
before a malformed line to ensure we get to the /v1/batches validation stage.
|
||||
"""
|
||||
with batch_helper.create_file(
|
||||
"""{"custom_id": "valid", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "test"}}
|
||||
{invalid json here""",
|
||||
"malformed_batch_input.jsonl",
|
||||
) as uploaded_file:
|
||||
batch = openai_client.batches.create(
|
||||
input_file_id=uploaded_file.id,
|
||||
endpoint="/v1/chat/completions",
|
||||
completion_window="24h",
|
||||
)
|
||||
|
||||
final_batch = batch_helper.wait_for(batch.id, expected_statuses={"failed"})
|
||||
|
||||
# Expecting -
|
||||
# Batch(...,
|
||||
# status='failed',
|
||||
# errors=Errors(data=[
|
||||
# ...,
|
||||
# BatchError(code='invalid_json_line',
|
||||
# line=2,
|
||||
# message='This line is not parseable as valid JSON.',
|
||||
# param=None)
|
||||
# ], object='list'),
|
||||
# ...)
|
||||
|
||||
assert final_batch.status == "failed"
|
||||
assert final_batch.errors is not None
|
||||
assert len(final_batch.errors.data) > 0
|
||||
error = final_batch.errors.data[-1] # get last error because first may be about the "test" model
|
||||
assert error.code == "invalid_json_line"
|
||||
assert error.line == 2
|
||||
assert "not" in error.message.lower()
|
||||
assert "valid json" in error.message.lower()
|
||||
|
||||
@pytest.mark.xfail(reason="Not all file providers validate content")
|
||||
@pytest.mark.parametrize("batch_requests", ["", "{malformed json"], ids=["empty", "malformed"])
|
||||
def test_file_malformed_batch_file(self, openai_client, batch_helper, batch_requests):
|
||||
"""Test file upload with malformed content."""
|
||||
|
||||
with pytest.raises(BadRequestError) as exc_info:
|
||||
with batch_helper.create_file(batch_requests, "malformed_batch_input_file.jsonl"):
|
||||
# /v1/files rejects the file, we don't get to batch creation
|
||||
pass
|
||||
|
||||
error_msg = str(exc_info.value).lower()
|
||||
assert exc_info.value.status_code == 400
|
||||
assert "invalid file format" in error_msg
|
||||
assert "jsonl" in error_msg
|
||||
|
||||
def test_batch_retrieve_nonexistent(self, openai_client):
|
||||
"""Test retrieving nonexistent batch."""
|
||||
with pytest.raises(NotFoundError) as exc_info:
|
||||
openai_client.batches.retrieve("batch-nonexistent-xyz")
|
||||
|
||||
error_msg = str(exc_info.value).lower()
|
||||
assert exc_info.value.status_code == 404
|
||||
assert "no batch found" in error_msg or "not found" in error_msg
|
||||
|
||||
def test_batch_cancel_nonexistent(self, openai_client):
|
||||
"""Test cancelling nonexistent batch."""
|
||||
with pytest.raises(NotFoundError) as exc_info:
|
||||
openai_client.batches.cancel("batch-nonexistent-xyz")
|
||||
|
||||
error_msg = str(exc_info.value).lower()
|
||||
assert exc_info.value.status_code == 404
|
||||
assert "no batch found" in error_msg or "not found" in error_msg
|
||||
|
||||
def test_batch_cancel_completed(self, openai_client, batch_helper, text_model_id):
|
||||
"""Test cancelling already completed batch."""
|
||||
batch_requests = [
|
||||
{
|
||||
"custom_id": "cancel-completed",
|
||||
"method": "POST",
|
||||
"url": "/v1/chat/completions",
|
||||
"body": {
|
||||
"model": text_model_id,
|
||||
"messages": [{"role": "user", "content": "Quick test"}],
|
||||
"max_tokens": 5,
|
||||
},
|
||||
}
|
||||
]
|
||||
|
||||
with batch_helper.create_file(batch_requests, "cancel_test_batch_input") as uploaded_file:
|
||||
batch = openai_client.batches.create(
|
||||
input_file_id=uploaded_file.id,
|
||||
endpoint="/v1/chat/completions",
|
||||
completion_window="24h",
|
||||
)
|
||||
|
||||
final_batch = batch_helper.wait_for(
|
||||
batch.id,
|
||||
max_wait_time=3 * 60, # often take 10-11 min, give it 3 min
|
||||
expected_statuses={"completed"},
|
||||
timeout_action="skip",
|
||||
)
|
||||
|
||||
deleted_file = openai_client.files.delete(final_batch.output_file_id)
|
||||
assert deleted_file.deleted, f"File {final_batch.output_file_id} was not deleted successfully"
|
||||
|
||||
with pytest.raises(ConflictError) as exc_info:
|
||||
openai_client.batches.cancel(batch.id)
|
||||
|
||||
# Expecting -
|
||||
# Error code: 409 - {
|
||||
# 'error': {
|
||||
# 'message': "Cannot cancel a batch with status 'completed'.",
|
||||
# 'type': 'invalid_request_error',
|
||||
# 'param': None,
|
||||
# 'code': None
|
||||
# }
|
||||
# }
|
||||
#
|
||||
# NOTE: Same for "failed", cancelling "cancelled" batches is allowed
|
||||
|
||||
error_msg = str(exc_info.value).lower()
|
||||
assert exc_info.value.status_code == 409
|
||||
assert "cannot cancel" in error_msg
|
||||
|
||||
def test_batch_missing_required_fields(self, openai_client, batch_helper, text_model_id):
|
||||
"""Test batch with requests missing required fields."""
|
||||
batch_requests = [
|
||||
{
|
||||
# Missing custom_id
|
||||
"method": "POST",
|
||||
"url": "/v1/chat/completions",
|
||||
"body": {
|
||||
"model": text_model_id,
|
||||
"messages": [{"role": "user", "content": "No custom_id"}],
|
||||
"max_tokens": 10,
|
||||
},
|
||||
},
|
||||
{
|
||||
"custom_id": "no-method",
|
||||
"url": "/v1/chat/completions",
|
||||
"body": {
|
||||
"model": text_model_id,
|
||||
"messages": [{"role": "user", "content": "No method"}],
|
||||
"max_tokens": 10,
|
||||
},
|
||||
},
|
||||
{
|
||||
"custom_id": "no-url",
|
||||
"method": "POST",
|
||||
"body": {
|
||||
"model": text_model_id,
|
||||
"messages": [{"role": "user", "content": "No URL"}],
|
||||
"max_tokens": 10,
|
||||
},
|
||||
},
|
||||
{
|
||||
"custom_id": "no-body",
|
||||
"method": "POST",
|
||||
"url": "/v1/chat/completions",
|
||||
},
|
||||
]
|
||||
|
||||
with batch_helper.create_file(batch_requests, "missing_fields_batch_input") as uploaded_file:
|
||||
batch = openai_client.batches.create(
|
||||
input_file_id=uploaded_file.id,
|
||||
endpoint="/v1/chat/completions",
|
||||
completion_window="24h",
|
||||
)
|
||||
|
||||
final_batch = batch_helper.wait_for(batch.id, expected_statuses={"failed"})
|
||||
|
||||
# Expecting -
|
||||
# Batch(...,
|
||||
# status='failed',
|
||||
# errors=Errors(
|
||||
# data=[
|
||||
# BatchError(
|
||||
# code='missing_required_parameter',
|
||||
# line=1,
|
||||
# message="Missing required parameter: 'custom_id'.",
|
||||
# param='custom_id'
|
||||
# ),
|
||||
# BatchError(
|
||||
# code='missing_required_parameter',
|
||||
# line=2,
|
||||
# message="Missing required parameter: 'method'.",
|
||||
# param='method'
|
||||
# ),
|
||||
# BatchError(
|
||||
# code='missing_required_parameter',
|
||||
# line=3,
|
||||
# message="Missing required parameter: 'url'.",
|
||||
# param='url'
|
||||
# ),
|
||||
# BatchError(
|
||||
# code='missing_required_parameter',
|
||||
# line=4,
|
||||
# message="Missing required parameter: 'body'.",
|
||||
# param='body'
|
||||
# )
|
||||
# ], object='list'),
|
||||
# failed_at=1754566945,
|
||||
# ...)
|
||||
# )
|
||||
|
||||
assert final_batch.status == "failed"
|
||||
assert final_batch.errors is not None
|
||||
assert len(final_batch.errors.data) == 4
|
||||
no_custom_id_error = final_batch.errors.data[0]
|
||||
assert no_custom_id_error.code == "missing_required_parameter"
|
||||
assert no_custom_id_error.line == 1
|
||||
assert "missing" in no_custom_id_error.message.lower()
|
||||
assert "custom_id" in no_custom_id_error.message.lower()
|
||||
no_method_error = final_batch.errors.data[1]
|
||||
assert no_method_error.code == "missing_required_parameter"
|
||||
assert no_method_error.line == 2
|
||||
assert "missing" in no_method_error.message.lower()
|
||||
assert "method" in no_method_error.message.lower()
|
||||
no_url_error = final_batch.errors.data[2]
|
||||
assert no_url_error.code == "missing_required_parameter"
|
||||
assert no_url_error.line == 3
|
||||
assert "missing" in no_url_error.message.lower()
|
||||
assert "url" in no_url_error.message.lower()
|
||||
no_body_error = final_batch.errors.data[3]
|
||||
assert no_body_error.code == "missing_required_parameter"
|
||||
assert no_body_error.line == 4
|
||||
assert "missing" in no_body_error.message.lower()
|
||||
assert "body" in no_body_error.message.lower()
|
||||
|
||||
def test_batch_invalid_completion_window(self, openai_client, batch_helper, text_model_id):
|
||||
"""Test batch creation with invalid completion window."""
|
||||
batch_requests = [
|
||||
{
|
||||
"custom_id": "invalid-completion-window",
|
||||
"method": "POST",
|
||||
"url": "/v1/chat/completions",
|
||||
"body": {
|
||||
"model": text_model_id,
|
||||
"messages": [{"role": "user", "content": "Hello"}],
|
||||
"max_tokens": 10,
|
||||
},
|
||||
}
|
||||
]
|
||||
|
||||
with batch_helper.create_file(batch_requests) as uploaded_file:
|
||||
for window in ["1h", "48h", "invalid", ""]:
|
||||
with pytest.raises(BadRequestError) as exc_info:
|
||||
openai_client.batches.create(
|
||||
input_file_id=uploaded_file.id,
|
||||
endpoint="/v1/chat/completions",
|
||||
completion_window=window,
|
||||
)
|
||||
assert exc_info.value.status_code == 400
|
||||
error_msg = str(exc_info.value).lower()
|
||||
assert "error" in error_msg
|
||||
assert "completion_window" in error_msg
|
||||
|
||||
def test_batch_streaming_not_supported(self, openai_client, batch_helper, text_model_id):
|
||||
"""Test that streaming responses are not supported in batches."""
|
||||
batch_requests = [
|
||||
{
|
||||
"custom_id": "streaming-test",
|
||||
"method": "POST",
|
||||
"url": "/v1/chat/completions",
|
||||
"body": {
|
||||
"model": text_model_id,
|
||||
"messages": [{"role": "user", "content": "Hello"}],
|
||||
"max_tokens": 10,
|
||||
"stream": True, # Not supported
|
||||
},
|
||||
}
|
||||
]
|
||||
|
||||
with batch_helper.create_file(batch_requests, "streaming_batch_input") as uploaded_file:
|
||||
batch = openai_client.batches.create(
|
||||
input_file_id=uploaded_file.id,
|
||||
endpoint="/v1/chat/completions",
|
||||
completion_window="24h",
|
||||
)
|
||||
|
||||
final_batch = batch_helper.wait_for(batch.id, expected_statuses={"failed"})
|
||||
|
||||
# Expecting -
|
||||
# Batch(...,
|
||||
# status='failed',
|
||||
# errors=Errors(data=[
|
||||
# BatchError(code='streaming_unsupported',
|
||||
# line=1,
|
||||
# message='Chat Completions: Streaming is not supported in the Batch API.',
|
||||
# param='body.stream')
|
||||
# ], object='list'),
|
||||
# failed_at=1754566965,
|
||||
# ...)
|
||||
|
||||
assert final_batch.status == "failed"
|
||||
assert final_batch.errors is not None
|
||||
assert len(final_batch.errors.data) == 1
|
||||
error = final_batch.errors.data[0]
|
||||
assert error.code == "streaming_unsupported"
|
||||
assert error.line == 1
|
||||
assert "streaming" in error.message.lower()
|
||||
assert "not supported" in error.message.lower()
|
||||
assert error.param == "body.stream"
|
||||
assert final_batch.failed_at is not None
|
||||
|
||||
def test_batch_mixed_streaming_requests(self, openai_client, batch_helper, text_model_id):
|
||||
"""
|
||||
Test batch with mixed streaming and non-streaming requests.
|
||||
|
||||
This is distinct from test_batch_streaming_not_supported, which tests a single
|
||||
streaming request, to ensure an otherwise valid batch fails when a single
|
||||
streaming request is included.
|
||||
"""
|
||||
batch_requests = [
|
||||
{
|
||||
"custom_id": "valid-non-streaming-request",
|
||||
"method": "POST",
|
||||
"url": "/v1/chat/completions",
|
||||
"body": {
|
||||
"model": text_model_id,
|
||||
"messages": [{"role": "user", "content": "Hello without streaming"}],
|
||||
"max_tokens": 10,
|
||||
},
|
||||
},
|
||||
{
|
||||
"custom_id": "streaming-request",
|
||||
"method": "POST",
|
||||
"url": "/v1/chat/completions",
|
||||
"body": {
|
||||
"model": text_model_id,
|
||||
"messages": [{"role": "user", "content": "Hello with streaming"}],
|
||||
"max_tokens": 10,
|
||||
"stream": True, # Not supported
|
||||
},
|
||||
},
|
||||
]
|
||||
|
||||
with batch_helper.create_file(batch_requests, "mixed_streaming_batch_input") as uploaded_file:
|
||||
batch = openai_client.batches.create(
|
||||
input_file_id=uploaded_file.id,
|
||||
endpoint="/v1/chat/completions",
|
||||
completion_window="24h",
|
||||
)
|
||||
|
||||
final_batch = batch_helper.wait_for(batch.id, expected_statuses={"failed"})
|
||||
|
||||
# Expecting -
|
||||
# Batch(...,
|
||||
# status='failed',
|
||||
# errors=Errors(data=[
|
||||
# BatchError(
|
||||
# code='streaming_unsupported',
|
||||
# line=2,
|
||||
# message='Chat Completions: Streaming is not supported in the Batch API.',
|
||||
# param='body.stream')
|
||||
# ], object='list'),
|
||||
# failed_at=1754574442,
|
||||
# ...)
|
||||
|
||||
assert final_batch.status == "failed"
|
||||
assert final_batch.errors is not None
|
||||
assert len(final_batch.errors.data) == 1
|
||||
error = final_batch.errors.data[0]
|
||||
assert error.code == "streaming_unsupported"
|
||||
assert error.line == 2
|
||||
assert "streaming" in error.message.lower()
|
||||
assert "not supported" in error.message.lower()
|
||||
assert error.param == "body.stream"
|
||||
assert final_batch.failed_at is not None
|
||||
|
||||
def test_batch_endpoint_mismatch(self, openai_client, batch_helper, text_model_id):
|
||||
"""Test batch creation with mismatched endpoint and request URL."""
|
||||
batch_requests = [
|
||||
{
|
||||
"custom_id": "endpoint-mismatch",
|
||||
"method": "POST",
|
||||
"url": "/v1/embeddings", # Different from batch endpoint
|
||||
"body": {
|
||||
"model": text_model_id,
|
||||
"messages": [{"role": "user", "content": "Hello"}],
|
||||
},
|
||||
}
|
||||
]
|
||||
|
||||
with batch_helper.create_file(batch_requests, "endpoint_mismatch_batch_input") as uploaded_file:
|
||||
batch = openai_client.batches.create(
|
||||
input_file_id=uploaded_file.id,
|
||||
endpoint="/v1/chat/completions", # Different from request URL
|
||||
completion_window="24h",
|
||||
)
|
||||
|
||||
final_batch = batch_helper.wait_for(batch.id, expected_statuses={"failed"})
|
||||
|
||||
# Expecting -
|
||||
# Batch(...,
|
||||
# status='failed',
|
||||
# errors=Errors(data=[
|
||||
# BatchError(
|
||||
# code='invalid_url',
|
||||
# line=1,
|
||||
# message='The URL provided for this request does not match the batch endpoint.',
|
||||
# param='url')
|
||||
# ], object='list'),
|
||||
# failed_at=1754566972,
|
||||
# ...)
|
||||
|
||||
assert final_batch.status == "failed"
|
||||
assert final_batch.errors is not None
|
||||
assert len(final_batch.errors.data) == 1
|
||||
error = final_batch.errors.data[0]
|
||||
assert error.line == 1
|
||||
assert error.code == "invalid_url"
|
||||
assert "does not match" in error.message.lower()
|
||||
assert "endpoint" in error.message.lower()
|
||||
assert final_batch.failed_at is not None
|
||||
|
||||
def test_batch_error_handling_invalid_model(self, openai_client, batch_helper):
|
||||
"""Test batch error handling with invalid model."""
|
||||
batch_requests = [
|
||||
{
|
||||
"custom_id": "invalid-model",
|
||||
"method": "POST",
|
||||
"url": "/v1/chat/completions",
|
||||
"body": {
|
||||
"model": "nonexistent-model-xyz",
|
||||
"messages": [{"role": "user", "content": "Hello"}],
|
||||
"max_tokens": 10,
|
||||
},
|
||||
}
|
||||
]
|
||||
|
||||
with batch_helper.create_file(batch_requests) as uploaded_file:
|
||||
batch = openai_client.batches.create(
|
||||
input_file_id=uploaded_file.id,
|
||||
endpoint="/v1/chat/completions",
|
||||
completion_window="24h",
|
||||
)
|
||||
|
||||
final_batch = batch_helper.wait_for(batch.id, expected_statuses={"failed"})
|
||||
|
||||
# Expecting -
|
||||
# Batch(...,
|
||||
# status='failed',
|
||||
# errors=Errors(data=[
|
||||
# BatchError(code='model_not_found',
|
||||
# line=1,
|
||||
# message="The provided model 'nonexistent-model-xyz' is not supported by the Batch API.",
|
||||
# param='body.model')
|
||||
# ], object='list'),
|
||||
# failed_at=1754566978,
|
||||
# ...)
|
||||
|
||||
assert final_batch.status == "failed"
|
||||
assert final_batch.errors is not None
|
||||
assert len(final_batch.errors.data) == 1
|
||||
error = final_batch.errors.data[0]
|
||||
assert error.line == 1
|
||||
assert error.code == "model_not_found"
|
||||
assert "not supported" in error.message.lower()
|
||||
assert error.param == "body.model"
|
||||
assert final_batch.failed_at is not None
|
||||
|
||||
def test_batch_missing_required_body_fields(self, openai_client, batch_helper, text_model_id):
|
||||
"""Test batch with requests missing required fields in body (model and messages)."""
|
||||
batch_requests = [
|
||||
{
|
||||
"custom_id": "missing-model",
|
||||
"method": "POST",
|
||||
"url": "/v1/chat/completions",
|
||||
"body": {
|
||||
# Missing model field
|
||||
"messages": [{"role": "user", "content": "Hello without model"}],
|
||||
"max_tokens": 10,
|
||||
},
|
||||
},
|
||||
{
|
||||
"custom_id": "missing-messages",
|
||||
"method": "POST",
|
||||
"url": "/v1/chat/completions",
|
||||
"body": {
|
||||
"model": text_model_id,
|
||||
# Missing messages field
|
||||
"max_tokens": 10,
|
||||
},
|
||||
},
|
||||
]
|
||||
|
||||
with batch_helper.create_file(batch_requests, "missing_body_fields_batch_input") as uploaded_file:
|
||||
batch = openai_client.batches.create(
|
||||
input_file_id=uploaded_file.id,
|
||||
endpoint="/v1/chat/completions",
|
||||
completion_window="24h",
|
||||
)
|
||||
|
||||
final_batch = batch_helper.wait_for(batch.id, expected_statuses={"failed"})
|
||||
|
||||
# Expecting -
|
||||
# Batch(...,
|
||||
# status='failed',
|
||||
# errors=Errors(data=[
|
||||
# BatchError(
|
||||
# code='invalid_request',
|
||||
# line=1,
|
||||
# message='Model parameter is required.',
|
||||
# param='body.model'),
|
||||
# BatchError(
|
||||
# code='invalid_request',
|
||||
# line=2,
|
||||
# message='Messages parameter is required.',
|
||||
# param='body.messages')
|
||||
# ], object='list'),
|
||||
# ...)
|
||||
|
||||
assert final_batch.status == "failed"
|
||||
assert final_batch.errors is not None
|
||||
assert len(final_batch.errors.data) == 2
|
||||
|
||||
model_error = final_batch.errors.data[0]
|
||||
assert model_error.line == 1
|
||||
assert "model" in model_error.message.lower()
|
||||
assert model_error.param == "body.model"
|
||||
|
||||
messages_error = final_batch.errors.data[1]
|
||||
assert messages_error.line == 2
|
||||
assert "messages" in messages_error.message.lower()
|
||||
assert messages_error.param == "body.messages"
|
||||
|
||||
assert final_batch.failed_at is not None
|
||||
|
||||
def test_batch_invalid_metadata_types(self, openai_client, batch_helper, text_model_id):
|
||||
"""Test batch creation with invalid metadata types (like lists)."""
|
||||
batch_requests = [
|
||||
{
|
||||
"custom_id": "invalid-metadata-type",
|
||||
"method": "POST",
|
||||
"url": "/v1/chat/completions",
|
||||
"body": {
|
||||
"model": text_model_id,
|
||||
"messages": [{"role": "user", "content": "Hello"}],
|
||||
"max_tokens": 10,
|
||||
},
|
||||
}
|
||||
]
|
||||
|
||||
with batch_helper.create_file(batch_requests) as uploaded_file:
|
||||
with pytest.raises(Exception) as exc_info:
|
||||
openai_client.batches.create(
|
||||
input_file_id=uploaded_file.id,
|
||||
endpoint="/v1/chat/completions",
|
||||
completion_window="24h",
|
||||
metadata={
|
||||
"tags": ["tag1", "tag2"], # Invalid type, should be a string
|
||||
},
|
||||
)
|
||||
|
||||
# Expecting -
|
||||
# Error code: 400 - {'error':
|
||||
# {'message': "Invalid type for 'metadata.tags': expected a string,
|
||||
# but got an array instead.",
|
||||
# 'type': 'invalid_request_error', 'param': 'metadata.tags',
|
||||
# 'code': 'invalid_type'}}
|
||||
|
||||
error_msg = str(exc_info.value).lower()
|
||||
assert "400" in error_msg
|
||||
assert "tags" in error_msg
|
||||
assert "string" in error_msg
|
753
tests/unit/providers/batches/test_reference.py
Normal file
753
tests/unit/providers/batches/test_reference.py
Normal file
|
@ -0,0 +1,753 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
"""
|
||||
Test suite for the reference implementation of the Batches API.
|
||||
|
||||
The tests are categorized and outlined below, keep this updated:
|
||||
|
||||
- Batch creation with various parameters and validation:
|
||||
* test_create_and_retrieve_batch_success (positive)
|
||||
* test_create_batch_without_metadata (positive)
|
||||
* test_create_batch_completion_window (negative)
|
||||
* test_create_batch_invalid_endpoints (negative)
|
||||
* test_create_batch_invalid_metadata (negative)
|
||||
|
||||
- Batch retrieval and error handling for non-existent batches:
|
||||
* test_retrieve_batch_not_found (negative)
|
||||
|
||||
- Batch cancellation with proper status transitions:
|
||||
* test_cancel_batch_success (positive)
|
||||
* test_cancel_batch_invalid_statuses (negative)
|
||||
* test_cancel_batch_not_found (negative)
|
||||
|
||||
- Batch listing with pagination and filtering:
|
||||
* test_list_batches_empty (positive)
|
||||
* test_list_batches_single_batch (positive)
|
||||
* test_list_batches_multiple_batches (positive)
|
||||
* test_list_batches_with_limit (positive)
|
||||
* test_list_batches_with_pagination (positive)
|
||||
* test_list_batches_invalid_after (negative)
|
||||
|
||||
- Data persistence in the underlying key-value store:
|
||||
* test_kvstore_persistence (positive)
|
||||
|
||||
- Batch processing concurrency control:
|
||||
* test_max_concurrent_batches (positive)
|
||||
|
||||
- Input validation testing (direct _validate_input method tests):
|
||||
* test_validate_input_file_not_found (negative)
|
||||
* test_validate_input_file_exists_empty_content (positive)
|
||||
* test_validate_input_file_mixed_valid_invalid_json (mixed)
|
||||
* test_validate_input_invalid_model (negative)
|
||||
* test_validate_input_url_mismatch (negative)
|
||||
* test_validate_input_multiple_errors_per_request (negative)
|
||||
* test_validate_input_invalid_request_format (negative)
|
||||
* test_validate_input_missing_parameters (parametrized negative - custom_id, method, url, body, model, messages missing validation)
|
||||
* test_validate_input_invalid_parameter_types (parametrized negative - custom_id, url, method, body, model, messages type validation)
|
||||
|
||||
The tests use temporary SQLite databases for isolation and mock external
|
||||
dependencies like inference, files, and models APIs.
|
||||
"""
|
||||
|
||||
import json
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from llama_stack.apis.batches import BatchObject
|
||||
from llama_stack.apis.common.errors import ConflictError, ResourceNotFoundError
|
||||
from llama_stack.providers.inline.batches.reference.batches import ReferenceBatchesImpl
|
||||
from llama_stack.providers.inline.batches.reference.config import ReferenceBatchesImplConfig
|
||||
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
|
||||
|
||||
|
||||
class TestReferenceBatchesImpl:
|
||||
"""Test the reference implementation of the Batches API."""
|
||||
|
||||
@pytest.fixture
|
||||
async def provider(self):
|
||||
"""Create a test provider instance with temporary database."""
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
db_path = Path(tmpdir) / "test_batches.db"
|
||||
kvstore_config = SqliteKVStoreConfig(db_path=str(db_path))
|
||||
config = ReferenceBatchesImplConfig(kvstore=kvstore_config)
|
||||
|
||||
# Create kvstore and mock APIs
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
from llama_stack.providers.utils.kvstore import kvstore_impl
|
||||
|
||||
kvstore = await kvstore_impl(config.kvstore)
|
||||
mock_inference = AsyncMock()
|
||||
mock_files = AsyncMock()
|
||||
mock_models = AsyncMock()
|
||||
|
||||
provider = ReferenceBatchesImpl(config, mock_inference, mock_files, mock_models, kvstore)
|
||||
await provider.initialize()
|
||||
|
||||
# unit tests should not require background processing
|
||||
provider.process_batches = False
|
||||
|
||||
yield provider
|
||||
|
||||
await provider.shutdown()
|
||||
|
||||
@pytest.fixture
|
||||
def sample_batch_data(self):
|
||||
"""Sample batch data for testing."""
|
||||
return {
|
||||
"input_file_id": "file_abc123",
|
||||
"endpoint": "/v1/chat/completions",
|
||||
"completion_window": "24h",
|
||||
"metadata": {"test": "true", "priority": "high"},
|
||||
}
|
||||
|
||||
def _validate_batch_type(self, batch, expected_metadata=None):
|
||||
"""
|
||||
Helper function to validate batch object structure and field types.
|
||||
|
||||
Note: This validates the direct BatchObject from the provider, not the
|
||||
client library response which has a different structure.
|
||||
|
||||
Args:
|
||||
batch: The BatchObject instance to validate.
|
||||
expected_metadata: Optional expected metadata dictionary to validate against.
|
||||
"""
|
||||
assert isinstance(batch.id, str)
|
||||
assert isinstance(batch.completion_window, str)
|
||||
assert isinstance(batch.created_at, int)
|
||||
assert isinstance(batch.endpoint, str)
|
||||
assert isinstance(batch.input_file_id, str)
|
||||
assert batch.object == "batch"
|
||||
assert batch.status in [
|
||||
"validating",
|
||||
"failed",
|
||||
"in_progress",
|
||||
"finalizing",
|
||||
"completed",
|
||||
"expired",
|
||||
"cancelling",
|
||||
"cancelled",
|
||||
]
|
||||
|
||||
if expected_metadata is not None:
|
||||
assert batch.metadata == expected_metadata
|
||||
|
||||
timestamp_fields = [
|
||||
"cancelled_at",
|
||||
"cancelling_at",
|
||||
"completed_at",
|
||||
"expired_at",
|
||||
"expires_at",
|
||||
"failed_at",
|
||||
"finalizing_at",
|
||||
"in_progress_at",
|
||||
]
|
||||
for field in timestamp_fields:
|
||||
field_value = getattr(batch, field, None)
|
||||
if field_value is not None:
|
||||
assert isinstance(field_value, int), f"{field} should be int or None, got {type(field_value)}"
|
||||
|
||||
file_id_fields = ["error_file_id", "output_file_id"]
|
||||
for field in file_id_fields:
|
||||
field_value = getattr(batch, field, None)
|
||||
if field_value is not None:
|
||||
assert isinstance(field_value, str), f"{field} should be str or None, got {type(field_value)}"
|
||||
|
||||
if hasattr(batch, "request_counts") and batch.request_counts is not None:
|
||||
assert isinstance(batch.request_counts.completed, int), (
|
||||
f"request_counts.completed should be int, got {type(batch.request_counts.completed)}"
|
||||
)
|
||||
assert isinstance(batch.request_counts.failed, int), (
|
||||
f"request_counts.failed should be int, got {type(batch.request_counts.failed)}"
|
||||
)
|
||||
assert isinstance(batch.request_counts.total, int), (
|
||||
f"request_counts.total should be int, got {type(batch.request_counts.total)}"
|
||||
)
|
||||
|
||||
if hasattr(batch, "errors") and batch.errors is not None:
|
||||
assert isinstance(batch.errors, dict), f"errors should be object or dict, got {type(batch.errors)}"
|
||||
|
||||
if hasattr(batch.errors, "data") and batch.errors.data is not None:
|
||||
assert isinstance(batch.errors.data, list), (
|
||||
f"errors.data should be list or None, got {type(batch.errors.data)}"
|
||||
)
|
||||
|
||||
for i, error_item in enumerate(batch.errors.data):
|
||||
assert isinstance(error_item, dict), (
|
||||
f"errors.data[{i}] should be object or dict, got {type(error_item)}"
|
||||
)
|
||||
|
||||
if hasattr(error_item, "code") and error_item.code is not None:
|
||||
assert isinstance(error_item.code, str), (
|
||||
f"errors.data[{i}].code should be str or None, got {type(error_item.code)}"
|
||||
)
|
||||
|
||||
if hasattr(error_item, "line") and error_item.line is not None:
|
||||
assert isinstance(error_item.line, int), (
|
||||
f"errors.data[{i}].line should be int or None, got {type(error_item.line)}"
|
||||
)
|
||||
|
||||
if hasattr(error_item, "message") and error_item.message is not None:
|
||||
assert isinstance(error_item.message, str), (
|
||||
f"errors.data[{i}].message should be str or None, got {type(error_item.message)}"
|
||||
)
|
||||
|
||||
if hasattr(error_item, "param") and error_item.param is not None:
|
||||
assert isinstance(error_item.param, str), (
|
||||
f"errors.data[{i}].param should be str or None, got {type(error_item.param)}"
|
||||
)
|
||||
|
||||
if hasattr(batch.errors, "object") and batch.errors.object is not None:
|
||||
assert isinstance(batch.errors.object, str), (
|
||||
f"errors.object should be str or None, got {type(batch.errors.object)}"
|
||||
)
|
||||
assert batch.errors.object == "list", f"errors.object should be 'list', got {batch.errors.object}"
|
||||
|
||||
async def test_create_and_retrieve_batch_success(self, provider, sample_batch_data):
|
||||
"""Test successful batch creation and retrieval."""
|
||||
created_batch = await provider.create_batch(**sample_batch_data)
|
||||
|
||||
self._validate_batch_type(created_batch, expected_metadata=sample_batch_data["metadata"])
|
||||
|
||||
assert created_batch.id.startswith("batch_")
|
||||
assert len(created_batch.id) > 13
|
||||
assert created_batch.object == "batch"
|
||||
assert created_batch.endpoint == sample_batch_data["endpoint"]
|
||||
assert created_batch.input_file_id == sample_batch_data["input_file_id"]
|
||||
assert created_batch.completion_window == sample_batch_data["completion_window"]
|
||||
assert created_batch.status == "validating"
|
||||
assert created_batch.metadata == sample_batch_data["metadata"]
|
||||
assert isinstance(created_batch.created_at, int)
|
||||
assert created_batch.created_at > 0
|
||||
|
||||
retrieved_batch = await provider.retrieve_batch(created_batch.id)
|
||||
|
||||
self._validate_batch_type(retrieved_batch, expected_metadata=sample_batch_data["metadata"])
|
||||
|
||||
assert retrieved_batch.id == created_batch.id
|
||||
assert retrieved_batch.input_file_id == created_batch.input_file_id
|
||||
assert retrieved_batch.endpoint == created_batch.endpoint
|
||||
assert retrieved_batch.status == created_batch.status
|
||||
assert retrieved_batch.metadata == created_batch.metadata
|
||||
|
||||
async def test_create_batch_without_metadata(self, provider):
|
||||
"""Test batch creation without optional metadata."""
|
||||
batch = await provider.create_batch(
|
||||
input_file_id="file_123", endpoint="/v1/chat/completions", completion_window="24h"
|
||||
)
|
||||
|
||||
assert batch.metadata is None
|
||||
|
||||
async def test_create_batch_completion_window(self, provider):
|
||||
"""Test batch creation with invalid completion window."""
|
||||
with pytest.raises(ValueError, match="Invalid completion_window"):
|
||||
await provider.create_batch(
|
||||
input_file_id="file_123", endpoint="/v1/chat/completions", completion_window="now"
|
||||
)
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"endpoint",
|
||||
[
|
||||
"/v1/embeddings",
|
||||
"/v1/completions",
|
||||
"/v1/invalid/endpoint",
|
||||
"",
|
||||
],
|
||||
)
|
||||
async def test_create_batch_invalid_endpoints(self, provider, endpoint):
|
||||
"""Test batch creation with various invalid endpoints."""
|
||||
with pytest.raises(ValueError, match="Invalid endpoint"):
|
||||
await provider.create_batch(input_file_id="file_123", endpoint=endpoint, completion_window="24h")
|
||||
|
||||
async def test_create_batch_invalid_metadata(self, provider):
|
||||
"""Test that batch creation fails with invalid metadata."""
|
||||
with pytest.raises(ValueError, match="should be a valid string"):
|
||||
await provider.create_batch(
|
||||
input_file_id="file_123",
|
||||
endpoint="/v1/chat/completions",
|
||||
completion_window="24h",
|
||||
metadata={123: "invalid_key"}, # Non-string key
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="should be a valid string"):
|
||||
await provider.create_batch(
|
||||
input_file_id="file_123",
|
||||
endpoint="/v1/chat/completions",
|
||||
completion_window="24h",
|
||||
metadata={"valid_key": 456}, # Non-string value
|
||||
)
|
||||
|
||||
async def test_retrieve_batch_not_found(self, provider):
|
||||
"""Test error when retrieving non-existent batch."""
|
||||
with pytest.raises(ResourceNotFoundError, match=r"Batch 'nonexistent_batch' not found"):
|
||||
await provider.retrieve_batch("nonexistent_batch")
|
||||
|
||||
async def test_cancel_batch_success(self, provider, sample_batch_data):
|
||||
"""Test successful batch cancellation."""
|
||||
created_batch = await provider.create_batch(**sample_batch_data)
|
||||
assert created_batch.status == "validating"
|
||||
|
||||
cancelled_batch = await provider.cancel_batch(created_batch.id)
|
||||
|
||||
assert cancelled_batch.id == created_batch.id
|
||||
assert cancelled_batch.status in ["cancelling", "cancelled"]
|
||||
assert isinstance(cancelled_batch.cancelling_at, int)
|
||||
assert cancelled_batch.cancelling_at >= created_batch.created_at
|
||||
|
||||
@pytest.mark.parametrize("status", ["failed", "expired", "completed"])
|
||||
async def test_cancel_batch_invalid_statuses(self, provider, sample_batch_data, status):
|
||||
"""Test error when cancelling batch in final states."""
|
||||
provider.process_batches = False
|
||||
created_batch = await provider.create_batch(**sample_batch_data)
|
||||
|
||||
# directly update status in kvstore
|
||||
await provider._update_batch(created_batch.id, status=status)
|
||||
|
||||
with pytest.raises(ConflictError, match=f"Cannot cancel batch '{created_batch.id}' with status '{status}'"):
|
||||
await provider.cancel_batch(created_batch.id)
|
||||
|
||||
async def test_cancel_batch_not_found(self, provider):
|
||||
"""Test error when cancelling non-existent batch."""
|
||||
with pytest.raises(ResourceNotFoundError, match=r"Batch 'nonexistent_batch' not found"):
|
||||
await provider.cancel_batch("nonexistent_batch")
|
||||
|
||||
async def test_list_batches_empty(self, provider):
|
||||
"""Test listing batches when none exist."""
|
||||
response = await provider.list_batches()
|
||||
|
||||
assert response.object == "list"
|
||||
assert response.data == []
|
||||
assert response.first_id is None
|
||||
assert response.last_id is None
|
||||
assert response.has_more is False
|
||||
|
||||
async def test_list_batches_single_batch(self, provider, sample_batch_data):
|
||||
"""Test listing batches with single batch."""
|
||||
created_batch = await provider.create_batch(**sample_batch_data)
|
||||
|
||||
response = await provider.list_batches()
|
||||
|
||||
assert len(response.data) == 1
|
||||
self._validate_batch_type(response.data[0], expected_metadata=sample_batch_data["metadata"])
|
||||
assert response.data[0].id == created_batch.id
|
||||
assert response.first_id == created_batch.id
|
||||
assert response.last_id == created_batch.id
|
||||
assert response.has_more is False
|
||||
|
||||
async def test_list_batches_multiple_batches(self, provider):
|
||||
"""Test listing multiple batches."""
|
||||
batches = [
|
||||
await provider.create_batch(
|
||||
input_file_id=f"file_{i}", endpoint="/v1/chat/completions", completion_window="24h"
|
||||
)
|
||||
for i in range(3)
|
||||
]
|
||||
|
||||
response = await provider.list_batches()
|
||||
|
||||
assert len(response.data) == 3
|
||||
|
||||
batch_ids = {batch.id for batch in response.data}
|
||||
expected_ids = {batch.id for batch in batches}
|
||||
assert batch_ids == expected_ids
|
||||
assert response.has_more is False
|
||||
|
||||
assert response.first_id in expected_ids
|
||||
assert response.last_id in expected_ids
|
||||
|
||||
async def test_list_batches_with_limit(self, provider):
|
||||
"""Test listing batches with limit parameter."""
|
||||
batches = [
|
||||
await provider.create_batch(
|
||||
input_file_id=f"file_{i}", endpoint="/v1/chat/completions", completion_window="24h"
|
||||
)
|
||||
for i in range(3)
|
||||
]
|
||||
|
||||
response = await provider.list_batches(limit=2)
|
||||
|
||||
assert len(response.data) == 2
|
||||
assert response.has_more is True
|
||||
assert response.first_id == response.data[0].id
|
||||
assert response.last_id == response.data[1].id
|
||||
batch_ids = {batch.id for batch in response.data}
|
||||
expected_ids = {batch.id for batch in batches}
|
||||
assert batch_ids.issubset(expected_ids)
|
||||
|
||||
async def test_list_batches_with_pagination(self, provider):
|
||||
"""Test listing batches with pagination using 'after' parameter."""
|
||||
for i in range(3):
|
||||
await provider.create_batch(
|
||||
input_file_id=f"file_{i}", endpoint="/v1/chat/completions", completion_window="24h"
|
||||
)
|
||||
|
||||
# Get first page
|
||||
first_page = await provider.list_batches(limit=1)
|
||||
assert len(first_page.data) == 1
|
||||
assert first_page.has_more is True
|
||||
|
||||
# Get second page using 'after'
|
||||
second_page = await provider.list_batches(limit=1, after=first_page.data[0].id)
|
||||
assert len(second_page.data) == 1
|
||||
assert second_page.data[0].id != first_page.data[0].id
|
||||
|
||||
# Verify we got the next batch in order
|
||||
all_batches = await provider.list_batches()
|
||||
expected_second_batch_id = all_batches.data[1].id
|
||||
assert second_page.data[0].id == expected_second_batch_id
|
||||
|
||||
async def test_list_batches_invalid_after(self, provider, sample_batch_data):
|
||||
"""Test listing batches with invalid 'after' parameter."""
|
||||
await provider.create_batch(**sample_batch_data)
|
||||
|
||||
response = await provider.list_batches(after="nonexistent_batch")
|
||||
|
||||
# Should return all batches (no filtering when 'after' batch not found)
|
||||
assert len(response.data) == 1
|
||||
|
||||
async def test_kvstore_persistence(self, provider, sample_batch_data):
|
||||
"""Test that batches are properly persisted in kvstore."""
|
||||
batch = await provider.create_batch(**sample_batch_data)
|
||||
|
||||
stored_data = await provider.kvstore.get(f"batch:{batch.id}")
|
||||
assert stored_data is not None
|
||||
|
||||
stored_batch_dict = json.loads(stored_data)
|
||||
assert stored_batch_dict["id"] == batch.id
|
||||
assert stored_batch_dict["input_file_id"] == sample_batch_data["input_file_id"]
|
||||
|
||||
async def test_validate_input_file_not_found(self, provider):
|
||||
"""Test _validate_input when input file does not exist."""
|
||||
provider.files_api.openai_retrieve_file = AsyncMock(side_effect=Exception("File not found"))
|
||||
|
||||
batch = BatchObject(
|
||||
id="batch_test",
|
||||
object="batch",
|
||||
endpoint="/v1/chat/completions",
|
||||
input_file_id="nonexistent_file",
|
||||
completion_window="24h",
|
||||
status="validating",
|
||||
created_at=1234567890,
|
||||
)
|
||||
|
||||
errors, requests = await provider._validate_input(batch)
|
||||
|
||||
assert len(errors) == 1
|
||||
assert len(requests) == 0
|
||||
assert errors[0].code == "invalid_request"
|
||||
assert errors[0].message == "Cannot find file nonexistent_file."
|
||||
assert errors[0].param == "input_file_id"
|
||||
assert errors[0].line is None
|
||||
|
||||
async def test_validate_input_file_exists_empty_content(self, provider):
|
||||
"""Test _validate_input when file exists but is empty."""
|
||||
provider.files_api.openai_retrieve_file = AsyncMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.body = b""
|
||||
provider.files_api.openai_retrieve_file_content = AsyncMock(return_value=mock_response)
|
||||
|
||||
batch = BatchObject(
|
||||
id="batch_test",
|
||||
object="batch",
|
||||
endpoint="/v1/chat/completions",
|
||||
input_file_id="empty_file",
|
||||
completion_window="24h",
|
||||
status="validating",
|
||||
created_at=1234567890,
|
||||
)
|
||||
|
||||
errors, requests = await provider._validate_input(batch)
|
||||
|
||||
assert len(errors) == 0
|
||||
assert len(requests) == 0
|
||||
|
||||
async def test_validate_input_file_mixed_valid_invalid_json(self, provider):
|
||||
"""Test _validate_input when file contains valid and invalid JSON lines."""
|
||||
provider.files_api.openai_retrieve_file = AsyncMock()
|
||||
mock_response = MagicMock()
|
||||
# Line 1: valid JSON with proper body args, Line 2: invalid JSON
|
||||
mock_response.body = b'{"custom_id": "req-1", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "test-model", "messages": [{"role": "user", "content": "Hello"}]}}\n{invalid json'
|
||||
provider.files_api.openai_retrieve_file_content = AsyncMock(return_value=mock_response)
|
||||
|
||||
batch = BatchObject(
|
||||
id="batch_test",
|
||||
object="batch",
|
||||
endpoint="/v1/chat/completions",
|
||||
input_file_id="mixed_file",
|
||||
completion_window="24h",
|
||||
status="validating",
|
||||
created_at=1234567890,
|
||||
)
|
||||
|
||||
errors, requests = await provider._validate_input(batch)
|
||||
|
||||
# Should have 1 JSON parsing error from line 2, and 1 valid request from line 1
|
||||
assert len(errors) == 1
|
||||
assert len(requests) == 1
|
||||
|
||||
assert errors[0].code == "invalid_json_line"
|
||||
assert errors[0].line == 2
|
||||
assert errors[0].message == "This line is not parseable as valid JSON."
|
||||
|
||||
assert requests[0].custom_id == "req-1"
|
||||
assert requests[0].method == "POST"
|
||||
assert requests[0].url == "/v1/chat/completions"
|
||||
assert requests[0].body["model"] == "test-model"
|
||||
assert requests[0].body["messages"] == [{"role": "user", "content": "Hello"}]
|
||||
|
||||
async def test_validate_input_invalid_model(self, provider):
|
||||
"""Test _validate_input when file contains request with non-existent model."""
|
||||
provider.files_api.openai_retrieve_file = AsyncMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.body = b'{"custom_id": "req-1", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "nonexistent-model", "messages": [{"role": "user", "content": "Hello"}]}}'
|
||||
provider.files_api.openai_retrieve_file_content = AsyncMock(return_value=mock_response)
|
||||
|
||||
provider.models_api.get_model = AsyncMock(side_effect=Exception("Model not found"))
|
||||
|
||||
batch = BatchObject(
|
||||
id="batch_test",
|
||||
object="batch",
|
||||
endpoint="/v1/chat/completions",
|
||||
input_file_id="invalid_model_file",
|
||||
completion_window="24h",
|
||||
status="validating",
|
||||
created_at=1234567890,
|
||||
)
|
||||
|
||||
errors, requests = await provider._validate_input(batch)
|
||||
|
||||
assert len(errors) == 1
|
||||
assert len(requests) == 0
|
||||
|
||||
assert errors[0].code == "model_not_found"
|
||||
assert errors[0].line == 1
|
||||
assert errors[0].message == "Model 'nonexistent-model' does not exist or is not supported"
|
||||
assert errors[0].param == "body.model"
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"param_name,param_path,error_code,error_message",
|
||||
[
|
||||
("custom_id", "custom_id", "missing_required_parameter", "Missing required parameter: custom_id"),
|
||||
("method", "method", "missing_required_parameter", "Missing required parameter: method"),
|
||||
("url", "url", "missing_required_parameter", "Missing required parameter: url"),
|
||||
("body", "body", "missing_required_parameter", "Missing required parameter: body"),
|
||||
("model", "body.model", "invalid_request", "Model parameter is required"),
|
||||
("messages", "body.messages", "invalid_request", "Messages parameter is required"),
|
||||
],
|
||||
)
|
||||
async def test_validate_input_missing_parameters(self, provider, param_name, param_path, error_code, error_message):
|
||||
"""Test _validate_input when file contains request with missing required parameters."""
|
||||
provider.files_api.openai_retrieve_file = AsyncMock()
|
||||
mock_response = MagicMock()
|
||||
|
||||
base_request = {
|
||||
"custom_id": "req-1",
|
||||
"method": "POST",
|
||||
"url": "/v1/chat/completions",
|
||||
"body": {"model": "test-model", "messages": [{"role": "user", "content": "Hello"}]},
|
||||
}
|
||||
|
||||
# Remove the specific parameter being tested
|
||||
if "." in param_path:
|
||||
top_level, nested_param = param_path.split(".", 1)
|
||||
del base_request[top_level][nested_param]
|
||||
else:
|
||||
del base_request[param_name]
|
||||
|
||||
mock_response.body = json.dumps(base_request).encode()
|
||||
provider.files_api.openai_retrieve_file_content = AsyncMock(return_value=mock_response)
|
||||
|
||||
batch = BatchObject(
|
||||
id="batch_test",
|
||||
object="batch",
|
||||
endpoint="/v1/chat/completions",
|
||||
input_file_id=f"missing_{param_name}_file",
|
||||
completion_window="24h",
|
||||
status="validating",
|
||||
created_at=1234567890,
|
||||
)
|
||||
|
||||
errors, requests = await provider._validate_input(batch)
|
||||
|
||||
assert len(errors) == 1
|
||||
assert len(requests) == 0
|
||||
|
||||
assert errors[0].code == error_code
|
||||
assert errors[0].line == 1
|
||||
assert errors[0].message == error_message
|
||||
assert errors[0].param == param_path
|
||||
|
||||
async def test_validate_input_url_mismatch(self, provider):
|
||||
"""Test _validate_input when file contains request with URL that doesn't match batch endpoint."""
|
||||
provider.files_api.openai_retrieve_file = AsyncMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.body = b'{"custom_id": "req-1", "method": "POST", "url": "/v1/embeddings", "body": {"model": "test-model", "messages": [{"role": "user", "content": "Hello"}]}}'
|
||||
provider.files_api.openai_retrieve_file_content = AsyncMock(return_value=mock_response)
|
||||
|
||||
batch = BatchObject(
|
||||
id="batch_test",
|
||||
object="batch",
|
||||
endpoint="/v1/chat/completions", # This doesn't match the URL in the request
|
||||
input_file_id="url_mismatch_file",
|
||||
completion_window="24h",
|
||||
status="validating",
|
||||
created_at=1234567890,
|
||||
)
|
||||
|
||||
errors, requests = await provider._validate_input(batch)
|
||||
|
||||
assert len(errors) == 1
|
||||
assert len(requests) == 0
|
||||
|
||||
assert errors[0].code == "invalid_url"
|
||||
assert errors[0].line == 1
|
||||
assert errors[0].message == "URL provided for this request does not match the batch endpoint"
|
||||
assert errors[0].param == "url"
|
||||
|
||||
async def test_validate_input_multiple_errors_per_request(self, provider):
|
||||
"""Test _validate_input when a single request has multiple validation errors."""
|
||||
provider.files_api.openai_retrieve_file = AsyncMock()
|
||||
mock_response = MagicMock()
|
||||
# Request missing custom_id, has invalid URL, and missing model in body
|
||||
mock_response.body = (
|
||||
b'{"method": "POST", "url": "/v1/embeddings", "body": {"messages": [{"role": "user", "content": "Hello"}]}}'
|
||||
)
|
||||
provider.files_api.openai_retrieve_file_content = AsyncMock(return_value=mock_response)
|
||||
|
||||
batch = BatchObject(
|
||||
id="batch_test",
|
||||
object="batch",
|
||||
endpoint="/v1/chat/completions", # Doesn't match /v1/embeddings in request
|
||||
input_file_id="multiple_errors_file",
|
||||
completion_window="24h",
|
||||
status="validating",
|
||||
created_at=1234567890,
|
||||
)
|
||||
|
||||
errors, requests = await provider._validate_input(batch)
|
||||
|
||||
assert len(errors) >= 2 # At least missing custom_id and URL mismatch
|
||||
assert len(requests) == 0
|
||||
|
||||
for error in errors:
|
||||
assert error.line == 1
|
||||
|
||||
error_codes = {error.code for error in errors}
|
||||
assert "missing_required_parameter" in error_codes # missing custom_id
|
||||
assert "invalid_url" in error_codes # URL mismatch
|
||||
|
||||
async def test_validate_input_invalid_request_format(self, provider):
|
||||
"""Test _validate_input when file contains non-object JSON (array, string, number)."""
|
||||
provider.files_api.openai_retrieve_file = AsyncMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.body = b'["not", "a", "request", "object"]'
|
||||
provider.files_api.openai_retrieve_file_content = AsyncMock(return_value=mock_response)
|
||||
|
||||
batch = BatchObject(
|
||||
id="batch_test",
|
||||
object="batch",
|
||||
endpoint="/v1/chat/completions",
|
||||
input_file_id="invalid_format_file",
|
||||
completion_window="24h",
|
||||
status="validating",
|
||||
created_at=1234567890,
|
||||
)
|
||||
|
||||
errors, requests = await provider._validate_input(batch)
|
||||
|
||||
assert len(errors) == 1
|
||||
assert len(requests) == 0
|
||||
|
||||
assert errors[0].code == "invalid_request"
|
||||
assert errors[0].line == 1
|
||||
assert errors[0].message == "Each line must be a JSON dictionary object"
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"param_name,param_path,invalid_value,error_message",
|
||||
[
|
||||
("custom_id", "custom_id", 12345, "Custom_id must be a string"),
|
||||
("url", "url", 123, "URL must be a string"),
|
||||
("method", "method", ["POST"], "Method must be a string"),
|
||||
("body", "body", ["not", "valid"], "Body must be a JSON dictionary object"),
|
||||
("model", "body.model", 123, "Model must be a string"),
|
||||
("messages", "body.messages", "invalid messages format", "Messages must be an array"),
|
||||
],
|
||||
)
|
||||
async def test_validate_input_invalid_parameter_types(
|
||||
self, provider, param_name, param_path, invalid_value, error_message
|
||||
):
|
||||
"""Test _validate_input when file contains request with parameters that have invalid types."""
|
||||
provider.files_api.openai_retrieve_file = AsyncMock()
|
||||
mock_response = MagicMock()
|
||||
|
||||
base_request = {
|
||||
"custom_id": "req-1",
|
||||
"method": "POST",
|
||||
"url": "/v1/chat/completions",
|
||||
"body": {"model": "test-model", "messages": [{"role": "user", "content": "Hello"}]},
|
||||
}
|
||||
|
||||
# Override the specific parameter with invalid value
|
||||
if "." in param_path:
|
||||
top_level, nested_param = param_path.split(".", 1)
|
||||
base_request[top_level][nested_param] = invalid_value
|
||||
else:
|
||||
base_request[param_name] = invalid_value
|
||||
|
||||
mock_response.body = json.dumps(base_request).encode()
|
||||
provider.files_api.openai_retrieve_file_content = AsyncMock(return_value=mock_response)
|
||||
|
||||
batch = BatchObject(
|
||||
id="batch_test",
|
||||
object="batch",
|
||||
endpoint="/v1/chat/completions",
|
||||
input_file_id=f"invalid_{param_name}_type_file",
|
||||
completion_window="24h",
|
||||
status="validating",
|
||||
created_at=1234567890,
|
||||
)
|
||||
|
||||
errors, requests = await provider._validate_input(batch)
|
||||
|
||||
assert len(errors) == 1
|
||||
assert len(requests) == 0
|
||||
|
||||
assert errors[0].code == "invalid_request"
|
||||
assert errors[0].line == 1
|
||||
assert errors[0].message == error_message
|
||||
assert errors[0].param == param_path
|
||||
|
||||
async def test_max_concurrent_batches(self, provider):
|
||||
"""Test max_concurrent_batches configuration and concurrency control."""
|
||||
import asyncio
|
||||
|
||||
provider._batch_semaphore = asyncio.Semaphore(2)
|
||||
|
||||
provider.process_batches = True # enable because we're testing background processing
|
||||
|
||||
active_batches = 0
|
||||
|
||||
async def add_and_wait(batch_id: str):
|
||||
nonlocal active_batches
|
||||
active_batches += 1
|
||||
await asyncio.sleep(float("inf"))
|
||||
|
||||
# the first thing done in _process_batch is to acquire the semaphore, then call _process_batch_impl,
|
||||
# so we can replace _process_batch_impl with our mock to control concurrency
|
||||
provider._process_batch_impl = add_and_wait
|
||||
|
||||
for _ in range(3):
|
||||
await provider.create_batch(
|
||||
input_file_id="file_id", endpoint="/v1/chat/completions", completion_window="24h"
|
||||
)
|
||||
|
||||
await asyncio.sleep(0.042) # let tasks start
|
||||
|
||||
assert active_batches == 2, f"Expected 2 active batches, got {active_batches}"
|
Loading…
Add table
Add a link
Reference in a new issue