diff --git a/llama_stack/apis/batches/__init__.py b/llama_stack/apis/batches/__init__.py index d3efe3dba..9ce7d3d75 100644 --- a/llama_stack/apis/batches/__init__.py +++ b/llama_stack/apis/batches/__init__.py @@ -4,6 +4,6 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from .batches import Batches, BatchObject, CreateBatchRequest, ListBatchesResponse +from .batches import Batches, BatchObject, ListBatchesResponse -__all__ = ["Batches", "BatchObject", "CreateBatchRequest", "ListBatchesResponse"] +__all__ = ["Batches", "BatchObject", "ListBatchesResponse"] diff --git a/llama_stack/apis/batches/batches.py b/llama_stack/apis/batches/batches.py index 81ab44ccd..9297d8597 100644 --- a/llama_stack/apis/batches/batches.py +++ b/llama_stack/apis/batches/batches.py @@ -4,7 +4,7 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from typing import Any, Literal, Protocol, runtime_checkable +from typing import Literal, Protocol, runtime_checkable from pydantic import BaseModel, Field @@ -16,16 +16,6 @@ except ImportError as e: raise ImportError("OpenAI package is required for batches API. Please install it with: pip install openai") from e -@json_schema_type -class CreateBatchRequest(BaseModel): - """Request to create a new batch.""" - - input_file_id: str = Field(..., description="The ID of an uploaded file that contains requests for the new batch") - endpoint: str = Field(..., description="The endpoint to be used for all requests in the batch") - completion_window: str = Field(..., description="The time window within which the batch should be processed") - metadata: dict[str, Any] | None = Field(default=None, description="Optional metadata for the batch") - - @json_schema_type class ListBatchesResponse(BaseModel): """Response containing a list of batch objects.""" @@ -53,7 +43,7 @@ class Batches(Protocol): self, input_file_id: str, endpoint: str, - completion_window: str, + completion_window: Literal["24h"], metadata: dict[str, str] | None = None, ) -> BatchObject: """Create a new batch for processing multiple API requests. diff --git a/llama_stack/providers/inline/batches/reference/batches.py b/llama_stack/providers/inline/batches/reference/batches.py index 6e99a00a1..984ef5a90 100644 --- a/llama_stack/providers/inline/batches/reference/batches.py +++ b/llama_stack/providers/inline/batches/reference/batches.py @@ -10,7 +10,7 @@ import json import time import uuid from io import BytesIO -from typing import Any +from typing import Any, Literal from openai.types.batch import BatchError, Errors from pydantic import BaseModel @@ -108,7 +108,7 @@ class ReferenceBatchesImpl(Batches): self, input_file_id: str, endpoint: str, - completion_window: str, + completion_window: Literal["24h"], metadata: dict[str, str] | None = None, ) -> BatchObject: """ diff --git a/tests/integration/batches/test_batches_errors.py b/tests/integration/batches/test_batches_errors.py index 2cd1e561e..bc94a182e 100644 --- a/tests/integration/batches/test_batches_errors.py +++ b/tests/integration/batches/test_batches_errors.py @@ -379,9 +379,8 @@ class TestBatchesErrorHandling: ) assert exc_info.value.status_code == 400 error_msg = str(exc_info.value).lower() - assert "invalid value" in error_msg + assert "error" in error_msg assert "completion_window" in error_msg - assert "supported values are" in error_msg def test_batch_streaming_not_supported(self, openai_client, batch_helper, text_model_id): """Test that streaming responses are not supported in batches."""