From 8e678912ec3af83222bd4abad0b532e772601d6e Mon Sep 17 00:00:00 2001
From: Matthew Farrellee
Date: Fri, 8 Aug 2025 08:08:08 -0400
Subject: [PATCH 1/3] feat: add batches API with OpenAI compatibility

Add complete batches API implementation with protocol, providers, and tests:

Core Infrastructure:
- Add batches API protocol using OpenAI Batch types directly
- Add Api.batches enum value and protocol mapping in resolver
- Add OpenAI "batch" file purpose support
- Include proper error handling (ConflictError, ResourceNotFoundError)

Reference Provider:
- Add ReferenceBatchesImpl with full CRUD operations (create, retrieve, cancel, list)
- Implement background batch processing with configurable concurrency
- Add SQLite KVStore backend for persistence
- Support /v1/chat/completions endpoint with request validation

Comprehensive Test Suite:
- Add unit tests for provider implementation with validation
- Add integration tests for end-to-end batch processing workflows
- Add error handling tests for validation, malformed inputs, and edge cases

Configuration:
- Add max_concurrent_batches and max_concurrent_requests_per_batch options
- Add provider documentation with sample configurations

Test with -

```
$ uv run llama stack build --image-type venv --providers inference=YOU_PICK,files=inline::localfs,batches=inline::reference --run &
$ LLAMA_STACK_CONFIG=http://localhost:8321 uv run pytest tests/unit/providers/batches tests/integration/batches --text-model YOU_PICK
```
---
 docs/_static/llama-stack-spec.html            |   6 +-
 docs/_static/llama-stack-spec.yaml            |   2 +
 docs/source/providers/batches/index.md        |  13 +
 .../providers/batches/inline_reference.md     |  23 +
 llama_stack/apis/batches/__init__.py          |   9 +
 llama_stack/apis/batches/batches.py           |  92 +++
 llama_stack/apis/common/errors.py             |   7 +
 llama_stack/apis/datatypes.py                 |   2 +
 llama_stack/apis/files/files.py               |   1 +
 llama_stack/core/resolver.py                  |   2 +
 llama_stack/core/server/server.py             |   5 +
 .../providers/inline/batches/__init__.py      |   5 +
 .../inline/batches/reference/__init__.py      |  36 +
 .../inline/batches/reference/batches.py       | 553 +++++++++++++
 .../inline/batches/reference/config.py        |  40 +
 llama_stack/providers/registry/batches.py     |  26 +
 tests/integration/batches/__init__.py         |   5 +
 tests/integration/batches/conftest.py         | 122 +++
 tests/integration/batches/test_batches.py     | 270 +++++++
 .../batches/test_batches_errors.py            | 694 ++++++++++++++++
 .../unit/providers/batches/test_reference.py  | 753 ++++++++++++++++++
 21 files changed, 2664 insertions(+), 2 deletions(-)
 create mode 100644 docs/source/providers/batches/index.md
 create mode 100644 docs/source/providers/batches/inline_reference.md
 create mode 100644 llama_stack/apis/batches/__init__.py
 create mode 100644 llama_stack/apis/batches/batches.py
 create mode 100644 llama_stack/providers/inline/batches/__init__.py
 create mode 100644 llama_stack/providers/inline/batches/reference/__init__.py
 create mode 100644 llama_stack/providers/inline/batches/reference/batches.py
 create mode 100644 llama_stack/providers/inline/batches/reference/config.py
 create mode 100644 llama_stack/providers/registry/batches.py
 create mode 100644 tests/integration/batches/__init__.py
 create mode 100644 tests/integration/batches/conftest.py
 create mode 100644 tests/integration/batches/test_batches.py
 create mode 100644 tests/integration/batches/test_batches_errors.py
 create mode 100644 tests/unit/providers/batches/test_reference.py

diff --git a/docs/_static/llama-stack-spec.html b/docs/_static/llama-stack-spec.html
index d480ff592..9896b36cd 100644
--- 
a/docs/_static/llama-stack-spec.html +++ b/docs/_static/llama-stack-spec.html @@ -14591,7 +14591,8 @@ "OpenAIFilePurpose": { "type": "string", "enum": [ - "assistants" + "assistants", + "batch" ], "title": "OpenAIFilePurpose", "description": "Valid purpose values for OpenAI Files API." @@ -14668,7 +14669,8 @@ "purpose": { "type": "string", "enum": [ - "assistants" + "assistants", + "batch" ], "description": "The intended purpose of the file" } diff --git a/docs/_static/llama-stack-spec.yaml b/docs/_static/llama-stack-spec.yaml index 9c0fba554..15d491a65 100644 --- a/docs/_static/llama-stack-spec.yaml +++ b/docs/_static/llama-stack-spec.yaml @@ -10804,6 +10804,7 @@ components: type: string enum: - assistants + - batch title: OpenAIFilePurpose description: >- Valid purpose values for OpenAI Files API. @@ -10872,6 +10873,7 @@ components: type: string enum: - assistants + - batch description: The intended purpose of the file additionalProperties: false required: diff --git a/docs/source/providers/batches/index.md b/docs/source/providers/batches/index.md new file mode 100644 index 000000000..d2405ecf7 --- /dev/null +++ b/docs/source/providers/batches/index.md @@ -0,0 +1,13 @@ +# Batches + +## Overview + +This section contains documentation for all available providers for the **batches** API. + +## Providers + +```{toctree} +:maxdepth: 1 + +inline_reference +``` diff --git a/docs/source/providers/batches/inline_reference.md b/docs/source/providers/batches/inline_reference.md new file mode 100644 index 000000000..a58e5124d --- /dev/null +++ b/docs/source/providers/batches/inline_reference.md @@ -0,0 +1,23 @@ +# inline::reference + +## Description + +Reference implementation of batches API with KVStore persistence. + +## Configuration + +| Field | Type | Required | Default | Description | +|-------|------|----------|---------|-------------| +| `kvstore` | `utils.kvstore.config.RedisKVStoreConfig \| utils.kvstore.config.SqliteKVStoreConfig \| utils.kvstore.config.PostgresKVStoreConfig \| utils.kvstore.config.MongoDBKVStoreConfig` | No | sqlite | Configuration for the key-value store backend. | +| `max_concurrent_batches` | `` | No | 1 | Maximum number of concurrent batches to process simultaneously. | +| `max_concurrent_requests_per_batch` | `` | No | 10 | Maximum number of concurrent requests to process per batch. | + +## Sample Configuration + +```yaml +kvstore: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:=~/.llama/dummy}/batches.db + +``` + diff --git a/llama_stack/apis/batches/__init__.py b/llama_stack/apis/batches/__init__.py new file mode 100644 index 000000000..d3efe3dba --- /dev/null +++ b/llama_stack/apis/batches/__init__.py @@ -0,0 +1,9 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from .batches import Batches, BatchObject, CreateBatchRequest, ListBatchesResponse + +__all__ = ["Batches", "BatchObject", "CreateBatchRequest", "ListBatchesResponse"] diff --git a/llama_stack/apis/batches/batches.py b/llama_stack/apis/batches/batches.py new file mode 100644 index 000000000..72742d4fa --- /dev/null +++ b/llama_stack/apis/batches/batches.py @@ -0,0 +1,92 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
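+#
+# Illustrative usage sketch (mirrors the integration tests added in this patch;
+# the client object and file/model names are placeholders, not part of this
+# protocol definition):
+#
+#   file = client.files.create(file=jsonl_buffer, purpose="batch")
+#   batch = client.batches.create(input_file_id=file.id,
+#                                 endpoint="/v1/chat/completions",
+#                                 completion_window="24h")
+#   batch = client.batches.retrieve(batch.id)  # poll until a terminal status
+#   results = client.files.content(batch.output_file_id)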
+ +from typing import Any, Literal, Protocol, runtime_checkable + +from pydantic import BaseModel, Field + +from llama_stack.schema_utils import json_schema_type, webmethod + +try: + from openai.types import Batch as BatchObject +except ImportError as e: + raise ImportError("OpenAI package is required for batches API. Please install it with: pip install openai") from e + + +@json_schema_type +class CreateBatchRequest(BaseModel): + """Request to create a new batch.""" + + input_file_id: str = Field(..., description="The ID of an uploaded file that contains requests for the new batch") + endpoint: str = Field(..., description="The endpoint to be used for all requests in the batch") + completion_window: str = Field(..., description="The time window within which the batch should be processed") + metadata: dict[str, Any] | None = Field(default=None, description="Optional metadata for the batch") + + +@json_schema_type +class ListBatchesResponse(BaseModel): + """Response containing a list of batch objects.""" + + object: Literal["list"] = "list" + data: list[BatchObject] = Field(..., description="List of batch objects") + first_id: str | None = Field(default=None, description="ID of the first batch in the list") + last_id: str | None = Field(default=None, description="ID of the last batch in the list") + has_more: bool = Field(default=False, description="Whether there are more batches available") + + +@runtime_checkable +class Batches(Protocol): + """Protocol for batch processing API operations.""" + + @webmethod(route="/openai/v1/batches", method="POST") + async def create_batch( + self, + input_file_id: str, + endpoint: str, + completion_window: str, + metadata: dict[str, str] | None = None, + ) -> BatchObject: + """Create a new batch for processing multiple API requests. + + :param input_file_id: The ID of an uploaded file containing requests for the batch. + :param endpoint: The endpoint to be used for all requests in the batch. + :param completion_window: The time window within which the batch should be processed. + :param metadata: Optional metadata for the batch. + :returns: The created batch object. + """ + ... + + @webmethod(route="/openai/v1/batches/{batch_id}", method="GET") + async def retrieve_batch(self, batch_id: str) -> BatchObject: + """Retrieve information about a specific batch. + + :param batch_id: The ID of the batch to retrieve. + :returns: The batch object. + """ + ... + + @webmethod(route="/openai/v1/batches/{batch_id}/cancel", method="POST") + async def cancel_batch(self, batch_id: str) -> BatchObject: + """Cancel a batch that is in progress. + + :param batch_id: The ID of the batch to cancel. + :returns: The updated batch object. + """ + ... + + @webmethod(route="/openai/v1/batches", method="GET") + async def list_batches( + self, + after: str | None = None, + limit: int = 20, + ) -> ListBatchesResponse: + """List all batches for the current user. + + :param after: A cursor for pagination; returns batches after this batch ID. + :param limit: Number of batches to return (default 20, max 100). + :returns: A list of batch objects. + """ + ... diff --git a/llama_stack/apis/common/errors.py b/llama_stack/apis/common/errors.py index 95d6ac18e..c47c99f8d 100644 --- a/llama_stack/apis/common/errors.py +++ b/llama_stack/apis/common/errors.py @@ -62,3 +62,10 @@ class SessionNotFoundError(ValueError): def __init__(self, session_name: str) -> None: message = f"Session '{session_name}' not found or access denied." 
super().__init__(message) + + +class ConflictError(ValueError): + """raised when an operation cannot be performed due to a conflict with the current state""" + + def __init__(self, message: str) -> None: + super().__init__(message) diff --git a/llama_stack/apis/datatypes.py b/llama_stack/apis/datatypes.py index cabe46a2f..87fc95917 100644 --- a/llama_stack/apis/datatypes.py +++ b/llama_stack/apis/datatypes.py @@ -86,6 +86,7 @@ class Api(Enum, metaclass=DynamicApiMeta): :cvar inference: Text generation, chat completions, and embeddings :cvar safety: Content moderation and safety shields :cvar agents: Agent orchestration and execution + :cvar batches: Batch processing for asynchronous API requests :cvar vector_io: Vector database operations and queries :cvar datasetio: Dataset input/output operations :cvar scoring: Model output evaluation and scoring @@ -108,6 +109,7 @@ class Api(Enum, metaclass=DynamicApiMeta): inference = "inference" safety = "safety" agents = "agents" + batches = "batches" vector_io = "vector_io" datasetio = "datasetio" scoring = "scoring" diff --git a/llama_stack/apis/files/files.py b/llama_stack/apis/files/files.py index ba8701e23..a1b9dd4dc 100644 --- a/llama_stack/apis/files/files.py +++ b/llama_stack/apis/files/files.py @@ -22,6 +22,7 @@ class OpenAIFilePurpose(StrEnum): """ ASSISTANTS = "assistants" + BATCH = "batch" # TODO: Add other purposes as needed diff --git a/llama_stack/core/resolver.py b/llama_stack/core/resolver.py index 70c78fb01..7ac98dac8 100644 --- a/llama_stack/core/resolver.py +++ b/llama_stack/core/resolver.py @@ -8,6 +8,7 @@ import inspect from typing import Any from llama_stack.apis.agents import Agents +from llama_stack.apis.batches import Batches from llama_stack.apis.benchmarks import Benchmarks from llama_stack.apis.datasetio import DatasetIO from llama_stack.apis.datasets import Datasets @@ -75,6 +76,7 @@ def api_protocol_map(external_apis: dict[Api, ExternalApiSpec] | None = None) -> Api.agents: Agents, Api.inference: Inference, Api.inspect: Inspect, + Api.batches: Batches, Api.vector_io: VectorIO, Api.vector_dbs: VectorDBs, Api.models: Models, diff --git a/llama_stack/core/server/server.py b/llama_stack/core/server/server.py index fe5cc68d7..f5ef40275 100644 --- a/llama_stack/core/server/server.py +++ b/llama_stack/core/server/server.py @@ -31,6 +31,7 @@ from fastapi.responses import JSONResponse, StreamingResponse from openai import BadRequestError from pydantic import BaseModel, ValidationError +from llama_stack.apis.common.errors import ConflictError, ResourceNotFoundError from llama_stack.apis.common.responses import PaginatedResponse from llama_stack.cli.utils import add_config_distro_args, get_config_from_args from llama_stack.core.access_control.access_control import AccessDeniedError @@ -127,6 +128,10 @@ def translate_exception(exc: Exception) -> HTTPException | RequestValidationErro ] }, ) + elif isinstance(exc, ConflictError): + return HTTPException(status_code=409, detail=str(exc)) + elif isinstance(exc, ResourceNotFoundError): + return HTTPException(status_code=404, detail=str(exc)) elif isinstance(exc, ValueError): return HTTPException(status_code=400, detail=f"Invalid value: {str(exc)}") elif isinstance(exc, BadRequestError): diff --git a/llama_stack/providers/inline/batches/__init__.py b/llama_stack/providers/inline/batches/__init__.py new file mode 100644 index 000000000..756f351d8 --- /dev/null +++ b/llama_stack/providers/inline/batches/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. 
+# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. diff --git a/llama_stack/providers/inline/batches/reference/__init__.py b/llama_stack/providers/inline/batches/reference/__init__.py new file mode 100644 index 000000000..a8ae92eb2 --- /dev/null +++ b/llama_stack/providers/inline/batches/reference/__init__.py @@ -0,0 +1,36 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from typing import Any + +from llama_stack.apis.files import Files +from llama_stack.apis.inference import Inference +from llama_stack.apis.models import Models +from llama_stack.core.datatypes import AccessRule, Api +from llama_stack.providers.utils.kvstore import kvstore_impl + +from .batches import ReferenceBatchesImpl +from .config import ReferenceBatchesImplConfig + +__all__ = ["ReferenceBatchesImpl", "ReferenceBatchesImplConfig"] + + +async def get_provider_impl(config: ReferenceBatchesImplConfig, deps: dict[Api, Any], policy: list[AccessRule]): + kvstore = await kvstore_impl(config.kvstore) + inference_api: Inference | None = deps.get(Api.inference) + files_api: Files | None = deps.get(Api.files) + models_api: Models | None = deps.get(Api.models) + + if inference_api is None: + raise ValueError("Inference API is required but not provided in dependencies") + if files_api is None: + raise ValueError("Files API is required but not provided in dependencies") + if models_api is None: + raise ValueError("Models API is required but not provided in dependencies") + + impl = ReferenceBatchesImpl(config, inference_api, files_api, models_api, kvstore) + await impl.initialize() + return impl diff --git a/llama_stack/providers/inline/batches/reference/batches.py b/llama_stack/providers/inline/batches/reference/batches.py new file mode 100644 index 000000000..6e99a00a1 --- /dev/null +++ b/llama_stack/providers/inline/batches/reference/batches.py @@ -0,0 +1,553 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import asyncio +import itertools +import json +import time +import uuid +from io import BytesIO +from typing import Any + +from openai.types.batch import BatchError, Errors +from pydantic import BaseModel + +from llama_stack.apis.batches import Batches, BatchObject, ListBatchesResponse +from llama_stack.apis.common.errors import ConflictError, ResourceNotFoundError +from llama_stack.apis.files import Files, OpenAIFilePurpose +from llama_stack.apis.inference import Inference +from llama_stack.apis.models import Models +from llama_stack.log import get_logger +from llama_stack.providers.utils.kvstore import KVStore + +from .config import ReferenceBatchesImplConfig + +BATCH_PREFIX = "batch:" + +logger = get_logger(__name__) + + +class AsyncBytesIO: + """ + Async-compatible BytesIO wrapper to allow async file-like operations. + + We use this when uploading files to the Files API, as it expects an + async file-like object. 
+ """ + + def __init__(self, data: bytes): + self._buffer = BytesIO(data) + + async def read(self, n=-1): + return self._buffer.read(n) + + async def seek(self, pos, whence=0): + return self._buffer.seek(pos, whence) + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self._buffer.close() + + def __getattr__(self, name): + return getattr(self._buffer, name) + + +class BatchRequest(BaseModel): + line_num: int + custom_id: str + method: str + url: str + body: dict[str, Any] + + +class ReferenceBatchesImpl(Batches): + """Reference implementation of the Batches API. + + This implementation processes batch files by making individual requests + to the inference API and generates output files with results. + """ + + def __init__( + self, + config: ReferenceBatchesImplConfig, + inference_api: Inference, + files_api: Files, + models_api: Models, + kvstore: KVStore, + ) -> None: + self.config = config + self.kvstore = kvstore + self.inference_api = inference_api + self.files_api = files_api + self.models_api = models_api + self._processing_tasks: dict[str, asyncio.Task] = {} + self._batch_semaphore = asyncio.Semaphore(config.max_concurrent_batches) + self._update_batch_lock = asyncio.Lock() + + # this is to allow tests to disable background processing + self.process_batches = True + + async def initialize(self) -> None: + # TODO: start background processing of existing tasks + pass + + async def shutdown(self) -> None: + """Shutdown the batches provider.""" + if self._processing_tasks: + # don't cancel tasks - just let them stop naturally on shutdown + # cancelling would mark batches as "cancelled" in the database + logger.info(f"Shutdown initiated with {len(self._processing_tasks)} active batch processing tasks") + + # TODO (SECURITY): this currently works w/ configured api keys, not with x-llamastack-provider-data or with user policy restrictions + async def create_batch( + self, + input_file_id: str, + endpoint: str, + completion_window: str, + metadata: dict[str, str] | None = None, + ) -> BatchObject: + """ + Create a new batch for processing multiple API requests. + + Error handling by levels - + 0. Input param handling, results in 40x errors before processing, e.g. + - Wrong completion_window + - Invalid metadata types + - Unknown endpoint + -> no batch created + 1. Errors preventing processing, result in BatchErrors aggregated in process_batch, e.g. + - input_file_id missing + - invalid json in file + - missing custom_id, method, url, body + - invalid model + - streaming + -> batch created, validation sends to failed status + 2. Processing errors, result in error_file_id entries, e.g. + - Any error returned from inference endpoint + -> batch created, goes to completed status + """ + + # TODO: set expiration time for garbage collection + + if endpoint not in ["/v1/chat/completions"]: + raise ValueError( + f"Invalid endpoint: {endpoint}. Supported values: /v1/chat/completions. Code: invalid_value. Param: endpoint", + ) + + if completion_window != "24h": + raise ValueError( + f"Invalid completion_window: {completion_window}. Supported values are: 24h. Code: invalid_value. 
Param: completion_window", + ) + + batch_id = f"batch_{uuid.uuid4().hex[:16]}" + current_time = int(time.time()) + + batch = BatchObject( + id=batch_id, + object="batch", + endpoint=endpoint, + input_file_id=input_file_id, + completion_window=completion_window, + status="validating", + created_at=current_time, + metadata=metadata, + ) + + await self.kvstore.set(f"batch:{batch_id}", batch.to_json()) + + if self.process_batches: + task = asyncio.create_task(self._process_batch(batch_id)) + self._processing_tasks[batch_id] = task + + return batch + + async def cancel_batch(self, batch_id: str) -> BatchObject: + """Cancel a batch that is in progress.""" + batch = await self.retrieve_batch(batch_id) + + if batch.status in ["cancelled", "cancelling"]: + return batch + + if batch.status in ["completed", "failed", "expired"]: + raise ConflictError(f"Cannot cancel batch '{batch_id}' with status '{batch.status}'") + + await self._update_batch(batch_id, status="cancelling", cancelling_at=int(time.time())) + + if batch_id in self._processing_tasks: + self._processing_tasks[batch_id].cancel() + # note: task removal and status="cancelled" handled in finally block of _process_batch + + return await self.retrieve_batch(batch_id) + + async def list_batches( + self, + after: str | None = None, + limit: int = 20, + ) -> ListBatchesResponse: + """ + List all batches, eventually only for the current user. + + With no notion of user, we return all batches. + """ + batch_values = await self.kvstore.values_in_range("batch:", "batch:\xff") + + batches = [] + for batch_data in batch_values: + if batch_data: + batches.append(BatchObject.model_validate_json(batch_data)) + + batches.sort(key=lambda b: b.created_at, reverse=True) + + start_idx = 0 + if after: + for i, batch in enumerate(batches): + if batch.id == after: + start_idx = i + 1 + break + + page_batches = batches[start_idx : start_idx + limit] + has_more = (start_idx + limit) < len(batches) + + first_id = page_batches[0].id if page_batches else None + last_id = page_batches[-1].id if page_batches else None + + return ListBatchesResponse( + data=page_batches, + first_id=first_id, + last_id=last_id, + has_more=has_more, + ) + + async def retrieve_batch(self, batch_id: str) -> BatchObject: + """Retrieve information about a specific batch.""" + batch_data = await self.kvstore.get(f"batch:{batch_id}") + if not batch_data: + raise ResourceNotFoundError(batch_id, "Batch", "batches.list()") + + return BatchObject.model_validate_json(batch_data) + + async def _update_batch(self, batch_id: str, **updates) -> None: + """Update batch fields in kvstore.""" + async with self._update_batch_lock: + try: + batch = await self.retrieve_batch(batch_id) + + # batch processing is async. once cancelling, only allow "cancelled" status updates + if batch.status == "cancelling" and updates.get("status") != "cancelled": + logger.info( + f"Skipping status update for cancelled batch {batch_id}: attempted {updates.get('status')}" + ) + return + + if "errors" in updates: + updates["errors"] = updates["errors"].model_dump() + + batch_dict = batch.model_dump() + batch_dict.update(updates) + + await self.kvstore.set(f"batch:{batch_id}", json.dumps(batch_dict)) + except Exception as e: + logger.error(f"Failed to update batch {batch_id}: {e}") + + async def _validate_input(self, batch: BatchObject) -> tuple[list[BatchError], list[BatchRequest]]: + """ + Read & validate input, return errors and valid input. 
+ + Validation of + - input_file_id existance + - valid json + - custom_id, method, url, body presence and valid + - no streaming + """ + requests: list[BatchRequest] = [] + errors: list[BatchError] = [] + try: + await self.files_api.openai_retrieve_file(batch.input_file_id) + except Exception: + errors.append( + BatchError( + code="invalid_request", + line=None, + message=f"Cannot find file {batch.input_file_id}.", + param="input_file_id", + ) + ) + return errors, requests + + # TODO(SECURITY): do something about large files + file_content_response = await self.files_api.openai_retrieve_file_content(batch.input_file_id) + file_content = file_content_response.body.decode("utf-8") + for line_num, line in enumerate(file_content.strip().split("\n"), 1): + if line.strip(): # skip empty lines + try: + request = json.loads(line) + + if not isinstance(request, dict): + errors.append( + BatchError( + code="invalid_request", + line=line_num, + message="Each line must be a JSON dictionary object", + ) + ) + continue + + valid = True + + for param, expected_type, type_string in [ + ("custom_id", str, "string"), + ("method", str, "string"), + ("url", str, "string"), + ("body", dict, "JSON dictionary object"), + ]: + if param not in request: + errors.append( + BatchError( + code="missing_required_parameter", + line=line_num, + message=f"Missing required parameter: {param}", + param=param, + ) + ) + valid = False + elif not isinstance(request[param], expected_type): + param_name = "URL" if param == "url" else param.capitalize() + errors.append( + BatchError( + code="invalid_request", + line=line_num, + message=f"{param_name} must be a {type_string}", + param=param, + ) + ) + valid = False + + if (url := request.get("url")) and isinstance(url, str) and url != batch.endpoint: + errors.append( + BatchError( + code="invalid_url", + line=line_num, + message="URL provided for this request does not match the batch endpoint", + param="url", + ) + ) + valid = False + + if (body := request.get("body")) and isinstance(body, dict): + if body.get("stream", False): + errors.append( + BatchError( + code="streaming_unsupported", + line=line_num, + message="Streaming is not supported in batch processing", + param="body.stream", + ) + ) + valid = False + + for param, expected_type, type_string in [ + ("model", str, "a string"), + # messages is specific to /v1/chat/completions + # we could skip validating messages here and let inference fail. however, + # that would be a very expensive way to find out messages is wrong. + ("messages", list, "an array"), # TODO: allow messages to be a string? 
+ ]: + if param not in body: + errors.append( + BatchError( + code="invalid_request", + line=line_num, + message=f"{param.capitalize()} parameter is required", + param=f"body.{param}", + ) + ) + valid = False + elif not isinstance(body[param], expected_type): + errors.append( + BatchError( + code="invalid_request", + line=line_num, + message=f"{param.capitalize()} must be {type_string}", + param=f"body.{param}", + ) + ) + valid = False + + if "model" in body and isinstance(body["model"], str): + try: + await self.models_api.get_model(body["model"]) + except Exception: + errors.append( + BatchError( + code="model_not_found", + line=line_num, + message=f"Model '{body['model']}' does not exist or is not supported", + param="body.model", + ) + ) + valid = False + + if valid: + assert isinstance(url, str), "URL must be a string" # for mypy + assert isinstance(body, dict), "Body must be a dictionary" # for mypy + requests.append( + BatchRequest( + line_num=line_num, + url=url, + method=request["method"], + custom_id=request["custom_id"], + body=body, + ), + ) + except json.JSONDecodeError: + errors.append( + BatchError( + code="invalid_json_line", + line=line_num, + message="This line is not parseable as valid JSON.", + ) + ) + + return errors, requests + + async def _process_batch(self, batch_id: str) -> None: + """Background task to process a batch of requests.""" + try: + logger.info(f"Starting batch processing for {batch_id}") + async with self._batch_semaphore: # semaphore to limit concurrency + logger.info(f"Acquired semaphore for batch {batch_id}") + await self._process_batch_impl(batch_id) + except asyncio.CancelledError: + logger.info(f"Batch processing cancelled for {batch_id}") + await self._update_batch(batch_id, status="cancelled", cancelled_at=int(time.time())) + except Exception as e: + logger.error(f"Batch processing failed for {batch_id}: {e}") + await self._update_batch( + batch_id, + status="failed", + failed_at=int(time.time()), + errors=Errors(data=[BatchError(code="internal_error", message=str(e))]), + ) + finally: + self._processing_tasks.pop(batch_id, None) + + async def _process_batch_impl(self, batch_id: str) -> None: + """Implementation of batch processing logic.""" + errors: list[BatchError] = [] + batch = await self.retrieve_batch(batch_id) + + errors, requests = await self._validate_input(batch) + if errors: + await self._update_batch(batch_id, status="failed", failed_at=int(time.time()), errors=Errors(data=errors)) + logger.info(f"Batch validation failed for {batch_id} with {len(errors)} errors") + return + + logger.info(f"Processing {len(requests)} requests for batch {batch_id}") + + total_requests = len(requests) + await self._update_batch( + batch_id, + status="in_progress", + request_counts={"total": total_requests, "completed": 0, "failed": 0}, + ) + + error_results = [] + success_results = [] + completed_count = 0 + failed_count = 0 + + for chunk in itertools.batched(requests, self.config.max_concurrent_requests_per_batch): + # we use a TaskGroup to ensure all process-single-request tasks are canceled when process-batch is cancelled + async with asyncio.TaskGroup() as tg: + chunk_tasks = [tg.create_task(self._process_single_request(batch_id, request)) for request in chunk] + + chunk_results = await asyncio.gather(*chunk_tasks, return_exceptions=True) + + for result in chunk_results: + if isinstance(result, dict) and result.get("error") is not None: # error response from inference + failed_count += 1 + error_results.append(result) + elif isinstance(result, 
dict) and result.get("response") is not None: # successful inference + completed_count += 1 + success_results.append(result) + else: # unexpected result + failed_count += 1 + errors.append(BatchError(code="internal_error", message=f"Unexpected result: {result}")) + + await self._update_batch( + batch_id, + request_counts={"total": total_requests, "completed": completed_count, "failed": failed_count}, + ) + + if errors: + await self._update_batch( + batch_id, status="failed", failed_at=int(time.time()), errors=Errors(data=errors) + ) + return + + try: + output_file_id = await self._create_output_file(batch_id, success_results, "success") + await self._update_batch(batch_id, output_file_id=output_file_id) + + error_file_id = await self._create_output_file(batch_id, error_results, "error") + await self._update_batch(batch_id, error_file_id=error_file_id) + + await self._update_batch(batch_id, status="completed", completed_at=int(time.time())) + + logger.info( + f"Batch processing completed for {batch_id}: {completed_count} completed, {failed_count} failed" + ) + except Exception as e: + # note: errors is empty at this point, so we don't lose anything by ignoring it + await self._update_batch( + batch_id, + status="failed", + failed_at=int(time.time()), + errors=Errors(data=[BatchError(code="output_failed", message=str(e))]), + ) + + async def _process_single_request(self, batch_id: str, request: BatchRequest) -> dict: + """Process a single request from the batch.""" + request_id = f"batch_req_{batch_id}_{request.line_num}" + + try: + # TODO(SECURITY): review body for security issues + chat_response = await self.inference_api.openai_chat_completion(**request.body) + + # this is for mypy, we don't allow streaming so we'll get the right type + assert hasattr(chat_response, "model_dump_json"), "Chat response must have model_dump_json method" + return { + "id": request_id, + "custom_id": request.custom_id, + "response": { + "status_code": 200, + "request_id": request_id, # TODO: should this be different? + "body": chat_response.model_dump_json(), + }, + } + except Exception as e: + logger.info(f"Error processing request {request.custom_id} in batch {batch_id}: {e}") + return { + "id": request_id, + "custom_id": request.custom_id, + "error": {"type": "request_failed", "message": str(e)}, + } + + async def _create_output_file(self, batch_id: str, results: list[dict], file_type: str) -> str: + """ + Create an output file with batch results. + + This function filters results based on the specified file_type + and uploads the file to the Files API. + """ + output_lines = [json.dumps(result) for result in results] + + with AsyncBytesIO("\n".join(output_lines).encode("utf-8")) as file_buffer: + file_buffer.filename = f"{batch_id}_{file_type}.jsonl" + uploaded_file = await self.files_api.openai_upload_file(file=file_buffer, purpose=OpenAIFilePurpose.BATCH) + return uploaded_file.id diff --git a/llama_stack/providers/inline/batches/reference/config.py b/llama_stack/providers/inline/batches/reference/config.py new file mode 100644 index 000000000..d8d06868b --- /dev/null +++ b/llama_stack/providers/inline/batches/reference/config.py @@ -0,0 +1,40 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
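+#
+# An illustrative provider configuration (sketch only; the kvstore block matches
+# the sample in the provider docs, and the two limits show the defaults declared
+# in ReferenceBatchesImplConfig below):
+#
+#   kvstore:
+#     type: sqlite
+#     db_path: ${env.SQLITE_STORE_DIR:=~/.llama/dummy}/batches.db
+#   max_concurrent_batches: 1
+#   max_concurrent_requests_per_batch: 10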
+ +from pydantic import BaseModel, Field + +from llama_stack.providers.utils.kvstore.config import KVStoreConfig, SqliteKVStoreConfig + + +class ReferenceBatchesImplConfig(BaseModel): + """Configuration for the Reference Batches implementation.""" + + kvstore: KVStoreConfig = Field( + description="Configuration for the key-value store backend.", + ) + + max_concurrent_batches: int = Field( + default=1, + description="Maximum number of concurrent batches to process simultaneously.", + ge=1, + ) + + max_concurrent_requests_per_batch: int = Field( + default=10, + description="Maximum number of concurrent requests to process per batch.", + ge=1, + ) + + # TODO: add a max requests per second rate limiter + + @classmethod + def sample_run_config(cls, __distro_dir__: str) -> dict: + return { + "kvstore": SqliteKVStoreConfig.sample_run_config( + __distro_dir__=__distro_dir__, + db_name="batches.db", + ), + } diff --git a/llama_stack/providers/registry/batches.py b/llama_stack/providers/registry/batches.py new file mode 100644 index 000000000..de7886efb --- /dev/null +++ b/llama_stack/providers/registry/batches.py @@ -0,0 +1,26 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + + +from llama_stack.providers.datatypes import Api, InlineProviderSpec, ProviderSpec + + +def available_providers() -> list[ProviderSpec]: + return [ + InlineProviderSpec( + api=Api.batches, + provider_type="inline::reference", + pip_packages=["openai"], + module="llama_stack.providers.inline.batches.reference", + config_class="llama_stack.providers.inline.batches.reference.config.ReferenceBatchesImplConfig", + api_dependencies=[ + Api.inference, + Api.files, + Api.models, + ], + description="Reference implementation of batches API with KVStore persistence.", + ), + ] diff --git a/tests/integration/batches/__init__.py b/tests/integration/batches/__init__.py new file mode 100644 index 000000000..756f351d8 --- /dev/null +++ b/tests/integration/batches/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. diff --git a/tests/integration/batches/conftest.py b/tests/integration/batches/conftest.py new file mode 100644 index 000000000..974fe77ab --- /dev/null +++ b/tests/integration/batches/conftest.py @@ -0,0 +1,122 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +"""Shared pytest fixtures for batch tests.""" + +import json +import time +import warnings +from contextlib import contextmanager +from io import BytesIO + +import pytest + +from llama_stack.apis.files import OpenAIFilePurpose + + +class BatchHelper: + """Helper class for creating and managing batch input files.""" + + def __init__(self, client): + """Initialize with either a batch_client or openai_client.""" + self.client = client + + @contextmanager + def create_file(self, content: str | list[dict], filename_prefix="batch_input"): + """Context manager for creating and cleaning up batch input files. 
+ + Args: + content: Either a list of batch request dictionaries or raw string content + filename_prefix: Prefix for the generated filename (or full filename if content is string) + + Yields: + The uploaded file object + """ + if isinstance(content, str): + # Handle raw string content (e.g., malformed JSONL, empty files) + file_content = content.encode("utf-8") + else: + # Handle list of batch request dictionaries + jsonl_content = "\n".join(json.dumps(req) for req in content) + file_content = jsonl_content.encode("utf-8") + + filename = filename_prefix if filename_prefix.endswith(".jsonl") else f"{filename_prefix}.jsonl" + + with BytesIO(file_content) as file_buffer: + file_buffer.name = filename + uploaded_file = self.client.files.create(file=file_buffer, purpose=OpenAIFilePurpose.BATCH) + + try: + yield uploaded_file + finally: + try: + self.client.files.delete(uploaded_file.id) + except Exception: + warnings.warn( + f"Failed to cleanup file {uploaded_file.id}: {uploaded_file.filename}", + stacklevel=2, + ) + + def wait_for( + self, + batch_id: str, + max_wait_time: int = 60, + sleep_interval: int | None = None, + expected_statuses: set[str] | None = None, + timeout_action: str = "fail", + ): + """Wait for a batch to reach a terminal status. + + Args: + batch_id: The batch ID to monitor + max_wait_time: Maximum time to wait in seconds (default: 60 seconds) + sleep_interval: Time to sleep between checks in seconds (default: 1/10th of max_wait_time, min 1s, max 15s) + expected_statuses: Set of expected terminal statuses (default: {"completed"}) + timeout_action: Action on timeout - "fail" (pytest.fail) or "skip" (pytest.skip) + + Returns: + The final batch object + + Raises: + pytest.Failed: If batch reaches an unexpected status or timeout_action is "fail" + pytest.Skipped: If timeout_action is "skip" on timeout or unexpected status + """ + if sleep_interval is None: + # Default to 1/10th of max_wait_time, with min 1s and max 15s + sleep_interval = max(1, min(15, max_wait_time // 10)) + + if expected_statuses is None: + expected_statuses = {"completed"} + + terminal_statuses = {"completed", "failed", "cancelled", "expired"} + unexpected_statuses = terminal_statuses - expected_statuses + + start_time = time.time() + while time.time() - start_time < max_wait_time: + current_batch = self.client.batches.retrieve(batch_id) + + if current_batch.status in expected_statuses: + return current_batch + elif current_batch.status in unexpected_statuses: + error_msg = f"Batch reached unexpected status: {current_batch.status}" + if timeout_action == "skip": + pytest.skip(error_msg) + else: + pytest.fail(error_msg) + + time.sleep(sleep_interval) + + timeout_msg = f"Batch did not reach expected status {expected_statuses} within {max_wait_time} seconds" + if timeout_action == "skip": + pytest.skip(timeout_msg) + else: + pytest.fail(timeout_msg) + + +@pytest.fixture +def batch_helper(openai_client): + """Fixture that provides a BatchHelper instance for OpenAI client.""" + return BatchHelper(openai_client) diff --git a/tests/integration/batches/test_batches.py b/tests/integration/batches/test_batches.py new file mode 100644 index 000000000..1ef3202d0 --- /dev/null +++ b/tests/integration/batches/test_batches.py @@ -0,0 +1,270 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +""" +Integration tests for the Llama Stack batch processing functionality. 
+ +This module contains comprehensive integration tests for the batch processing API, +using the OpenAI-compatible client interface for consistency. + +Test Categories: + 1. Core Batch Operations: + - test_batch_creation_and_retrieval: Comprehensive batch creation, structure validation, and retrieval + - test_batch_listing: Basic batch listing functionality + - test_batch_immediate_cancellation: Batch cancellation workflow + # TODO: cancel during processing + + 2. End-to-End Processing: + - test_batch_e2e_chat_completions: Full chat completions workflow with output and error validation + +Note: Error conditions and edge cases are primarily tested in test_batches_errors.py +for better organization and separation of concerns. + +CLEANUP WARNING: These tests currently create batches that are not automatically +cleaned up after test completion. This may lead to resource accumulation over +multiple test runs. Only test_batch_immediate_cancellation properly cancels its batch. +The test_batch_e2e_chat_completions test does clean up its output and error files. +""" + +import json + + +class TestBatchesIntegration: + """Integration tests for the batches API.""" + + def test_batch_creation_and_retrieval(self, openai_client, batch_helper, text_model_id): + """Test comprehensive batch creation and retrieval scenarios.""" + test_metadata = { + "test_type": "comprehensive", + "purpose": "creation_and_retrieval_test", + "version": "1.0", + "tags": "test,batch", + } + + batch_requests = [ + { + "custom_id": "request-1", + "method": "POST", + "url": "/v1/chat/completions", + "body": { + "model": text_model_id, + "messages": [{"role": "user", "content": "Hello"}], + "max_tokens": 10, + }, + } + ] + + with batch_helper.create_file(batch_requests, "batch_creation_test") as uploaded_file: + batch = openai_client.batches.create( + input_file_id=uploaded_file.id, + endpoint="/v1/chat/completions", + completion_window="24h", + metadata=test_metadata, + ) + + assert batch.endpoint == "/v1/chat/completions" + assert batch.input_file_id == uploaded_file.id + assert batch.completion_window == "24h" + assert batch.metadata == test_metadata + + retrieved_batch = openai_client.batches.retrieve(batch.id) + + assert retrieved_batch.id == batch.id + assert retrieved_batch.object == batch.object + assert retrieved_batch.endpoint == batch.endpoint + assert retrieved_batch.input_file_id == batch.input_file_id + assert retrieved_batch.completion_window == batch.completion_window + assert retrieved_batch.metadata == batch.metadata + + def test_batch_listing(self, openai_client, batch_helper, text_model_id): + """ + Test batch listing. + + This test creates multiple batches and verifies that they can be listed. + It also deletes the input files before execution, which means the batches + will appear as failed due to missing input files. This is expected and + a good thing, because it means no inference is performed. 
+ """ + batch_ids = [] + + for i in range(2): + batch_requests = [ + { + "custom_id": f"request-{i}", + "method": "POST", + "url": "/v1/chat/completions", + "body": { + "model": text_model_id, + "messages": [{"role": "user", "content": f"Hello {i}"}], + "max_tokens": 10, + }, + } + ] + + with batch_helper.create_file(batch_requests, f"batch_input_{i}") as uploaded_file: + batch = openai_client.batches.create( + input_file_id=uploaded_file.id, + endpoint="/v1/chat/completions", + completion_window="24h", + ) + batch_ids.append(batch.id) + + batch_list = openai_client.batches.list() + + assert isinstance(batch_list.data, list) + + listed_batch_ids = {b.id for b in batch_list.data} + for batch_id in batch_ids: + assert batch_id in listed_batch_ids + + def test_batch_immediate_cancellation(self, openai_client, batch_helper, text_model_id): + """Test immediate batch cancellation.""" + batch_requests = [ + { + "custom_id": "request-1", + "method": "POST", + "url": "/v1/chat/completions", + "body": { + "model": text_model_id, + "messages": [{"role": "user", "content": "Hello"}], + "max_tokens": 10, + }, + } + ] + + with batch_helper.create_file(batch_requests) as uploaded_file: + batch = openai_client.batches.create( + input_file_id=uploaded_file.id, + endpoint="/v1/chat/completions", + completion_window="24h", + ) + + # hopefully cancel the batch before it completes + cancelling_batch = openai_client.batches.cancel(batch.id) + assert cancelling_batch.status in ["cancelling", "cancelled"] + assert isinstance(cancelling_batch.cancelling_at, int), ( + f"cancelling_at should be int, got {type(cancelling_batch.cancelling_at)}" + ) + + final_batch = batch_helper.wait_for( + batch.id, + max_wait_time=3 * 60, # often takes 10-11 minutes, give it 3 min + expected_statuses={"cancelled"}, + timeout_action="skip", + ) + + assert final_batch.status == "cancelled" + assert isinstance(final_batch.cancelled_at, int), ( + f"cancelled_at should be int, got {type(final_batch.cancelled_at)}" + ) + + def test_batch_e2e_chat_completions(self, openai_client, batch_helper, text_model_id): + """Test end-to-end batch processing for chat completions with both successful and failed operations.""" + batch_requests = [ + { + "custom_id": "success-1", + "method": "POST", + "url": "/v1/chat/completions", + "body": { + "model": text_model_id, + "messages": [{"role": "user", "content": "Say hello"}], + "max_tokens": 20, + }, + }, + { + "custom_id": "error-1", + "method": "POST", + "url": "/v1/chat/completions", + "body": { + "model": text_model_id, + "messages": [{"role": "user", "content": "This should fail"}], + "max_tokens": -1, # Invalid negative max_tokens will cause inference error + }, + }, + ] + + with batch_helper.create_file(batch_requests) as uploaded_file: + batch = openai_client.batches.create( + input_file_id=uploaded_file.id, + endpoint="/v1/chat/completions", + completion_window="24h", + metadata={"test": "e2e_success_and_errors_test"}, + ) + + final_batch = batch_helper.wait_for( + batch.id, + max_wait_time=3 * 60, # often takes 2-3 minutes + expected_statuses={"completed"}, + timeout_action="skip", + ) + + # Expecting a completed batch with both successful and failed requests + # Batch(id='batch_xxx', + # completion_window='24h', + # created_at=..., + # endpoint='/v1/chat/completions', + # input_file_id='file-xxx', + # object='batch', + # status='completed', + # output_file_id='file-xxx', + # error_file_id='file-xxx', + # request_counts=BatchRequestCounts(completed=1, failed=1, total=2)) + + assert 
final_batch.status == "completed" + assert final_batch.request_counts is not None + assert final_batch.request_counts.total == 2 + assert final_batch.request_counts.completed == 1 + assert final_batch.request_counts.failed == 1 + + assert final_batch.output_file_id is not None, "Output file should exist for successful requests" + + output_content = openai_client.files.content(final_batch.output_file_id) + if isinstance(output_content, str): + output_text = output_content + else: + output_text = output_content.content.decode("utf-8") + + output_lines = output_text.strip().split("\n") + + for line in output_lines: + result = json.loads(line) + + assert "id" in result + assert "custom_id" in result + assert result["custom_id"] == "success-1" + + assert "response" in result + + assert result["response"]["status_code"] == 200 + assert "body" in result["response"] + assert "choices" in result["response"]["body"] + + assert final_batch.error_file_id is not None, "Error file should exist for failed requests" + + error_content = openai_client.files.content(final_batch.error_file_id) + if isinstance(error_content, str): + error_text = error_content + else: + error_text = error_content.content.decode("utf-8") + + error_lines = error_text.strip().split("\n") + + for line in error_lines: + result = json.loads(line) + + assert "id" in result + assert "custom_id" in result + assert result["custom_id"] == "error-1" + assert "error" in result + error = result["error"] + assert error is not None + assert "code" in error or "message" in error, "Error should have code or message" + + deleted_output_file = openai_client.files.delete(final_batch.output_file_id) + assert deleted_output_file.deleted, f"Output file {final_batch.output_file_id} was not deleted successfully" + + deleted_error_file = openai_client.files.delete(final_batch.error_file_id) + assert deleted_error_file.deleted, f"Error file {final_batch.error_file_id} was not deleted successfully" diff --git a/tests/integration/batches/test_batches_errors.py b/tests/integration/batches/test_batches_errors.py new file mode 100644 index 000000000..2cd1e561e --- /dev/null +++ b/tests/integration/batches/test_batches_errors.py @@ -0,0 +1,694 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +""" +Error handling and edge case tests for the Llama Stack batch processing functionality. + +This module focuses exclusively on testing error conditions, validation failures, +and edge cases for batch operations to ensure robust error handling and graceful +degradation. + +Test Categories: + 1. File and Input Validation: + - test_batch_nonexistent_file_id: Handling invalid file IDs + - test_batch_malformed_jsonl: Processing malformed JSONL input files + - test_file_malformed_batch_file: Handling malformed files at upload time + - test_batch_missing_required_fields: Validation of required request fields + + 2. API Endpoint and Model Validation: + - test_batch_invalid_endpoint: Invalid endpoint handling during creation + - test_batch_error_handling_invalid_model: Error handling with nonexistent models + - test_batch_endpoint_mismatch: Validation of endpoint/URL consistency + + 3. 
Batch Lifecycle Error Handling: + - test_batch_retrieve_nonexistent: Retrieving non-existent batches + - test_batch_cancel_nonexistent: Cancelling non-existent batches + - test_batch_cancel_completed: Attempting to cancel completed batches + + 4. Parameter and Configuration Validation: + - test_batch_invalid_completion_window: Invalid completion window values + - test_batch_invalid_metadata_types: Invalid metadata type validation + - test_batch_missing_required_body_fields: Validation of required fields in request body + + 5. Feature Restriction and Compatibility: + - test_batch_streaming_not_supported: Streaming request rejection + - test_batch_mixed_streaming_requests: Mixed streaming/non-streaming validation + +Note: Core functionality and OpenAI compatibility tests are located in +test_batches_integration.py for better organization and separation of concerns. + +CLEANUP WARNING: These tests create batches to test error conditions but do not +automatically clean them up after test completion. While most error tests create +batches that fail quickly, some may create valid batches that consume resources. +""" + +import pytest +from openai import BadRequestError, ConflictError, NotFoundError + + +class TestBatchesErrorHandling: + """Error handling and edge case tests for the batches API using OpenAI client.""" + + def test_batch_nonexistent_file_id(self, openai_client, batch_helper): + """Test batch creation with nonexistent input file ID.""" + + batch = openai_client.batches.create( + input_file_id="file-nonexistent-xyz", + endpoint="/v1/chat/completions", + completion_window="24h", + ) + + final_batch = batch_helper.wait_for(batch.id, expected_statuses={"failed"}) + + # Expecting - + # Batch(..., + # status='failed', + # errors=Errors(data=[ + # BatchError( + # code='invalid_request', + # line=None, + # message='Cannot find file ..., or organization ... does not have access to it.', + # param='file_id') + # ], object='list'), + # failed_at=1754566971, + # ...) + + assert final_batch.status == "failed" + assert final_batch.errors is not None + assert len(final_batch.errors.data) == 1 + error = final_batch.errors.data[0] + assert error.code == "invalid_request" + assert "cannot find file" in error.message.lower() + + def test_batch_invalid_endpoint(self, openai_client, batch_helper, text_model_id): + """Test batch creation with invalid endpoint.""" + batch_requests = [ + { + "custom_id": "invalid-endpoint", + "method": "POST", + "url": "/v1/chat/completions", + "body": { + "model": text_model_id, + "messages": [{"role": "user", "content": "Hello"}], + "max_tokens": 10, + }, + } + ] + + with batch_helper.create_file(batch_requests) as uploaded_file: + with pytest.raises(BadRequestError) as exc_info: + openai_client.batches.create( + input_file_id=uploaded_file.id, + endpoint="/v1/invalid/endpoint", + completion_window="24h", + ) + + # Expected - + # Error code: 400 - { + # 'error': { + # 'message': "Invalid value: '/v1/invalid/endpoint'. 
Supported values are: '/v1/chat/completions', '/v1/completions', '/v1/embeddings', and '/v1/responses'.", + # 'type': 'invalid_request_error', + # 'param': 'endpoint', + # 'code': 'invalid_value' + # } + # } + + error_msg = str(exc_info.value).lower() + assert exc_info.value.status_code == 400 + assert "invalid value" in error_msg + assert "/v1/invalid/endpoint" in error_msg + assert "supported values" in error_msg + assert "endpoint" in error_msg + assert "invalid_value" in error_msg + + def test_batch_malformed_jsonl(self, openai_client, batch_helper): + """ + Test batch with malformed JSONL input. + + The /v1/files endpoint requires valid JSONL format, so we provide a well formed line + before a malformed line to ensure we get to the /v1/batches validation stage. + """ + with batch_helper.create_file( + """{"custom_id": "valid", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "test"}} +{invalid json here""", + "malformed_batch_input.jsonl", + ) as uploaded_file: + batch = openai_client.batches.create( + input_file_id=uploaded_file.id, + endpoint="/v1/chat/completions", + completion_window="24h", + ) + + final_batch = batch_helper.wait_for(batch.id, expected_statuses={"failed"}) + + # Expecting - + # Batch(..., + # status='failed', + # errors=Errors(data=[ + # ..., + # BatchError(code='invalid_json_line', + # line=2, + # message='This line is not parseable as valid JSON.', + # param=None) + # ], object='list'), + # ...) + + assert final_batch.status == "failed" + assert final_batch.errors is not None + assert len(final_batch.errors.data) > 0 + error = final_batch.errors.data[-1] # get last error because first may be about the "test" model + assert error.code == "invalid_json_line" + assert error.line == 2 + assert "not" in error.message.lower() + assert "valid json" in error.message.lower() + + @pytest.mark.xfail(reason="Not all file providers validate content") + @pytest.mark.parametrize("batch_requests", ["", "{malformed json"], ids=["empty", "malformed"]) + def test_file_malformed_batch_file(self, openai_client, batch_helper, batch_requests): + """Test file upload with malformed content.""" + + with pytest.raises(BadRequestError) as exc_info: + with batch_helper.create_file(batch_requests, "malformed_batch_input_file.jsonl"): + # /v1/files rejects the file, we don't get to batch creation + pass + + error_msg = str(exc_info.value).lower() + assert exc_info.value.status_code == 400 + assert "invalid file format" in error_msg + assert "jsonl" in error_msg + + def test_batch_retrieve_nonexistent(self, openai_client): + """Test retrieving nonexistent batch.""" + with pytest.raises(NotFoundError) as exc_info: + openai_client.batches.retrieve("batch-nonexistent-xyz") + + error_msg = str(exc_info.value).lower() + assert exc_info.value.status_code == 404 + assert "no batch found" in error_msg or "not found" in error_msg + + def test_batch_cancel_nonexistent(self, openai_client): + """Test cancelling nonexistent batch.""" + with pytest.raises(NotFoundError) as exc_info: + openai_client.batches.cancel("batch-nonexistent-xyz") + + error_msg = str(exc_info.value).lower() + assert exc_info.value.status_code == 404 + assert "no batch found" in error_msg or "not found" in error_msg + + def test_batch_cancel_completed(self, openai_client, batch_helper, text_model_id): + """Test cancelling already completed batch.""" + batch_requests = [ + { + "custom_id": "cancel-completed", + "method": "POST", + "url": "/v1/chat/completions", + "body": { + "model": text_model_id, + "messages": 
[{"role": "user", "content": "Quick test"}], + "max_tokens": 5, + }, + } + ] + + with batch_helper.create_file(batch_requests, "cancel_test_batch_input") as uploaded_file: + batch = openai_client.batches.create( + input_file_id=uploaded_file.id, + endpoint="/v1/chat/completions", + completion_window="24h", + ) + + final_batch = batch_helper.wait_for( + batch.id, + max_wait_time=3 * 60, # often take 10-11 min, give it 3 min + expected_statuses={"completed"}, + timeout_action="skip", + ) + + deleted_file = openai_client.files.delete(final_batch.output_file_id) + assert deleted_file.deleted, f"File {final_batch.output_file_id} was not deleted successfully" + + with pytest.raises(ConflictError) as exc_info: + openai_client.batches.cancel(batch.id) + + # Expecting - + # Error code: 409 - { + # 'error': { + # 'message': "Cannot cancel a batch with status 'completed'.", + # 'type': 'invalid_request_error', + # 'param': None, + # 'code': None + # } + # } + # + # NOTE: Same for "failed", cancelling "cancelled" batches is allowed + + error_msg = str(exc_info.value).lower() + assert exc_info.value.status_code == 409 + assert "cannot cancel" in error_msg + + def test_batch_missing_required_fields(self, openai_client, batch_helper, text_model_id): + """Test batch with requests missing required fields.""" + batch_requests = [ + { + # Missing custom_id + "method": "POST", + "url": "/v1/chat/completions", + "body": { + "model": text_model_id, + "messages": [{"role": "user", "content": "No custom_id"}], + "max_tokens": 10, + }, + }, + { + "custom_id": "no-method", + "url": "/v1/chat/completions", + "body": { + "model": text_model_id, + "messages": [{"role": "user", "content": "No method"}], + "max_tokens": 10, + }, + }, + { + "custom_id": "no-url", + "method": "POST", + "body": { + "model": text_model_id, + "messages": [{"role": "user", "content": "No URL"}], + "max_tokens": 10, + }, + }, + { + "custom_id": "no-body", + "method": "POST", + "url": "/v1/chat/completions", + }, + ] + + with batch_helper.create_file(batch_requests, "missing_fields_batch_input") as uploaded_file: + batch = openai_client.batches.create( + input_file_id=uploaded_file.id, + endpoint="/v1/chat/completions", + completion_window="24h", + ) + + final_batch = batch_helper.wait_for(batch.id, expected_statuses={"failed"}) + + # Expecting - + # Batch(..., + # status='failed', + # errors=Errors( + # data=[ + # BatchError( + # code='missing_required_parameter', + # line=1, + # message="Missing required parameter: 'custom_id'.", + # param='custom_id' + # ), + # BatchError( + # code='missing_required_parameter', + # line=2, + # message="Missing required parameter: 'method'.", + # param='method' + # ), + # BatchError( + # code='missing_required_parameter', + # line=3, + # message="Missing required parameter: 'url'.", + # param='url' + # ), + # BatchError( + # code='missing_required_parameter', + # line=4, + # message="Missing required parameter: 'body'.", + # param='body' + # ) + # ], object='list'), + # failed_at=1754566945, + # ...) 
+ # ) + + assert final_batch.status == "failed" + assert final_batch.errors is not None + assert len(final_batch.errors.data) == 4 + no_custom_id_error = final_batch.errors.data[0] + assert no_custom_id_error.code == "missing_required_parameter" + assert no_custom_id_error.line == 1 + assert "missing" in no_custom_id_error.message.lower() + assert "custom_id" in no_custom_id_error.message.lower() + no_method_error = final_batch.errors.data[1] + assert no_method_error.code == "missing_required_parameter" + assert no_method_error.line == 2 + assert "missing" in no_method_error.message.lower() + assert "method" in no_method_error.message.lower() + no_url_error = final_batch.errors.data[2] + assert no_url_error.code == "missing_required_parameter" + assert no_url_error.line == 3 + assert "missing" in no_url_error.message.lower() + assert "url" in no_url_error.message.lower() + no_body_error = final_batch.errors.data[3] + assert no_body_error.code == "missing_required_parameter" + assert no_body_error.line == 4 + assert "missing" in no_body_error.message.lower() + assert "body" in no_body_error.message.lower() + + def test_batch_invalid_completion_window(self, openai_client, batch_helper, text_model_id): + """Test batch creation with invalid completion window.""" + batch_requests = [ + { + "custom_id": "invalid-completion-window", + "method": "POST", + "url": "/v1/chat/completions", + "body": { + "model": text_model_id, + "messages": [{"role": "user", "content": "Hello"}], + "max_tokens": 10, + }, + } + ] + + with batch_helper.create_file(batch_requests) as uploaded_file: + for window in ["1h", "48h", "invalid", ""]: + with pytest.raises(BadRequestError) as exc_info: + openai_client.batches.create( + input_file_id=uploaded_file.id, + endpoint="/v1/chat/completions", + completion_window=window, + ) + assert exc_info.value.status_code == 400 + error_msg = str(exc_info.value).lower() + assert "invalid value" in error_msg + assert "completion_window" in error_msg + assert "supported values are" in error_msg + + def test_batch_streaming_not_supported(self, openai_client, batch_helper, text_model_id): + """Test that streaming responses are not supported in batches.""" + batch_requests = [ + { + "custom_id": "streaming-test", + "method": "POST", + "url": "/v1/chat/completions", + "body": { + "model": text_model_id, + "messages": [{"role": "user", "content": "Hello"}], + "max_tokens": 10, + "stream": True, # Not supported + }, + } + ] + + with batch_helper.create_file(batch_requests, "streaming_batch_input") as uploaded_file: + batch = openai_client.batches.create( + input_file_id=uploaded_file.id, + endpoint="/v1/chat/completions", + completion_window="24h", + ) + + final_batch = batch_helper.wait_for(batch.id, expected_statuses={"failed"}) + + # Expecting - + # Batch(..., + # status='failed', + # errors=Errors(data=[ + # BatchError(code='streaming_unsupported', + # line=1, + # message='Chat Completions: Streaming is not supported in the Batch API.', + # param='body.stream') + # ], object='list'), + # failed_at=1754566965, + # ...) 
+ + assert final_batch.status == "failed" + assert final_batch.errors is not None + assert len(final_batch.errors.data) == 1 + error = final_batch.errors.data[0] + assert error.code == "streaming_unsupported" + assert error.line == 1 + assert "streaming" in error.message.lower() + assert "not supported" in error.message.lower() + assert error.param == "body.stream" + assert final_batch.failed_at is not None + + def test_batch_mixed_streaming_requests(self, openai_client, batch_helper, text_model_id): + """ + Test batch with mixed streaming and non-streaming requests. + + This is distinct from test_batch_streaming_not_supported, which tests a single + streaming request, to ensure an otherwise valid batch fails when a single + streaming request is included. + """ + batch_requests = [ + { + "custom_id": "valid-non-streaming-request", + "method": "POST", + "url": "/v1/chat/completions", + "body": { + "model": text_model_id, + "messages": [{"role": "user", "content": "Hello without streaming"}], + "max_tokens": 10, + }, + }, + { + "custom_id": "streaming-request", + "method": "POST", + "url": "/v1/chat/completions", + "body": { + "model": text_model_id, + "messages": [{"role": "user", "content": "Hello with streaming"}], + "max_tokens": 10, + "stream": True, # Not supported + }, + }, + ] + + with batch_helper.create_file(batch_requests, "mixed_streaming_batch_input") as uploaded_file: + batch = openai_client.batches.create( + input_file_id=uploaded_file.id, + endpoint="/v1/chat/completions", + completion_window="24h", + ) + + final_batch = batch_helper.wait_for(batch.id, expected_statuses={"failed"}) + + # Expecting - + # Batch(..., + # status='failed', + # errors=Errors(data=[ + # BatchError( + # code='streaming_unsupported', + # line=2, + # message='Chat Completions: Streaming is not supported in the Batch API.', + # param='body.stream') + # ], object='list'), + # failed_at=1754574442, + # ...) + + assert final_batch.status == "failed" + assert final_batch.errors is not None + assert len(final_batch.errors.data) == 1 + error = final_batch.errors.data[0] + assert error.code == "streaming_unsupported" + assert error.line == 2 + assert "streaming" in error.message.lower() + assert "not supported" in error.message.lower() + assert error.param == "body.stream" + assert final_batch.failed_at is not None + + def test_batch_endpoint_mismatch(self, openai_client, batch_helper, text_model_id): + """Test batch creation with mismatched endpoint and request URL.""" + batch_requests = [ + { + "custom_id": "endpoint-mismatch", + "method": "POST", + "url": "/v1/embeddings", # Different from batch endpoint + "body": { + "model": text_model_id, + "messages": [{"role": "user", "content": "Hello"}], + }, + } + ] + + with batch_helper.create_file(batch_requests, "endpoint_mismatch_batch_input") as uploaded_file: + batch = openai_client.batches.create( + input_file_id=uploaded_file.id, + endpoint="/v1/chat/completions", # Different from request URL + completion_window="24h", + ) + + final_batch = batch_helper.wait_for(batch.id, expected_statuses={"failed"}) + + # Expecting - + # Batch(..., + # status='failed', + # errors=Errors(data=[ + # BatchError( + # code='invalid_url', + # line=1, + # message='The URL provided for this request does not match the batch endpoint.', + # param='url') + # ], object='list'), + # failed_at=1754566972, + # ...) 
+ + assert final_batch.status == "failed" + assert final_batch.errors is not None + assert len(final_batch.errors.data) == 1 + error = final_batch.errors.data[0] + assert error.line == 1 + assert error.code == "invalid_url" + assert "does not match" in error.message.lower() + assert "endpoint" in error.message.lower() + assert final_batch.failed_at is not None + + def test_batch_error_handling_invalid_model(self, openai_client, batch_helper): + """Test batch error handling with invalid model.""" + batch_requests = [ + { + "custom_id": "invalid-model", + "method": "POST", + "url": "/v1/chat/completions", + "body": { + "model": "nonexistent-model-xyz", + "messages": [{"role": "user", "content": "Hello"}], + "max_tokens": 10, + }, + } + ] + + with batch_helper.create_file(batch_requests) as uploaded_file: + batch = openai_client.batches.create( + input_file_id=uploaded_file.id, + endpoint="/v1/chat/completions", + completion_window="24h", + ) + + final_batch = batch_helper.wait_for(batch.id, expected_statuses={"failed"}) + + # Expecting - + # Batch(..., + # status='failed', + # errors=Errors(data=[ + # BatchError(code='model_not_found', + # line=1, + # message="The provided model 'nonexistent-model-xyz' is not supported by the Batch API.", + # param='body.model') + # ], object='list'), + # failed_at=1754566978, + # ...) + + assert final_batch.status == "failed" + assert final_batch.errors is not None + assert len(final_batch.errors.data) == 1 + error = final_batch.errors.data[0] + assert error.line == 1 + assert error.code == "model_not_found" + assert "not supported" in error.message.lower() + assert error.param == "body.model" + assert final_batch.failed_at is not None + + def test_batch_missing_required_body_fields(self, openai_client, batch_helper, text_model_id): + """Test batch with requests missing required fields in body (model and messages).""" + batch_requests = [ + { + "custom_id": "missing-model", + "method": "POST", + "url": "/v1/chat/completions", + "body": { + # Missing model field + "messages": [{"role": "user", "content": "Hello without model"}], + "max_tokens": 10, + }, + }, + { + "custom_id": "missing-messages", + "method": "POST", + "url": "/v1/chat/completions", + "body": { + "model": text_model_id, + # Missing messages field + "max_tokens": 10, + }, + }, + ] + + with batch_helper.create_file(batch_requests, "missing_body_fields_batch_input") as uploaded_file: + batch = openai_client.batches.create( + input_file_id=uploaded_file.id, + endpoint="/v1/chat/completions", + completion_window="24h", + ) + + final_batch = batch_helper.wait_for(batch.id, expected_statuses={"failed"}) + + # Expecting - + # Batch(..., + # status='failed', + # errors=Errors(data=[ + # BatchError( + # code='invalid_request', + # line=1, + # message='Model parameter is required.', + # param='body.model'), + # BatchError( + # code='invalid_request', + # line=2, + # message='Messages parameter is required.', + # param='body.messages') + # ], object='list'), + # ...) 
+ + assert final_batch.status == "failed" + assert final_batch.errors is not None + assert len(final_batch.errors.data) == 2 + + model_error = final_batch.errors.data[0] + assert model_error.line == 1 + assert "model" in model_error.message.lower() + assert model_error.param == "body.model" + + messages_error = final_batch.errors.data[1] + assert messages_error.line == 2 + assert "messages" in messages_error.message.lower() + assert messages_error.param == "body.messages" + + assert final_batch.failed_at is not None + + def test_batch_invalid_metadata_types(self, openai_client, batch_helper, text_model_id): + """Test batch creation with invalid metadata types (like lists).""" + batch_requests = [ + { + "custom_id": "invalid-metadata-type", + "method": "POST", + "url": "/v1/chat/completions", + "body": { + "model": text_model_id, + "messages": [{"role": "user", "content": "Hello"}], + "max_tokens": 10, + }, + } + ] + + with batch_helper.create_file(batch_requests) as uploaded_file: + with pytest.raises(Exception) as exc_info: + openai_client.batches.create( + input_file_id=uploaded_file.id, + endpoint="/v1/chat/completions", + completion_window="24h", + metadata={ + "tags": ["tag1", "tag2"], # Invalid type, should be a string + }, + ) + + # Expecting - + # Error code: 400 - {'error': + # {'message': "Invalid type for 'metadata.tags': expected a string, + # but got an array instead.", + # 'type': 'invalid_request_error', 'param': 'metadata.tags', + # 'code': 'invalid_type'}} + + error_msg = str(exc_info.value).lower() + assert "400" in error_msg + assert "tags" in error_msg + assert "string" in error_msg diff --git a/tests/unit/providers/batches/test_reference.py b/tests/unit/providers/batches/test_reference.py new file mode 100644 index 000000000..9fe0cc710 --- /dev/null +++ b/tests/unit/providers/batches/test_reference.py @@ -0,0 +1,753 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +""" +Test suite for the reference implementation of the Batches API. 
+ +The tests are categorized and outlined below; keep this updated: + +- Batch creation with various parameters and validation: + * test_create_and_retrieve_batch_success (positive) + * test_create_batch_without_metadata (positive) + * test_create_batch_completion_window (negative) + * test_create_batch_invalid_endpoints (negative) + * test_create_batch_invalid_metadata (negative) + +- Batch retrieval and error handling for non-existent batches: + * test_retrieve_batch_not_found (negative) + +- Batch cancellation with proper status transitions: + * test_cancel_batch_success (positive) + * test_cancel_batch_invalid_statuses (negative) + * test_cancel_batch_not_found (negative) + +- Batch listing with pagination and filtering: + * test_list_batches_empty (positive) + * test_list_batches_single_batch (positive) + * test_list_batches_multiple_batches (positive) + * test_list_batches_with_limit (positive) + * test_list_batches_with_pagination (positive) + * test_list_batches_invalid_after (negative) + +- Data persistence in the underlying key-value store: + * test_kvstore_persistence (positive) + +- Batch processing concurrency control: + * test_max_concurrent_batches (positive) + +- Input validation testing (direct _validate_input method tests): + * test_validate_input_file_not_found (negative) + * test_validate_input_file_exists_empty_content (positive) + * test_validate_input_file_mixed_valid_invalid_json (mixed) + * test_validate_input_invalid_model (negative) + * test_validate_input_url_mismatch (negative) + * test_validate_input_multiple_errors_per_request (negative) + * test_validate_input_invalid_request_format (negative) + * test_validate_input_missing_parameters (parametrized negative - custom_id, method, url, body, model, messages missing validation) + * test_validate_input_invalid_parameter_types (parametrized negative - custom_id, url, method, body, model, messages type validation) + +The tests use temporary SQLite databases for isolation and mock external +dependencies like inference, files, and models APIs.
+""" + +import json +import tempfile +from pathlib import Path +from unittest.mock import AsyncMock, MagicMock + +import pytest + +from llama_stack.apis.batches import BatchObject +from llama_stack.apis.common.errors import ConflictError, ResourceNotFoundError +from llama_stack.providers.inline.batches.reference.batches import ReferenceBatchesImpl +from llama_stack.providers.inline.batches.reference.config import ReferenceBatchesImplConfig +from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig + + +class TestReferenceBatchesImpl: + """Test the reference implementation of the Batches API.""" + + @pytest.fixture + async def provider(self): + """Create a test provider instance with temporary database.""" + with tempfile.TemporaryDirectory() as tmpdir: + db_path = Path(tmpdir) / "test_batches.db" + kvstore_config = SqliteKVStoreConfig(db_path=str(db_path)) + config = ReferenceBatchesImplConfig(kvstore=kvstore_config) + + # Create kvstore and mock APIs + from unittest.mock import AsyncMock + + from llama_stack.providers.utils.kvstore import kvstore_impl + + kvstore = await kvstore_impl(config.kvstore) + mock_inference = AsyncMock() + mock_files = AsyncMock() + mock_models = AsyncMock() + + provider = ReferenceBatchesImpl(config, mock_inference, mock_files, mock_models, kvstore) + await provider.initialize() + + # unit tests should not require background processing + provider.process_batches = False + + yield provider + + await provider.shutdown() + + @pytest.fixture + def sample_batch_data(self): + """Sample batch data for testing.""" + return { + "input_file_id": "file_abc123", + "endpoint": "/v1/chat/completions", + "completion_window": "24h", + "metadata": {"test": "true", "priority": "high"}, + } + + def _validate_batch_type(self, batch, expected_metadata=None): + """ + Helper function to validate batch object structure and field types. + + Note: This validates the direct BatchObject from the provider, not the + client library response which has a different structure. + + Args: + batch: The BatchObject instance to validate. + expected_metadata: Optional expected metadata dictionary to validate against. 
+ """ + assert isinstance(batch.id, str) + assert isinstance(batch.completion_window, str) + assert isinstance(batch.created_at, int) + assert isinstance(batch.endpoint, str) + assert isinstance(batch.input_file_id, str) + assert batch.object == "batch" + assert batch.status in [ + "validating", + "failed", + "in_progress", + "finalizing", + "completed", + "expired", + "cancelling", + "cancelled", + ] + + if expected_metadata is not None: + assert batch.metadata == expected_metadata + + timestamp_fields = [ + "cancelled_at", + "cancelling_at", + "completed_at", + "expired_at", + "expires_at", + "failed_at", + "finalizing_at", + "in_progress_at", + ] + for field in timestamp_fields: + field_value = getattr(batch, field, None) + if field_value is not None: + assert isinstance(field_value, int), f"{field} should be int or None, got {type(field_value)}" + + file_id_fields = ["error_file_id", "output_file_id"] + for field in file_id_fields: + field_value = getattr(batch, field, None) + if field_value is not None: + assert isinstance(field_value, str), f"{field} should be str or None, got {type(field_value)}" + + if hasattr(batch, "request_counts") and batch.request_counts is not None: + assert isinstance(batch.request_counts.completed, int), ( + f"request_counts.completed should be int, got {type(batch.request_counts.completed)}" + ) + assert isinstance(batch.request_counts.failed, int), ( + f"request_counts.failed should be int, got {type(batch.request_counts.failed)}" + ) + assert isinstance(batch.request_counts.total, int), ( + f"request_counts.total should be int, got {type(batch.request_counts.total)}" + ) + + if hasattr(batch, "errors") and batch.errors is not None: + assert isinstance(batch.errors, dict), f"errors should be object or dict, got {type(batch.errors)}" + + if hasattr(batch.errors, "data") and batch.errors.data is not None: + assert isinstance(batch.errors.data, list), ( + f"errors.data should be list or None, got {type(batch.errors.data)}" + ) + + for i, error_item in enumerate(batch.errors.data): + assert isinstance(error_item, dict), ( + f"errors.data[{i}] should be object or dict, got {type(error_item)}" + ) + + if hasattr(error_item, "code") and error_item.code is not None: + assert isinstance(error_item.code, str), ( + f"errors.data[{i}].code should be str or None, got {type(error_item.code)}" + ) + + if hasattr(error_item, "line") and error_item.line is not None: + assert isinstance(error_item.line, int), ( + f"errors.data[{i}].line should be int or None, got {type(error_item.line)}" + ) + + if hasattr(error_item, "message") and error_item.message is not None: + assert isinstance(error_item.message, str), ( + f"errors.data[{i}].message should be str or None, got {type(error_item.message)}" + ) + + if hasattr(error_item, "param") and error_item.param is not None: + assert isinstance(error_item.param, str), ( + f"errors.data[{i}].param should be str or None, got {type(error_item.param)}" + ) + + if hasattr(batch.errors, "object") and batch.errors.object is not None: + assert isinstance(batch.errors.object, str), ( + f"errors.object should be str or None, got {type(batch.errors.object)}" + ) + assert batch.errors.object == "list", f"errors.object should be 'list', got {batch.errors.object}" + + async def test_create_and_retrieve_batch_success(self, provider, sample_batch_data): + """Test successful batch creation and retrieval.""" + created_batch = await provider.create_batch(**sample_batch_data) + + self._validate_batch_type(created_batch, 
expected_metadata=sample_batch_data["metadata"]) + + assert created_batch.id.startswith("batch_") + assert len(created_batch.id) > 13 + assert created_batch.object == "batch" + assert created_batch.endpoint == sample_batch_data["endpoint"] + assert created_batch.input_file_id == sample_batch_data["input_file_id"] + assert created_batch.completion_window == sample_batch_data["completion_window"] + assert created_batch.status == "validating" + assert created_batch.metadata == sample_batch_data["metadata"] + assert isinstance(created_batch.created_at, int) + assert created_batch.created_at > 0 + + retrieved_batch = await provider.retrieve_batch(created_batch.id) + + self._validate_batch_type(retrieved_batch, expected_metadata=sample_batch_data["metadata"]) + + assert retrieved_batch.id == created_batch.id + assert retrieved_batch.input_file_id == created_batch.input_file_id + assert retrieved_batch.endpoint == created_batch.endpoint + assert retrieved_batch.status == created_batch.status + assert retrieved_batch.metadata == created_batch.metadata + + async def test_create_batch_without_metadata(self, provider): + """Test batch creation without optional metadata.""" + batch = await provider.create_batch( + input_file_id="file_123", endpoint="/v1/chat/completions", completion_window="24h" + ) + + assert batch.metadata is None + + async def test_create_batch_completion_window(self, provider): + """Test batch creation with invalid completion window.""" + with pytest.raises(ValueError, match="Invalid completion_window"): + await provider.create_batch( + input_file_id="file_123", endpoint="/v1/chat/completions", completion_window="now" + ) + + @pytest.mark.parametrize( + "endpoint", + [ + "/v1/embeddings", + "/v1/completions", + "/v1/invalid/endpoint", + "", + ], + ) + async def test_create_batch_invalid_endpoints(self, provider, endpoint): + """Test batch creation with various invalid endpoints.""" + with pytest.raises(ValueError, match="Invalid endpoint"): + await provider.create_batch(input_file_id="file_123", endpoint=endpoint, completion_window="24h") + + async def test_create_batch_invalid_metadata(self, provider): + """Test that batch creation fails with invalid metadata.""" + with pytest.raises(ValueError, match="should be a valid string"): + await provider.create_batch( + input_file_id="file_123", + endpoint="/v1/chat/completions", + completion_window="24h", + metadata={123: "invalid_key"}, # Non-string key + ) + + with pytest.raises(ValueError, match="should be a valid string"): + await provider.create_batch( + input_file_id="file_123", + endpoint="/v1/chat/completions", + completion_window="24h", + metadata={"valid_key": 456}, # Non-string value + ) + + async def test_retrieve_batch_not_found(self, provider): + """Test error when retrieving non-existent batch.""" + with pytest.raises(ResourceNotFoundError, match=r"Batch 'nonexistent_batch' not found"): + await provider.retrieve_batch("nonexistent_batch") + + async def test_cancel_batch_success(self, provider, sample_batch_data): + """Test successful batch cancellation.""" + created_batch = await provider.create_batch(**sample_batch_data) + assert created_batch.status == "validating" + + cancelled_batch = await provider.cancel_batch(created_batch.id) + + assert cancelled_batch.id == created_batch.id + assert cancelled_batch.status in ["cancelling", "cancelled"] + assert isinstance(cancelled_batch.cancelling_at, int) + assert cancelled_batch.cancelling_at >= created_batch.created_at + + @pytest.mark.parametrize("status", ["failed", 
"expired", "completed"]) + async def test_cancel_batch_invalid_statuses(self, provider, sample_batch_data, status): + """Test error when cancelling batch in final states.""" + provider.process_batches = False + created_batch = await provider.create_batch(**sample_batch_data) + + # directly update status in kvstore + await provider._update_batch(created_batch.id, status=status) + + with pytest.raises(ConflictError, match=f"Cannot cancel batch '{created_batch.id}' with status '{status}'"): + await provider.cancel_batch(created_batch.id) + + async def test_cancel_batch_not_found(self, provider): + """Test error when cancelling non-existent batch.""" + with pytest.raises(ResourceNotFoundError, match=r"Batch 'nonexistent_batch' not found"): + await provider.cancel_batch("nonexistent_batch") + + async def test_list_batches_empty(self, provider): + """Test listing batches when none exist.""" + response = await provider.list_batches() + + assert response.object == "list" + assert response.data == [] + assert response.first_id is None + assert response.last_id is None + assert response.has_more is False + + async def test_list_batches_single_batch(self, provider, sample_batch_data): + """Test listing batches with single batch.""" + created_batch = await provider.create_batch(**sample_batch_data) + + response = await provider.list_batches() + + assert len(response.data) == 1 + self._validate_batch_type(response.data[0], expected_metadata=sample_batch_data["metadata"]) + assert response.data[0].id == created_batch.id + assert response.first_id == created_batch.id + assert response.last_id == created_batch.id + assert response.has_more is False + + async def test_list_batches_multiple_batches(self, provider): + """Test listing multiple batches.""" + batches = [ + await provider.create_batch( + input_file_id=f"file_{i}", endpoint="/v1/chat/completions", completion_window="24h" + ) + for i in range(3) + ] + + response = await provider.list_batches() + + assert len(response.data) == 3 + + batch_ids = {batch.id for batch in response.data} + expected_ids = {batch.id for batch in batches} + assert batch_ids == expected_ids + assert response.has_more is False + + assert response.first_id in expected_ids + assert response.last_id in expected_ids + + async def test_list_batches_with_limit(self, provider): + """Test listing batches with limit parameter.""" + batches = [ + await provider.create_batch( + input_file_id=f"file_{i}", endpoint="/v1/chat/completions", completion_window="24h" + ) + for i in range(3) + ] + + response = await provider.list_batches(limit=2) + + assert len(response.data) == 2 + assert response.has_more is True + assert response.first_id == response.data[0].id + assert response.last_id == response.data[1].id + batch_ids = {batch.id for batch in response.data} + expected_ids = {batch.id for batch in batches} + assert batch_ids.issubset(expected_ids) + + async def test_list_batches_with_pagination(self, provider): + """Test listing batches with pagination using 'after' parameter.""" + for i in range(3): + await provider.create_batch( + input_file_id=f"file_{i}", endpoint="/v1/chat/completions", completion_window="24h" + ) + + # Get first page + first_page = await provider.list_batches(limit=1) + assert len(first_page.data) == 1 + assert first_page.has_more is True + + # Get second page using 'after' + second_page = await provider.list_batches(limit=1, after=first_page.data[0].id) + assert len(second_page.data) == 1 + assert second_page.data[0].id != first_page.data[0].id + + # Verify we got 
the next batch in order + all_batches = await provider.list_batches() + expected_second_batch_id = all_batches.data[1].id + assert second_page.data[0].id == expected_second_batch_id + + async def test_list_batches_invalid_after(self, provider, sample_batch_data): + """Test listing batches with invalid 'after' parameter.""" + await provider.create_batch(**sample_batch_data) + + response = await provider.list_batches(after="nonexistent_batch") + + # Should return all batches (no filtering when 'after' batch not found) + assert len(response.data) == 1 + + async def test_kvstore_persistence(self, provider, sample_batch_data): + """Test that batches are properly persisted in kvstore.""" + batch = await provider.create_batch(**sample_batch_data) + + stored_data = await provider.kvstore.get(f"batch:{batch.id}") + assert stored_data is not None + + stored_batch_dict = json.loads(stored_data) + assert stored_batch_dict["id"] == batch.id + assert stored_batch_dict["input_file_id"] == sample_batch_data["input_file_id"] + + async def test_validate_input_file_not_found(self, provider): + """Test _validate_input when input file does not exist.""" + provider.files_api.openai_retrieve_file = AsyncMock(side_effect=Exception("File not found")) + + batch = BatchObject( + id="batch_test", + object="batch", + endpoint="/v1/chat/completions", + input_file_id="nonexistent_file", + completion_window="24h", + status="validating", + created_at=1234567890, + ) + + errors, requests = await provider._validate_input(batch) + + assert len(errors) == 1 + assert len(requests) == 0 + assert errors[0].code == "invalid_request" + assert errors[0].message == "Cannot find file nonexistent_file." + assert errors[0].param == "input_file_id" + assert errors[0].line is None + + async def test_validate_input_file_exists_empty_content(self, provider): + """Test _validate_input when file exists but is empty.""" + provider.files_api.openai_retrieve_file = AsyncMock() + mock_response = MagicMock() + mock_response.body = b"" + provider.files_api.openai_retrieve_file_content = AsyncMock(return_value=mock_response) + + batch = BatchObject( + id="batch_test", + object="batch", + endpoint="/v1/chat/completions", + input_file_id="empty_file", + completion_window="24h", + status="validating", + created_at=1234567890, + ) + + errors, requests = await provider._validate_input(batch) + + assert len(errors) == 0 + assert len(requests) == 0 + + async def test_validate_input_file_mixed_valid_invalid_json(self, provider): + """Test _validate_input when file contains valid and invalid JSON lines.""" + provider.files_api.openai_retrieve_file = AsyncMock() + mock_response = MagicMock() + # Line 1: valid JSON with proper body args, Line 2: invalid JSON + mock_response.body = b'{"custom_id": "req-1", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "test-model", "messages": [{"role": "user", "content": "Hello"}]}}\n{invalid json' + provider.files_api.openai_retrieve_file_content = AsyncMock(return_value=mock_response) + + batch = BatchObject( + id="batch_test", + object="batch", + endpoint="/v1/chat/completions", + input_file_id="mixed_file", + completion_window="24h", + status="validating", + created_at=1234567890, + ) + + errors, requests = await provider._validate_input(batch) + + # Should have 1 JSON parsing error from line 2, and 1 valid request from line 1 + assert len(errors) == 1 + assert len(requests) == 1 + + assert errors[0].code == "invalid_json_line" + assert errors[0].line == 2 + assert errors[0].message == "This line is 
not parseable as valid JSON." + + assert requests[0].custom_id == "req-1" + assert requests[0].method == "POST" + assert requests[0].url == "/v1/chat/completions" + assert requests[0].body["model"] == "test-model" + assert requests[0].body["messages"] == [{"role": "user", "content": "Hello"}] + + async def test_validate_input_invalid_model(self, provider): + """Test _validate_input when file contains request with non-existent model.""" + provider.files_api.openai_retrieve_file = AsyncMock() + mock_response = MagicMock() + mock_response.body = b'{"custom_id": "req-1", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "nonexistent-model", "messages": [{"role": "user", "content": "Hello"}]}}' + provider.files_api.openai_retrieve_file_content = AsyncMock(return_value=mock_response) + + provider.models_api.get_model = AsyncMock(side_effect=Exception("Model not found")) + + batch = BatchObject( + id="batch_test", + object="batch", + endpoint="/v1/chat/completions", + input_file_id="invalid_model_file", + completion_window="24h", + status="validating", + created_at=1234567890, + ) + + errors, requests = await provider._validate_input(batch) + + assert len(errors) == 1 + assert len(requests) == 0 + + assert errors[0].code == "model_not_found" + assert errors[0].line == 1 + assert errors[0].message == "Model 'nonexistent-model' does not exist or is not supported" + assert errors[0].param == "body.model" + + @pytest.mark.parametrize( + "param_name,param_path,error_code,error_message", + [ + ("custom_id", "custom_id", "missing_required_parameter", "Missing required parameter: custom_id"), + ("method", "method", "missing_required_parameter", "Missing required parameter: method"), + ("url", "url", "missing_required_parameter", "Missing required parameter: url"), + ("body", "body", "missing_required_parameter", "Missing required parameter: body"), + ("model", "body.model", "invalid_request", "Model parameter is required"), + ("messages", "body.messages", "invalid_request", "Messages parameter is required"), + ], + ) + async def test_validate_input_missing_parameters(self, provider, param_name, param_path, error_code, error_message): + """Test _validate_input when file contains request with missing required parameters.""" + provider.files_api.openai_retrieve_file = AsyncMock() + mock_response = MagicMock() + + base_request = { + "custom_id": "req-1", + "method": "POST", + "url": "/v1/chat/completions", + "body": {"model": "test-model", "messages": [{"role": "user", "content": "Hello"}]}, + } + + # Remove the specific parameter being tested + if "." 
in param_path: + top_level, nested_param = param_path.split(".", 1) + del base_request[top_level][nested_param] + else: + del base_request[param_name] + + mock_response.body = json.dumps(base_request).encode() + provider.files_api.openai_retrieve_file_content = AsyncMock(return_value=mock_response) + + batch = BatchObject( + id="batch_test", + object="batch", + endpoint="/v1/chat/completions", + input_file_id=f"missing_{param_name}_file", + completion_window="24h", + status="validating", + created_at=1234567890, + ) + + errors, requests = await provider._validate_input(batch) + + assert len(errors) == 1 + assert len(requests) == 0 + + assert errors[0].code == error_code + assert errors[0].line == 1 + assert errors[0].message == error_message + assert errors[0].param == param_path + + async def test_validate_input_url_mismatch(self, provider): + """Test _validate_input when file contains request with URL that doesn't match batch endpoint.""" + provider.files_api.openai_retrieve_file = AsyncMock() + mock_response = MagicMock() + mock_response.body = b'{"custom_id": "req-1", "method": "POST", "url": "/v1/embeddings", "body": {"model": "test-model", "messages": [{"role": "user", "content": "Hello"}]}}' + provider.files_api.openai_retrieve_file_content = AsyncMock(return_value=mock_response) + + batch = BatchObject( + id="batch_test", + object="batch", + endpoint="/v1/chat/completions", # This doesn't match the URL in the request + input_file_id="url_mismatch_file", + completion_window="24h", + status="validating", + created_at=1234567890, + ) + + errors, requests = await provider._validate_input(batch) + + assert len(errors) == 1 + assert len(requests) == 0 + + assert errors[0].code == "invalid_url" + assert errors[0].line == 1 + assert errors[0].message == "URL provided for this request does not match the batch endpoint" + assert errors[0].param == "url" + + async def test_validate_input_multiple_errors_per_request(self, provider): + """Test _validate_input when a single request has multiple validation errors.""" + provider.files_api.openai_retrieve_file = AsyncMock() + mock_response = MagicMock() + # Request missing custom_id, has invalid URL, and missing model in body + mock_response.body = ( + b'{"method": "POST", "url": "/v1/embeddings", "body": {"messages": [{"role": "user", "content": "Hello"}]}}' + ) + provider.files_api.openai_retrieve_file_content = AsyncMock(return_value=mock_response) + + batch = BatchObject( + id="batch_test", + object="batch", + endpoint="/v1/chat/completions", # Doesn't match /v1/embeddings in request + input_file_id="multiple_errors_file", + completion_window="24h", + status="validating", + created_at=1234567890, + ) + + errors, requests = await provider._validate_input(batch) + + assert len(errors) >= 2 # At least missing custom_id and URL mismatch + assert len(requests) == 0 + + for error in errors: + assert error.line == 1 + + error_codes = {error.code for error in errors} + assert "missing_required_parameter" in error_codes # missing custom_id + assert "invalid_url" in error_codes # URL mismatch + + async def test_validate_input_invalid_request_format(self, provider): + """Test _validate_input when file contains non-object JSON (array, string, number).""" + provider.files_api.openai_retrieve_file = AsyncMock() + mock_response = MagicMock() + mock_response.body = b'["not", "a", "request", "object"]' + provider.files_api.openai_retrieve_file_content = AsyncMock(return_value=mock_response) + + batch = BatchObject( + id="batch_test", + object="batch", + 
endpoint="/v1/chat/completions", + input_file_id="invalid_format_file", + completion_window="24h", + status="validating", + created_at=1234567890, + ) + + errors, requests = await provider._validate_input(batch) + + assert len(errors) == 1 + assert len(requests) == 0 + + assert errors[0].code == "invalid_request" + assert errors[0].line == 1 + assert errors[0].message == "Each line must be a JSON dictionary object" + + @pytest.mark.parametrize( + "param_name,param_path,invalid_value,error_message", + [ + ("custom_id", "custom_id", 12345, "Custom_id must be a string"), + ("url", "url", 123, "URL must be a string"), + ("method", "method", ["POST"], "Method must be a string"), + ("body", "body", ["not", "valid"], "Body must be a JSON dictionary object"), + ("model", "body.model", 123, "Model must be a string"), + ("messages", "body.messages", "invalid messages format", "Messages must be an array"), + ], + ) + async def test_validate_input_invalid_parameter_types( + self, provider, param_name, param_path, invalid_value, error_message + ): + """Test _validate_input when file contains request with parameters that have invalid types.""" + provider.files_api.openai_retrieve_file = AsyncMock() + mock_response = MagicMock() + + base_request = { + "custom_id": "req-1", + "method": "POST", + "url": "/v1/chat/completions", + "body": {"model": "test-model", "messages": [{"role": "user", "content": "Hello"}]}, + } + + # Override the specific parameter with invalid value + if "." in param_path: + top_level, nested_param = param_path.split(".", 1) + base_request[top_level][nested_param] = invalid_value + else: + base_request[param_name] = invalid_value + + mock_response.body = json.dumps(base_request).encode() + provider.files_api.openai_retrieve_file_content = AsyncMock(return_value=mock_response) + + batch = BatchObject( + id="batch_test", + object="batch", + endpoint="/v1/chat/completions", + input_file_id=f"invalid_{param_name}_type_file", + completion_window="24h", + status="validating", + created_at=1234567890, + ) + + errors, requests = await provider._validate_input(batch) + + assert len(errors) == 1 + assert len(requests) == 0 + + assert errors[0].code == "invalid_request" + assert errors[0].line == 1 + assert errors[0].message == error_message + assert errors[0].param == param_path + + async def test_max_concurrent_batches(self, provider): + """Test max_concurrent_batches configuration and concurrency control.""" + import asyncio + + provider._batch_semaphore = asyncio.Semaphore(2) + + provider.process_batches = True # enable because we're testing background processing + + active_batches = 0 + + async def add_and_wait(batch_id: str): + nonlocal active_batches + active_batches += 1 + await asyncio.sleep(float("inf")) + + # the first thing done in _process_batch is to acquire the semaphore, then call _process_batch_impl, + # so we can replace _process_batch_impl with our mock to control concurrency + provider._process_batch_impl = add_and_wait + + for _ in range(3): + await provider.create_batch( + input_file_id="file_id", endpoint="/v1/chat/completions", completion_window="24h" + ) + + await asyncio.sleep(0.042) # let tasks start + + assert active_batches == 2, f"Expected 2 active batches, got {active_batches}" From 04a73c89efbc9160e2a246215ae43af489e7f74c Mon Sep 17 00:00:00 2001 From: Matthew Farrellee Date: Wed, 13 Aug 2025 07:15:07 -0400 Subject: [PATCH 2/3] add notes about batches development status to docs this also captures other notes from agents, eval and inference apis --- 
docs/source/concepts/apis.md | 1 + docs/source/providers/agents/index.md | 9 +++++++++ docs/source/providers/batches/index.md | 8 ++++++++ docs/source/providers/eval/index.md | 2 ++ docs/source/providers/inference/index.md | 6 ++++++ llama_stack/apis/batches/batches.py | 9 ++++++++- scripts/provider_codegen.py | 22 ++++++++++++++++++++++ 7 files changed, 56 insertions(+), 1 deletion(-) diff --git a/docs/source/concepts/apis.md b/docs/source/concepts/apis.md index 5a10d6498..f8f73a928 100644 --- a/docs/source/concepts/apis.md +++ b/docs/source/concepts/apis.md @@ -18,3 +18,4 @@ We are working on adding a few more APIs to complete the application lifecycle. - **Batch Inference**: run inference on a dataset of inputs - **Batch Agents**: run agents on a dataset of inputs - **Synthetic Data Generation**: generate synthetic data for model development +- **Batches**: OpenAI-compatible batch management for inference diff --git a/docs/source/providers/agents/index.md b/docs/source/providers/agents/index.md index 92bf9edc0..a2c48d4b9 100644 --- a/docs/source/providers/agents/index.md +++ b/docs/source/providers/agents/index.md @@ -2,6 +2,15 @@ ## Overview +Agents API for creating and interacting with agentic systems. + + Main functionalities provided by this API: + - Create agents with specific instructions and ability to use tools. + - Interactions with agents are grouped into sessions ("threads"), and each interaction is called a "turn". + - Agents can be provided with various tools (see the ToolGroups and ToolRuntime APIs for more details). + - Agents can be provided with various shields (see the Safety API for more details). + - Agents can also use Memory to retrieve information from knowledge bases. See the RAG Tool and Vector IO APIs for more details. + This section contains documentation for all available providers for the **agents** API. ## Providers diff --git a/docs/source/providers/batches/index.md b/docs/source/providers/batches/index.md index d2405ecf7..2a39a626c 100644 --- a/docs/source/providers/batches/index.md +++ b/docs/source/providers/batches/index.md @@ -2,6 +2,14 @@ ## Overview +Protocol for batch processing API operations. + + The Batches API enables efficient processing of multiple requests in a single operation, + particularly useful for processing large datasets, batch evaluation workflows, and + cost-effective inference at scale. + + Note: This API is currently under active development and may undergo changes. + This section contains documentation for all available providers for the **batches** API. ## Providers diff --git a/docs/source/providers/eval/index.md b/docs/source/providers/eval/index.md index d180d256c..a14fada1d 100644 --- a/docs/source/providers/eval/index.md +++ b/docs/source/providers/eval/index.md @@ -2,6 +2,8 @@ ## Overview +Llama Stack Evaluation API for running evaluations on model and agent candidates. + This section contains documentation for all available providers for the **eval** API. ## Providers diff --git a/docs/source/providers/inference/index.md b/docs/source/providers/inference/index.md index 1c7bc86b9..cdde3a18a 100644 --- a/docs/source/providers/inference/index.md +++ b/docs/source/providers/inference/index.md @@ -2,6 +2,12 @@ ## Overview +Llama Stack Inference API for generating completions, chat completions, and embeddings. + + This API provides the raw interface to the underlying models. Two kinds of models are supported: + - LLM models: these models generate "raw" and "chat" (conversational) completions. 
+ - Embedding models: these models generate embeddings to be used for semantic search. + This section contains documentation for all available providers for the **inference** API. ## Providers diff --git a/llama_stack/apis/batches/batches.py b/llama_stack/apis/batches/batches.py index 72742d4fa..81ab44ccd 100644 --- a/llama_stack/apis/batches/batches.py +++ b/llama_stack/apis/batches/batches.py @@ -39,7 +39,14 @@ class ListBatchesResponse(BaseModel): @runtime_checkable class Batches(Protocol): - """Protocol for batch processing API operations.""" + """Protocol for batch processing API operations. + + The Batches API enables efficient processing of multiple requests in a single operation, + particularly useful for processing large datasets, batch evaluation workflows, and + cost-effective inference at scale. + + Note: This API is currently under active development and may undergo changes. + """ @webmethod(route="/openai/v1/batches", method="POST") async def create_batch( diff --git a/scripts/provider_codegen.py b/scripts/provider_codegen.py index 84c45fe27..beaeeae38 100755 --- a/scripts/provider_codegen.py +++ b/scripts/provider_codegen.py @@ -18,6 +18,23 @@ from llama_stack.core.distribution import get_provider_registry REPO_ROOT = Path(__file__).parent.parent +def get_api_docstring(api_name: str) -> str | None: + """Extract docstring from the API protocol class.""" + try: + # Import the API module dynamically + api_module = __import__(f"llama_stack.apis.{api_name}", fromlist=[api_name.title()]) + + # Get the main protocol class (usually capitalized API name) + protocol_class_name = api_name.title() + if hasattr(api_module, protocol_class_name): + protocol_class = getattr(api_module, protocol_class_name) + return protocol_class.__doc__ + except (ImportError, AttributeError): + pass + + return None + + class ChangedPathTracker: """Track a list of paths we may have changed.""" @@ -261,6 +278,11 @@ def process_provider_registry(progress, change_tracker: ChangedPathTracker) -> N index_content.append(f"# {api_name.title()}\n") index_content.append("## Overview\n") + api_docstring = get_api_docstring(api_name) + if api_docstring: + cleaned_docstring = api_docstring.strip() + index_content.append(f"{cleaned_docstring}\n") + index_content.append( f"This section contains documentation for all available providers for the **{api_name}** API.\n" ) From 44263ce9549da7dc93bdbc36075a7b89d833753e Mon Sep 17 00:00:00 2001 From: Matthew Farrellee Date: Thu, 14 Aug 2025 09:25:22 -0400 Subject: [PATCH 3/3] remove unused CreateBatchRequest, update completion_window to be literal "24h" --- llama_stack/apis/batches/__init__.py | 4 ++-- llama_stack/apis/batches/batches.py | 14 ++------------ .../providers/inline/batches/reference/batches.py | 4 ++-- tests/integration/batches/test_batches_errors.py | 3 +-- 4 files changed, 7 insertions(+), 18 deletions(-) diff --git a/llama_stack/apis/batches/__init__.py b/llama_stack/apis/batches/__init__.py index d3efe3dba..9ce7d3d75 100644 --- a/llama_stack/apis/batches/__init__.py +++ b/llama_stack/apis/batches/__init__.py @@ -4,6 +4,6 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. 
-from .batches import Batches, BatchObject, CreateBatchRequest, ListBatchesResponse +from .batches import Batches, BatchObject, ListBatchesResponse -__all__ = ["Batches", "BatchObject", "CreateBatchRequest", "ListBatchesResponse"] +__all__ = ["Batches", "BatchObject", "ListBatchesResponse"] diff --git a/llama_stack/apis/batches/batches.py b/llama_stack/apis/batches/batches.py index 81ab44ccd..9297d8597 100644 --- a/llama_stack/apis/batches/batches.py +++ b/llama_stack/apis/batches/batches.py @@ -4,7 +4,7 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from typing import Any, Literal, Protocol, runtime_checkable +from typing import Literal, Protocol, runtime_checkable from pydantic import BaseModel, Field @@ -16,16 +16,6 @@ except ImportError as e: raise ImportError("OpenAI package is required for batches API. Please install it with: pip install openai") from e -@json_schema_type -class CreateBatchRequest(BaseModel): - """Request to create a new batch.""" - - input_file_id: str = Field(..., description="The ID of an uploaded file that contains requests for the new batch") - endpoint: str = Field(..., description="The endpoint to be used for all requests in the batch") - completion_window: str = Field(..., description="The time window within which the batch should be processed") - metadata: dict[str, Any] | None = Field(default=None, description="Optional metadata for the batch") - - @json_schema_type class ListBatchesResponse(BaseModel): """Response containing a list of batch objects.""" @@ -53,7 +43,7 @@ class Batches(Protocol): self, input_file_id: str, endpoint: str, - completion_window: str, + completion_window: Literal["24h"], metadata: dict[str, str] | None = None, ) -> BatchObject: """Create a new batch for processing multiple API requests. diff --git a/llama_stack/providers/inline/batches/reference/batches.py b/llama_stack/providers/inline/batches/reference/batches.py index 6e99a00a1..984ef5a90 100644 --- a/llama_stack/providers/inline/batches/reference/batches.py +++ b/llama_stack/providers/inline/batches/reference/batches.py @@ -10,7 +10,7 @@ import json import time import uuid from io import BytesIO -from typing import Any +from typing import Any, Literal from openai.types.batch import BatchError, Errors from pydantic import BaseModel @@ -108,7 +108,7 @@ class ReferenceBatchesImpl(Batches): self, input_file_id: str, endpoint: str, - completion_window: str, + completion_window: Literal["24h"], metadata: dict[str, str] | None = None, ) -> BatchObject: """ diff --git a/tests/integration/batches/test_batches_errors.py b/tests/integration/batches/test_batches_errors.py index 2cd1e561e..bc94a182e 100644 --- a/tests/integration/batches/test_batches_errors.py +++ b/tests/integration/batches/test_batches_errors.py @@ -379,9 +379,8 @@ class TestBatchesErrorHandling: ) assert exc_info.value.status_code == 400 error_msg = str(exc_info.value).lower() - assert "invalid value" in error_msg + assert "error" in error_msg assert "completion_window" in error_msg - assert "supported values are" in error_msg def test_batch_streaming_not_supported(self, openai_client, batch_helper, text_model_id): """Test that streaming responses are not supported in batches."""
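A minimal end-to-end sketch of the workflow these patches enable, for anyone exercising the API by hand rather than through the test suites. This is illustrative only: the base URL, api_key, and model name below are placeholders/assumptions rather than part of the patches; the "batch" file purpose, the "/v1/chat/completions" endpoint, and the literal "24h" completion window are the values the reference provider accepts.

```python
import json
import time
from io import BytesIO

from openai import OpenAI

# Assumption: a Llama Stack server exposing the OpenAI-compatible routes at this base URL.
client = OpenAI(base_url="http://localhost:8321/v1/openai/v1", api_key="none")

# One request per JSONL line; each url must match the batch endpoint.
requests = [
    {
        "custom_id": "req-1",
        "method": "POST",
        "url": "/v1/chat/completions",
        "body": {
            "model": "YOUR-MODEL",  # placeholder: any model registered with the stack
            "messages": [{"role": "user", "content": "Hello"}],
            "max_tokens": 10,
        },
    }
]
jsonl = "\n".join(json.dumps(r) for r in requests).encode()

# Upload the input file using the newly supported "batch" purpose.
input_file = client.files.create(file=("input.jsonl", BytesIO(jsonl)), purpose="batch")

# Create the batch; only the "24h" completion window is accepted.
batch = client.batches.create(
    input_file_id=input_file.id,
    endpoint="/v1/chat/completions",
    completion_window="24h",
)

# Poll until the batch reaches a terminal status.
while batch.status not in ("completed", "failed", "cancelled", "expired"):
    time.sleep(5)
    batch = client.batches.retrieve(batch.id)

# On success, each output line pairs a custom_id with its response;
# failures are reported via batch.errors and an optional error_file_id.
if batch.status == "completed" and batch.output_file_id:
    print(client.files.content(batch.output_file_id).text)
```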