diff --git a/docs/_static/llama-stack-spec.html b/docs/_static/llama-stack-spec.html
index b36626719..0549dda21 100644
--- a/docs/_static/llama-stack-spec.html
+++ b/docs/_static/llama-stack-spec.html
@@ -14767,8 +14767,7 @@
"OpenAIFilePurpose": {
"type": "string",
"enum": [
- "assistants",
- "batch"
+ "assistants"
],
"title": "OpenAIFilePurpose",
"description": "Valid purpose values for OpenAI Files API."
@@ -14845,8 +14844,7 @@
"purpose": {
"type": "string",
"enum": [
- "assistants",
- "batch"
+ "assistants"
],
"description": "The intended purpose of the file"
}
diff --git a/docs/_static/llama-stack-spec.yaml b/docs/_static/llama-stack-spec.yaml
index e7733b3c3..aa47cd58d 100644
--- a/docs/_static/llama-stack-spec.yaml
+++ b/docs/_static/llama-stack-spec.yaml
@@ -10951,7 +10951,6 @@ components:
type: string
enum:
- assistants
- - batch
title: OpenAIFilePurpose
description: >-
Valid purpose values for OpenAI Files API.
@@ -11020,7 +11019,6 @@ components:
type: string
enum:
- assistants
- - batch
description: The intended purpose of the file
additionalProperties: false
required:
diff --git a/docs/source/concepts/apis.md b/docs/source/concepts/apis.md
index f8f73a928..5a10d6498 100644
--- a/docs/source/concepts/apis.md
+++ b/docs/source/concepts/apis.md
@@ -18,4 +18,3 @@ We are working on adding a few more APIs to complete the application lifecycle.
- **Batch Inference**: run inference on a dataset of inputs
- **Batch Agents**: run agents on a dataset of inputs
- **Synthetic Data Generation**: generate synthetic data for model development
-- **Batches**: OpenAI-compatible batch management for inference
diff --git a/docs/source/providers/agents/index.md b/docs/source/providers/agents/index.md
index a2c48d4b9..92bf9edc0 100644
--- a/docs/source/providers/agents/index.md
+++ b/docs/source/providers/agents/index.md
@@ -2,15 +2,6 @@
## Overview
-Agents API for creating and interacting with agentic systems.
-
- Main functionalities provided by this API:
- - Create agents with specific instructions and ability to use tools.
- - Interactions with agents are grouped into sessions ("threads"), and each interaction is called a "turn".
- - Agents can be provided with various tools (see the ToolGroups and ToolRuntime APIs for more details).
- - Agents can be provided with various shields (see the Safety API for more details).
- - Agents can also use Memory to retrieve information from knowledge bases. See the RAG Tool and Vector IO APIs for more details.
-
This section contains documentation for all available providers for the **agents** API.
## Providers
diff --git a/docs/source/providers/batches/index.md b/docs/source/providers/batches/index.md
deleted file mode 100644
index 2a39a626c..000000000
--- a/docs/source/providers/batches/index.md
+++ /dev/null
@@ -1,21 +0,0 @@
-# Batches
-
-## Overview
-
-Protocol for batch processing API operations.
-
- The Batches API enables efficient processing of multiple requests in a single operation,
- particularly useful for processing large datasets, batch evaluation workflows, and
- cost-effective inference at scale.
-
- Note: This API is currently under active development and may undergo changes.
-
-This section contains documentation for all available providers for the **batches** API.
-
-## Providers
-
-```{toctree}
-:maxdepth: 1
-
-inline_reference
-```
diff --git a/docs/source/providers/batches/inline_reference.md b/docs/source/providers/batches/inline_reference.md
deleted file mode 100644
index a58e5124d..000000000
--- a/docs/source/providers/batches/inline_reference.md
+++ /dev/null
@@ -1,23 +0,0 @@
-# inline::reference
-
-## Description
-
-Reference implementation of batches API with KVStore persistence.
-
-## Configuration
-
-| Field | Type | Required | Default | Description |
-|-------|------|----------|---------|-------------|
-| `kvstore` | `utils.kvstore.config.RedisKVStoreConfig \| utils.kvstore.config.SqliteKVStoreConfig \| utils.kvstore.config.PostgresKVStoreConfig \| utils.kvstore.config.MongoDBKVStoreConfig` | No | sqlite | Configuration for the key-value store backend. |
-| `max_concurrent_batches` | `` | No | 1 | Maximum number of concurrent batches to process simultaneously. |
-| `max_concurrent_requests_per_batch` | `` | No | 10 | Maximum number of concurrent requests to process per batch. |
-
-## Sample Configuration
-
-```yaml
-kvstore:
- type: sqlite
- db_path: ${env.SQLITE_STORE_DIR:=~/.llama/dummy}/batches.db
-
-```
-
diff --git a/docs/source/providers/eval/index.md b/docs/source/providers/eval/index.md
index a14fada1d..d180d256c 100644
--- a/docs/source/providers/eval/index.md
+++ b/docs/source/providers/eval/index.md
@@ -2,8 +2,6 @@
## Overview
-Llama Stack Evaluation API for running evaluations on model and agent candidates.
-
This section contains documentation for all available providers for the **eval** API.
## Providers
diff --git a/docs/source/providers/inference/index.md b/docs/source/providers/inference/index.md
index b6d215474..38781e5eb 100644
--- a/docs/source/providers/inference/index.md
+++ b/docs/source/providers/inference/index.md
@@ -2,12 +2,6 @@
## Overview
-Llama Stack Inference API for generating completions, chat completions, and embeddings.
-
- This API provides the raw interface to the underlying models. Two kinds of models are supported:
- - LLM models: these models generate "raw" and "chat" (conversational) completions.
- - Embedding models: these models generate embeddings to be used for semantic search.
-
This section contains documentation for all available providers for the **inference** API.
## Providers
diff --git a/llama_stack/apis/batches/__init__.py b/llama_stack/apis/batches/__init__.py
deleted file mode 100644
index 9ce7d3d75..000000000
--- a/llama_stack/apis/batches/__init__.py
+++ /dev/null
@@ -1,9 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the terms described in the LICENSE file in
-# the root directory of this source tree.
-
-from .batches import Batches, BatchObject, ListBatchesResponse
-
-__all__ = ["Batches", "BatchObject", "ListBatchesResponse"]
diff --git a/llama_stack/apis/batches/batches.py b/llama_stack/apis/batches/batches.py
deleted file mode 100644
index 9297d8597..000000000
--- a/llama_stack/apis/batches/batches.py
+++ /dev/null
@@ -1,89 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the terms described in the LICENSE file in
-# the root directory of this source tree.
-
-from typing import Literal, Protocol, runtime_checkable
-
-from pydantic import BaseModel, Field
-
-from llama_stack.schema_utils import json_schema_type, webmethod
-
-try:
- from openai.types import Batch as BatchObject
-except ImportError as e:
- raise ImportError("OpenAI package is required for batches API. Please install it with: pip install openai") from e
-
-
-@json_schema_type
-class ListBatchesResponse(BaseModel):
- """Response containing a list of batch objects."""
-
- object: Literal["list"] = "list"
- data: list[BatchObject] = Field(..., description="List of batch objects")
- first_id: str | None = Field(default=None, description="ID of the first batch in the list")
- last_id: str | None = Field(default=None, description="ID of the last batch in the list")
- has_more: bool = Field(default=False, description="Whether there are more batches available")
-
-
-@runtime_checkable
-class Batches(Protocol):
- """Protocol for batch processing API operations.
-
- The Batches API enables efficient processing of multiple requests in a single operation,
- particularly useful for processing large datasets, batch evaluation workflows, and
- cost-effective inference at scale.
-
- Note: This API is currently under active development and may undergo changes.
- """
-
- @webmethod(route="/openai/v1/batches", method="POST")
- async def create_batch(
- self,
- input_file_id: str,
- endpoint: str,
- completion_window: Literal["24h"],
- metadata: dict[str, str] | None = None,
- ) -> BatchObject:
- """Create a new batch for processing multiple API requests.
-
- :param input_file_id: The ID of an uploaded file containing requests for the batch.
- :param endpoint: The endpoint to be used for all requests in the batch.
- :param completion_window: The time window within which the batch should be processed.
- :param metadata: Optional metadata for the batch.
- :returns: The created batch object.
- """
- ...
-
- @webmethod(route="/openai/v1/batches/{batch_id}", method="GET")
- async def retrieve_batch(self, batch_id: str) -> BatchObject:
- """Retrieve information about a specific batch.
-
- :param batch_id: The ID of the batch to retrieve.
- :returns: The batch object.
- """
- ...
-
- @webmethod(route="/openai/v1/batches/{batch_id}/cancel", method="POST")
- async def cancel_batch(self, batch_id: str) -> BatchObject:
- """Cancel a batch that is in progress.
-
- :param batch_id: The ID of the batch to cancel.
- :returns: The updated batch object.
- """
- ...
-
- @webmethod(route="/openai/v1/batches", method="GET")
- async def list_batches(
- self,
- after: str | None = None,
- limit: int = 20,
- ) -> ListBatchesResponse:
- """List all batches for the current user.
-
- :param after: A cursor for pagination; returns batches after this batch ID.
- :param limit: Number of batches to return (default 20, max 100).
- :returns: A list of batch objects.
- """
- ...
diff --git a/llama_stack/apis/common/errors.py b/llama_stack/apis/common/errors.py
index 7104d8db6..6e0fa0b3c 100644
--- a/llama_stack/apis/common/errors.py
+++ b/llama_stack/apis/common/errors.py
@@ -64,12 +64,6 @@ class SessionNotFoundError(ValueError):
super().__init__(message)
-class ConflictError(ValueError):
- """raised when an operation cannot be performed due to a conflict with the current state"""
-
- pass
-
-
class ModelTypeError(TypeError):
"""raised when a model is present but not the correct type"""
diff --git a/llama_stack/apis/datatypes.py b/llama_stack/apis/datatypes.py
index 87fc95917..cabe46a2f 100644
--- a/llama_stack/apis/datatypes.py
+++ b/llama_stack/apis/datatypes.py
@@ -86,7 +86,6 @@ class Api(Enum, metaclass=DynamicApiMeta):
:cvar inference: Text generation, chat completions, and embeddings
:cvar safety: Content moderation and safety shields
:cvar agents: Agent orchestration and execution
- :cvar batches: Batch processing for asynchronous API requests
:cvar vector_io: Vector database operations and queries
:cvar datasetio: Dataset input/output operations
:cvar scoring: Model output evaluation and scoring
@@ -109,7 +108,6 @@ class Api(Enum, metaclass=DynamicApiMeta):
inference = "inference"
safety = "safety"
agents = "agents"
- batches = "batches"
vector_io = "vector_io"
datasetio = "datasetio"
scoring = "scoring"
diff --git a/llama_stack/apis/files/files.py b/llama_stack/apis/files/files.py
index a1b9dd4dc..ba8701e23 100644
--- a/llama_stack/apis/files/files.py
+++ b/llama_stack/apis/files/files.py
@@ -22,7 +22,6 @@ class OpenAIFilePurpose(StrEnum):
"""
ASSISTANTS = "assistants"
- BATCH = "batch"
# TODO: Add other purposes as needed
diff --git a/llama_stack/core/resolver.py b/llama_stack/core/resolver.py
index 7ac98dac8..70c78fb01 100644
--- a/llama_stack/core/resolver.py
+++ b/llama_stack/core/resolver.py
@@ -8,7 +8,6 @@ import inspect
from typing import Any
from llama_stack.apis.agents import Agents
-from llama_stack.apis.batches import Batches
from llama_stack.apis.benchmarks import Benchmarks
from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.datasets import Datasets
@@ -76,7 +75,6 @@ def api_protocol_map(external_apis: dict[Api, ExternalApiSpec] | None = None) ->
Api.agents: Agents,
Api.inference: Inference,
Api.inspect: Inspect,
- Api.batches: Batches,
Api.vector_io: VectorIO,
Api.vector_dbs: VectorDBs,
Api.models: Models,
diff --git a/llama_stack/core/server/server.py b/llama_stack/core/server/server.py
index cbef8ef88..e9d70fc8d 100644
--- a/llama_stack/core/server/server.py
+++ b/llama_stack/core/server/server.py
@@ -32,7 +32,6 @@ from fastapi.responses import JSONResponse, StreamingResponse
from openai import BadRequestError
from pydantic import BaseModel, ValidationError
-from llama_stack.apis.common.errors import ConflictError, ResourceNotFoundError
from llama_stack.apis.common.responses import PaginatedResponse
from llama_stack.cli.utils import add_config_distro_args, get_config_from_args
from llama_stack.core.access_control.access_control import AccessDeniedError
@@ -129,10 +128,6 @@ def translate_exception(exc: Exception) -> HTTPException | RequestValidationErro
]
},
)
- elif isinstance(exc, ConflictError):
- return HTTPException(status_code=409, detail=str(exc))
- elif isinstance(exc, ResourceNotFoundError):
- return HTTPException(status_code=404, detail=str(exc))
elif isinstance(exc, ValueError):
return HTTPException(status_code=httpx.codes.BAD_REQUEST, detail=f"Invalid value: {str(exc)}")
elif isinstance(exc, BadRequestError):
diff --git a/llama_stack/providers/inline/batches/__init__.py b/llama_stack/providers/inline/batches/__init__.py
deleted file mode 100644
index 756f351d8..000000000
--- a/llama_stack/providers/inline/batches/__init__.py
+++ /dev/null
@@ -1,5 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the terms described in the LICENSE file in
-# the root directory of this source tree.
diff --git a/llama_stack/providers/inline/batches/reference/__init__.py b/llama_stack/providers/inline/batches/reference/__init__.py
deleted file mode 100644
index a8ae92eb2..000000000
--- a/llama_stack/providers/inline/batches/reference/__init__.py
+++ /dev/null
@@ -1,36 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the terms described in the LICENSE file in
-# the root directory of this source tree.
-
-from typing import Any
-
-from llama_stack.apis.files import Files
-from llama_stack.apis.inference import Inference
-from llama_stack.apis.models import Models
-from llama_stack.core.datatypes import AccessRule, Api
-from llama_stack.providers.utils.kvstore import kvstore_impl
-
-from .batches import ReferenceBatchesImpl
-from .config import ReferenceBatchesImplConfig
-
-__all__ = ["ReferenceBatchesImpl", "ReferenceBatchesImplConfig"]
-
-
-async def get_provider_impl(config: ReferenceBatchesImplConfig, deps: dict[Api, Any], policy: list[AccessRule]):
- kvstore = await kvstore_impl(config.kvstore)
- inference_api: Inference | None = deps.get(Api.inference)
- files_api: Files | None = deps.get(Api.files)
- models_api: Models | None = deps.get(Api.models)
-
- if inference_api is None:
- raise ValueError("Inference API is required but not provided in dependencies")
- if files_api is None:
- raise ValueError("Files API is required but not provided in dependencies")
- if models_api is None:
- raise ValueError("Models API is required but not provided in dependencies")
-
- impl = ReferenceBatchesImpl(config, inference_api, files_api, models_api, kvstore)
- await impl.initialize()
- return impl
diff --git a/llama_stack/providers/inline/batches/reference/batches.py b/llama_stack/providers/inline/batches/reference/batches.py
deleted file mode 100644
index 984ef5a90..000000000
--- a/llama_stack/providers/inline/batches/reference/batches.py
+++ /dev/null
@@ -1,553 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the terms described in the LICENSE file in
-# the root directory of this source tree.
-
-import asyncio
-import itertools
-import json
-import time
-import uuid
-from io import BytesIO
-from typing import Any, Literal
-
-from openai.types.batch import BatchError, Errors
-from pydantic import BaseModel
-
-from llama_stack.apis.batches import Batches, BatchObject, ListBatchesResponse
-from llama_stack.apis.common.errors import ConflictError, ResourceNotFoundError
-from llama_stack.apis.files import Files, OpenAIFilePurpose
-from llama_stack.apis.inference import Inference
-from llama_stack.apis.models import Models
-from llama_stack.log import get_logger
-from llama_stack.providers.utils.kvstore import KVStore
-
-from .config import ReferenceBatchesImplConfig
-
-BATCH_PREFIX = "batch:"
-
-logger = get_logger(__name__)
-
-
-class AsyncBytesIO:
- """
- Async-compatible BytesIO wrapper to allow async file-like operations.
-
- We use this when uploading files to the Files API, as it expects an
- async file-like object.
- """
-
- def __init__(self, data: bytes):
- self._buffer = BytesIO(data)
-
- async def read(self, n=-1):
- return self._buffer.read(n)
-
- async def seek(self, pos, whence=0):
- return self._buffer.seek(pos, whence)
-
- def __enter__(self):
- return self
-
- def __exit__(self, exc_type, exc_val, exc_tb):
- self._buffer.close()
-
- def __getattr__(self, name):
- return getattr(self._buffer, name)
-
-
-class BatchRequest(BaseModel):
- line_num: int
- custom_id: str
- method: str
- url: str
- body: dict[str, Any]
-
-
-class ReferenceBatchesImpl(Batches):
- """Reference implementation of the Batches API.
-
- This implementation processes batch files by making individual requests
- to the inference API and generates output files with results.
- """
-
- def __init__(
- self,
- config: ReferenceBatchesImplConfig,
- inference_api: Inference,
- files_api: Files,
- models_api: Models,
- kvstore: KVStore,
- ) -> None:
- self.config = config
- self.kvstore = kvstore
- self.inference_api = inference_api
- self.files_api = files_api
- self.models_api = models_api
- self._processing_tasks: dict[str, asyncio.Task] = {}
- self._batch_semaphore = asyncio.Semaphore(config.max_concurrent_batches)
- self._update_batch_lock = asyncio.Lock()
-
- # this is to allow tests to disable background processing
- self.process_batches = True
-
- async def initialize(self) -> None:
- # TODO: start background processing of existing tasks
- pass
-
- async def shutdown(self) -> None:
- """Shutdown the batches provider."""
- if self._processing_tasks:
- # don't cancel tasks - just let them stop naturally on shutdown
- # cancelling would mark batches as "cancelled" in the database
- logger.info(f"Shutdown initiated with {len(self._processing_tasks)} active batch processing tasks")
-
- # TODO (SECURITY): this currently works w/ configured api keys, not with x-llamastack-provider-data or with user policy restrictions
- async def create_batch(
- self,
- input_file_id: str,
- endpoint: str,
- completion_window: Literal["24h"],
- metadata: dict[str, str] | None = None,
- ) -> BatchObject:
- """
- Create a new batch for processing multiple API requests.
-
- Error handling by levels -
- 0. Input param handling, results in 40x errors before processing, e.g.
- - Wrong completion_window
- - Invalid metadata types
- - Unknown endpoint
- -> no batch created
- 1. Errors preventing processing, result in BatchErrors aggregated in process_batch, e.g.
- - input_file_id missing
- - invalid json in file
- - missing custom_id, method, url, body
- - invalid model
- - streaming
- -> batch created, validation sends to failed status
- 2. Processing errors, result in error_file_id entries, e.g.
- - Any error returned from inference endpoint
- -> batch created, goes to completed status
- """
-
- # TODO: set expiration time for garbage collection
-
- if endpoint not in ["/v1/chat/completions"]:
- raise ValueError(
- f"Invalid endpoint: {endpoint}. Supported values: /v1/chat/completions. Code: invalid_value. Param: endpoint",
- )
-
- if completion_window != "24h":
- raise ValueError(
- f"Invalid completion_window: {completion_window}. Supported values are: 24h. Code: invalid_value. Param: completion_window",
- )
-
- batch_id = f"batch_{uuid.uuid4().hex[:16]}"
- current_time = int(time.time())
-
- batch = BatchObject(
- id=batch_id,
- object="batch",
- endpoint=endpoint,
- input_file_id=input_file_id,
- completion_window=completion_window,
- status="validating",
- created_at=current_time,
- metadata=metadata,
- )
-
- await self.kvstore.set(f"batch:{batch_id}", batch.to_json())
-
- if self.process_batches:
- task = asyncio.create_task(self._process_batch(batch_id))
- self._processing_tasks[batch_id] = task
-
- return batch
-
- async def cancel_batch(self, batch_id: str) -> BatchObject:
- """Cancel a batch that is in progress."""
- batch = await self.retrieve_batch(batch_id)
-
- if batch.status in ["cancelled", "cancelling"]:
- return batch
-
- if batch.status in ["completed", "failed", "expired"]:
- raise ConflictError(f"Cannot cancel batch '{batch_id}' with status '{batch.status}'")
-
- await self._update_batch(batch_id, status="cancelling", cancelling_at=int(time.time()))
-
- if batch_id in self._processing_tasks:
- self._processing_tasks[batch_id].cancel()
- # note: task removal and status="cancelled" handled in finally block of _process_batch
-
- return await self.retrieve_batch(batch_id)
-
- async def list_batches(
- self,
- after: str | None = None,
- limit: int = 20,
- ) -> ListBatchesResponse:
- """
- List all batches, eventually only for the current user.
-
- With no notion of user, we return all batches.
- """
- batch_values = await self.kvstore.values_in_range("batch:", "batch:\xff")
-
- batches = []
- for batch_data in batch_values:
- if batch_data:
- batches.append(BatchObject.model_validate_json(batch_data))
-
- batches.sort(key=lambda b: b.created_at, reverse=True)
-
- start_idx = 0
- if after:
- for i, batch in enumerate(batches):
- if batch.id == after:
- start_idx = i + 1
- break
-
- page_batches = batches[start_idx : start_idx + limit]
- has_more = (start_idx + limit) < len(batches)
-
- first_id = page_batches[0].id if page_batches else None
- last_id = page_batches[-1].id if page_batches else None
-
- return ListBatchesResponse(
- data=page_batches,
- first_id=first_id,
- last_id=last_id,
- has_more=has_more,
- )
-
- async def retrieve_batch(self, batch_id: str) -> BatchObject:
- """Retrieve information about a specific batch."""
- batch_data = await self.kvstore.get(f"batch:{batch_id}")
- if not batch_data:
- raise ResourceNotFoundError(batch_id, "Batch", "batches.list()")
-
- return BatchObject.model_validate_json(batch_data)
-
- async def _update_batch(self, batch_id: str, **updates) -> None:
- """Update batch fields in kvstore."""
- async with self._update_batch_lock:
- try:
- batch = await self.retrieve_batch(batch_id)
-
- # batch processing is async. once cancelling, only allow "cancelled" status updates
- if batch.status == "cancelling" and updates.get("status") != "cancelled":
- logger.info(
- f"Skipping status update for cancelled batch {batch_id}: attempted {updates.get('status')}"
- )
- return
-
- if "errors" in updates:
- updates["errors"] = updates["errors"].model_dump()
-
- batch_dict = batch.model_dump()
- batch_dict.update(updates)
-
- await self.kvstore.set(f"batch:{batch_id}", json.dumps(batch_dict))
- except Exception as e:
- logger.error(f"Failed to update batch {batch_id}: {e}")
-
- async def _validate_input(self, batch: BatchObject) -> tuple[list[BatchError], list[BatchRequest]]:
- """
- Read & validate input, return errors and valid input.
-
- Validation of
- - input_file_id existance
- - valid json
- - custom_id, method, url, body presence and valid
- - no streaming
- """
- requests: list[BatchRequest] = []
- errors: list[BatchError] = []
- try:
- await self.files_api.openai_retrieve_file(batch.input_file_id)
- except Exception:
- errors.append(
- BatchError(
- code="invalid_request",
- line=None,
- message=f"Cannot find file {batch.input_file_id}.",
- param="input_file_id",
- )
- )
- return errors, requests
-
- # TODO(SECURITY): do something about large files
- file_content_response = await self.files_api.openai_retrieve_file_content(batch.input_file_id)
- file_content = file_content_response.body.decode("utf-8")
- for line_num, line in enumerate(file_content.strip().split("\n"), 1):
- if line.strip(): # skip empty lines
- try:
- request = json.loads(line)
-
- if not isinstance(request, dict):
- errors.append(
- BatchError(
- code="invalid_request",
- line=line_num,
- message="Each line must be a JSON dictionary object",
- )
- )
- continue
-
- valid = True
-
- for param, expected_type, type_string in [
- ("custom_id", str, "string"),
- ("method", str, "string"),
- ("url", str, "string"),
- ("body", dict, "JSON dictionary object"),
- ]:
- if param not in request:
- errors.append(
- BatchError(
- code="missing_required_parameter",
- line=line_num,
- message=f"Missing required parameter: {param}",
- param=param,
- )
- )
- valid = False
- elif not isinstance(request[param], expected_type):
- param_name = "URL" if param == "url" else param.capitalize()
- errors.append(
- BatchError(
- code="invalid_request",
- line=line_num,
- message=f"{param_name} must be a {type_string}",
- param=param,
- )
- )
- valid = False
-
- if (url := request.get("url")) and isinstance(url, str) and url != batch.endpoint:
- errors.append(
- BatchError(
- code="invalid_url",
- line=line_num,
- message="URL provided for this request does not match the batch endpoint",
- param="url",
- )
- )
- valid = False
-
- if (body := request.get("body")) and isinstance(body, dict):
- if body.get("stream", False):
- errors.append(
- BatchError(
- code="streaming_unsupported",
- line=line_num,
- message="Streaming is not supported in batch processing",
- param="body.stream",
- )
- )
- valid = False
-
- for param, expected_type, type_string in [
- ("model", str, "a string"),
- # messages is specific to /v1/chat/completions
- # we could skip validating messages here and let inference fail. however,
- # that would be a very expensive way to find out messages is wrong.
- ("messages", list, "an array"), # TODO: allow messages to be a string?
- ]:
- if param not in body:
- errors.append(
- BatchError(
- code="invalid_request",
- line=line_num,
- message=f"{param.capitalize()} parameter is required",
- param=f"body.{param}",
- )
- )
- valid = False
- elif not isinstance(body[param], expected_type):
- errors.append(
- BatchError(
- code="invalid_request",
- line=line_num,
- message=f"{param.capitalize()} must be {type_string}",
- param=f"body.{param}",
- )
- )
- valid = False
-
- if "model" in body and isinstance(body["model"], str):
- try:
- await self.models_api.get_model(body["model"])
- except Exception:
- errors.append(
- BatchError(
- code="model_not_found",
- line=line_num,
- message=f"Model '{body['model']}' does not exist or is not supported",
- param="body.model",
- )
- )
- valid = False
-
- if valid:
- assert isinstance(url, str), "URL must be a string" # for mypy
- assert isinstance(body, dict), "Body must be a dictionary" # for mypy
- requests.append(
- BatchRequest(
- line_num=line_num,
- url=url,
- method=request["method"],
- custom_id=request["custom_id"],
- body=body,
- ),
- )
- except json.JSONDecodeError:
- errors.append(
- BatchError(
- code="invalid_json_line",
- line=line_num,
- message="This line is not parseable as valid JSON.",
- )
- )
-
- return errors, requests
-
- async def _process_batch(self, batch_id: str) -> None:
- """Background task to process a batch of requests."""
- try:
- logger.info(f"Starting batch processing for {batch_id}")
- async with self._batch_semaphore: # semaphore to limit concurrency
- logger.info(f"Acquired semaphore for batch {batch_id}")
- await self._process_batch_impl(batch_id)
- except asyncio.CancelledError:
- logger.info(f"Batch processing cancelled for {batch_id}")
- await self._update_batch(batch_id, status="cancelled", cancelled_at=int(time.time()))
- except Exception as e:
- logger.error(f"Batch processing failed for {batch_id}: {e}")
- await self._update_batch(
- batch_id,
- status="failed",
- failed_at=int(time.time()),
- errors=Errors(data=[BatchError(code="internal_error", message=str(e))]),
- )
- finally:
- self._processing_tasks.pop(batch_id, None)
-
- async def _process_batch_impl(self, batch_id: str) -> None:
- """Implementation of batch processing logic."""
- errors: list[BatchError] = []
- batch = await self.retrieve_batch(batch_id)
-
- errors, requests = await self._validate_input(batch)
- if errors:
- await self._update_batch(batch_id, status="failed", failed_at=int(time.time()), errors=Errors(data=errors))
- logger.info(f"Batch validation failed for {batch_id} with {len(errors)} errors")
- return
-
- logger.info(f"Processing {len(requests)} requests for batch {batch_id}")
-
- total_requests = len(requests)
- await self._update_batch(
- batch_id,
- status="in_progress",
- request_counts={"total": total_requests, "completed": 0, "failed": 0},
- )
-
- error_results = []
- success_results = []
- completed_count = 0
- failed_count = 0
-
- for chunk in itertools.batched(requests, self.config.max_concurrent_requests_per_batch):
- # we use a TaskGroup to ensure all process-single-request tasks are canceled when process-batch is cancelled
- async with asyncio.TaskGroup() as tg:
- chunk_tasks = [tg.create_task(self._process_single_request(batch_id, request)) for request in chunk]
-
- chunk_results = await asyncio.gather(*chunk_tasks, return_exceptions=True)
-
- for result in chunk_results:
- if isinstance(result, dict) and result.get("error") is not None: # error response from inference
- failed_count += 1
- error_results.append(result)
- elif isinstance(result, dict) and result.get("response") is not None: # successful inference
- completed_count += 1
- success_results.append(result)
- else: # unexpected result
- failed_count += 1
- errors.append(BatchError(code="internal_error", message=f"Unexpected result: {result}"))
-
- await self._update_batch(
- batch_id,
- request_counts={"total": total_requests, "completed": completed_count, "failed": failed_count},
- )
-
- if errors:
- await self._update_batch(
- batch_id, status="failed", failed_at=int(time.time()), errors=Errors(data=errors)
- )
- return
-
- try:
- output_file_id = await self._create_output_file(batch_id, success_results, "success")
- await self._update_batch(batch_id, output_file_id=output_file_id)
-
- error_file_id = await self._create_output_file(batch_id, error_results, "error")
- await self._update_batch(batch_id, error_file_id=error_file_id)
-
- await self._update_batch(batch_id, status="completed", completed_at=int(time.time()))
-
- logger.info(
- f"Batch processing completed for {batch_id}: {completed_count} completed, {failed_count} failed"
- )
- except Exception as e:
- # note: errors is empty at this point, so we don't lose anything by ignoring it
- await self._update_batch(
- batch_id,
- status="failed",
- failed_at=int(time.time()),
- errors=Errors(data=[BatchError(code="output_failed", message=str(e))]),
- )
-
- async def _process_single_request(self, batch_id: str, request: BatchRequest) -> dict:
- """Process a single request from the batch."""
- request_id = f"batch_req_{batch_id}_{request.line_num}"
-
- try:
- # TODO(SECURITY): review body for security issues
- chat_response = await self.inference_api.openai_chat_completion(**request.body)
-
- # this is for mypy, we don't allow streaming so we'll get the right type
- assert hasattr(chat_response, "model_dump_json"), "Chat response must have model_dump_json method"
- return {
- "id": request_id,
- "custom_id": request.custom_id,
- "response": {
- "status_code": 200,
- "request_id": request_id, # TODO: should this be different?
- "body": chat_response.model_dump_json(),
- },
- }
- except Exception as e:
- logger.info(f"Error processing request {request.custom_id} in batch {batch_id}: {e}")
- return {
- "id": request_id,
- "custom_id": request.custom_id,
- "error": {"type": "request_failed", "message": str(e)},
- }
-
- async def _create_output_file(self, batch_id: str, results: list[dict], file_type: str) -> str:
- """
- Create an output file with batch results.
-
- This function filters results based on the specified file_type
- and uploads the file to the Files API.
- """
- output_lines = [json.dumps(result) for result in results]
-
- with AsyncBytesIO("\n".join(output_lines).encode("utf-8")) as file_buffer:
- file_buffer.filename = f"{batch_id}_{file_type}.jsonl"
- uploaded_file = await self.files_api.openai_upload_file(file=file_buffer, purpose=OpenAIFilePurpose.BATCH)
- return uploaded_file.id
diff --git a/llama_stack/providers/inline/batches/reference/config.py b/llama_stack/providers/inline/batches/reference/config.py
deleted file mode 100644
index d8d06868b..000000000
--- a/llama_stack/providers/inline/batches/reference/config.py
+++ /dev/null
@@ -1,40 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the terms described in the LICENSE file in
-# the root directory of this source tree.
-
-from pydantic import BaseModel, Field
-
-from llama_stack.providers.utils.kvstore.config import KVStoreConfig, SqliteKVStoreConfig
-
-
-class ReferenceBatchesImplConfig(BaseModel):
- """Configuration for the Reference Batches implementation."""
-
- kvstore: KVStoreConfig = Field(
- description="Configuration for the key-value store backend.",
- )
-
- max_concurrent_batches: int = Field(
- default=1,
- description="Maximum number of concurrent batches to process simultaneously.",
- ge=1,
- )
-
- max_concurrent_requests_per_batch: int = Field(
- default=10,
- description="Maximum number of concurrent requests to process per batch.",
- ge=1,
- )
-
- # TODO: add a max requests per second rate limiter
-
- @classmethod
- def sample_run_config(cls, __distro_dir__: str) -> dict:
- return {
- "kvstore": SqliteKVStoreConfig.sample_run_config(
- __distro_dir__=__distro_dir__,
- db_name="batches.db",
- ),
- }
diff --git a/llama_stack/providers/registry/batches.py b/llama_stack/providers/registry/batches.py
deleted file mode 100644
index de7886efb..000000000
--- a/llama_stack/providers/registry/batches.py
+++ /dev/null
@@ -1,26 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the terms described in the LICENSE file in
-# the root directory of this source tree.
-
-
-from llama_stack.providers.datatypes import Api, InlineProviderSpec, ProviderSpec
-
-
-def available_providers() -> list[ProviderSpec]:
- return [
- InlineProviderSpec(
- api=Api.batches,
- provider_type="inline::reference",
- pip_packages=["openai"],
- module="llama_stack.providers.inline.batches.reference",
- config_class="llama_stack.providers.inline.batches.reference.config.ReferenceBatchesImplConfig",
- api_dependencies=[
- Api.inference,
- Api.files,
- Api.models,
- ],
- description="Reference implementation of batches API with KVStore persistence.",
- ),
- ]
diff --git a/scripts/provider_codegen.py b/scripts/provider_codegen.py
index 060acfa72..717677c52 100755
--- a/scripts/provider_codegen.py
+++ b/scripts/provider_codegen.py
@@ -18,23 +18,6 @@ from llama_stack.core.distribution import get_provider_registry
REPO_ROOT = Path(__file__).parent.parent
-def get_api_docstring(api_name: str) -> str | None:
- """Extract docstring from the API protocol class."""
- try:
- # Import the API module dynamically
- api_module = __import__(f"llama_stack.apis.{api_name}", fromlist=[api_name.title()])
-
- # Get the main protocol class (usually capitalized API name)
- protocol_class_name = api_name.title()
- if hasattr(api_module, protocol_class_name):
- protocol_class = getattr(api_module, protocol_class_name)
- return protocol_class.__doc__
- except (ImportError, AttributeError):
- pass
-
- return None
-
-
class ChangedPathTracker:
"""Track a list of paths we may have changed."""
@@ -278,11 +261,6 @@ def process_provider_registry(progress, change_tracker: ChangedPathTracker) -> N
index_content.append(f"# {api_name.title()}\n")
index_content.append("## Overview\n")
- api_docstring = get_api_docstring(api_name)
- if api_docstring:
- cleaned_docstring = api_docstring.strip()
- index_content.append(f"{cleaned_docstring}\n")
-
index_content.append(
f"This section contains documentation for all available providers for the **{api_name}** API.\n"
)
diff --git a/tests/integration/batches/__init__.py b/tests/integration/batches/__init__.py
deleted file mode 100644
index 756f351d8..000000000
--- a/tests/integration/batches/__init__.py
+++ /dev/null
@@ -1,5 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the terms described in the LICENSE file in
-# the root directory of this source tree.
diff --git a/tests/integration/batches/conftest.py b/tests/integration/batches/conftest.py
deleted file mode 100644
index 974fe77ab..000000000
--- a/tests/integration/batches/conftest.py
+++ /dev/null
@@ -1,122 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the terms described in the LICENSE file in
-# the root directory of this source tree.
-
-"""Shared pytest fixtures for batch tests."""
-
-import json
-import time
-import warnings
-from contextlib import contextmanager
-from io import BytesIO
-
-import pytest
-
-from llama_stack.apis.files import OpenAIFilePurpose
-
-
-class BatchHelper:
- """Helper class for creating and managing batch input files."""
-
- def __init__(self, client):
- """Initialize with either a batch_client or openai_client."""
- self.client = client
-
- @contextmanager
- def create_file(self, content: str | list[dict], filename_prefix="batch_input"):
- """Context manager for creating and cleaning up batch input files.
-
- Args:
- content: Either a list of batch request dictionaries or raw string content
- filename_prefix: Prefix for the generated filename (or full filename if content is string)
-
- Yields:
- The uploaded file object
- """
- if isinstance(content, str):
- # Handle raw string content (e.g., malformed JSONL, empty files)
- file_content = content.encode("utf-8")
- else:
- # Handle list of batch request dictionaries
- jsonl_content = "\n".join(json.dumps(req) for req in content)
- file_content = jsonl_content.encode("utf-8")
-
- filename = filename_prefix if filename_prefix.endswith(".jsonl") else f"{filename_prefix}.jsonl"
-
- with BytesIO(file_content) as file_buffer:
- file_buffer.name = filename
- uploaded_file = self.client.files.create(file=file_buffer, purpose=OpenAIFilePurpose.BATCH)
-
- try:
- yield uploaded_file
- finally:
- try:
- self.client.files.delete(uploaded_file.id)
- except Exception:
- warnings.warn(
- f"Failed to cleanup file {uploaded_file.id}: {uploaded_file.filename}",
- stacklevel=2,
- )
-
- def wait_for(
- self,
- batch_id: str,
- max_wait_time: int = 60,
- sleep_interval: int | None = None,
- expected_statuses: set[str] | None = None,
- timeout_action: str = "fail",
- ):
- """Wait for a batch to reach a terminal status.
-
- Args:
- batch_id: The batch ID to monitor
- max_wait_time: Maximum time to wait in seconds (default: 60 seconds)
- sleep_interval: Time to sleep between checks in seconds (default: 1/10th of max_wait_time, min 1s, max 15s)
- expected_statuses: Set of expected terminal statuses (default: {"completed"})
- timeout_action: Action on timeout - "fail" (pytest.fail) or "skip" (pytest.skip)
-
- Returns:
- The final batch object
-
- Raises:
- pytest.Failed: If batch reaches an unexpected status or timeout_action is "fail"
- pytest.Skipped: If timeout_action is "skip" on timeout or unexpected status
- """
- if sleep_interval is None:
- # Default to 1/10th of max_wait_time, with min 1s and max 15s
- sleep_interval = max(1, min(15, max_wait_time // 10))
-
- if expected_statuses is None:
- expected_statuses = {"completed"}
-
- terminal_statuses = {"completed", "failed", "cancelled", "expired"}
- unexpected_statuses = terminal_statuses - expected_statuses
-
- start_time = time.time()
- while time.time() - start_time < max_wait_time:
- current_batch = self.client.batches.retrieve(batch_id)
-
- if current_batch.status in expected_statuses:
- return current_batch
- elif current_batch.status in unexpected_statuses:
- error_msg = f"Batch reached unexpected status: {current_batch.status}"
- if timeout_action == "skip":
- pytest.skip(error_msg)
- else:
- pytest.fail(error_msg)
-
- time.sleep(sleep_interval)
-
- timeout_msg = f"Batch did not reach expected status {expected_statuses} within {max_wait_time} seconds"
- if timeout_action == "skip":
- pytest.skip(timeout_msg)
- else:
- pytest.fail(timeout_msg)
-
-
-@pytest.fixture
-def batch_helper(openai_client):
- """Fixture that provides a BatchHelper instance for OpenAI client."""
- return BatchHelper(openai_client)
diff --git a/tests/integration/batches/test_batches.py b/tests/integration/batches/test_batches.py
deleted file mode 100644
index 1ef3202d0..000000000
--- a/tests/integration/batches/test_batches.py
+++ /dev/null
@@ -1,270 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the terms described in the LICENSE file in
-# the root directory of this source tree.
-
-"""
-Integration tests for the Llama Stack batch processing functionality.
-
-This module contains comprehensive integration tests for the batch processing API,
-using the OpenAI-compatible client interface for consistency.
-
-Test Categories:
- 1. Core Batch Operations:
- - test_batch_creation_and_retrieval: Comprehensive batch creation, structure validation, and retrieval
- - test_batch_listing: Basic batch listing functionality
- - test_batch_immediate_cancellation: Batch cancellation workflow
- # TODO: cancel during processing
-
- 2. End-to-End Processing:
- - test_batch_e2e_chat_completions: Full chat completions workflow with output and error validation
-
-Note: Error conditions and edge cases are primarily tested in test_batches_errors.py
-for better organization and separation of concerns.
-
-CLEANUP WARNING: These tests currently create batches that are not automatically
-cleaned up after test completion. This may lead to resource accumulation over
-multiple test runs. Only test_batch_immediate_cancellation properly cancels its batch.
-The test_batch_e2e_chat_completions test does clean up its output and error files.
-"""
-
-import json
-
-
-class TestBatchesIntegration:
- """Integration tests for the batches API."""
-
- def test_batch_creation_and_retrieval(self, openai_client, batch_helper, text_model_id):
- """Test comprehensive batch creation and retrieval scenarios."""
- test_metadata = {
- "test_type": "comprehensive",
- "purpose": "creation_and_retrieval_test",
- "version": "1.0",
- "tags": "test,batch",
- }
-
- batch_requests = [
- {
- "custom_id": "request-1",
- "method": "POST",
- "url": "/v1/chat/completions",
- "body": {
- "model": text_model_id,
- "messages": [{"role": "user", "content": "Hello"}],
- "max_tokens": 10,
- },
- }
- ]
-
- with batch_helper.create_file(batch_requests, "batch_creation_test") as uploaded_file:
- batch = openai_client.batches.create(
- input_file_id=uploaded_file.id,
- endpoint="/v1/chat/completions",
- completion_window="24h",
- metadata=test_metadata,
- )
-
- assert batch.endpoint == "/v1/chat/completions"
- assert batch.input_file_id == uploaded_file.id
- assert batch.completion_window == "24h"
- assert batch.metadata == test_metadata
-
- retrieved_batch = openai_client.batches.retrieve(batch.id)
-
- assert retrieved_batch.id == batch.id
- assert retrieved_batch.object == batch.object
- assert retrieved_batch.endpoint == batch.endpoint
- assert retrieved_batch.input_file_id == batch.input_file_id
- assert retrieved_batch.completion_window == batch.completion_window
- assert retrieved_batch.metadata == batch.metadata
-
- def test_batch_listing(self, openai_client, batch_helper, text_model_id):
- """
- Test batch listing.
-
- This test creates multiple batches and verifies that they can be listed.
- It also deletes the input files before execution, which means the batches
- will appear as failed due to missing input files. This is expected and
- a good thing, because it means no inference is performed.
- """
- batch_ids = []
-
- for i in range(2):
- batch_requests = [
- {
- "custom_id": f"request-{i}",
- "method": "POST",
- "url": "/v1/chat/completions",
- "body": {
- "model": text_model_id,
- "messages": [{"role": "user", "content": f"Hello {i}"}],
- "max_tokens": 10,
- },
- }
- ]
-
- with batch_helper.create_file(batch_requests, f"batch_input_{i}") as uploaded_file:
- batch = openai_client.batches.create(
- input_file_id=uploaded_file.id,
- endpoint="/v1/chat/completions",
- completion_window="24h",
- )
- batch_ids.append(batch.id)
-
- batch_list = openai_client.batches.list()
-
- assert isinstance(batch_list.data, list)
-
- listed_batch_ids = {b.id for b in batch_list.data}
- for batch_id in batch_ids:
- assert batch_id in listed_batch_ids
-
- def test_batch_immediate_cancellation(self, openai_client, batch_helper, text_model_id):
- """Test immediate batch cancellation."""
- batch_requests = [
- {
- "custom_id": "request-1",
- "method": "POST",
- "url": "/v1/chat/completions",
- "body": {
- "model": text_model_id,
- "messages": [{"role": "user", "content": "Hello"}],
- "max_tokens": 10,
- },
- }
- ]
-
- with batch_helper.create_file(batch_requests) as uploaded_file:
- batch = openai_client.batches.create(
- input_file_id=uploaded_file.id,
- endpoint="/v1/chat/completions",
- completion_window="24h",
- )
-
- # hopefully cancel the batch before it completes
- cancelling_batch = openai_client.batches.cancel(batch.id)
- assert cancelling_batch.status in ["cancelling", "cancelled"]
- assert isinstance(cancelling_batch.cancelling_at, int), (
- f"cancelling_at should be int, got {type(cancelling_batch.cancelling_at)}"
- )
-
- final_batch = batch_helper.wait_for(
- batch.id,
- max_wait_time=3 * 60, # often takes 10-11 minutes, give it 3 min
- expected_statuses={"cancelled"},
- timeout_action="skip",
- )
-
- assert final_batch.status == "cancelled"
- assert isinstance(final_batch.cancelled_at, int), (
- f"cancelled_at should be int, got {type(final_batch.cancelled_at)}"
- )
-
- def test_batch_e2e_chat_completions(self, openai_client, batch_helper, text_model_id):
- """Test end-to-end batch processing for chat completions with both successful and failed operations."""
- batch_requests = [
- {
- "custom_id": "success-1",
- "method": "POST",
- "url": "/v1/chat/completions",
- "body": {
- "model": text_model_id,
- "messages": [{"role": "user", "content": "Say hello"}],
- "max_tokens": 20,
- },
- },
- {
- "custom_id": "error-1",
- "method": "POST",
- "url": "/v1/chat/completions",
- "body": {
- "model": text_model_id,
- "messages": [{"role": "user", "content": "This should fail"}],
- "max_tokens": -1, # Invalid negative max_tokens will cause inference error
- },
- },
- ]
-
- with batch_helper.create_file(batch_requests) as uploaded_file:
- batch = openai_client.batches.create(
- input_file_id=uploaded_file.id,
- endpoint="/v1/chat/completions",
- completion_window="24h",
- metadata={"test": "e2e_success_and_errors_test"},
- )
-
- final_batch = batch_helper.wait_for(
- batch.id,
- max_wait_time=3 * 60, # often takes 2-3 minutes
- expected_statuses={"completed"},
- timeout_action="skip",
- )
-
- # Expecting a completed batch with both successful and failed requests
- # Batch(id='batch_xxx',
- # completion_window='24h',
- # created_at=...,
- # endpoint='/v1/chat/completions',
- # input_file_id='file-xxx',
- # object='batch',
- # status='completed',
- # output_file_id='file-xxx',
- # error_file_id='file-xxx',
- # request_counts=BatchRequestCounts(completed=1, failed=1, total=2))
-
- assert final_batch.status == "completed"
- assert final_batch.request_counts is not None
- assert final_batch.request_counts.total == 2
- assert final_batch.request_counts.completed == 1
- assert final_batch.request_counts.failed == 1
-
- assert final_batch.output_file_id is not None, "Output file should exist for successful requests"
-
- output_content = openai_client.files.content(final_batch.output_file_id)
- if isinstance(output_content, str):
- output_text = output_content
- else:
- output_text = output_content.content.decode("utf-8")
-
- output_lines = output_text.strip().split("\n")
-
- for line in output_lines:
- result = json.loads(line)
-
- assert "id" in result
- assert "custom_id" in result
- assert result["custom_id"] == "success-1"
-
- assert "response" in result
-
- assert result["response"]["status_code"] == 200
- assert "body" in result["response"]
- assert "choices" in result["response"]["body"]
-
- assert final_batch.error_file_id is not None, "Error file should exist for failed requests"
-
- error_content = openai_client.files.content(final_batch.error_file_id)
- if isinstance(error_content, str):
- error_text = error_content
- else:
- error_text = error_content.content.decode("utf-8")
-
- error_lines = error_text.strip().split("\n")
-
- for line in error_lines:
- result = json.loads(line)
-
- assert "id" in result
- assert "custom_id" in result
- assert result["custom_id"] == "error-1"
- assert "error" in result
- error = result["error"]
- assert error is not None
- assert "code" in error or "message" in error, "Error should have code or message"
-
- deleted_output_file = openai_client.files.delete(final_batch.output_file_id)
- assert deleted_output_file.deleted, f"Output file {final_batch.output_file_id} was not deleted successfully"
-
- deleted_error_file = openai_client.files.delete(final_batch.error_file_id)
- assert deleted_error_file.deleted, f"Error file {final_batch.error_file_id} was not deleted successfully"
diff --git a/tests/integration/batches/test_batches_errors.py b/tests/integration/batches/test_batches_errors.py
deleted file mode 100644
index bc94a182e..000000000
--- a/tests/integration/batches/test_batches_errors.py
+++ /dev/null
@@ -1,693 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the terms described in the LICENSE file in
-# the root directory of this source tree.
-
-"""
-Error handling and edge case tests for the Llama Stack batch processing functionality.
-
-This module focuses exclusively on testing error conditions, validation failures,
-and edge cases for batch operations to ensure robust error handling and graceful
-degradation.
-
-Test Categories:
- 1. File and Input Validation:
- - test_batch_nonexistent_file_id: Handling invalid file IDs
- - test_batch_malformed_jsonl: Processing malformed JSONL input files
- - test_file_malformed_batch_file: Handling malformed files at upload time
- - test_batch_missing_required_fields: Validation of required request fields
-
- 2. API Endpoint and Model Validation:
- - test_batch_invalid_endpoint: Invalid endpoint handling during creation
- - test_batch_error_handling_invalid_model: Error handling with nonexistent models
- - test_batch_endpoint_mismatch: Validation of endpoint/URL consistency
-
- 3. Batch Lifecycle Error Handling:
- - test_batch_retrieve_nonexistent: Retrieving non-existent batches
- - test_batch_cancel_nonexistent: Cancelling non-existent batches
- - test_batch_cancel_completed: Attempting to cancel completed batches
-
- 4. Parameter and Configuration Validation:
- - test_batch_invalid_completion_window: Invalid completion window values
- - test_batch_invalid_metadata_types: Invalid metadata type validation
- - test_batch_missing_required_body_fields: Validation of required fields in request body
-
- 5. Feature Restriction and Compatibility:
- - test_batch_streaming_not_supported: Streaming request rejection
- - test_batch_mixed_streaming_requests: Mixed streaming/non-streaming validation
-
-Note: Core functionality and OpenAI compatibility tests are located in
-test_batches_integration.py for better organization and separation of concerns.
-
-CLEANUP WARNING: These tests create batches to test error conditions but do not
-automatically clean them up after test completion. While most error tests create
-batches that fail quickly, some may create valid batches that consume resources.
-"""
-
-import pytest
-from openai import BadRequestError, ConflictError, NotFoundError
-
-
-class TestBatchesErrorHandling:
- """Error handling and edge case tests for the batches API using OpenAI client."""
-
- def test_batch_nonexistent_file_id(self, openai_client, batch_helper):
- """Test batch creation with nonexistent input file ID."""
-
- batch = openai_client.batches.create(
- input_file_id="file-nonexistent-xyz",
- endpoint="/v1/chat/completions",
- completion_window="24h",
- )
-
- final_batch = batch_helper.wait_for(batch.id, expected_statuses={"failed"})
-
- # Expecting -
- # Batch(...,
- # status='failed',
- # errors=Errors(data=[
- # BatchError(
- # code='invalid_request',
- # line=None,
- # message='Cannot find file ..., or organization ... does not have access to it.',
- # param='file_id')
- # ], object='list'),
- # failed_at=1754566971,
- # ...)
-
- assert final_batch.status == "failed"
- assert final_batch.errors is not None
- assert len(final_batch.errors.data) == 1
- error = final_batch.errors.data[0]
- assert error.code == "invalid_request"
- assert "cannot find file" in error.message.lower()
-
- def test_batch_invalid_endpoint(self, openai_client, batch_helper, text_model_id):
- """Test batch creation with invalid endpoint."""
- batch_requests = [
- {
- "custom_id": "invalid-endpoint",
- "method": "POST",
- "url": "/v1/chat/completions",
- "body": {
- "model": text_model_id,
- "messages": [{"role": "user", "content": "Hello"}],
- "max_tokens": 10,
- },
- }
- ]
-
- with batch_helper.create_file(batch_requests) as uploaded_file:
- with pytest.raises(BadRequestError) as exc_info:
- openai_client.batches.create(
- input_file_id=uploaded_file.id,
- endpoint="/v1/invalid/endpoint",
- completion_window="24h",
- )
-
- # Expected -
- # Error code: 400 - {
- # 'error': {
- # 'message': "Invalid value: '/v1/invalid/endpoint'. Supported values are: '/v1/chat/completions', '/v1/completions', '/v1/embeddings', and '/v1/responses'.",
- # 'type': 'invalid_request_error',
- # 'param': 'endpoint',
- # 'code': 'invalid_value'
- # }
- # }
-
- error_msg = str(exc_info.value).lower()
- assert exc_info.value.status_code == 400
- assert "invalid value" in error_msg
- assert "/v1/invalid/endpoint" in error_msg
- assert "supported values" in error_msg
- assert "endpoint" in error_msg
- assert "invalid_value" in error_msg
-
- def test_batch_malformed_jsonl(self, openai_client, batch_helper):
- """
- Test batch with malformed JSONL input.
-
- The /v1/files endpoint requires valid JSONL format, so we provide a well formed line
- before a malformed line to ensure we get to the /v1/batches validation stage.
- """
- with batch_helper.create_file(
- """{"custom_id": "valid", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "test"}}
-{invalid json here""",
- "malformed_batch_input.jsonl",
- ) as uploaded_file:
- batch = openai_client.batches.create(
- input_file_id=uploaded_file.id,
- endpoint="/v1/chat/completions",
- completion_window="24h",
- )
-
- final_batch = batch_helper.wait_for(batch.id, expected_statuses={"failed"})
-
- # Expecting -
- # Batch(...,
- # status='failed',
- # errors=Errors(data=[
- # ...,
- # BatchError(code='invalid_json_line',
- # line=2,
- # message='This line is not parseable as valid JSON.',
- # param=None)
- # ], object='list'),
- # ...)
-
- assert final_batch.status == "failed"
- assert final_batch.errors is not None
- assert len(final_batch.errors.data) > 0
- error = final_batch.errors.data[-1] # get last error because first may be about the "test" model
- assert error.code == "invalid_json_line"
- assert error.line == 2
- assert "not" in error.message.lower()
- assert "valid json" in error.message.lower()
-
- @pytest.mark.xfail(reason="Not all file providers validate content")
- @pytest.mark.parametrize("batch_requests", ["", "{malformed json"], ids=["empty", "malformed"])
- def test_file_malformed_batch_file(self, openai_client, batch_helper, batch_requests):
- """Test file upload with malformed content."""
-
- with pytest.raises(BadRequestError) as exc_info:
- with batch_helper.create_file(batch_requests, "malformed_batch_input_file.jsonl"):
- # /v1/files rejects the file, we don't get to batch creation
- pass
-
- error_msg = str(exc_info.value).lower()
- assert exc_info.value.status_code == 400
- assert "invalid file format" in error_msg
- assert "jsonl" in error_msg
-
- def test_batch_retrieve_nonexistent(self, openai_client):
- """Test retrieving nonexistent batch."""
- with pytest.raises(NotFoundError) as exc_info:
- openai_client.batches.retrieve("batch-nonexistent-xyz")
-
- error_msg = str(exc_info.value).lower()
- assert exc_info.value.status_code == 404
- assert "no batch found" in error_msg or "not found" in error_msg
-
- def test_batch_cancel_nonexistent(self, openai_client):
- """Test cancelling nonexistent batch."""
- with pytest.raises(NotFoundError) as exc_info:
- openai_client.batches.cancel("batch-nonexistent-xyz")
-
- error_msg = str(exc_info.value).lower()
- assert exc_info.value.status_code == 404
- assert "no batch found" in error_msg or "not found" in error_msg
-
- def test_batch_cancel_completed(self, openai_client, batch_helper, text_model_id):
- """Test cancelling already completed batch."""
- batch_requests = [
- {
- "custom_id": "cancel-completed",
- "method": "POST",
- "url": "/v1/chat/completions",
- "body": {
- "model": text_model_id,
- "messages": [{"role": "user", "content": "Quick test"}],
- "max_tokens": 5,
- },
- }
- ]
-
- with batch_helper.create_file(batch_requests, "cancel_test_batch_input") as uploaded_file:
- batch = openai_client.batches.create(
- input_file_id=uploaded_file.id,
- endpoint="/v1/chat/completions",
- completion_window="24h",
- )
-
- final_batch = batch_helper.wait_for(
- batch.id,
- max_wait_time=3 * 60, # often take 10-11 min, give it 3 min
- expected_statuses={"completed"},
- timeout_action="skip",
- )
-
- deleted_file = openai_client.files.delete(final_batch.output_file_id)
- assert deleted_file.deleted, f"File {final_batch.output_file_id} was not deleted successfully"
-
- with pytest.raises(ConflictError) as exc_info:
- openai_client.batches.cancel(batch.id)
-
- # Expecting -
- # Error code: 409 - {
- # 'error': {
- # 'message': "Cannot cancel a batch with status 'completed'.",
- # 'type': 'invalid_request_error',
- # 'param': None,
- # 'code': None
- # }
- # }
- #
- # NOTE: Same for "failed", cancelling "cancelled" batches is allowed
-
- error_msg = str(exc_info.value).lower()
- assert exc_info.value.status_code == 409
- assert "cannot cancel" in error_msg
-
- def test_batch_missing_required_fields(self, openai_client, batch_helper, text_model_id):
- """Test batch with requests missing required fields."""
- batch_requests = [
- {
- # Missing custom_id
- "method": "POST",
- "url": "/v1/chat/completions",
- "body": {
- "model": text_model_id,
- "messages": [{"role": "user", "content": "No custom_id"}],
- "max_tokens": 10,
- },
- },
- {
- "custom_id": "no-method",
- "url": "/v1/chat/completions",
- "body": {
- "model": text_model_id,
- "messages": [{"role": "user", "content": "No method"}],
- "max_tokens": 10,
- },
- },
- {
- "custom_id": "no-url",
- "method": "POST",
- "body": {
- "model": text_model_id,
- "messages": [{"role": "user", "content": "No URL"}],
- "max_tokens": 10,
- },
- },
- {
- "custom_id": "no-body",
- "method": "POST",
- "url": "/v1/chat/completions",
- },
- ]
-
- with batch_helper.create_file(batch_requests, "missing_fields_batch_input") as uploaded_file:
- batch = openai_client.batches.create(
- input_file_id=uploaded_file.id,
- endpoint="/v1/chat/completions",
- completion_window="24h",
- )
-
- final_batch = batch_helper.wait_for(batch.id, expected_statuses={"failed"})
-
- # Expecting -
- # Batch(...,
- # status='failed',
- # errors=Errors(
- # data=[
- # BatchError(
- # code='missing_required_parameter',
- # line=1,
- # message="Missing required parameter: 'custom_id'.",
- # param='custom_id'
- # ),
- # BatchError(
- # code='missing_required_parameter',
- # line=2,
- # message="Missing required parameter: 'method'.",
- # param='method'
- # ),
- # BatchError(
- # code='missing_required_parameter',
- # line=3,
- # message="Missing required parameter: 'url'.",
- # param='url'
- # ),
- # BatchError(
- # code='missing_required_parameter',
- # line=4,
- # message="Missing required parameter: 'body'.",
- # param='body'
- # )
- # ], object='list'),
- # failed_at=1754566945,
- # ...)
- # )
-
- assert final_batch.status == "failed"
- assert final_batch.errors is not None
- assert len(final_batch.errors.data) == 4
- no_custom_id_error = final_batch.errors.data[0]
- assert no_custom_id_error.code == "missing_required_parameter"
- assert no_custom_id_error.line == 1
- assert "missing" in no_custom_id_error.message.lower()
- assert "custom_id" in no_custom_id_error.message.lower()
- no_method_error = final_batch.errors.data[1]
- assert no_method_error.code == "missing_required_parameter"
- assert no_method_error.line == 2
- assert "missing" in no_method_error.message.lower()
- assert "method" in no_method_error.message.lower()
- no_url_error = final_batch.errors.data[2]
- assert no_url_error.code == "missing_required_parameter"
- assert no_url_error.line == 3
- assert "missing" in no_url_error.message.lower()
- assert "url" in no_url_error.message.lower()
- no_body_error = final_batch.errors.data[3]
- assert no_body_error.code == "missing_required_parameter"
- assert no_body_error.line == 4
- assert "missing" in no_body_error.message.lower()
- assert "body" in no_body_error.message.lower()
-
- def test_batch_invalid_completion_window(self, openai_client, batch_helper, text_model_id):
- """Test batch creation with invalid completion window."""
- batch_requests = [
- {
- "custom_id": "invalid-completion-window",
- "method": "POST",
- "url": "/v1/chat/completions",
- "body": {
- "model": text_model_id,
- "messages": [{"role": "user", "content": "Hello"}],
- "max_tokens": 10,
- },
- }
- ]
-
- with batch_helper.create_file(batch_requests) as uploaded_file:
- for window in ["1h", "48h", "invalid", ""]:
- with pytest.raises(BadRequestError) as exc_info:
- openai_client.batches.create(
- input_file_id=uploaded_file.id,
- endpoint="/v1/chat/completions",
- completion_window=window,
- )
- assert exc_info.value.status_code == 400
- error_msg = str(exc_info.value).lower()
- assert "error" in error_msg
- assert "completion_window" in error_msg
-
- def test_batch_streaming_not_supported(self, openai_client, batch_helper, text_model_id):
- """Test that streaming responses are not supported in batches."""
- batch_requests = [
- {
- "custom_id": "streaming-test",
- "method": "POST",
- "url": "/v1/chat/completions",
- "body": {
- "model": text_model_id,
- "messages": [{"role": "user", "content": "Hello"}],
- "max_tokens": 10,
- "stream": True, # Not supported
- },
- }
- ]
-
- with batch_helper.create_file(batch_requests, "streaming_batch_input") as uploaded_file:
- batch = openai_client.batches.create(
- input_file_id=uploaded_file.id,
- endpoint="/v1/chat/completions",
- completion_window="24h",
- )
-
- final_batch = batch_helper.wait_for(batch.id, expected_statuses={"failed"})
-
- # Expecting -
- # Batch(...,
- # status='failed',
- # errors=Errors(data=[
- # BatchError(code='streaming_unsupported',
- # line=1,
- # message='Chat Completions: Streaming is not supported in the Batch API.',
- # param='body.stream')
- # ], object='list'),
- # failed_at=1754566965,
- # ...)
-
- assert final_batch.status == "failed"
- assert final_batch.errors is not None
- assert len(final_batch.errors.data) == 1
- error = final_batch.errors.data[0]
- assert error.code == "streaming_unsupported"
- assert error.line == 1
- assert "streaming" in error.message.lower()
- assert "not supported" in error.message.lower()
- assert error.param == "body.stream"
- assert final_batch.failed_at is not None
-
- def test_batch_mixed_streaming_requests(self, openai_client, batch_helper, text_model_id):
- """
- Test batch with mixed streaming and non-streaming requests.
-
- This is distinct from test_batch_streaming_not_supported, which tests a single
- streaming request, to ensure an otherwise valid batch fails when a single
- streaming request is included.
- """
- batch_requests = [
- {
- "custom_id": "valid-non-streaming-request",
- "method": "POST",
- "url": "/v1/chat/completions",
- "body": {
- "model": text_model_id,
- "messages": [{"role": "user", "content": "Hello without streaming"}],
- "max_tokens": 10,
- },
- },
- {
- "custom_id": "streaming-request",
- "method": "POST",
- "url": "/v1/chat/completions",
- "body": {
- "model": text_model_id,
- "messages": [{"role": "user", "content": "Hello with streaming"}],
- "max_tokens": 10,
- "stream": True, # Not supported
- },
- },
- ]
-
- with batch_helper.create_file(batch_requests, "mixed_streaming_batch_input") as uploaded_file:
- batch = openai_client.batches.create(
- input_file_id=uploaded_file.id,
- endpoint="/v1/chat/completions",
- completion_window="24h",
- )
-
- final_batch = batch_helper.wait_for(batch.id, expected_statuses={"failed"})
-
- # Expecting -
- # Batch(...,
- # status='failed',
- # errors=Errors(data=[
- # BatchError(
- # code='streaming_unsupported',
- # line=2,
- # message='Chat Completions: Streaming is not supported in the Batch API.',
- # param='body.stream')
- # ], object='list'),
- # failed_at=1754574442,
- # ...)
-
- assert final_batch.status == "failed"
- assert final_batch.errors is not None
- assert len(final_batch.errors.data) == 1
- error = final_batch.errors.data[0]
- assert error.code == "streaming_unsupported"
- assert error.line == 2
- assert "streaming" in error.message.lower()
- assert "not supported" in error.message.lower()
- assert error.param == "body.stream"
- assert final_batch.failed_at is not None
-
- def test_batch_endpoint_mismatch(self, openai_client, batch_helper, text_model_id):
- """Test batch creation with mismatched endpoint and request URL."""
- batch_requests = [
- {
- "custom_id": "endpoint-mismatch",
- "method": "POST",
- "url": "/v1/embeddings", # Different from batch endpoint
- "body": {
- "model": text_model_id,
- "messages": [{"role": "user", "content": "Hello"}],
- },
- }
- ]
-
- with batch_helper.create_file(batch_requests, "endpoint_mismatch_batch_input") as uploaded_file:
- batch = openai_client.batches.create(
- input_file_id=uploaded_file.id,
- endpoint="/v1/chat/completions", # Different from request URL
- completion_window="24h",
- )
-
- final_batch = batch_helper.wait_for(batch.id, expected_statuses={"failed"})
-
- # Expecting -
- # Batch(...,
- # status='failed',
- # errors=Errors(data=[
- # BatchError(
- # code='invalid_url',
- # line=1,
- # message='The URL provided for this request does not match the batch endpoint.',
- # param='url')
- # ], object='list'),
- # failed_at=1754566972,
- # ...)
-
- assert final_batch.status == "failed"
- assert final_batch.errors is not None
- assert len(final_batch.errors.data) == 1
- error = final_batch.errors.data[0]
- assert error.line == 1
- assert error.code == "invalid_url"
- assert "does not match" in error.message.lower()
- assert "endpoint" in error.message.lower()
- assert final_batch.failed_at is not None
-
- def test_batch_error_handling_invalid_model(self, openai_client, batch_helper):
- """Test batch error handling with invalid model."""
- batch_requests = [
- {
- "custom_id": "invalid-model",
- "method": "POST",
- "url": "/v1/chat/completions",
- "body": {
- "model": "nonexistent-model-xyz",
- "messages": [{"role": "user", "content": "Hello"}],
- "max_tokens": 10,
- },
- }
- ]
-
- with batch_helper.create_file(batch_requests) as uploaded_file:
- batch = openai_client.batches.create(
- input_file_id=uploaded_file.id,
- endpoint="/v1/chat/completions",
- completion_window="24h",
- )
-
- final_batch = batch_helper.wait_for(batch.id, expected_statuses={"failed"})
-
- # Expecting -
- # Batch(...,
- # status='failed',
- # errors=Errors(data=[
- # BatchError(code='model_not_found',
- # line=1,
- # message="The provided model 'nonexistent-model-xyz' is not supported by the Batch API.",
- # param='body.model')
- # ], object='list'),
- # failed_at=1754566978,
- # ...)
-
- assert final_batch.status == "failed"
- assert final_batch.errors is not None
- assert len(final_batch.errors.data) == 1
- error = final_batch.errors.data[0]
- assert error.line == 1
- assert error.code == "model_not_found"
- assert "not supported" in error.message.lower()
- assert error.param == "body.model"
- assert final_batch.failed_at is not None
-
- def test_batch_missing_required_body_fields(self, openai_client, batch_helper, text_model_id):
- """Test batch with requests missing required fields in body (model and messages)."""
- batch_requests = [
- {
- "custom_id": "missing-model",
- "method": "POST",
- "url": "/v1/chat/completions",
- "body": {
- # Missing model field
- "messages": [{"role": "user", "content": "Hello without model"}],
- "max_tokens": 10,
- },
- },
- {
- "custom_id": "missing-messages",
- "method": "POST",
- "url": "/v1/chat/completions",
- "body": {
- "model": text_model_id,
- # Missing messages field
- "max_tokens": 10,
- },
- },
- ]
-
- with batch_helper.create_file(batch_requests, "missing_body_fields_batch_input") as uploaded_file:
- batch = openai_client.batches.create(
- input_file_id=uploaded_file.id,
- endpoint="/v1/chat/completions",
- completion_window="24h",
- )
-
- final_batch = batch_helper.wait_for(batch.id, expected_statuses={"failed"})
-
- # Expecting -
- # Batch(...,
- # status='failed',
- # errors=Errors(data=[
- # BatchError(
- # code='invalid_request',
- # line=1,
- # message='Model parameter is required.',
- # param='body.model'),
- # BatchError(
- # code='invalid_request',
- # line=2,
- # message='Messages parameter is required.',
- # param='body.messages')
- # ], object='list'),
- # ...)
-
- assert final_batch.status == "failed"
- assert final_batch.errors is not None
- assert len(final_batch.errors.data) == 2
-
- model_error = final_batch.errors.data[0]
- assert model_error.line == 1
- assert "model" in model_error.message.lower()
- assert model_error.param == "body.model"
-
- messages_error = final_batch.errors.data[1]
- assert messages_error.line == 2
- assert "messages" in messages_error.message.lower()
- assert messages_error.param == "body.messages"
-
- assert final_batch.failed_at is not None
-
- def test_batch_invalid_metadata_types(self, openai_client, batch_helper, text_model_id):
- """Test batch creation with invalid metadata types (like lists)."""
- batch_requests = [
- {
- "custom_id": "invalid-metadata-type",
- "method": "POST",
- "url": "/v1/chat/completions",
- "body": {
- "model": text_model_id,
- "messages": [{"role": "user", "content": "Hello"}],
- "max_tokens": 10,
- },
- }
- ]
-
- with batch_helper.create_file(batch_requests) as uploaded_file:
- with pytest.raises(Exception) as exc_info:
- openai_client.batches.create(
- input_file_id=uploaded_file.id,
- endpoint="/v1/chat/completions",
- completion_window="24h",
- metadata={
- "tags": ["tag1", "tag2"], # Invalid type, should be a string
- },
- )
-
- # Expecting -
- # Error code: 400 - {'error':
- # {'message': "Invalid type for 'metadata.tags': expected a string,
- # but got an array instead.",
- # 'type': 'invalid_request_error', 'param': 'metadata.tags',
- # 'code': 'invalid_type'}}
-
- error_msg = str(exc_info.value).lower()
- assert "400" in error_msg
- assert "tags" in error_msg
- assert "string" in error_msg
diff --git a/tests/unit/providers/batches/test_reference.py b/tests/unit/providers/batches/test_reference.py
deleted file mode 100644
index 9fe0cc710..000000000
--- a/tests/unit/providers/batches/test_reference.py
+++ /dev/null
@@ -1,753 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the terms described in the LICENSE file in
-# the root directory of this source tree.
-
-"""
-Test suite for the reference implementation of the Batches API.
-
-The tests are categorized and outlined below, keep this updated:
-
-- Batch creation with various parameters and validation:
- * test_create_and_retrieve_batch_success (positive)
- * test_create_batch_without_metadata (positive)
- * test_create_batch_completion_window (negative)
- * test_create_batch_invalid_endpoints (negative)
- * test_create_batch_invalid_metadata (negative)
-
-- Batch retrieval and error handling for non-existent batches:
- * test_retrieve_batch_not_found (negative)
-
-- Batch cancellation with proper status transitions:
- * test_cancel_batch_success (positive)
- * test_cancel_batch_invalid_statuses (negative)
- * test_cancel_batch_not_found (negative)
-
-- Batch listing with pagination and filtering:
- * test_list_batches_empty (positive)
- * test_list_batches_single_batch (positive)
- * test_list_batches_multiple_batches (positive)
- * test_list_batches_with_limit (positive)
- * test_list_batches_with_pagination (positive)
- * test_list_batches_invalid_after (negative)
-
-- Data persistence in the underlying key-value store:
- * test_kvstore_persistence (positive)
-
-- Batch processing concurrency control:
- * test_max_concurrent_batches (positive)
-
-- Input validation testing (direct _validate_input method tests):
- * test_validate_input_file_not_found (negative)
- * test_validate_input_file_exists_empty_content (positive)
- * test_validate_input_file_mixed_valid_invalid_json (mixed)
- * test_validate_input_invalid_model (negative)
- * test_validate_input_url_mismatch (negative)
- * test_validate_input_multiple_errors_per_request (negative)
- * test_validate_input_invalid_request_format (negative)
- * test_validate_input_missing_parameters (parametrized negative - custom_id, method, url, body, model, messages missing validation)
- * test_validate_input_invalid_parameter_types (parametrized negative - custom_id, url, method, body, model, messages type validation)
-
-The tests use temporary SQLite databases for isolation and mock external
-dependencies like inference, files, and models APIs.
-"""
-
-import json
-import tempfile
-from pathlib import Path
-from unittest.mock import AsyncMock, MagicMock
-
-import pytest
-
-from llama_stack.apis.batches import BatchObject
-from llama_stack.apis.common.errors import ConflictError, ResourceNotFoundError
-from llama_stack.providers.inline.batches.reference.batches import ReferenceBatchesImpl
-from llama_stack.providers.inline.batches.reference.config import ReferenceBatchesImplConfig
-from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
-
-
-class TestReferenceBatchesImpl:
- """Test the reference implementation of the Batches API."""
-
- @pytest.fixture
- async def provider(self):
- """Create a test provider instance with temporary database."""
- with tempfile.TemporaryDirectory() as tmpdir:
- db_path = Path(tmpdir) / "test_batches.db"
- kvstore_config = SqliteKVStoreConfig(db_path=str(db_path))
- config = ReferenceBatchesImplConfig(kvstore=kvstore_config)
-
- # Create kvstore and mock APIs
- from unittest.mock import AsyncMock
-
- from llama_stack.providers.utils.kvstore import kvstore_impl
-
- kvstore = await kvstore_impl(config.kvstore)
- mock_inference = AsyncMock()
- mock_files = AsyncMock()
- mock_models = AsyncMock()
-
- provider = ReferenceBatchesImpl(config, mock_inference, mock_files, mock_models, kvstore)
- await provider.initialize()
-
- # unit tests should not require background processing
- provider.process_batches = False
-
- yield provider
-
- await provider.shutdown()
-
- @pytest.fixture
- def sample_batch_data(self):
- """Sample batch data for testing."""
- return {
- "input_file_id": "file_abc123",
- "endpoint": "/v1/chat/completions",
- "completion_window": "24h",
- "metadata": {"test": "true", "priority": "high"},
- }
-
- def _validate_batch_type(self, batch, expected_metadata=None):
- """
- Helper function to validate batch object structure and field types.
-
- Note: This validates the direct BatchObject from the provider, not the
- client library response which has a different structure.
-
- Args:
- batch: The BatchObject instance to validate.
- expected_metadata: Optional expected metadata dictionary to validate against.
- """
- assert isinstance(batch.id, str)
- assert isinstance(batch.completion_window, str)
- assert isinstance(batch.created_at, int)
- assert isinstance(batch.endpoint, str)
- assert isinstance(batch.input_file_id, str)
- assert batch.object == "batch"
- assert batch.status in [
- "validating",
- "failed",
- "in_progress",
- "finalizing",
- "completed",
- "expired",
- "cancelling",
- "cancelled",
- ]
-
- if expected_metadata is not None:
- assert batch.metadata == expected_metadata
-
- timestamp_fields = [
- "cancelled_at",
- "cancelling_at",
- "completed_at",
- "expired_at",
- "expires_at",
- "failed_at",
- "finalizing_at",
- "in_progress_at",
- ]
- for field in timestamp_fields:
- field_value = getattr(batch, field, None)
- if field_value is not None:
- assert isinstance(field_value, int), f"{field} should be int or None, got {type(field_value)}"
-
- file_id_fields = ["error_file_id", "output_file_id"]
- for field in file_id_fields:
- field_value = getattr(batch, field, None)
- if field_value is not None:
- assert isinstance(field_value, str), f"{field} should be str or None, got {type(field_value)}"
-
- if hasattr(batch, "request_counts") and batch.request_counts is not None:
- assert isinstance(batch.request_counts.completed, int), (
- f"request_counts.completed should be int, got {type(batch.request_counts.completed)}"
- )
- assert isinstance(batch.request_counts.failed, int), (
- f"request_counts.failed should be int, got {type(batch.request_counts.failed)}"
- )
- assert isinstance(batch.request_counts.total, int), (
- f"request_counts.total should be int, got {type(batch.request_counts.total)}"
- )
-
- if hasattr(batch, "errors") and batch.errors is not None:
- assert isinstance(batch.errors, dict), f"errors should be object or dict, got {type(batch.errors)}"
-
- if hasattr(batch.errors, "data") and batch.errors.data is not None:
- assert isinstance(batch.errors.data, list), (
- f"errors.data should be list or None, got {type(batch.errors.data)}"
- )
-
- for i, error_item in enumerate(batch.errors.data):
- assert isinstance(error_item, dict), (
- f"errors.data[{i}] should be object or dict, got {type(error_item)}"
- )
-
- if hasattr(error_item, "code") and error_item.code is not None:
- assert isinstance(error_item.code, str), (
- f"errors.data[{i}].code should be str or None, got {type(error_item.code)}"
- )
-
- if hasattr(error_item, "line") and error_item.line is not None:
- assert isinstance(error_item.line, int), (
- f"errors.data[{i}].line should be int or None, got {type(error_item.line)}"
- )
-
- if hasattr(error_item, "message") and error_item.message is not None:
- assert isinstance(error_item.message, str), (
- f"errors.data[{i}].message should be str or None, got {type(error_item.message)}"
- )
-
- if hasattr(error_item, "param") and error_item.param is not None:
- assert isinstance(error_item.param, str), (
- f"errors.data[{i}].param should be str or None, got {type(error_item.param)}"
- )
-
- if hasattr(batch.errors, "object") and batch.errors.object is not None:
- assert isinstance(batch.errors.object, str), (
- f"errors.object should be str or None, got {type(batch.errors.object)}"
- )
- assert batch.errors.object == "list", f"errors.object should be 'list', got {batch.errors.object}"
-
- async def test_create_and_retrieve_batch_success(self, provider, sample_batch_data):
- """Test successful batch creation and retrieval."""
- created_batch = await provider.create_batch(**sample_batch_data)
-
- self._validate_batch_type(created_batch, expected_metadata=sample_batch_data["metadata"])
-
- assert created_batch.id.startswith("batch_")
- assert len(created_batch.id) > 13
- assert created_batch.object == "batch"
- assert created_batch.endpoint == sample_batch_data["endpoint"]
- assert created_batch.input_file_id == sample_batch_data["input_file_id"]
- assert created_batch.completion_window == sample_batch_data["completion_window"]
- assert created_batch.status == "validating"
- assert created_batch.metadata == sample_batch_data["metadata"]
- assert isinstance(created_batch.created_at, int)
- assert created_batch.created_at > 0
-
- retrieved_batch = await provider.retrieve_batch(created_batch.id)
-
- self._validate_batch_type(retrieved_batch, expected_metadata=sample_batch_data["metadata"])
-
- assert retrieved_batch.id == created_batch.id
- assert retrieved_batch.input_file_id == created_batch.input_file_id
- assert retrieved_batch.endpoint == created_batch.endpoint
- assert retrieved_batch.status == created_batch.status
- assert retrieved_batch.metadata == created_batch.metadata
-
- async def test_create_batch_without_metadata(self, provider):
- """Test batch creation without optional metadata."""
- batch = await provider.create_batch(
- input_file_id="file_123", endpoint="/v1/chat/completions", completion_window="24h"
- )
-
- assert batch.metadata is None
-
- async def test_create_batch_completion_window(self, provider):
- """Test batch creation with invalid completion window."""
- with pytest.raises(ValueError, match="Invalid completion_window"):
- await provider.create_batch(
- input_file_id="file_123", endpoint="/v1/chat/completions", completion_window="now"
- )
-
- @pytest.mark.parametrize(
- "endpoint",
- [
- "/v1/embeddings",
- "/v1/completions",
- "/v1/invalid/endpoint",
- "",
- ],
- )
- async def test_create_batch_invalid_endpoints(self, provider, endpoint):
- """Test batch creation with various invalid endpoints."""
- with pytest.raises(ValueError, match="Invalid endpoint"):
- await provider.create_batch(input_file_id="file_123", endpoint=endpoint, completion_window="24h")
-
- async def test_create_batch_invalid_metadata(self, provider):
- """Test that batch creation fails with invalid metadata."""
- with pytest.raises(ValueError, match="should be a valid string"):
- await provider.create_batch(
- input_file_id="file_123",
- endpoint="/v1/chat/completions",
- completion_window="24h",
- metadata={123: "invalid_key"}, # Non-string key
- )
-
- with pytest.raises(ValueError, match="should be a valid string"):
- await provider.create_batch(
- input_file_id="file_123",
- endpoint="/v1/chat/completions",
- completion_window="24h",
- metadata={"valid_key": 456}, # Non-string value
- )
-
- async def test_retrieve_batch_not_found(self, provider):
- """Test error when retrieving non-existent batch."""
- with pytest.raises(ResourceNotFoundError, match=r"Batch 'nonexistent_batch' not found"):
- await provider.retrieve_batch("nonexistent_batch")
-
- async def test_cancel_batch_success(self, provider, sample_batch_data):
- """Test successful batch cancellation."""
- created_batch = await provider.create_batch(**sample_batch_data)
- assert created_batch.status == "validating"
-
- cancelled_batch = await provider.cancel_batch(created_batch.id)
-
- assert cancelled_batch.id == created_batch.id
- assert cancelled_batch.status in ["cancelling", "cancelled"]
- assert isinstance(cancelled_batch.cancelling_at, int)
- assert cancelled_batch.cancelling_at >= created_batch.created_at
-
- @pytest.mark.parametrize("status", ["failed", "expired", "completed"])
- async def test_cancel_batch_invalid_statuses(self, provider, sample_batch_data, status):
- """Test error when cancelling batch in final states."""
- provider.process_batches = False
- created_batch = await provider.create_batch(**sample_batch_data)
-
- # directly update status in kvstore
- await provider._update_batch(created_batch.id, status=status)
-
- with pytest.raises(ConflictError, match=f"Cannot cancel batch '{created_batch.id}' with status '{status}'"):
- await provider.cancel_batch(created_batch.id)
-
- async def test_cancel_batch_not_found(self, provider):
- """Test error when cancelling non-existent batch."""
- with pytest.raises(ResourceNotFoundError, match=r"Batch 'nonexistent_batch' not found"):
- await provider.cancel_batch("nonexistent_batch")
-
- async def test_list_batches_empty(self, provider):
- """Test listing batches when none exist."""
- response = await provider.list_batches()
-
- assert response.object == "list"
- assert response.data == []
- assert response.first_id is None
- assert response.last_id is None
- assert response.has_more is False
-
- async def test_list_batches_single_batch(self, provider, sample_batch_data):
- """Test listing batches with single batch."""
- created_batch = await provider.create_batch(**sample_batch_data)
-
- response = await provider.list_batches()
-
- assert len(response.data) == 1
- self._validate_batch_type(response.data[0], expected_metadata=sample_batch_data["metadata"])
- assert response.data[0].id == created_batch.id
- assert response.first_id == created_batch.id
- assert response.last_id == created_batch.id
- assert response.has_more is False
-
- async def test_list_batches_multiple_batches(self, provider):
- """Test listing multiple batches."""
- batches = [
- await provider.create_batch(
- input_file_id=f"file_{i}", endpoint="/v1/chat/completions", completion_window="24h"
- )
- for i in range(3)
- ]
-
- response = await provider.list_batches()
-
- assert len(response.data) == 3
-
- batch_ids = {batch.id for batch in response.data}
- expected_ids = {batch.id for batch in batches}
- assert batch_ids == expected_ids
- assert response.has_more is False
-
- assert response.first_id in expected_ids
- assert response.last_id in expected_ids
-
- async def test_list_batches_with_limit(self, provider):
- """Test listing batches with limit parameter."""
- batches = [
- await provider.create_batch(
- input_file_id=f"file_{i}", endpoint="/v1/chat/completions", completion_window="24h"
- )
- for i in range(3)
- ]
-
- response = await provider.list_batches(limit=2)
-
- assert len(response.data) == 2
- assert response.has_more is True
- assert response.first_id == response.data[0].id
- assert response.last_id == response.data[1].id
- batch_ids = {batch.id for batch in response.data}
- expected_ids = {batch.id for batch in batches}
- assert batch_ids.issubset(expected_ids)
-
- async def test_list_batches_with_pagination(self, provider):
- """Test listing batches with pagination using 'after' parameter."""
- for i in range(3):
- await provider.create_batch(
- input_file_id=f"file_{i}", endpoint="/v1/chat/completions", completion_window="24h"
- )
-
- # Get first page
- first_page = await provider.list_batches(limit=1)
- assert len(first_page.data) == 1
- assert first_page.has_more is True
-
- # Get second page using 'after'
- second_page = await provider.list_batches(limit=1, after=first_page.data[0].id)
- assert len(second_page.data) == 1
- assert second_page.data[0].id != first_page.data[0].id
-
- # Verify we got the next batch in order
- all_batches = await provider.list_batches()
- expected_second_batch_id = all_batches.data[1].id
- assert second_page.data[0].id == expected_second_batch_id
-
- async def test_list_batches_invalid_after(self, provider, sample_batch_data):
- """Test listing batches with invalid 'after' parameter."""
- await provider.create_batch(**sample_batch_data)
-
- response = await provider.list_batches(after="nonexistent_batch")
-
- # Should return all batches (no filtering when 'after' batch not found)
- assert len(response.data) == 1
-
- async def test_kvstore_persistence(self, provider, sample_batch_data):
- """Test that batches are properly persisted in kvstore."""
- batch = await provider.create_batch(**sample_batch_data)
-
- stored_data = await provider.kvstore.get(f"batch:{batch.id}")
- assert stored_data is not None
-
- stored_batch_dict = json.loads(stored_data)
- assert stored_batch_dict["id"] == batch.id
- assert stored_batch_dict["input_file_id"] == sample_batch_data["input_file_id"]
-
- async def test_validate_input_file_not_found(self, provider):
- """Test _validate_input when input file does not exist."""
- provider.files_api.openai_retrieve_file = AsyncMock(side_effect=Exception("File not found"))
-
- batch = BatchObject(
- id="batch_test",
- object="batch",
- endpoint="/v1/chat/completions",
- input_file_id="nonexistent_file",
- completion_window="24h",
- status="validating",
- created_at=1234567890,
- )
-
- errors, requests = await provider._validate_input(batch)
-
- assert len(errors) == 1
- assert len(requests) == 0
- assert errors[0].code == "invalid_request"
- assert errors[0].message == "Cannot find file nonexistent_file."
- assert errors[0].param == "input_file_id"
- assert errors[0].line is None
-
- async def test_validate_input_file_exists_empty_content(self, provider):
- """Test _validate_input when file exists but is empty."""
- provider.files_api.openai_retrieve_file = AsyncMock()
- mock_response = MagicMock()
- mock_response.body = b""
- provider.files_api.openai_retrieve_file_content = AsyncMock(return_value=mock_response)
-
- batch = BatchObject(
- id="batch_test",
- object="batch",
- endpoint="/v1/chat/completions",
- input_file_id="empty_file",
- completion_window="24h",
- status="validating",
- created_at=1234567890,
- )
-
- errors, requests = await provider._validate_input(batch)
-
- assert len(errors) == 0
- assert len(requests) == 0
-
- async def test_validate_input_file_mixed_valid_invalid_json(self, provider):
- """Test _validate_input when file contains valid and invalid JSON lines."""
- provider.files_api.openai_retrieve_file = AsyncMock()
- mock_response = MagicMock()
- # Line 1: valid JSON with proper body args, Line 2: invalid JSON
- mock_response.body = b'{"custom_id": "req-1", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "test-model", "messages": [{"role": "user", "content": "Hello"}]}}\n{invalid json'
- provider.files_api.openai_retrieve_file_content = AsyncMock(return_value=mock_response)
-
- batch = BatchObject(
- id="batch_test",
- object="batch",
- endpoint="/v1/chat/completions",
- input_file_id="mixed_file",
- completion_window="24h",
- status="validating",
- created_at=1234567890,
- )
-
- errors, requests = await provider._validate_input(batch)
-
- # Should have 1 JSON parsing error from line 2, and 1 valid request from line 1
- assert len(errors) == 1
- assert len(requests) == 1
-
- assert errors[0].code == "invalid_json_line"
- assert errors[0].line == 2
- assert errors[0].message == "This line is not parseable as valid JSON."
-
- assert requests[0].custom_id == "req-1"
- assert requests[0].method == "POST"
- assert requests[0].url == "/v1/chat/completions"
- assert requests[0].body["model"] == "test-model"
- assert requests[0].body["messages"] == [{"role": "user", "content": "Hello"}]
-
- async def test_validate_input_invalid_model(self, provider):
- """Test _validate_input when file contains request with non-existent model."""
- provider.files_api.openai_retrieve_file = AsyncMock()
- mock_response = MagicMock()
- mock_response.body = b'{"custom_id": "req-1", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "nonexistent-model", "messages": [{"role": "user", "content": "Hello"}]}}'
- provider.files_api.openai_retrieve_file_content = AsyncMock(return_value=mock_response)
-
- provider.models_api.get_model = AsyncMock(side_effect=Exception("Model not found"))
-
- batch = BatchObject(
- id="batch_test",
- object="batch",
- endpoint="/v1/chat/completions",
- input_file_id="invalid_model_file",
- completion_window="24h",
- status="validating",
- created_at=1234567890,
- )
-
- errors, requests = await provider._validate_input(batch)
-
- assert len(errors) == 1
- assert len(requests) == 0
-
- assert errors[0].code == "model_not_found"
- assert errors[0].line == 1
- assert errors[0].message == "Model 'nonexistent-model' does not exist or is not supported"
- assert errors[0].param == "body.model"
-
- @pytest.mark.parametrize(
- "param_name,param_path,error_code,error_message",
- [
- ("custom_id", "custom_id", "missing_required_parameter", "Missing required parameter: custom_id"),
- ("method", "method", "missing_required_parameter", "Missing required parameter: method"),
- ("url", "url", "missing_required_parameter", "Missing required parameter: url"),
- ("body", "body", "missing_required_parameter", "Missing required parameter: body"),
- ("model", "body.model", "invalid_request", "Model parameter is required"),
- ("messages", "body.messages", "invalid_request", "Messages parameter is required"),
- ],
- )
- async def test_validate_input_missing_parameters(self, provider, param_name, param_path, error_code, error_message):
- """Test _validate_input when file contains request with missing required parameters."""
- provider.files_api.openai_retrieve_file = AsyncMock()
- mock_response = MagicMock()
-
- base_request = {
- "custom_id": "req-1",
- "method": "POST",
- "url": "/v1/chat/completions",
- "body": {"model": "test-model", "messages": [{"role": "user", "content": "Hello"}]},
- }
-
- # Remove the specific parameter being tested
- if "." in param_path:
- top_level, nested_param = param_path.split(".", 1)
- del base_request[top_level][nested_param]
- else:
- del base_request[param_name]
-
- mock_response.body = json.dumps(base_request).encode()
- provider.files_api.openai_retrieve_file_content = AsyncMock(return_value=mock_response)
-
- batch = BatchObject(
- id="batch_test",
- object="batch",
- endpoint="/v1/chat/completions",
- input_file_id=f"missing_{param_name}_file",
- completion_window="24h",
- status="validating",
- created_at=1234567890,
- )
-
- errors, requests = await provider._validate_input(batch)
-
- assert len(errors) == 1
- assert len(requests) == 0
-
- assert errors[0].code == error_code
- assert errors[0].line == 1
- assert errors[0].message == error_message
- assert errors[0].param == param_path
-
- async def test_validate_input_url_mismatch(self, provider):
- """Test _validate_input when file contains request with URL that doesn't match batch endpoint."""
- provider.files_api.openai_retrieve_file = AsyncMock()
- mock_response = MagicMock()
- mock_response.body = b'{"custom_id": "req-1", "method": "POST", "url": "/v1/embeddings", "body": {"model": "test-model", "messages": [{"role": "user", "content": "Hello"}]}}'
- provider.files_api.openai_retrieve_file_content = AsyncMock(return_value=mock_response)
-
- batch = BatchObject(
- id="batch_test",
- object="batch",
- endpoint="/v1/chat/completions", # This doesn't match the URL in the request
- input_file_id="url_mismatch_file",
- completion_window="24h",
- status="validating",
- created_at=1234567890,
- )
-
- errors, requests = await provider._validate_input(batch)
-
- assert len(errors) == 1
- assert len(requests) == 0
-
- assert errors[0].code == "invalid_url"
- assert errors[0].line == 1
- assert errors[0].message == "URL provided for this request does not match the batch endpoint"
- assert errors[0].param == "url"
-
- async def test_validate_input_multiple_errors_per_request(self, provider):
- """Test _validate_input when a single request has multiple validation errors."""
- provider.files_api.openai_retrieve_file = AsyncMock()
- mock_response = MagicMock()
- # Request missing custom_id, has invalid URL, and missing model in body
- mock_response.body = (
- b'{"method": "POST", "url": "/v1/embeddings", "body": {"messages": [{"role": "user", "content": "Hello"}]}}'
- )
- provider.files_api.openai_retrieve_file_content = AsyncMock(return_value=mock_response)
-
- batch = BatchObject(
- id="batch_test",
- object="batch",
- endpoint="/v1/chat/completions", # Doesn't match /v1/embeddings in request
- input_file_id="multiple_errors_file",
- completion_window="24h",
- status="validating",
- created_at=1234567890,
- )
-
- errors, requests = await provider._validate_input(batch)
-
- assert len(errors) >= 2 # At least missing custom_id and URL mismatch
- assert len(requests) == 0
-
- for error in errors:
- assert error.line == 1
-
- error_codes = {error.code for error in errors}
- assert "missing_required_parameter" in error_codes # missing custom_id
- assert "invalid_url" in error_codes # URL mismatch
-
- async def test_validate_input_invalid_request_format(self, provider):
- """Test _validate_input when file contains non-object JSON (array, string, number)."""
- provider.files_api.openai_retrieve_file = AsyncMock()
- mock_response = MagicMock()
- mock_response.body = b'["not", "a", "request", "object"]'
- provider.files_api.openai_retrieve_file_content = AsyncMock(return_value=mock_response)
-
- batch = BatchObject(
- id="batch_test",
- object="batch",
- endpoint="/v1/chat/completions",
- input_file_id="invalid_format_file",
- completion_window="24h",
- status="validating",
- created_at=1234567890,
- )
-
- errors, requests = await provider._validate_input(batch)
-
- assert len(errors) == 1
- assert len(requests) == 0
-
- assert errors[0].code == "invalid_request"
- assert errors[0].line == 1
- assert errors[0].message == "Each line must be a JSON dictionary object"
-
- @pytest.mark.parametrize(
- "param_name,param_path,invalid_value,error_message",
- [
- ("custom_id", "custom_id", 12345, "Custom_id must be a string"),
- ("url", "url", 123, "URL must be a string"),
- ("method", "method", ["POST"], "Method must be a string"),
- ("body", "body", ["not", "valid"], "Body must be a JSON dictionary object"),
- ("model", "body.model", 123, "Model must be a string"),
- ("messages", "body.messages", "invalid messages format", "Messages must be an array"),
- ],
- )
- async def test_validate_input_invalid_parameter_types(
- self, provider, param_name, param_path, invalid_value, error_message
- ):
- """Test _validate_input when file contains request with parameters that have invalid types."""
- provider.files_api.openai_retrieve_file = AsyncMock()
- mock_response = MagicMock()
-
- base_request = {
- "custom_id": "req-1",
- "method": "POST",
- "url": "/v1/chat/completions",
- "body": {"model": "test-model", "messages": [{"role": "user", "content": "Hello"}]},
- }
-
- # Override the specific parameter with invalid value
- if "." in param_path:
- top_level, nested_param = param_path.split(".", 1)
- base_request[top_level][nested_param] = invalid_value
- else:
- base_request[param_name] = invalid_value
-
- mock_response.body = json.dumps(base_request).encode()
- provider.files_api.openai_retrieve_file_content = AsyncMock(return_value=mock_response)
-
- batch = BatchObject(
- id="batch_test",
- object="batch",
- endpoint="/v1/chat/completions",
- input_file_id=f"invalid_{param_name}_type_file",
- completion_window="24h",
- status="validating",
- created_at=1234567890,
- )
-
- errors, requests = await provider._validate_input(batch)
-
- assert len(errors) == 1
- assert len(requests) == 0
-
- assert errors[0].code == "invalid_request"
- assert errors[0].line == 1
- assert errors[0].message == error_message
- assert errors[0].param == param_path
-
- async def test_max_concurrent_batches(self, provider):
- """Test max_concurrent_batches configuration and concurrency control."""
- import asyncio
-
- provider._batch_semaphore = asyncio.Semaphore(2)
-
- provider.process_batches = True # enable because we're testing background processing
-
- active_batches = 0
-
- async def add_and_wait(batch_id: str):
- nonlocal active_batches
- active_batches += 1
- await asyncio.sleep(float("inf"))
-
- # the first thing done in _process_batch is to acquire the semaphore, then call _process_batch_impl,
- # so we can replace _process_batch_impl with our mock to control concurrency
- provider._process_batch_impl = add_and_wait
-
- for _ in range(3):
- await provider.create_batch(
- input_file_id="file_id", endpoint="/v1/chat/completions", completion_window="24h"
- )
-
- await asyncio.sleep(0.042) # let tasks start
-
- assert active_batches == 2, f"Expected 2 active batches, got {active_batches}"