diff --git a/.github/workflows/integration-tests.yml b/.github/workflows/integration-tests.yml index f330d2c45..9ef49fba3 100644 --- a/.github/workflows/integration-tests.yml +++ b/.github/workflows/integration-tests.yml @@ -5,7 +5,7 @@ run-name: Run the integration test suite from tests/integration in replay mode on: push: branches: [ main ] - pull_request_target: + pull_request: branches: [ main ] types: [opened, synchronize, reopened] paths: @@ -34,7 +34,7 @@ on: concurrency: # Skip concurrency for pushes to main - each commit should be tested independently - group: ${{ github.workflow }}-${{ github.ref == 'refs/heads/main' && github.run_id || github.event.pull_request.number }} + group: ${{ github.workflow }}-${{ github.ref == 'refs/heads/main' && github.run_id || github.ref }} cancel-in-progress: true jobs: diff --git a/.github/workflows/record-integration-tests.yml b/.github/workflows/record-integration-tests.yml index 12957db27..b31709a4f 100644 --- a/.github/workflows/record-integration-tests.yml +++ b/.github/workflows/record-integration-tests.yml @@ -3,7 +3,7 @@ name: Integration Tests (Record) run-name: Run the integration test suite from tests/integration on: - pull_request: + pull_request_target: branches: [ main ] types: [opened, synchronize, labeled] paths: @@ -23,7 +23,7 @@ on: default: 'ollama' concurrency: - group: ${{ github.workflow }}-${{ github.ref }} + group: ${{ github.workflow }}-${{ github.event.pull_request.number }} cancel-in-progress: true jobs: diff --git a/docs/_static/llama-stack-spec.html b/docs/_static/llama-stack-spec.html index b36626719..0549dda21 100644 --- a/docs/_static/llama-stack-spec.html +++ b/docs/_static/llama-stack-spec.html @@ -14767,8 +14767,7 @@ "OpenAIFilePurpose": { "type": "string", "enum": [ - "assistants", - "batch" + "assistants" ], "title": "OpenAIFilePurpose", "description": "Valid purpose values for OpenAI Files API." @@ -14845,8 +14844,7 @@ "purpose": { "type": "string", "enum": [ - "assistants", - "batch" + "assistants" ], "description": "The intended purpose of the file" } diff --git a/docs/_static/llama-stack-spec.yaml b/docs/_static/llama-stack-spec.yaml index e7733b3c3..aa47cd58d 100644 --- a/docs/_static/llama-stack-spec.yaml +++ b/docs/_static/llama-stack-spec.yaml @@ -10951,7 +10951,6 @@ components: type: string enum: - assistants - - batch title: OpenAIFilePurpose description: >- Valid purpose values for OpenAI Files API. @@ -11020,7 +11019,6 @@ components: type: string enum: - assistants - - batch description: The intended purpose of the file additionalProperties: false required: diff --git a/docs/source/concepts/apis.md b/docs/source/concepts/apis.md index f8f73a928..5a10d6498 100644 --- a/docs/source/concepts/apis.md +++ b/docs/source/concepts/apis.md @@ -18,4 +18,3 @@ We are working on adding a few more APIs to complete the application lifecycle. - **Batch Inference**: run inference on a dataset of inputs - **Batch Agents**: run agents on a dataset of inputs - **Synthetic Data Generation**: generate synthetic data for model development -- **Batches**: OpenAI-compatible batch management for inference diff --git a/docs/source/providers/agents/index.md b/docs/source/providers/agents/index.md index a2c48d4b9..92bf9edc0 100644 --- a/docs/source/providers/agents/index.md +++ b/docs/source/providers/agents/index.md @@ -2,15 +2,6 @@ ## Overview -Agents API for creating and interacting with agentic systems. - - Main functionalities provided by this API: - - Create agents with specific instructions and ability to use tools. 
- - Interactions with agents are grouped into sessions ("threads"), and each interaction is called a "turn". - - Agents can be provided with various tools (see the ToolGroups and ToolRuntime APIs for more details). - - Agents can be provided with various shields (see the Safety API for more details). - - Agents can also use Memory to retrieve information from knowledge bases. See the RAG Tool and Vector IO APIs for more details. - This section contains documentation for all available providers for the **agents** API. ## Providers diff --git a/docs/source/providers/batches/index.md b/docs/source/providers/batches/index.md deleted file mode 100644 index 2a39a626c..000000000 --- a/docs/source/providers/batches/index.md +++ /dev/null @@ -1,21 +0,0 @@ -# Batches - -## Overview - -Protocol for batch processing API operations. - - The Batches API enables efficient processing of multiple requests in a single operation, - particularly useful for processing large datasets, batch evaluation workflows, and - cost-effective inference at scale. - - Note: This API is currently under active development and may undergo changes. - -This section contains documentation for all available providers for the **batches** API. - -## Providers - -```{toctree} -:maxdepth: 1 - -inline_reference -``` diff --git a/docs/source/providers/batches/inline_reference.md b/docs/source/providers/batches/inline_reference.md deleted file mode 100644 index a58e5124d..000000000 --- a/docs/source/providers/batches/inline_reference.md +++ /dev/null @@ -1,23 +0,0 @@ -# inline::reference - -## Description - -Reference implementation of batches API with KVStore persistence. - -## Configuration - -| Field | Type | Required | Default | Description | -|-------|------|----------|---------|-------------| -| `kvstore` | `utils.kvstore.config.RedisKVStoreConfig \| utils.kvstore.config.SqliteKVStoreConfig \| utils.kvstore.config.PostgresKVStoreConfig \| utils.kvstore.config.MongoDBKVStoreConfig` | No | sqlite | Configuration for the key-value store backend. | -| `max_concurrent_batches` | `` | No | 1 | Maximum number of concurrent batches to process simultaneously. | -| `max_concurrent_requests_per_batch` | `` | No | 10 | Maximum number of concurrent requests to process per batch. | - -## Sample Configuration - -```yaml -kvstore: - type: sqlite - db_path: ${env.SQLITE_STORE_DIR:=~/.llama/dummy}/batches.db - -``` - diff --git a/docs/source/providers/eval/index.md b/docs/source/providers/eval/index.md index a14fada1d..d180d256c 100644 --- a/docs/source/providers/eval/index.md +++ b/docs/source/providers/eval/index.md @@ -2,8 +2,6 @@ ## Overview -Llama Stack Evaluation API for running evaluations on model and agent candidates. - This section contains documentation for all available providers for the **eval** API. ## Providers diff --git a/docs/source/providers/inference/index.md b/docs/source/providers/inference/index.md index b6d215474..38781e5eb 100644 --- a/docs/source/providers/inference/index.md +++ b/docs/source/providers/inference/index.md @@ -2,12 +2,6 @@ ## Overview -Llama Stack Inference API for generating completions, chat completions, and embeddings. - - This API provides the raw interface to the underlying models. Two kinds of models are supported: - - LLM models: these models generate "raw" and "chat" (conversational) completions. - - Embedding models: these models generate embeddings to be used for semantic search. - This section contains documentation for all available providers for the **inference** API. 
## Providers diff --git a/llama_stack/apis/batches/__init__.py b/llama_stack/apis/batches/__init__.py deleted file mode 100644 index 9ce7d3d75..000000000 --- a/llama_stack/apis/batches/__init__.py +++ /dev/null @@ -1,9 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. - -from .batches import Batches, BatchObject, ListBatchesResponse - -__all__ = ["Batches", "BatchObject", "ListBatchesResponse"] diff --git a/llama_stack/apis/batches/batches.py b/llama_stack/apis/batches/batches.py deleted file mode 100644 index 9297d8597..000000000 --- a/llama_stack/apis/batches/batches.py +++ /dev/null @@ -1,89 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. - -from typing import Literal, Protocol, runtime_checkable - -from pydantic import BaseModel, Field - -from llama_stack.schema_utils import json_schema_type, webmethod - -try: - from openai.types import Batch as BatchObject -except ImportError as e: - raise ImportError("OpenAI package is required for batches API. Please install it with: pip install openai") from e - - -@json_schema_type -class ListBatchesResponse(BaseModel): - """Response containing a list of batch objects.""" - - object: Literal["list"] = "list" - data: list[BatchObject] = Field(..., description="List of batch objects") - first_id: str | None = Field(default=None, description="ID of the first batch in the list") - last_id: str | None = Field(default=None, description="ID of the last batch in the list") - has_more: bool = Field(default=False, description="Whether there are more batches available") - - -@runtime_checkable -class Batches(Protocol): - """Protocol for batch processing API operations. - - The Batches API enables efficient processing of multiple requests in a single operation, - particularly useful for processing large datasets, batch evaluation workflows, and - cost-effective inference at scale. - - Note: This API is currently under active development and may undergo changes. - """ - - @webmethod(route="/openai/v1/batches", method="POST") - async def create_batch( - self, - input_file_id: str, - endpoint: str, - completion_window: Literal["24h"], - metadata: dict[str, str] | None = None, - ) -> BatchObject: - """Create a new batch for processing multiple API requests. - - :param input_file_id: The ID of an uploaded file containing requests for the batch. - :param endpoint: The endpoint to be used for all requests in the batch. - :param completion_window: The time window within which the batch should be processed. - :param metadata: Optional metadata for the batch. - :returns: The created batch object. - """ - ... - - @webmethod(route="/openai/v1/batches/{batch_id}", method="GET") - async def retrieve_batch(self, batch_id: str) -> BatchObject: - """Retrieve information about a specific batch. - - :param batch_id: The ID of the batch to retrieve. - :returns: The batch object. - """ - ... - - @webmethod(route="/openai/v1/batches/{batch_id}/cancel", method="POST") - async def cancel_batch(self, batch_id: str) -> BatchObject: - """Cancel a batch that is in progress. - - :param batch_id: The ID of the batch to cancel. - :returns: The updated batch object. - """ - ... 
- - @webmethod(route="/openai/v1/batches", method="GET") - async def list_batches( - self, - after: str | None = None, - limit: int = 20, - ) -> ListBatchesResponse: - """List all batches for the current user. - - :param after: A cursor for pagination; returns batches after this batch ID. - :param limit: Number of batches to return (default 20, max 100). - :returns: A list of batch objects. - """ - ... diff --git a/llama_stack/apis/common/errors.py b/llama_stack/apis/common/errors.py index 7104d8db6..6e0fa0b3c 100644 --- a/llama_stack/apis/common/errors.py +++ b/llama_stack/apis/common/errors.py @@ -64,12 +64,6 @@ class SessionNotFoundError(ValueError): super().__init__(message) -class ConflictError(ValueError): - """raised when an operation cannot be performed due to a conflict with the current state""" - - pass - - class ModelTypeError(TypeError): """raised when a model is present but not the correct type""" diff --git a/llama_stack/apis/datatypes.py b/llama_stack/apis/datatypes.py index 87fc95917..cabe46a2f 100644 --- a/llama_stack/apis/datatypes.py +++ b/llama_stack/apis/datatypes.py @@ -86,7 +86,6 @@ class Api(Enum, metaclass=DynamicApiMeta): :cvar inference: Text generation, chat completions, and embeddings :cvar safety: Content moderation and safety shields :cvar agents: Agent orchestration and execution - :cvar batches: Batch processing for asynchronous API requests :cvar vector_io: Vector database operations and queries :cvar datasetio: Dataset input/output operations :cvar scoring: Model output evaluation and scoring @@ -109,7 +108,6 @@ class Api(Enum, metaclass=DynamicApiMeta): inference = "inference" safety = "safety" agents = "agents" - batches = "batches" vector_io = "vector_io" datasetio = "datasetio" scoring = "scoring" diff --git a/llama_stack/apis/files/files.py b/llama_stack/apis/files/files.py index a1b9dd4dc..ba8701e23 100644 --- a/llama_stack/apis/files/files.py +++ b/llama_stack/apis/files/files.py @@ -22,7 +22,6 @@ class OpenAIFilePurpose(StrEnum): """ ASSISTANTS = "assistants" - BATCH = "batch" # TODO: Add other purposes as needed diff --git a/llama_stack/core/resolver.py b/llama_stack/core/resolver.py index 7ac98dac8..70c78fb01 100644 --- a/llama_stack/core/resolver.py +++ b/llama_stack/core/resolver.py @@ -8,7 +8,6 @@ import inspect from typing import Any from llama_stack.apis.agents import Agents -from llama_stack.apis.batches import Batches from llama_stack.apis.benchmarks import Benchmarks from llama_stack.apis.datasetio import DatasetIO from llama_stack.apis.datasets import Datasets @@ -76,7 +75,6 @@ def api_protocol_map(external_apis: dict[Api, ExternalApiSpec] | None = None) -> Api.agents: Agents, Api.inference: Inference, Api.inspect: Inspect, - Api.batches: Batches, Api.vector_io: VectorIO, Api.vector_dbs: VectorDBs, Api.models: Models, diff --git a/llama_stack/core/server/server.py b/llama_stack/core/server/server.py index cbef8ef88..e9d70fc8d 100644 --- a/llama_stack/core/server/server.py +++ b/llama_stack/core/server/server.py @@ -32,7 +32,6 @@ from fastapi.responses import JSONResponse, StreamingResponse from openai import BadRequestError from pydantic import BaseModel, ValidationError -from llama_stack.apis.common.errors import ConflictError, ResourceNotFoundError from llama_stack.apis.common.responses import PaginatedResponse from llama_stack.cli.utils import add_config_distro_args, get_config_from_args from llama_stack.core.access_control.access_control import AccessDeniedError @@ -129,10 +128,6 @@ def translate_exception(exc: Exception) -> 
HTTPException | RequestValidationErro ] }, ) - elif isinstance(exc, ConflictError): - return HTTPException(status_code=409, detail=str(exc)) - elif isinstance(exc, ResourceNotFoundError): - return HTTPException(status_code=404, detail=str(exc)) elif isinstance(exc, ValueError): return HTTPException(status_code=httpx.codes.BAD_REQUEST, detail=f"Invalid value: {str(exc)}") elif isinstance(exc, BadRequestError): diff --git a/llama_stack/providers/inline/batches/__init__.py b/llama_stack/providers/inline/batches/__init__.py deleted file mode 100644 index 756f351d8..000000000 --- a/llama_stack/providers/inline/batches/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. diff --git a/llama_stack/providers/inline/batches/reference/__init__.py b/llama_stack/providers/inline/batches/reference/__init__.py deleted file mode 100644 index a8ae92eb2..000000000 --- a/llama_stack/providers/inline/batches/reference/__init__.py +++ /dev/null @@ -1,36 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. - -from typing import Any - -from llama_stack.apis.files import Files -from llama_stack.apis.inference import Inference -from llama_stack.apis.models import Models -from llama_stack.core.datatypes import AccessRule, Api -from llama_stack.providers.utils.kvstore import kvstore_impl - -from .batches import ReferenceBatchesImpl -from .config import ReferenceBatchesImplConfig - -__all__ = ["ReferenceBatchesImpl", "ReferenceBatchesImplConfig"] - - -async def get_provider_impl(config: ReferenceBatchesImplConfig, deps: dict[Api, Any], policy: list[AccessRule]): - kvstore = await kvstore_impl(config.kvstore) - inference_api: Inference | None = deps.get(Api.inference) - files_api: Files | None = deps.get(Api.files) - models_api: Models | None = deps.get(Api.models) - - if inference_api is None: - raise ValueError("Inference API is required but not provided in dependencies") - if files_api is None: - raise ValueError("Files API is required but not provided in dependencies") - if models_api is None: - raise ValueError("Models API is required but not provided in dependencies") - - impl = ReferenceBatchesImpl(config, inference_api, files_api, models_api, kvstore) - await impl.initialize() - return impl diff --git a/llama_stack/providers/inline/batches/reference/batches.py b/llama_stack/providers/inline/batches/reference/batches.py deleted file mode 100644 index 984ef5a90..000000000 --- a/llama_stack/providers/inline/batches/reference/batches.py +++ /dev/null @@ -1,553 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. 
- -import asyncio -import itertools -import json -import time -import uuid -from io import BytesIO -from typing import Any, Literal - -from openai.types.batch import BatchError, Errors -from pydantic import BaseModel - -from llama_stack.apis.batches import Batches, BatchObject, ListBatchesResponse -from llama_stack.apis.common.errors import ConflictError, ResourceNotFoundError -from llama_stack.apis.files import Files, OpenAIFilePurpose -from llama_stack.apis.inference import Inference -from llama_stack.apis.models import Models -from llama_stack.log import get_logger -from llama_stack.providers.utils.kvstore import KVStore - -from .config import ReferenceBatchesImplConfig - -BATCH_PREFIX = "batch:" - -logger = get_logger(__name__) - - -class AsyncBytesIO: - """ - Async-compatible BytesIO wrapper to allow async file-like operations. - - We use this when uploading files to the Files API, as it expects an - async file-like object. - """ - - def __init__(self, data: bytes): - self._buffer = BytesIO(data) - - async def read(self, n=-1): - return self._buffer.read(n) - - async def seek(self, pos, whence=0): - return self._buffer.seek(pos, whence) - - def __enter__(self): - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - self._buffer.close() - - def __getattr__(self, name): - return getattr(self._buffer, name) - - -class BatchRequest(BaseModel): - line_num: int - custom_id: str - method: str - url: str - body: dict[str, Any] - - -class ReferenceBatchesImpl(Batches): - """Reference implementation of the Batches API. - - This implementation processes batch files by making individual requests - to the inference API and generates output files with results. - """ - - def __init__( - self, - config: ReferenceBatchesImplConfig, - inference_api: Inference, - files_api: Files, - models_api: Models, - kvstore: KVStore, - ) -> None: - self.config = config - self.kvstore = kvstore - self.inference_api = inference_api - self.files_api = files_api - self.models_api = models_api - self._processing_tasks: dict[str, asyncio.Task] = {} - self._batch_semaphore = asyncio.Semaphore(config.max_concurrent_batches) - self._update_batch_lock = asyncio.Lock() - - # this is to allow tests to disable background processing - self.process_batches = True - - async def initialize(self) -> None: - # TODO: start background processing of existing tasks - pass - - async def shutdown(self) -> None: - """Shutdown the batches provider.""" - if self._processing_tasks: - # don't cancel tasks - just let them stop naturally on shutdown - # cancelling would mark batches as "cancelled" in the database - logger.info(f"Shutdown initiated with {len(self._processing_tasks)} active batch processing tasks") - - # TODO (SECURITY): this currently works w/ configured api keys, not with x-llamastack-provider-data or with user policy restrictions - async def create_batch( - self, - input_file_id: str, - endpoint: str, - completion_window: Literal["24h"], - metadata: dict[str, str] | None = None, - ) -> BatchObject: - """ - Create a new batch for processing multiple API requests. - - Error handling by levels - - 0. Input param handling, results in 40x errors before processing, e.g. - - Wrong completion_window - - Invalid metadata types - - Unknown endpoint - -> no batch created - 1. Errors preventing processing, result in BatchErrors aggregated in process_batch, e.g. 
- - input_file_id missing - - invalid json in file - - missing custom_id, method, url, body - - invalid model - - streaming - -> batch created, validation sends to failed status - 2. Processing errors, result in error_file_id entries, e.g. - - Any error returned from inference endpoint - -> batch created, goes to completed status - """ - - # TODO: set expiration time for garbage collection - - if endpoint not in ["/v1/chat/completions"]: - raise ValueError( - f"Invalid endpoint: {endpoint}. Supported values: /v1/chat/completions. Code: invalid_value. Param: endpoint", - ) - - if completion_window != "24h": - raise ValueError( - f"Invalid completion_window: {completion_window}. Supported values are: 24h. Code: invalid_value. Param: completion_window", - ) - - batch_id = f"batch_{uuid.uuid4().hex[:16]}" - current_time = int(time.time()) - - batch = BatchObject( - id=batch_id, - object="batch", - endpoint=endpoint, - input_file_id=input_file_id, - completion_window=completion_window, - status="validating", - created_at=current_time, - metadata=metadata, - ) - - await self.kvstore.set(f"batch:{batch_id}", batch.to_json()) - - if self.process_batches: - task = asyncio.create_task(self._process_batch(batch_id)) - self._processing_tasks[batch_id] = task - - return batch - - async def cancel_batch(self, batch_id: str) -> BatchObject: - """Cancel a batch that is in progress.""" - batch = await self.retrieve_batch(batch_id) - - if batch.status in ["cancelled", "cancelling"]: - return batch - - if batch.status in ["completed", "failed", "expired"]: - raise ConflictError(f"Cannot cancel batch '{batch_id}' with status '{batch.status}'") - - await self._update_batch(batch_id, status="cancelling", cancelling_at=int(time.time())) - - if batch_id in self._processing_tasks: - self._processing_tasks[batch_id].cancel() - # note: task removal and status="cancelled" handled in finally block of _process_batch - - return await self.retrieve_batch(batch_id) - - async def list_batches( - self, - after: str | None = None, - limit: int = 20, - ) -> ListBatchesResponse: - """ - List all batches, eventually only for the current user. - - With no notion of user, we return all batches. - """ - batch_values = await self.kvstore.values_in_range("batch:", "batch:\xff") - - batches = [] - for batch_data in batch_values: - if batch_data: - batches.append(BatchObject.model_validate_json(batch_data)) - - batches.sort(key=lambda b: b.created_at, reverse=True) - - start_idx = 0 - if after: - for i, batch in enumerate(batches): - if batch.id == after: - start_idx = i + 1 - break - - page_batches = batches[start_idx : start_idx + limit] - has_more = (start_idx + limit) < len(batches) - - first_id = page_batches[0].id if page_batches else None - last_id = page_batches[-1].id if page_batches else None - - return ListBatchesResponse( - data=page_batches, - first_id=first_id, - last_id=last_id, - has_more=has_more, - ) - - async def retrieve_batch(self, batch_id: str) -> BatchObject: - """Retrieve information about a specific batch.""" - batch_data = await self.kvstore.get(f"batch:{batch_id}") - if not batch_data: - raise ResourceNotFoundError(batch_id, "Batch", "batches.list()") - - return BatchObject.model_validate_json(batch_data) - - async def _update_batch(self, batch_id: str, **updates) -> None: - """Update batch fields in kvstore.""" - async with self._update_batch_lock: - try: - batch = await self.retrieve_batch(batch_id) - - # batch processing is async. 
once cancelling, only allow "cancelled" status updates - if batch.status == "cancelling" and updates.get("status") != "cancelled": - logger.info( - f"Skipping status update for cancelled batch {batch_id}: attempted {updates.get('status')}" - ) - return - - if "errors" in updates: - updates["errors"] = updates["errors"].model_dump() - - batch_dict = batch.model_dump() - batch_dict.update(updates) - - await self.kvstore.set(f"batch:{batch_id}", json.dumps(batch_dict)) - except Exception as e: - logger.error(f"Failed to update batch {batch_id}: {e}") - - async def _validate_input(self, batch: BatchObject) -> tuple[list[BatchError], list[BatchRequest]]: - """ - Read & validate input, return errors and valid input. - - Validation of - - input_file_id existance - - valid json - - custom_id, method, url, body presence and valid - - no streaming - """ - requests: list[BatchRequest] = [] - errors: list[BatchError] = [] - try: - await self.files_api.openai_retrieve_file(batch.input_file_id) - except Exception: - errors.append( - BatchError( - code="invalid_request", - line=None, - message=f"Cannot find file {batch.input_file_id}.", - param="input_file_id", - ) - ) - return errors, requests - - # TODO(SECURITY): do something about large files - file_content_response = await self.files_api.openai_retrieve_file_content(batch.input_file_id) - file_content = file_content_response.body.decode("utf-8") - for line_num, line in enumerate(file_content.strip().split("\n"), 1): - if line.strip(): # skip empty lines - try: - request = json.loads(line) - - if not isinstance(request, dict): - errors.append( - BatchError( - code="invalid_request", - line=line_num, - message="Each line must be a JSON dictionary object", - ) - ) - continue - - valid = True - - for param, expected_type, type_string in [ - ("custom_id", str, "string"), - ("method", str, "string"), - ("url", str, "string"), - ("body", dict, "JSON dictionary object"), - ]: - if param not in request: - errors.append( - BatchError( - code="missing_required_parameter", - line=line_num, - message=f"Missing required parameter: {param}", - param=param, - ) - ) - valid = False - elif not isinstance(request[param], expected_type): - param_name = "URL" if param == "url" else param.capitalize() - errors.append( - BatchError( - code="invalid_request", - line=line_num, - message=f"{param_name} must be a {type_string}", - param=param, - ) - ) - valid = False - - if (url := request.get("url")) and isinstance(url, str) and url != batch.endpoint: - errors.append( - BatchError( - code="invalid_url", - line=line_num, - message="URL provided for this request does not match the batch endpoint", - param="url", - ) - ) - valid = False - - if (body := request.get("body")) and isinstance(body, dict): - if body.get("stream", False): - errors.append( - BatchError( - code="streaming_unsupported", - line=line_num, - message="Streaming is not supported in batch processing", - param="body.stream", - ) - ) - valid = False - - for param, expected_type, type_string in [ - ("model", str, "a string"), - # messages is specific to /v1/chat/completions - # we could skip validating messages here and let inference fail. however, - # that would be a very expensive way to find out messages is wrong. - ("messages", list, "an array"), # TODO: allow messages to be a string? 
- ]: - if param not in body: - errors.append( - BatchError( - code="invalid_request", - line=line_num, - message=f"{param.capitalize()} parameter is required", - param=f"body.{param}", - ) - ) - valid = False - elif not isinstance(body[param], expected_type): - errors.append( - BatchError( - code="invalid_request", - line=line_num, - message=f"{param.capitalize()} must be {type_string}", - param=f"body.{param}", - ) - ) - valid = False - - if "model" in body and isinstance(body["model"], str): - try: - await self.models_api.get_model(body["model"]) - except Exception: - errors.append( - BatchError( - code="model_not_found", - line=line_num, - message=f"Model '{body['model']}' does not exist or is not supported", - param="body.model", - ) - ) - valid = False - - if valid: - assert isinstance(url, str), "URL must be a string" # for mypy - assert isinstance(body, dict), "Body must be a dictionary" # for mypy - requests.append( - BatchRequest( - line_num=line_num, - url=url, - method=request["method"], - custom_id=request["custom_id"], - body=body, - ), - ) - except json.JSONDecodeError: - errors.append( - BatchError( - code="invalid_json_line", - line=line_num, - message="This line is not parseable as valid JSON.", - ) - ) - - return errors, requests - - async def _process_batch(self, batch_id: str) -> None: - """Background task to process a batch of requests.""" - try: - logger.info(f"Starting batch processing for {batch_id}") - async with self._batch_semaphore: # semaphore to limit concurrency - logger.info(f"Acquired semaphore for batch {batch_id}") - await self._process_batch_impl(batch_id) - except asyncio.CancelledError: - logger.info(f"Batch processing cancelled for {batch_id}") - await self._update_batch(batch_id, status="cancelled", cancelled_at=int(time.time())) - except Exception as e: - logger.error(f"Batch processing failed for {batch_id}: {e}") - await self._update_batch( - batch_id, - status="failed", - failed_at=int(time.time()), - errors=Errors(data=[BatchError(code="internal_error", message=str(e))]), - ) - finally: - self._processing_tasks.pop(batch_id, None) - - async def _process_batch_impl(self, batch_id: str) -> None: - """Implementation of batch processing logic.""" - errors: list[BatchError] = [] - batch = await self.retrieve_batch(batch_id) - - errors, requests = await self._validate_input(batch) - if errors: - await self._update_batch(batch_id, status="failed", failed_at=int(time.time()), errors=Errors(data=errors)) - logger.info(f"Batch validation failed for {batch_id} with {len(errors)} errors") - return - - logger.info(f"Processing {len(requests)} requests for batch {batch_id}") - - total_requests = len(requests) - await self._update_batch( - batch_id, - status="in_progress", - request_counts={"total": total_requests, "completed": 0, "failed": 0}, - ) - - error_results = [] - success_results = [] - completed_count = 0 - failed_count = 0 - - for chunk in itertools.batched(requests, self.config.max_concurrent_requests_per_batch): - # we use a TaskGroup to ensure all process-single-request tasks are canceled when process-batch is cancelled - async with asyncio.TaskGroup() as tg: - chunk_tasks = [tg.create_task(self._process_single_request(batch_id, request)) for request in chunk] - - chunk_results = await asyncio.gather(*chunk_tasks, return_exceptions=True) - - for result in chunk_results: - if isinstance(result, dict) and result.get("error") is not None: # error response from inference - failed_count += 1 - error_results.append(result) - elif isinstance(result, 
dict) and result.get("response") is not None: # successful inference - completed_count += 1 - success_results.append(result) - else: # unexpected result - failed_count += 1 - errors.append(BatchError(code="internal_error", message=f"Unexpected result: {result}")) - - await self._update_batch( - batch_id, - request_counts={"total": total_requests, "completed": completed_count, "failed": failed_count}, - ) - - if errors: - await self._update_batch( - batch_id, status="failed", failed_at=int(time.time()), errors=Errors(data=errors) - ) - return - - try: - output_file_id = await self._create_output_file(batch_id, success_results, "success") - await self._update_batch(batch_id, output_file_id=output_file_id) - - error_file_id = await self._create_output_file(batch_id, error_results, "error") - await self._update_batch(batch_id, error_file_id=error_file_id) - - await self._update_batch(batch_id, status="completed", completed_at=int(time.time())) - - logger.info( - f"Batch processing completed for {batch_id}: {completed_count} completed, {failed_count} failed" - ) - except Exception as e: - # note: errors is empty at this point, so we don't lose anything by ignoring it - await self._update_batch( - batch_id, - status="failed", - failed_at=int(time.time()), - errors=Errors(data=[BatchError(code="output_failed", message=str(e))]), - ) - - async def _process_single_request(self, batch_id: str, request: BatchRequest) -> dict: - """Process a single request from the batch.""" - request_id = f"batch_req_{batch_id}_{request.line_num}" - - try: - # TODO(SECURITY): review body for security issues - chat_response = await self.inference_api.openai_chat_completion(**request.body) - - # this is for mypy, we don't allow streaming so we'll get the right type - assert hasattr(chat_response, "model_dump_json"), "Chat response must have model_dump_json method" - return { - "id": request_id, - "custom_id": request.custom_id, - "response": { - "status_code": 200, - "request_id": request_id, # TODO: should this be different? - "body": chat_response.model_dump_json(), - }, - } - except Exception as e: - logger.info(f"Error processing request {request.custom_id} in batch {batch_id}: {e}") - return { - "id": request_id, - "custom_id": request.custom_id, - "error": {"type": "request_failed", "message": str(e)}, - } - - async def _create_output_file(self, batch_id: str, results: list[dict], file_type: str) -> str: - """ - Create an output file with batch results. - - This function filters results based on the specified file_type - and uploads the file to the Files API. - """ - output_lines = [json.dumps(result) for result in results] - - with AsyncBytesIO("\n".join(output_lines).encode("utf-8")) as file_buffer: - file_buffer.filename = f"{batch_id}_{file_type}.jsonl" - uploaded_file = await self.files_api.openai_upload_file(file=file_buffer, purpose=OpenAIFilePurpose.BATCH) - return uploaded_file.id diff --git a/llama_stack/providers/inline/batches/reference/config.py b/llama_stack/providers/inline/batches/reference/config.py deleted file mode 100644 index d8d06868b..000000000 --- a/llama_stack/providers/inline/batches/reference/config.py +++ /dev/null @@ -1,40 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. 
- -from pydantic import BaseModel, Field - -from llama_stack.providers.utils.kvstore.config import KVStoreConfig, SqliteKVStoreConfig - - -class ReferenceBatchesImplConfig(BaseModel): - """Configuration for the Reference Batches implementation.""" - - kvstore: KVStoreConfig = Field( - description="Configuration for the key-value store backend.", - ) - - max_concurrent_batches: int = Field( - default=1, - description="Maximum number of concurrent batches to process simultaneously.", - ge=1, - ) - - max_concurrent_requests_per_batch: int = Field( - default=10, - description="Maximum number of concurrent requests to process per batch.", - ge=1, - ) - - # TODO: add a max requests per second rate limiter - - @classmethod - def sample_run_config(cls, __distro_dir__: str) -> dict: - return { - "kvstore": SqliteKVStoreConfig.sample_run_config( - __distro_dir__=__distro_dir__, - db_name="batches.db", - ), - } diff --git a/llama_stack/providers/registry/batches.py b/llama_stack/providers/registry/batches.py deleted file mode 100644 index de7886efb..000000000 --- a/llama_stack/providers/registry/batches.py +++ /dev/null @@ -1,26 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. - - -from llama_stack.providers.datatypes import Api, InlineProviderSpec, ProviderSpec - - -def available_providers() -> list[ProviderSpec]: - return [ - InlineProviderSpec( - api=Api.batches, - provider_type="inline::reference", - pip_packages=["openai"], - module="llama_stack.providers.inline.batches.reference", - config_class="llama_stack.providers.inline.batches.reference.config.ReferenceBatchesImplConfig", - api_dependencies=[ - Api.inference, - Api.files, - Api.models, - ], - description="Reference implementation of batches API with KVStore persistence.", - ), - ] diff --git a/llama_stack/providers/utils/inference/openai_compat.py b/llama_stack/providers/utils/inference/openai_compat.py index 9a77c8cc4..6297cc2ed 100644 --- a/llama_stack/providers/utils/inference/openai_compat.py +++ b/llama_stack/providers/utils/inference/openai_compat.py @@ -31,15 +31,15 @@ from openai.types.chat import ( from openai.types.chat import ( ChatCompletionContentPartTextParam as OpenAIChatCompletionContentPartTextParam, ) +from openai.types.chat import ( + ChatCompletionMessageFunctionToolCall as OpenAIChatCompletionMessageFunctionToolCall, +) from openai.types.chat import ( ChatCompletionMessageParam as OpenAIChatCompletionMessage, ) from openai.types.chat import ( ChatCompletionMessageToolCall, ) -from openai.types.chat import ( - ChatCompletionMessageToolCallParam as OpenAIChatCompletionMessageToolCall, -) from openai.types.chat import ( ChatCompletionSystemMessageParam as OpenAIChatCompletionSystemMessage, ) @@ -633,7 +633,7 @@ async def convert_message_to_openai_dict_new( ) elif isinstance(message, CompletionMessage): tool_calls = [ - OpenAIChatCompletionMessageToolCall( + OpenAIChatCompletionMessageFunctionToolCall( id=tool.call_id, function=OpenAIFunction( name=(tool.tool_name if not isinstance(tool.tool_name, BuiltinTool) else tool.tool_name.value), @@ -903,7 +903,7 @@ def _convert_openai_request_response_format( def _convert_openai_tool_calls( - tool_calls: list[OpenAIChatCompletionMessageToolCall], + tool_calls: list[OpenAIChatCompletionMessageFunctionToolCall], ) -> list[ToolCall]: """ Convert an OpenAI ChatCompletionMessageToolCall list into a list of ToolCall. 
diff --git a/scripts/provider_codegen.py b/scripts/provider_codegen.py index 060acfa72..717677c52 100755 --- a/scripts/provider_codegen.py +++ b/scripts/provider_codegen.py @@ -18,23 +18,6 @@ from llama_stack.core.distribution import get_provider_registry REPO_ROOT = Path(__file__).parent.parent -def get_api_docstring(api_name: str) -> str | None: - """Extract docstring from the API protocol class.""" - try: - # Import the API module dynamically - api_module = __import__(f"llama_stack.apis.{api_name}", fromlist=[api_name.title()]) - - # Get the main protocol class (usually capitalized API name) - protocol_class_name = api_name.title() - if hasattr(api_module, protocol_class_name): - protocol_class = getattr(api_module, protocol_class_name) - return protocol_class.__doc__ - except (ImportError, AttributeError): - pass - - return None - - class ChangedPathTracker: """Track a list of paths we may have changed.""" @@ -278,11 +261,6 @@ def process_provider_registry(progress, change_tracker: ChangedPathTracker) -> N index_content.append(f"# {api_name.title()}\n") index_content.append("## Overview\n") - api_docstring = get_api_docstring(api_name) - if api_docstring: - cleaned_docstring = api_docstring.strip() - index_content.append(f"{cleaned_docstring}\n") - index_content.append( f"This section contains documentation for all available providers for the **{api_name}** API.\n" ) diff --git a/tests/integration/batches/__init__.py b/tests/integration/batches/__init__.py deleted file mode 100644 index 756f351d8..000000000 --- a/tests/integration/batches/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. diff --git a/tests/integration/batches/conftest.py b/tests/integration/batches/conftest.py deleted file mode 100644 index 974fe77ab..000000000 --- a/tests/integration/batches/conftest.py +++ /dev/null @@ -1,122 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. - -"""Shared pytest fixtures for batch tests.""" - -import json -import time -import warnings -from contextlib import contextmanager -from io import BytesIO - -import pytest - -from llama_stack.apis.files import OpenAIFilePurpose - - -class BatchHelper: - """Helper class for creating and managing batch input files.""" - - def __init__(self, client): - """Initialize with either a batch_client or openai_client.""" - self.client = client - - @contextmanager - def create_file(self, content: str | list[dict], filename_prefix="batch_input"): - """Context manager for creating and cleaning up batch input files. 
- - Args: - content: Either a list of batch request dictionaries or raw string content - filename_prefix: Prefix for the generated filename (or full filename if content is string) - - Yields: - The uploaded file object - """ - if isinstance(content, str): - # Handle raw string content (e.g., malformed JSONL, empty files) - file_content = content.encode("utf-8") - else: - # Handle list of batch request dictionaries - jsonl_content = "\n".join(json.dumps(req) for req in content) - file_content = jsonl_content.encode("utf-8") - - filename = filename_prefix if filename_prefix.endswith(".jsonl") else f"{filename_prefix}.jsonl" - - with BytesIO(file_content) as file_buffer: - file_buffer.name = filename - uploaded_file = self.client.files.create(file=file_buffer, purpose=OpenAIFilePurpose.BATCH) - - try: - yield uploaded_file - finally: - try: - self.client.files.delete(uploaded_file.id) - except Exception: - warnings.warn( - f"Failed to cleanup file {uploaded_file.id}: {uploaded_file.filename}", - stacklevel=2, - ) - - def wait_for( - self, - batch_id: str, - max_wait_time: int = 60, - sleep_interval: int | None = None, - expected_statuses: set[str] | None = None, - timeout_action: str = "fail", - ): - """Wait for a batch to reach a terminal status. - - Args: - batch_id: The batch ID to monitor - max_wait_time: Maximum time to wait in seconds (default: 60 seconds) - sleep_interval: Time to sleep between checks in seconds (default: 1/10th of max_wait_time, min 1s, max 15s) - expected_statuses: Set of expected terminal statuses (default: {"completed"}) - timeout_action: Action on timeout - "fail" (pytest.fail) or "skip" (pytest.skip) - - Returns: - The final batch object - - Raises: - pytest.Failed: If batch reaches an unexpected status or timeout_action is "fail" - pytest.Skipped: If timeout_action is "skip" on timeout or unexpected status - """ - if sleep_interval is None: - # Default to 1/10th of max_wait_time, with min 1s and max 15s - sleep_interval = max(1, min(15, max_wait_time // 10)) - - if expected_statuses is None: - expected_statuses = {"completed"} - - terminal_statuses = {"completed", "failed", "cancelled", "expired"} - unexpected_statuses = terminal_statuses - expected_statuses - - start_time = time.time() - while time.time() - start_time < max_wait_time: - current_batch = self.client.batches.retrieve(batch_id) - - if current_batch.status in expected_statuses: - return current_batch - elif current_batch.status in unexpected_statuses: - error_msg = f"Batch reached unexpected status: {current_batch.status}" - if timeout_action == "skip": - pytest.skip(error_msg) - else: - pytest.fail(error_msg) - - time.sleep(sleep_interval) - - timeout_msg = f"Batch did not reach expected status {expected_statuses} within {max_wait_time} seconds" - if timeout_action == "skip": - pytest.skip(timeout_msg) - else: - pytest.fail(timeout_msg) - - -@pytest.fixture -def batch_helper(openai_client): - """Fixture that provides a BatchHelper instance for OpenAI client.""" - return BatchHelper(openai_client) diff --git a/tests/integration/batches/test_batches.py b/tests/integration/batches/test_batches.py deleted file mode 100644 index 1ef3202d0..000000000 --- a/tests/integration/batches/test_batches.py +++ /dev/null @@ -1,270 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. 
- -""" -Integration tests for the Llama Stack batch processing functionality. - -This module contains comprehensive integration tests for the batch processing API, -using the OpenAI-compatible client interface for consistency. - -Test Categories: - 1. Core Batch Operations: - - test_batch_creation_and_retrieval: Comprehensive batch creation, structure validation, and retrieval - - test_batch_listing: Basic batch listing functionality - - test_batch_immediate_cancellation: Batch cancellation workflow - # TODO: cancel during processing - - 2. End-to-End Processing: - - test_batch_e2e_chat_completions: Full chat completions workflow with output and error validation - -Note: Error conditions and edge cases are primarily tested in test_batches_errors.py -for better organization and separation of concerns. - -CLEANUP WARNING: These tests currently create batches that are not automatically -cleaned up after test completion. This may lead to resource accumulation over -multiple test runs. Only test_batch_immediate_cancellation properly cancels its batch. -The test_batch_e2e_chat_completions test does clean up its output and error files. -""" - -import json - - -class TestBatchesIntegration: - """Integration tests for the batches API.""" - - def test_batch_creation_and_retrieval(self, openai_client, batch_helper, text_model_id): - """Test comprehensive batch creation and retrieval scenarios.""" - test_metadata = { - "test_type": "comprehensive", - "purpose": "creation_and_retrieval_test", - "version": "1.0", - "tags": "test,batch", - } - - batch_requests = [ - { - "custom_id": "request-1", - "method": "POST", - "url": "/v1/chat/completions", - "body": { - "model": text_model_id, - "messages": [{"role": "user", "content": "Hello"}], - "max_tokens": 10, - }, - } - ] - - with batch_helper.create_file(batch_requests, "batch_creation_test") as uploaded_file: - batch = openai_client.batches.create( - input_file_id=uploaded_file.id, - endpoint="/v1/chat/completions", - completion_window="24h", - metadata=test_metadata, - ) - - assert batch.endpoint == "/v1/chat/completions" - assert batch.input_file_id == uploaded_file.id - assert batch.completion_window == "24h" - assert batch.metadata == test_metadata - - retrieved_batch = openai_client.batches.retrieve(batch.id) - - assert retrieved_batch.id == batch.id - assert retrieved_batch.object == batch.object - assert retrieved_batch.endpoint == batch.endpoint - assert retrieved_batch.input_file_id == batch.input_file_id - assert retrieved_batch.completion_window == batch.completion_window - assert retrieved_batch.metadata == batch.metadata - - def test_batch_listing(self, openai_client, batch_helper, text_model_id): - """ - Test batch listing. - - This test creates multiple batches and verifies that they can be listed. - It also deletes the input files before execution, which means the batches - will appear as failed due to missing input files. This is expected and - a good thing, because it means no inference is performed. 
- """ - batch_ids = [] - - for i in range(2): - batch_requests = [ - { - "custom_id": f"request-{i}", - "method": "POST", - "url": "/v1/chat/completions", - "body": { - "model": text_model_id, - "messages": [{"role": "user", "content": f"Hello {i}"}], - "max_tokens": 10, - }, - } - ] - - with batch_helper.create_file(batch_requests, f"batch_input_{i}") as uploaded_file: - batch = openai_client.batches.create( - input_file_id=uploaded_file.id, - endpoint="/v1/chat/completions", - completion_window="24h", - ) - batch_ids.append(batch.id) - - batch_list = openai_client.batches.list() - - assert isinstance(batch_list.data, list) - - listed_batch_ids = {b.id for b in batch_list.data} - for batch_id in batch_ids: - assert batch_id in listed_batch_ids - - def test_batch_immediate_cancellation(self, openai_client, batch_helper, text_model_id): - """Test immediate batch cancellation.""" - batch_requests = [ - { - "custom_id": "request-1", - "method": "POST", - "url": "/v1/chat/completions", - "body": { - "model": text_model_id, - "messages": [{"role": "user", "content": "Hello"}], - "max_tokens": 10, - }, - } - ] - - with batch_helper.create_file(batch_requests) as uploaded_file: - batch = openai_client.batches.create( - input_file_id=uploaded_file.id, - endpoint="/v1/chat/completions", - completion_window="24h", - ) - - # hopefully cancel the batch before it completes - cancelling_batch = openai_client.batches.cancel(batch.id) - assert cancelling_batch.status in ["cancelling", "cancelled"] - assert isinstance(cancelling_batch.cancelling_at, int), ( - f"cancelling_at should be int, got {type(cancelling_batch.cancelling_at)}" - ) - - final_batch = batch_helper.wait_for( - batch.id, - max_wait_time=3 * 60, # often takes 10-11 minutes, give it 3 min - expected_statuses={"cancelled"}, - timeout_action="skip", - ) - - assert final_batch.status == "cancelled" - assert isinstance(final_batch.cancelled_at, int), ( - f"cancelled_at should be int, got {type(final_batch.cancelled_at)}" - ) - - def test_batch_e2e_chat_completions(self, openai_client, batch_helper, text_model_id): - """Test end-to-end batch processing for chat completions with both successful and failed operations.""" - batch_requests = [ - { - "custom_id": "success-1", - "method": "POST", - "url": "/v1/chat/completions", - "body": { - "model": text_model_id, - "messages": [{"role": "user", "content": "Say hello"}], - "max_tokens": 20, - }, - }, - { - "custom_id": "error-1", - "method": "POST", - "url": "/v1/chat/completions", - "body": { - "model": text_model_id, - "messages": [{"role": "user", "content": "This should fail"}], - "max_tokens": -1, # Invalid negative max_tokens will cause inference error - }, - }, - ] - - with batch_helper.create_file(batch_requests) as uploaded_file: - batch = openai_client.batches.create( - input_file_id=uploaded_file.id, - endpoint="/v1/chat/completions", - completion_window="24h", - metadata={"test": "e2e_success_and_errors_test"}, - ) - - final_batch = batch_helper.wait_for( - batch.id, - max_wait_time=3 * 60, # often takes 2-3 minutes - expected_statuses={"completed"}, - timeout_action="skip", - ) - - # Expecting a completed batch with both successful and failed requests - # Batch(id='batch_xxx', - # completion_window='24h', - # created_at=..., - # endpoint='/v1/chat/completions', - # input_file_id='file-xxx', - # object='batch', - # status='completed', - # output_file_id='file-xxx', - # error_file_id='file-xxx', - # request_counts=BatchRequestCounts(completed=1, failed=1, total=2)) - - assert 
final_batch.status == "completed" - assert final_batch.request_counts is not None - assert final_batch.request_counts.total == 2 - assert final_batch.request_counts.completed == 1 - assert final_batch.request_counts.failed == 1 - - assert final_batch.output_file_id is not None, "Output file should exist for successful requests" - - output_content = openai_client.files.content(final_batch.output_file_id) - if isinstance(output_content, str): - output_text = output_content - else: - output_text = output_content.content.decode("utf-8") - - output_lines = output_text.strip().split("\n") - - for line in output_lines: - result = json.loads(line) - - assert "id" in result - assert "custom_id" in result - assert result["custom_id"] == "success-1" - - assert "response" in result - - assert result["response"]["status_code"] == 200 - assert "body" in result["response"] - assert "choices" in result["response"]["body"] - - assert final_batch.error_file_id is not None, "Error file should exist for failed requests" - - error_content = openai_client.files.content(final_batch.error_file_id) - if isinstance(error_content, str): - error_text = error_content - else: - error_text = error_content.content.decode("utf-8") - - error_lines = error_text.strip().split("\n") - - for line in error_lines: - result = json.loads(line) - - assert "id" in result - assert "custom_id" in result - assert result["custom_id"] == "error-1" - assert "error" in result - error = result["error"] - assert error is not None - assert "code" in error or "message" in error, "Error should have code or message" - - deleted_output_file = openai_client.files.delete(final_batch.output_file_id) - assert deleted_output_file.deleted, f"Output file {final_batch.output_file_id} was not deleted successfully" - - deleted_error_file = openai_client.files.delete(final_batch.error_file_id) - assert deleted_error_file.deleted, f"Error file {final_batch.error_file_id} was not deleted successfully" diff --git a/tests/integration/batches/test_batches_errors.py b/tests/integration/batches/test_batches_errors.py deleted file mode 100644 index bc94a182e..000000000 --- a/tests/integration/batches/test_batches_errors.py +++ /dev/null @@ -1,693 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. - -""" -Error handling and edge case tests for the Llama Stack batch processing functionality. - -This module focuses exclusively on testing error conditions, validation failures, -and edge cases for batch operations to ensure robust error handling and graceful -degradation. - -Test Categories: - 1. File and Input Validation: - - test_batch_nonexistent_file_id: Handling invalid file IDs - - test_batch_malformed_jsonl: Processing malformed JSONL input files - - test_file_malformed_batch_file: Handling malformed files at upload time - - test_batch_missing_required_fields: Validation of required request fields - - 2. API Endpoint and Model Validation: - - test_batch_invalid_endpoint: Invalid endpoint handling during creation - - test_batch_error_handling_invalid_model: Error handling with nonexistent models - - test_batch_endpoint_mismatch: Validation of endpoint/URL consistency - - 3. 
Batch Lifecycle Error Handling: - - test_batch_retrieve_nonexistent: Retrieving non-existent batches - - test_batch_cancel_nonexistent: Cancelling non-existent batches - - test_batch_cancel_completed: Attempting to cancel completed batches - - 4. Parameter and Configuration Validation: - - test_batch_invalid_completion_window: Invalid completion window values - - test_batch_invalid_metadata_types: Invalid metadata type validation - - test_batch_missing_required_body_fields: Validation of required fields in request body - - 5. Feature Restriction and Compatibility: - - test_batch_streaming_not_supported: Streaming request rejection - - test_batch_mixed_streaming_requests: Mixed streaming/non-streaming validation - -Note: Core functionality and OpenAI compatibility tests are located in -test_batches_integration.py for better organization and separation of concerns. - -CLEANUP WARNING: These tests create batches to test error conditions but do not -automatically clean them up after test completion. While most error tests create -batches that fail quickly, some may create valid batches that consume resources. -""" - -import pytest -from openai import BadRequestError, ConflictError, NotFoundError - - -class TestBatchesErrorHandling: - """Error handling and edge case tests for the batches API using OpenAI client.""" - - def test_batch_nonexistent_file_id(self, openai_client, batch_helper): - """Test batch creation with nonexistent input file ID.""" - - batch = openai_client.batches.create( - input_file_id="file-nonexistent-xyz", - endpoint="/v1/chat/completions", - completion_window="24h", - ) - - final_batch = batch_helper.wait_for(batch.id, expected_statuses={"failed"}) - - # Expecting - - # Batch(..., - # status='failed', - # errors=Errors(data=[ - # BatchError( - # code='invalid_request', - # line=None, - # message='Cannot find file ..., or organization ... does not have access to it.', - # param='file_id') - # ], object='list'), - # failed_at=1754566971, - # ...) - - assert final_batch.status == "failed" - assert final_batch.errors is not None - assert len(final_batch.errors.data) == 1 - error = final_batch.errors.data[0] - assert error.code == "invalid_request" - assert "cannot find file" in error.message.lower() - - def test_batch_invalid_endpoint(self, openai_client, batch_helper, text_model_id): - """Test batch creation with invalid endpoint.""" - batch_requests = [ - { - "custom_id": "invalid-endpoint", - "method": "POST", - "url": "/v1/chat/completions", - "body": { - "model": text_model_id, - "messages": [{"role": "user", "content": "Hello"}], - "max_tokens": 10, - }, - } - ] - - with batch_helper.create_file(batch_requests) as uploaded_file: - with pytest.raises(BadRequestError) as exc_info: - openai_client.batches.create( - input_file_id=uploaded_file.id, - endpoint="/v1/invalid/endpoint", - completion_window="24h", - ) - - # Expected - - # Error code: 400 - { - # 'error': { - # 'message': "Invalid value: '/v1/invalid/endpoint'. 
Supported values are: '/v1/chat/completions', '/v1/completions', '/v1/embeddings', and '/v1/responses'.", - # 'type': 'invalid_request_error', - # 'param': 'endpoint', - # 'code': 'invalid_value' - # } - # } - - error_msg = str(exc_info.value).lower() - assert exc_info.value.status_code == 400 - assert "invalid value" in error_msg - assert "/v1/invalid/endpoint" in error_msg - assert "supported values" in error_msg - assert "endpoint" in error_msg - assert "invalid_value" in error_msg - - def test_batch_malformed_jsonl(self, openai_client, batch_helper): - """ - Test batch with malformed JSONL input. - - The /v1/files endpoint requires valid JSONL format, so we provide a well formed line - before a malformed line to ensure we get to the /v1/batches validation stage. - """ - with batch_helper.create_file( - """{"custom_id": "valid", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "test"}} -{invalid json here""", - "malformed_batch_input.jsonl", - ) as uploaded_file: - batch = openai_client.batches.create( - input_file_id=uploaded_file.id, - endpoint="/v1/chat/completions", - completion_window="24h", - ) - - final_batch = batch_helper.wait_for(batch.id, expected_statuses={"failed"}) - - # Expecting - - # Batch(..., - # status='failed', - # errors=Errors(data=[ - # ..., - # BatchError(code='invalid_json_line', - # line=2, - # message='This line is not parseable as valid JSON.', - # param=None) - # ], object='list'), - # ...) - - assert final_batch.status == "failed" - assert final_batch.errors is not None - assert len(final_batch.errors.data) > 0 - error = final_batch.errors.data[-1] # get last error because first may be about the "test" model - assert error.code == "invalid_json_line" - assert error.line == 2 - assert "not" in error.message.lower() - assert "valid json" in error.message.lower() - - @pytest.mark.xfail(reason="Not all file providers validate content") - @pytest.mark.parametrize("batch_requests", ["", "{malformed json"], ids=["empty", "malformed"]) - def test_file_malformed_batch_file(self, openai_client, batch_helper, batch_requests): - """Test file upload with malformed content.""" - - with pytest.raises(BadRequestError) as exc_info: - with batch_helper.create_file(batch_requests, "malformed_batch_input_file.jsonl"): - # /v1/files rejects the file, we don't get to batch creation - pass - - error_msg = str(exc_info.value).lower() - assert exc_info.value.status_code == 400 - assert "invalid file format" in error_msg - assert "jsonl" in error_msg - - def test_batch_retrieve_nonexistent(self, openai_client): - """Test retrieving nonexistent batch.""" - with pytest.raises(NotFoundError) as exc_info: - openai_client.batches.retrieve("batch-nonexistent-xyz") - - error_msg = str(exc_info.value).lower() - assert exc_info.value.status_code == 404 - assert "no batch found" in error_msg or "not found" in error_msg - - def test_batch_cancel_nonexistent(self, openai_client): - """Test cancelling nonexistent batch.""" - with pytest.raises(NotFoundError) as exc_info: - openai_client.batches.cancel("batch-nonexistent-xyz") - - error_msg = str(exc_info.value).lower() - assert exc_info.value.status_code == 404 - assert "no batch found" in error_msg or "not found" in error_msg - - def test_batch_cancel_completed(self, openai_client, batch_helper, text_model_id): - """Test cancelling already completed batch.""" - batch_requests = [ - { - "custom_id": "cancel-completed", - "method": "POST", - "url": "/v1/chat/completions", - "body": { - "model": text_model_id, - "messages": 
[{"role": "user", "content": "Quick test"}], - "max_tokens": 5, - }, - } - ] - - with batch_helper.create_file(batch_requests, "cancel_test_batch_input") as uploaded_file: - batch = openai_client.batches.create( - input_file_id=uploaded_file.id, - endpoint="/v1/chat/completions", - completion_window="24h", - ) - - final_batch = batch_helper.wait_for( - batch.id, - max_wait_time=3 * 60, # often take 10-11 min, give it 3 min - expected_statuses={"completed"}, - timeout_action="skip", - ) - - deleted_file = openai_client.files.delete(final_batch.output_file_id) - assert deleted_file.deleted, f"File {final_batch.output_file_id} was not deleted successfully" - - with pytest.raises(ConflictError) as exc_info: - openai_client.batches.cancel(batch.id) - - # Expecting - - # Error code: 409 - { - # 'error': { - # 'message': "Cannot cancel a batch with status 'completed'.", - # 'type': 'invalid_request_error', - # 'param': None, - # 'code': None - # } - # } - # - # NOTE: Same for "failed", cancelling "cancelled" batches is allowed - - error_msg = str(exc_info.value).lower() - assert exc_info.value.status_code == 409 - assert "cannot cancel" in error_msg - - def test_batch_missing_required_fields(self, openai_client, batch_helper, text_model_id): - """Test batch with requests missing required fields.""" - batch_requests = [ - { - # Missing custom_id - "method": "POST", - "url": "/v1/chat/completions", - "body": { - "model": text_model_id, - "messages": [{"role": "user", "content": "No custom_id"}], - "max_tokens": 10, - }, - }, - { - "custom_id": "no-method", - "url": "/v1/chat/completions", - "body": { - "model": text_model_id, - "messages": [{"role": "user", "content": "No method"}], - "max_tokens": 10, - }, - }, - { - "custom_id": "no-url", - "method": "POST", - "body": { - "model": text_model_id, - "messages": [{"role": "user", "content": "No URL"}], - "max_tokens": 10, - }, - }, - { - "custom_id": "no-body", - "method": "POST", - "url": "/v1/chat/completions", - }, - ] - - with batch_helper.create_file(batch_requests, "missing_fields_batch_input") as uploaded_file: - batch = openai_client.batches.create( - input_file_id=uploaded_file.id, - endpoint="/v1/chat/completions", - completion_window="24h", - ) - - final_batch = batch_helper.wait_for(batch.id, expected_statuses={"failed"}) - - # Expecting - - # Batch(..., - # status='failed', - # errors=Errors( - # data=[ - # BatchError( - # code='missing_required_parameter', - # line=1, - # message="Missing required parameter: 'custom_id'.", - # param='custom_id' - # ), - # BatchError( - # code='missing_required_parameter', - # line=2, - # message="Missing required parameter: 'method'.", - # param='method' - # ), - # BatchError( - # code='missing_required_parameter', - # line=3, - # message="Missing required parameter: 'url'.", - # param='url' - # ), - # BatchError( - # code='missing_required_parameter', - # line=4, - # message="Missing required parameter: 'body'.", - # param='body' - # ) - # ], object='list'), - # failed_at=1754566945, - # ...) 
- # ) - - assert final_batch.status == "failed" - assert final_batch.errors is not None - assert len(final_batch.errors.data) == 4 - no_custom_id_error = final_batch.errors.data[0] - assert no_custom_id_error.code == "missing_required_parameter" - assert no_custom_id_error.line == 1 - assert "missing" in no_custom_id_error.message.lower() - assert "custom_id" in no_custom_id_error.message.lower() - no_method_error = final_batch.errors.data[1] - assert no_method_error.code == "missing_required_parameter" - assert no_method_error.line == 2 - assert "missing" in no_method_error.message.lower() - assert "method" in no_method_error.message.lower() - no_url_error = final_batch.errors.data[2] - assert no_url_error.code == "missing_required_parameter" - assert no_url_error.line == 3 - assert "missing" in no_url_error.message.lower() - assert "url" in no_url_error.message.lower() - no_body_error = final_batch.errors.data[3] - assert no_body_error.code == "missing_required_parameter" - assert no_body_error.line == 4 - assert "missing" in no_body_error.message.lower() - assert "body" in no_body_error.message.lower() - - def test_batch_invalid_completion_window(self, openai_client, batch_helper, text_model_id): - """Test batch creation with invalid completion window.""" - batch_requests = [ - { - "custom_id": "invalid-completion-window", - "method": "POST", - "url": "/v1/chat/completions", - "body": { - "model": text_model_id, - "messages": [{"role": "user", "content": "Hello"}], - "max_tokens": 10, - }, - } - ] - - with batch_helper.create_file(batch_requests) as uploaded_file: - for window in ["1h", "48h", "invalid", ""]: - with pytest.raises(BadRequestError) as exc_info: - openai_client.batches.create( - input_file_id=uploaded_file.id, - endpoint="/v1/chat/completions", - completion_window=window, - ) - assert exc_info.value.status_code == 400 - error_msg = str(exc_info.value).lower() - assert "error" in error_msg - assert "completion_window" in error_msg - - def test_batch_streaming_not_supported(self, openai_client, batch_helper, text_model_id): - """Test that streaming responses are not supported in batches.""" - batch_requests = [ - { - "custom_id": "streaming-test", - "method": "POST", - "url": "/v1/chat/completions", - "body": { - "model": text_model_id, - "messages": [{"role": "user", "content": "Hello"}], - "max_tokens": 10, - "stream": True, # Not supported - }, - } - ] - - with batch_helper.create_file(batch_requests, "streaming_batch_input") as uploaded_file: - batch = openai_client.batches.create( - input_file_id=uploaded_file.id, - endpoint="/v1/chat/completions", - completion_window="24h", - ) - - final_batch = batch_helper.wait_for(batch.id, expected_statuses={"failed"}) - - # Expecting - - # Batch(..., - # status='failed', - # errors=Errors(data=[ - # BatchError(code='streaming_unsupported', - # line=1, - # message='Chat Completions: Streaming is not supported in the Batch API.', - # param='body.stream') - # ], object='list'), - # failed_at=1754566965, - # ...) 
- - assert final_batch.status == "failed" - assert final_batch.errors is not None - assert len(final_batch.errors.data) == 1 - error = final_batch.errors.data[0] - assert error.code == "streaming_unsupported" - assert error.line == 1 - assert "streaming" in error.message.lower() - assert "not supported" in error.message.lower() - assert error.param == "body.stream" - assert final_batch.failed_at is not None - - def test_batch_mixed_streaming_requests(self, openai_client, batch_helper, text_model_id): - """ - Test batch with mixed streaming and non-streaming requests. - - This is distinct from test_batch_streaming_not_supported, which tests a single - streaming request, to ensure an otherwise valid batch fails when a single - streaming request is included. - """ - batch_requests = [ - { - "custom_id": "valid-non-streaming-request", - "method": "POST", - "url": "/v1/chat/completions", - "body": { - "model": text_model_id, - "messages": [{"role": "user", "content": "Hello without streaming"}], - "max_tokens": 10, - }, - }, - { - "custom_id": "streaming-request", - "method": "POST", - "url": "/v1/chat/completions", - "body": { - "model": text_model_id, - "messages": [{"role": "user", "content": "Hello with streaming"}], - "max_tokens": 10, - "stream": True, # Not supported - }, - }, - ] - - with batch_helper.create_file(batch_requests, "mixed_streaming_batch_input") as uploaded_file: - batch = openai_client.batches.create( - input_file_id=uploaded_file.id, - endpoint="/v1/chat/completions", - completion_window="24h", - ) - - final_batch = batch_helper.wait_for(batch.id, expected_statuses={"failed"}) - - # Expecting - - # Batch(..., - # status='failed', - # errors=Errors(data=[ - # BatchError( - # code='streaming_unsupported', - # line=2, - # message='Chat Completions: Streaming is not supported in the Batch API.', - # param='body.stream') - # ], object='list'), - # failed_at=1754574442, - # ...) - - assert final_batch.status == "failed" - assert final_batch.errors is not None - assert len(final_batch.errors.data) == 1 - error = final_batch.errors.data[0] - assert error.code == "streaming_unsupported" - assert error.line == 2 - assert "streaming" in error.message.lower() - assert "not supported" in error.message.lower() - assert error.param == "body.stream" - assert final_batch.failed_at is not None - - def test_batch_endpoint_mismatch(self, openai_client, batch_helper, text_model_id): - """Test batch creation with mismatched endpoint and request URL.""" - batch_requests = [ - { - "custom_id": "endpoint-mismatch", - "method": "POST", - "url": "/v1/embeddings", # Different from batch endpoint - "body": { - "model": text_model_id, - "messages": [{"role": "user", "content": "Hello"}], - }, - } - ] - - with batch_helper.create_file(batch_requests, "endpoint_mismatch_batch_input") as uploaded_file: - batch = openai_client.batches.create( - input_file_id=uploaded_file.id, - endpoint="/v1/chat/completions", # Different from request URL - completion_window="24h", - ) - - final_batch = batch_helper.wait_for(batch.id, expected_statuses={"failed"}) - - # Expecting - - # Batch(..., - # status='failed', - # errors=Errors(data=[ - # BatchError( - # code='invalid_url', - # line=1, - # message='The URL provided for this request does not match the batch endpoint.', - # param='url') - # ], object='list'), - # failed_at=1754566972, - # ...) 
- - assert final_batch.status == "failed" - assert final_batch.errors is not None - assert len(final_batch.errors.data) == 1 - error = final_batch.errors.data[0] - assert error.line == 1 - assert error.code == "invalid_url" - assert "does not match" in error.message.lower() - assert "endpoint" in error.message.lower() - assert final_batch.failed_at is not None - - def test_batch_error_handling_invalid_model(self, openai_client, batch_helper): - """Test batch error handling with invalid model.""" - batch_requests = [ - { - "custom_id": "invalid-model", - "method": "POST", - "url": "/v1/chat/completions", - "body": { - "model": "nonexistent-model-xyz", - "messages": [{"role": "user", "content": "Hello"}], - "max_tokens": 10, - }, - } - ] - - with batch_helper.create_file(batch_requests) as uploaded_file: - batch = openai_client.batches.create( - input_file_id=uploaded_file.id, - endpoint="/v1/chat/completions", - completion_window="24h", - ) - - final_batch = batch_helper.wait_for(batch.id, expected_statuses={"failed"}) - - # Expecting - - # Batch(..., - # status='failed', - # errors=Errors(data=[ - # BatchError(code='model_not_found', - # line=1, - # message="The provided model 'nonexistent-model-xyz' is not supported by the Batch API.", - # param='body.model') - # ], object='list'), - # failed_at=1754566978, - # ...) - - assert final_batch.status == "failed" - assert final_batch.errors is not None - assert len(final_batch.errors.data) == 1 - error = final_batch.errors.data[0] - assert error.line == 1 - assert error.code == "model_not_found" - assert "not supported" in error.message.lower() - assert error.param == "body.model" - assert final_batch.failed_at is not None - - def test_batch_missing_required_body_fields(self, openai_client, batch_helper, text_model_id): - """Test batch with requests missing required fields in body (model and messages).""" - batch_requests = [ - { - "custom_id": "missing-model", - "method": "POST", - "url": "/v1/chat/completions", - "body": { - # Missing model field - "messages": [{"role": "user", "content": "Hello without model"}], - "max_tokens": 10, - }, - }, - { - "custom_id": "missing-messages", - "method": "POST", - "url": "/v1/chat/completions", - "body": { - "model": text_model_id, - # Missing messages field - "max_tokens": 10, - }, - }, - ] - - with batch_helper.create_file(batch_requests, "missing_body_fields_batch_input") as uploaded_file: - batch = openai_client.batches.create( - input_file_id=uploaded_file.id, - endpoint="/v1/chat/completions", - completion_window="24h", - ) - - final_batch = batch_helper.wait_for(batch.id, expected_statuses={"failed"}) - - # Expecting - - # Batch(..., - # status='failed', - # errors=Errors(data=[ - # BatchError( - # code='invalid_request', - # line=1, - # message='Model parameter is required.', - # param='body.model'), - # BatchError( - # code='invalid_request', - # line=2, - # message='Messages parameter is required.', - # param='body.messages') - # ], object='list'), - # ...) 
- - assert final_batch.status == "failed" - assert final_batch.errors is not None - assert len(final_batch.errors.data) == 2 - - model_error = final_batch.errors.data[0] - assert model_error.line == 1 - assert "model" in model_error.message.lower() - assert model_error.param == "body.model" - - messages_error = final_batch.errors.data[1] - assert messages_error.line == 2 - assert "messages" in messages_error.message.lower() - assert messages_error.param == "body.messages" - - assert final_batch.failed_at is not None - - def test_batch_invalid_metadata_types(self, openai_client, batch_helper, text_model_id): - """Test batch creation with invalid metadata types (like lists).""" - batch_requests = [ - { - "custom_id": "invalid-metadata-type", - "method": "POST", - "url": "/v1/chat/completions", - "body": { - "model": text_model_id, - "messages": [{"role": "user", "content": "Hello"}], - "max_tokens": 10, - }, - } - ] - - with batch_helper.create_file(batch_requests) as uploaded_file: - with pytest.raises(Exception) as exc_info: - openai_client.batches.create( - input_file_id=uploaded_file.id, - endpoint="/v1/chat/completions", - completion_window="24h", - metadata={ - "tags": ["tag1", "tag2"], # Invalid type, should be a string - }, - ) - - # Expecting - - # Error code: 400 - {'error': - # {'message': "Invalid type for 'metadata.tags': expected a string, - # but got an array instead.", - # 'type': 'invalid_request_error', 'param': 'metadata.tags', - # 'code': 'invalid_type'}} - - error_msg = str(exc_info.value).lower() - assert "400" in error_msg - assert "tags" in error_msg - assert "string" in error_msg diff --git a/tests/unit/providers/batches/test_reference.py b/tests/unit/providers/batches/test_reference.py deleted file mode 100644 index 9fe0cc710..000000000 --- a/tests/unit/providers/batches/test_reference.py +++ /dev/null @@ -1,753 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. - -""" -Test suite for the reference implementation of the Batches API. 
- -The tests are categorized and outlined below, keep this updated: - -- Batch creation with various parameters and validation: - * test_create_and_retrieve_batch_success (positive) - * test_create_batch_without_metadata (positive) - * test_create_batch_completion_window (negative) - * test_create_batch_invalid_endpoints (negative) - * test_create_batch_invalid_metadata (negative) - -- Batch retrieval and error handling for non-existent batches: - * test_retrieve_batch_not_found (negative) - -- Batch cancellation with proper status transitions: - * test_cancel_batch_success (positive) - * test_cancel_batch_invalid_statuses (negative) - * test_cancel_batch_not_found (negative) - -- Batch listing with pagination and filtering: - * test_list_batches_empty (positive) - * test_list_batches_single_batch (positive) - * test_list_batches_multiple_batches (positive) - * test_list_batches_with_limit (positive) - * test_list_batches_with_pagination (positive) - * test_list_batches_invalid_after (negative) - -- Data persistence in the underlying key-value store: - * test_kvstore_persistence (positive) - -- Batch processing concurrency control: - * test_max_concurrent_batches (positive) - -- Input validation testing (direct _validate_input method tests): - * test_validate_input_file_not_found (negative) - * test_validate_input_file_exists_empty_content (positive) - * test_validate_input_file_mixed_valid_invalid_json (mixed) - * test_validate_input_invalid_model (negative) - * test_validate_input_url_mismatch (negative) - * test_validate_input_multiple_errors_per_request (negative) - * test_validate_input_invalid_request_format (negative) - * test_validate_input_missing_parameters (parametrized negative - custom_id, method, url, body, model, messages missing validation) - * test_validate_input_invalid_parameter_types (parametrized negative - custom_id, url, method, body, model, messages type validation) - -The tests use temporary SQLite databases for isolation and mock external -dependencies like inference, files, and models APIs. 
-""" - -import json -import tempfile -from pathlib import Path -from unittest.mock import AsyncMock, MagicMock - -import pytest - -from llama_stack.apis.batches import BatchObject -from llama_stack.apis.common.errors import ConflictError, ResourceNotFoundError -from llama_stack.providers.inline.batches.reference.batches import ReferenceBatchesImpl -from llama_stack.providers.inline.batches.reference.config import ReferenceBatchesImplConfig -from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig - - -class TestReferenceBatchesImpl: - """Test the reference implementation of the Batches API.""" - - @pytest.fixture - async def provider(self): - """Create a test provider instance with temporary database.""" - with tempfile.TemporaryDirectory() as tmpdir: - db_path = Path(tmpdir) / "test_batches.db" - kvstore_config = SqliteKVStoreConfig(db_path=str(db_path)) - config = ReferenceBatchesImplConfig(kvstore=kvstore_config) - - # Create kvstore and mock APIs - from unittest.mock import AsyncMock - - from llama_stack.providers.utils.kvstore import kvstore_impl - - kvstore = await kvstore_impl(config.kvstore) - mock_inference = AsyncMock() - mock_files = AsyncMock() - mock_models = AsyncMock() - - provider = ReferenceBatchesImpl(config, mock_inference, mock_files, mock_models, kvstore) - await provider.initialize() - - # unit tests should not require background processing - provider.process_batches = False - - yield provider - - await provider.shutdown() - - @pytest.fixture - def sample_batch_data(self): - """Sample batch data for testing.""" - return { - "input_file_id": "file_abc123", - "endpoint": "/v1/chat/completions", - "completion_window": "24h", - "metadata": {"test": "true", "priority": "high"}, - } - - def _validate_batch_type(self, batch, expected_metadata=None): - """ - Helper function to validate batch object structure and field types. - - Note: This validates the direct BatchObject from the provider, not the - client library response which has a different structure. - - Args: - batch: The BatchObject instance to validate. - expected_metadata: Optional expected metadata dictionary to validate against. 
- """ - assert isinstance(batch.id, str) - assert isinstance(batch.completion_window, str) - assert isinstance(batch.created_at, int) - assert isinstance(batch.endpoint, str) - assert isinstance(batch.input_file_id, str) - assert batch.object == "batch" - assert batch.status in [ - "validating", - "failed", - "in_progress", - "finalizing", - "completed", - "expired", - "cancelling", - "cancelled", - ] - - if expected_metadata is not None: - assert batch.metadata == expected_metadata - - timestamp_fields = [ - "cancelled_at", - "cancelling_at", - "completed_at", - "expired_at", - "expires_at", - "failed_at", - "finalizing_at", - "in_progress_at", - ] - for field in timestamp_fields: - field_value = getattr(batch, field, None) - if field_value is not None: - assert isinstance(field_value, int), f"{field} should be int or None, got {type(field_value)}" - - file_id_fields = ["error_file_id", "output_file_id"] - for field in file_id_fields: - field_value = getattr(batch, field, None) - if field_value is not None: - assert isinstance(field_value, str), f"{field} should be str or None, got {type(field_value)}" - - if hasattr(batch, "request_counts") and batch.request_counts is not None: - assert isinstance(batch.request_counts.completed, int), ( - f"request_counts.completed should be int, got {type(batch.request_counts.completed)}" - ) - assert isinstance(batch.request_counts.failed, int), ( - f"request_counts.failed should be int, got {type(batch.request_counts.failed)}" - ) - assert isinstance(batch.request_counts.total, int), ( - f"request_counts.total should be int, got {type(batch.request_counts.total)}" - ) - - if hasattr(batch, "errors") and batch.errors is not None: - assert isinstance(batch.errors, dict), f"errors should be object or dict, got {type(batch.errors)}" - - if hasattr(batch.errors, "data") and batch.errors.data is not None: - assert isinstance(batch.errors.data, list), ( - f"errors.data should be list or None, got {type(batch.errors.data)}" - ) - - for i, error_item in enumerate(batch.errors.data): - assert isinstance(error_item, dict), ( - f"errors.data[{i}] should be object or dict, got {type(error_item)}" - ) - - if hasattr(error_item, "code") and error_item.code is not None: - assert isinstance(error_item.code, str), ( - f"errors.data[{i}].code should be str or None, got {type(error_item.code)}" - ) - - if hasattr(error_item, "line") and error_item.line is not None: - assert isinstance(error_item.line, int), ( - f"errors.data[{i}].line should be int or None, got {type(error_item.line)}" - ) - - if hasattr(error_item, "message") and error_item.message is not None: - assert isinstance(error_item.message, str), ( - f"errors.data[{i}].message should be str or None, got {type(error_item.message)}" - ) - - if hasattr(error_item, "param") and error_item.param is not None: - assert isinstance(error_item.param, str), ( - f"errors.data[{i}].param should be str or None, got {type(error_item.param)}" - ) - - if hasattr(batch.errors, "object") and batch.errors.object is not None: - assert isinstance(batch.errors.object, str), ( - f"errors.object should be str or None, got {type(batch.errors.object)}" - ) - assert batch.errors.object == "list", f"errors.object should be 'list', got {batch.errors.object}" - - async def test_create_and_retrieve_batch_success(self, provider, sample_batch_data): - """Test successful batch creation and retrieval.""" - created_batch = await provider.create_batch(**sample_batch_data) - - self._validate_batch_type(created_batch, 
expected_metadata=sample_batch_data["metadata"]) - - assert created_batch.id.startswith("batch_") - assert len(created_batch.id) > 13 - assert created_batch.object == "batch" - assert created_batch.endpoint == sample_batch_data["endpoint"] - assert created_batch.input_file_id == sample_batch_data["input_file_id"] - assert created_batch.completion_window == sample_batch_data["completion_window"] - assert created_batch.status == "validating" - assert created_batch.metadata == sample_batch_data["metadata"] - assert isinstance(created_batch.created_at, int) - assert created_batch.created_at > 0 - - retrieved_batch = await provider.retrieve_batch(created_batch.id) - - self._validate_batch_type(retrieved_batch, expected_metadata=sample_batch_data["metadata"]) - - assert retrieved_batch.id == created_batch.id - assert retrieved_batch.input_file_id == created_batch.input_file_id - assert retrieved_batch.endpoint == created_batch.endpoint - assert retrieved_batch.status == created_batch.status - assert retrieved_batch.metadata == created_batch.metadata - - async def test_create_batch_without_metadata(self, provider): - """Test batch creation without optional metadata.""" - batch = await provider.create_batch( - input_file_id="file_123", endpoint="/v1/chat/completions", completion_window="24h" - ) - - assert batch.metadata is None - - async def test_create_batch_completion_window(self, provider): - """Test batch creation with invalid completion window.""" - with pytest.raises(ValueError, match="Invalid completion_window"): - await provider.create_batch( - input_file_id="file_123", endpoint="/v1/chat/completions", completion_window="now" - ) - - @pytest.mark.parametrize( - "endpoint", - [ - "/v1/embeddings", - "/v1/completions", - "/v1/invalid/endpoint", - "", - ], - ) - async def test_create_batch_invalid_endpoints(self, provider, endpoint): - """Test batch creation with various invalid endpoints.""" - with pytest.raises(ValueError, match="Invalid endpoint"): - await provider.create_batch(input_file_id="file_123", endpoint=endpoint, completion_window="24h") - - async def test_create_batch_invalid_metadata(self, provider): - """Test that batch creation fails with invalid metadata.""" - with pytest.raises(ValueError, match="should be a valid string"): - await provider.create_batch( - input_file_id="file_123", - endpoint="/v1/chat/completions", - completion_window="24h", - metadata={123: "invalid_key"}, # Non-string key - ) - - with pytest.raises(ValueError, match="should be a valid string"): - await provider.create_batch( - input_file_id="file_123", - endpoint="/v1/chat/completions", - completion_window="24h", - metadata={"valid_key": 456}, # Non-string value - ) - - async def test_retrieve_batch_not_found(self, provider): - """Test error when retrieving non-existent batch.""" - with pytest.raises(ResourceNotFoundError, match=r"Batch 'nonexistent_batch' not found"): - await provider.retrieve_batch("nonexistent_batch") - - async def test_cancel_batch_success(self, provider, sample_batch_data): - """Test successful batch cancellation.""" - created_batch = await provider.create_batch(**sample_batch_data) - assert created_batch.status == "validating" - - cancelled_batch = await provider.cancel_batch(created_batch.id) - - assert cancelled_batch.id == created_batch.id - assert cancelled_batch.status in ["cancelling", "cancelled"] - assert isinstance(cancelled_batch.cancelling_at, int) - assert cancelled_batch.cancelling_at >= created_batch.created_at - - @pytest.mark.parametrize("status", ["failed", 
"expired", "completed"]) - async def test_cancel_batch_invalid_statuses(self, provider, sample_batch_data, status): - """Test error when cancelling batch in final states.""" - provider.process_batches = False - created_batch = await provider.create_batch(**sample_batch_data) - - # directly update status in kvstore - await provider._update_batch(created_batch.id, status=status) - - with pytest.raises(ConflictError, match=f"Cannot cancel batch '{created_batch.id}' with status '{status}'"): - await provider.cancel_batch(created_batch.id) - - async def test_cancel_batch_not_found(self, provider): - """Test error when cancelling non-existent batch.""" - with pytest.raises(ResourceNotFoundError, match=r"Batch 'nonexistent_batch' not found"): - await provider.cancel_batch("nonexistent_batch") - - async def test_list_batches_empty(self, provider): - """Test listing batches when none exist.""" - response = await provider.list_batches() - - assert response.object == "list" - assert response.data == [] - assert response.first_id is None - assert response.last_id is None - assert response.has_more is False - - async def test_list_batches_single_batch(self, provider, sample_batch_data): - """Test listing batches with single batch.""" - created_batch = await provider.create_batch(**sample_batch_data) - - response = await provider.list_batches() - - assert len(response.data) == 1 - self._validate_batch_type(response.data[0], expected_metadata=sample_batch_data["metadata"]) - assert response.data[0].id == created_batch.id - assert response.first_id == created_batch.id - assert response.last_id == created_batch.id - assert response.has_more is False - - async def test_list_batches_multiple_batches(self, provider): - """Test listing multiple batches.""" - batches = [ - await provider.create_batch( - input_file_id=f"file_{i}", endpoint="/v1/chat/completions", completion_window="24h" - ) - for i in range(3) - ] - - response = await provider.list_batches() - - assert len(response.data) == 3 - - batch_ids = {batch.id for batch in response.data} - expected_ids = {batch.id for batch in batches} - assert batch_ids == expected_ids - assert response.has_more is False - - assert response.first_id in expected_ids - assert response.last_id in expected_ids - - async def test_list_batches_with_limit(self, provider): - """Test listing batches with limit parameter.""" - batches = [ - await provider.create_batch( - input_file_id=f"file_{i}", endpoint="/v1/chat/completions", completion_window="24h" - ) - for i in range(3) - ] - - response = await provider.list_batches(limit=2) - - assert len(response.data) == 2 - assert response.has_more is True - assert response.first_id == response.data[0].id - assert response.last_id == response.data[1].id - batch_ids = {batch.id for batch in response.data} - expected_ids = {batch.id for batch in batches} - assert batch_ids.issubset(expected_ids) - - async def test_list_batches_with_pagination(self, provider): - """Test listing batches with pagination using 'after' parameter.""" - for i in range(3): - await provider.create_batch( - input_file_id=f"file_{i}", endpoint="/v1/chat/completions", completion_window="24h" - ) - - # Get first page - first_page = await provider.list_batches(limit=1) - assert len(first_page.data) == 1 - assert first_page.has_more is True - - # Get second page using 'after' - second_page = await provider.list_batches(limit=1, after=first_page.data[0].id) - assert len(second_page.data) == 1 - assert second_page.data[0].id != first_page.data[0].id - - # Verify we got 
the next batch in order - all_batches = await provider.list_batches() - expected_second_batch_id = all_batches.data[1].id - assert second_page.data[0].id == expected_second_batch_id - - async def test_list_batches_invalid_after(self, provider, sample_batch_data): - """Test listing batches with invalid 'after' parameter.""" - await provider.create_batch(**sample_batch_data) - - response = await provider.list_batches(after="nonexistent_batch") - - # Should return all batches (no filtering when 'after' batch not found) - assert len(response.data) == 1 - - async def test_kvstore_persistence(self, provider, sample_batch_data): - """Test that batches are properly persisted in kvstore.""" - batch = await provider.create_batch(**sample_batch_data) - - stored_data = await provider.kvstore.get(f"batch:{batch.id}") - assert stored_data is not None - - stored_batch_dict = json.loads(stored_data) - assert stored_batch_dict["id"] == batch.id - assert stored_batch_dict["input_file_id"] == sample_batch_data["input_file_id"] - - async def test_validate_input_file_not_found(self, provider): - """Test _validate_input when input file does not exist.""" - provider.files_api.openai_retrieve_file = AsyncMock(side_effect=Exception("File not found")) - - batch = BatchObject( - id="batch_test", - object="batch", - endpoint="/v1/chat/completions", - input_file_id="nonexistent_file", - completion_window="24h", - status="validating", - created_at=1234567890, - ) - - errors, requests = await provider._validate_input(batch) - - assert len(errors) == 1 - assert len(requests) == 0 - assert errors[0].code == "invalid_request" - assert errors[0].message == "Cannot find file nonexistent_file." - assert errors[0].param == "input_file_id" - assert errors[0].line is None - - async def test_validate_input_file_exists_empty_content(self, provider): - """Test _validate_input when file exists but is empty.""" - provider.files_api.openai_retrieve_file = AsyncMock() - mock_response = MagicMock() - mock_response.body = b"" - provider.files_api.openai_retrieve_file_content = AsyncMock(return_value=mock_response) - - batch = BatchObject( - id="batch_test", - object="batch", - endpoint="/v1/chat/completions", - input_file_id="empty_file", - completion_window="24h", - status="validating", - created_at=1234567890, - ) - - errors, requests = await provider._validate_input(batch) - - assert len(errors) == 0 - assert len(requests) == 0 - - async def test_validate_input_file_mixed_valid_invalid_json(self, provider): - """Test _validate_input when file contains valid and invalid JSON lines.""" - provider.files_api.openai_retrieve_file = AsyncMock() - mock_response = MagicMock() - # Line 1: valid JSON with proper body args, Line 2: invalid JSON - mock_response.body = b'{"custom_id": "req-1", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "test-model", "messages": [{"role": "user", "content": "Hello"}]}}\n{invalid json' - provider.files_api.openai_retrieve_file_content = AsyncMock(return_value=mock_response) - - batch = BatchObject( - id="batch_test", - object="batch", - endpoint="/v1/chat/completions", - input_file_id="mixed_file", - completion_window="24h", - status="validating", - created_at=1234567890, - ) - - errors, requests = await provider._validate_input(batch) - - # Should have 1 JSON parsing error from line 2, and 1 valid request from line 1 - assert len(errors) == 1 - assert len(requests) == 1 - - assert errors[0].code == "invalid_json_line" - assert errors[0].line == 2 - assert errors[0].message == "This line is 
not parseable as valid JSON." - - assert requests[0].custom_id == "req-1" - assert requests[0].method == "POST" - assert requests[0].url == "/v1/chat/completions" - assert requests[0].body["model"] == "test-model" - assert requests[0].body["messages"] == [{"role": "user", "content": "Hello"}] - - async def test_validate_input_invalid_model(self, provider): - """Test _validate_input when file contains request with non-existent model.""" - provider.files_api.openai_retrieve_file = AsyncMock() - mock_response = MagicMock() - mock_response.body = b'{"custom_id": "req-1", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "nonexistent-model", "messages": [{"role": "user", "content": "Hello"}]}}' - provider.files_api.openai_retrieve_file_content = AsyncMock(return_value=mock_response) - - provider.models_api.get_model = AsyncMock(side_effect=Exception("Model not found")) - - batch = BatchObject( - id="batch_test", - object="batch", - endpoint="/v1/chat/completions", - input_file_id="invalid_model_file", - completion_window="24h", - status="validating", - created_at=1234567890, - ) - - errors, requests = await provider._validate_input(batch) - - assert len(errors) == 1 - assert len(requests) == 0 - - assert errors[0].code == "model_not_found" - assert errors[0].line == 1 - assert errors[0].message == "Model 'nonexistent-model' does not exist or is not supported" - assert errors[0].param == "body.model" - - @pytest.mark.parametrize( - "param_name,param_path,error_code,error_message", - [ - ("custom_id", "custom_id", "missing_required_parameter", "Missing required parameter: custom_id"), - ("method", "method", "missing_required_parameter", "Missing required parameter: method"), - ("url", "url", "missing_required_parameter", "Missing required parameter: url"), - ("body", "body", "missing_required_parameter", "Missing required parameter: body"), - ("model", "body.model", "invalid_request", "Model parameter is required"), - ("messages", "body.messages", "invalid_request", "Messages parameter is required"), - ], - ) - async def test_validate_input_missing_parameters(self, provider, param_name, param_path, error_code, error_message): - """Test _validate_input when file contains request with missing required parameters.""" - provider.files_api.openai_retrieve_file = AsyncMock() - mock_response = MagicMock() - - base_request = { - "custom_id": "req-1", - "method": "POST", - "url": "/v1/chat/completions", - "body": {"model": "test-model", "messages": [{"role": "user", "content": "Hello"}]}, - } - - # Remove the specific parameter being tested - if "." 
in param_path: - top_level, nested_param = param_path.split(".", 1) - del base_request[top_level][nested_param] - else: - del base_request[param_name] - - mock_response.body = json.dumps(base_request).encode() - provider.files_api.openai_retrieve_file_content = AsyncMock(return_value=mock_response) - - batch = BatchObject( - id="batch_test", - object="batch", - endpoint="/v1/chat/completions", - input_file_id=f"missing_{param_name}_file", - completion_window="24h", - status="validating", - created_at=1234567890, - ) - - errors, requests = await provider._validate_input(batch) - - assert len(errors) == 1 - assert len(requests) == 0 - - assert errors[0].code == error_code - assert errors[0].line == 1 - assert errors[0].message == error_message - assert errors[0].param == param_path - - async def test_validate_input_url_mismatch(self, provider): - """Test _validate_input when file contains request with URL that doesn't match batch endpoint.""" - provider.files_api.openai_retrieve_file = AsyncMock() - mock_response = MagicMock() - mock_response.body = b'{"custom_id": "req-1", "method": "POST", "url": "/v1/embeddings", "body": {"model": "test-model", "messages": [{"role": "user", "content": "Hello"}]}}' - provider.files_api.openai_retrieve_file_content = AsyncMock(return_value=mock_response) - - batch = BatchObject( - id="batch_test", - object="batch", - endpoint="/v1/chat/completions", # This doesn't match the URL in the request - input_file_id="url_mismatch_file", - completion_window="24h", - status="validating", - created_at=1234567890, - ) - - errors, requests = await provider._validate_input(batch) - - assert len(errors) == 1 - assert len(requests) == 0 - - assert errors[0].code == "invalid_url" - assert errors[0].line == 1 - assert errors[0].message == "URL provided for this request does not match the batch endpoint" - assert errors[0].param == "url" - - async def test_validate_input_multiple_errors_per_request(self, provider): - """Test _validate_input when a single request has multiple validation errors.""" - provider.files_api.openai_retrieve_file = AsyncMock() - mock_response = MagicMock() - # Request missing custom_id, has invalid URL, and missing model in body - mock_response.body = ( - b'{"method": "POST", "url": "/v1/embeddings", "body": {"messages": [{"role": "user", "content": "Hello"}]}}' - ) - provider.files_api.openai_retrieve_file_content = AsyncMock(return_value=mock_response) - - batch = BatchObject( - id="batch_test", - object="batch", - endpoint="/v1/chat/completions", # Doesn't match /v1/embeddings in request - input_file_id="multiple_errors_file", - completion_window="24h", - status="validating", - created_at=1234567890, - ) - - errors, requests = await provider._validate_input(batch) - - assert len(errors) >= 2 # At least missing custom_id and URL mismatch - assert len(requests) == 0 - - for error in errors: - assert error.line == 1 - - error_codes = {error.code for error in errors} - assert "missing_required_parameter" in error_codes # missing custom_id - assert "invalid_url" in error_codes # URL mismatch - - async def test_validate_input_invalid_request_format(self, provider): - """Test _validate_input when file contains non-object JSON (array, string, number).""" - provider.files_api.openai_retrieve_file = AsyncMock() - mock_response = MagicMock() - mock_response.body = b'["not", "a", "request", "object"]' - provider.files_api.openai_retrieve_file_content = AsyncMock(return_value=mock_response) - - batch = BatchObject( - id="batch_test", - object="batch", - 
endpoint="/v1/chat/completions", - input_file_id="invalid_format_file", - completion_window="24h", - status="validating", - created_at=1234567890, - ) - - errors, requests = await provider._validate_input(batch) - - assert len(errors) == 1 - assert len(requests) == 0 - - assert errors[0].code == "invalid_request" - assert errors[0].line == 1 - assert errors[0].message == "Each line must be a JSON dictionary object" - - @pytest.mark.parametrize( - "param_name,param_path,invalid_value,error_message", - [ - ("custom_id", "custom_id", 12345, "Custom_id must be a string"), - ("url", "url", 123, "URL must be a string"), - ("method", "method", ["POST"], "Method must be a string"), - ("body", "body", ["not", "valid"], "Body must be a JSON dictionary object"), - ("model", "body.model", 123, "Model must be a string"), - ("messages", "body.messages", "invalid messages format", "Messages must be an array"), - ], - ) - async def test_validate_input_invalid_parameter_types( - self, provider, param_name, param_path, invalid_value, error_message - ): - """Test _validate_input when file contains request with parameters that have invalid types.""" - provider.files_api.openai_retrieve_file = AsyncMock() - mock_response = MagicMock() - - base_request = { - "custom_id": "req-1", - "method": "POST", - "url": "/v1/chat/completions", - "body": {"model": "test-model", "messages": [{"role": "user", "content": "Hello"}]}, - } - - # Override the specific parameter with invalid value - if "." in param_path: - top_level, nested_param = param_path.split(".", 1) - base_request[top_level][nested_param] = invalid_value - else: - base_request[param_name] = invalid_value - - mock_response.body = json.dumps(base_request).encode() - provider.files_api.openai_retrieve_file_content = AsyncMock(return_value=mock_response) - - batch = BatchObject( - id="batch_test", - object="batch", - endpoint="/v1/chat/completions", - input_file_id=f"invalid_{param_name}_type_file", - completion_window="24h", - status="validating", - created_at=1234567890, - ) - - errors, requests = await provider._validate_input(batch) - - assert len(errors) == 1 - assert len(requests) == 0 - - assert errors[0].code == "invalid_request" - assert errors[0].line == 1 - assert errors[0].message == error_message - assert errors[0].param == param_path - - async def test_max_concurrent_batches(self, provider): - """Test max_concurrent_batches configuration and concurrency control.""" - import asyncio - - provider._batch_semaphore = asyncio.Semaphore(2) - - provider.process_batches = True # enable because we're testing background processing - - active_batches = 0 - - async def add_and_wait(batch_id: str): - nonlocal active_batches - active_batches += 1 - await asyncio.sleep(float("inf")) - - # the first thing done in _process_batch is to acquire the semaphore, then call _process_batch_impl, - # so we can replace _process_batch_impl with our mock to control concurrency - provider._process_batch_impl = add_and_wait - - for _ in range(3): - await provider.create_batch( - input_file_id="file_id", endpoint="/v1/chat/completions", completion_window="24h" - ) - - await asyncio.sleep(0.042) # let tasks start - - assert active_batches == 2, f"Expected 2 active batches, got {active_batches}" diff --git a/tests/unit/providers/utils/inference/test_openai_compat.py b/tests/unit/providers/utils/inference/test_openai_compat.py index 5b8527d1b..ddc70e102 100644 --- a/tests/unit/providers/utils/inference/test_openai_compat.py +++ 
b/tests/unit/providers/utils/inference/test_openai_compat.py @@ -24,6 +24,7 @@ from llama_stack.apis.inference import ( from llama_stack.models.llama.datatypes import BuiltinTool, StopReason, ToolCall from llama_stack.providers.utils.inference.openai_compat import ( convert_message_to_openai_dict, + convert_message_to_openai_dict_new, openai_messages_to_messages, ) @@ -182,3 +183,42 @@ def test_user_message_accepts_images(): assert len(msg.content) == 2 assert msg.content[0].text == "Describe this image:" assert msg.content[1].image_url.url == "http://example.com/image.jpg" + + +async def test_convert_message_to_openai_dict_new_user_message(): + """Test convert_message_to_openai_dict_new with UserMessage.""" + message = UserMessage(content="Hello, world!", role="user") + result = await convert_message_to_openai_dict_new(message) + + assert result["role"] == "user" + assert result["content"] == "Hello, world!" + + +async def test_convert_message_to_openai_dict_new_completion_message_with_tool_calls(): + """Test convert_message_to_openai_dict_new with CompletionMessage containing tool calls.""" + message = CompletionMessage( + content="I'll help you find the weather.", + tool_calls=[ + ToolCall( + call_id="call_123", + tool_name="get_weather", + arguments={"city": "Sligo"}, + arguments_json='{"city": "Sligo"}', + ) + ], + stop_reason=StopReason.end_of_turn, + ) + result = await convert_message_to_openai_dict_new(message) + + # This would have failed with "Cannot instantiate typing.Union" before the fix + assert result["role"] == "assistant" + assert result["content"] == "I'll help you find the weather." + assert "tool_calls" in result + assert result["tool_calls"] is not None + assert len(result["tool_calls"]) == 1 + + tool_call = result["tool_calls"][0] + assert tool_call.id == "call_123" + assert tool_call.type == "function" + assert tool_call.function.name == "get_weather" + assert tool_call.function.arguments == '{"city": "Sligo"}'
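For readers skimming the hunk above, here is a minimal standalone sketch of the conversion the new tests exercise. It is not part of the patch; it reuses only names visible in the hunk (`convert_message_to_openai_dict_new`, `CompletionMessage`, `ToolCall`, `StopReason`), and the import path for `CompletionMessage` is assumed from the hunk's `from llama_stack.apis.inference import (` context line.

```python
# Hedged sketch, not part of the diff: drive convert_message_to_openai_dict_new outside
# pytest. Assumes a llama_stack checkout where the import paths shown in the hunk hold.
import asyncio

from llama_stack.apis.inference import CompletionMessage  # assumed export, per hunk context
from llama_stack.models.llama.datatypes import StopReason, ToolCall
from llama_stack.providers.utils.inference.openai_compat import (
    convert_message_to_openai_dict_new,
)


async def main() -> None:
    # Same shape as the test fixture: an assistant message carrying one tool call.
    message = CompletionMessage(
        content="I'll help you find the weather.",
        tool_calls=[
            ToolCall(
                call_id="call_123",
                tool_name="get_weather",
                arguments={"city": "Sligo"},
                arguments_json='{"city": "Sligo"}',
            )
        ],
        stop_reason=StopReason.end_of_turn,
    )
    result = await convert_message_to_openai_dict_new(message)
    # Per the assertions above, the result is an OpenAI-style dict with role/content
    # plus a tool_calls list whose entries expose .function.name and .function.arguments.
    print(result["role"], result["tool_calls"][0].function.name)


if __name__ == "__main__":
    asyncio.run(main())
```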