From 30cab020835ae0e2472be0090bf53d03cca5f75d Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?S=C3=A9bastien=20Han?=
Date: Thu, 20 Nov 2025 15:54:07 +0100
Subject: [PATCH] chore: refactor Batches protocol to use request models
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

This commit refactors the Batches protocol to use Pydantic request models
for both create_batch and list_batches methods, improving consistency,
readability, and maintainability.

- create_batch now accepts a single CreateBatchRequest parameter instead
  of individual arguments. This aligns the protocol with FastAPI's request
  model pattern, allowing the router to pass the request object directly
  without unpacking parameters. Provider implementations now access fields
  via request.input_file_id, request.endpoint, etc.

- list_batches now accepts a single ListBatchesRequest parameter, replacing
  individual query parameters. The model includes after and limit fields
  with proper OpenAPI descriptions. FastAPI automatically parses query
  parameters into the model for GET requests, keeping router code clean.
  Provider implementations access fields via request.after and
  request.limit.

Signed-off-by: Sébastien Han
---
 client-sdks/stainless/openapi.yml              | 20 ++++++++
 docs/static/deprecated-llama-stack-spec.yaml   | 16 ++++++
 .../static/experimental-llama-stack-spec.yaml  | 16 ++++++
 docs/static/llama-stack-spec.yaml              | 20 ++++++++
 docs/static/stainless-llama-stack-spec.yaml    | 20 ++++++++
 .../inline/batches/reference/batches.py       | 50 +++++++++----------
 src/llama_stack_api/batches/__init__.py       | 15 ++----
 src/llama_stack_api/batches/models.py         | 12 ++++-
 src/llama_stack_api/batches/routes.py         | 26 +++++-----
 9 files changed, 145 insertions(+), 50 deletions(-)

diff --git a/client-sdks/stainless/openapi.yml b/client-sdks/stainless/openapi.yml
index be941f652..c477ae32c 100644
--- a/client-sdks/stainless/openapi.yml
+++ b/client-sdks/stainless/openapi.yml
@@ -48,14 +48,18 @@ paths:
             anyOf:
               - type: string
               - type: 'null'
+            description: Optional cursor for pagination. Returns batches after this ID.
             title: After
+          description: Optional cursor for pagination. Returns batches after this ID.
         - name: limit
           in: query
           required: false
           schema:
             type: integer
+            description: Maximum number of batches to return. Defaults to 20.
             default: 20
             title: Limit
+          description: Maximum number of batches to return. Defaults to 20.
     post:
       responses:
         '200':
@@ -12690,6 +12694,22 @@ components:
       - query
       title: VectorStoreSearchRequest
       type: object
+    ListBatchesRequest:
+      description: Request model for listing batches.
+      properties:
+        after:
+          anyOf:
+            - type: string
+            - type: 'null'
+          description: Optional cursor for pagination. Returns batches after this ID.
+          nullable: true
+        limit:
+          default: 20
+          description: Maximum number of batches to return. Defaults to 20.
+          title: Limit
+          type: integer
+      title: ListBatchesRequest
+      type: object
     DialogType:
       description: Parameter type for dialog data with semantic output labels.
       properties:
diff --git a/docs/static/deprecated-llama-stack-spec.yaml b/docs/static/deprecated-llama-stack-spec.yaml
index 94b1a69a7..2e517da70 100644
--- a/docs/static/deprecated-llama-stack-spec.yaml
+++ b/docs/static/deprecated-llama-stack-spec.yaml
@@ -9531,6 +9531,22 @@ components:
       - query
       title: VectorStoreSearchRequest
       type: object
+    ListBatchesRequest:
+      description: Request model for listing batches.
+      properties:
+        after:
+          anyOf:
+            - type: string
+            - type: 'null'
+          description: Optional cursor for pagination. Returns batches after this ID.
+          nullable: true
+        limit:
+          default: 20
+          description: Maximum number of batches to return. Defaults to 20.
+          title: Limit
+          type: integer
+      title: ListBatchesRequest
+      type: object
     DialogType:
       description: Parameter type for dialog data with semantic output labels.
       properties:
diff --git a/docs/static/experimental-llama-stack-spec.yaml b/docs/static/experimental-llama-stack-spec.yaml
index dfd354544..21112924c 100644
--- a/docs/static/experimental-llama-stack-spec.yaml
+++ b/docs/static/experimental-llama-stack-spec.yaml
@@ -8544,6 +8544,22 @@ components:
       - query
       title: VectorStoreSearchRequest
       type: object
+    ListBatchesRequest:
+      description: Request model for listing batches.
+      properties:
+        after:
+          anyOf:
+            - type: string
+            - type: 'null'
+          description: Optional cursor for pagination. Returns batches after this ID.
+          nullable: true
+        limit:
+          default: 20
+          description: Maximum number of batches to return. Defaults to 20.
+          title: Limit
+          type: integer
+      title: ListBatchesRequest
+      type: object
     DialogType:
       description: Parameter type for dialog data with semantic output labels.
      properties:
diff --git a/docs/static/llama-stack-spec.yaml b/docs/static/llama-stack-spec.yaml
index a736fc8f9..5335e69b8 100644
--- a/docs/static/llama-stack-spec.yaml
+++ b/docs/static/llama-stack-spec.yaml
@@ -46,14 +46,18 @@ paths:
             anyOf:
               - type: string
               - type: 'null'
+            description: Optional cursor for pagination. Returns batches after this ID.
             title: After
+          description: Optional cursor for pagination. Returns batches after this ID.
         - name: limit
           in: query
           required: false
           schema:
             type: integer
+            description: Maximum number of batches to return. Defaults to 20.
             default: 20
             title: Limit
+          description: Maximum number of batches to return. Defaults to 20.
     post:
       responses:
         '200':
@@ -11419,6 +11423,22 @@ components:
       - query
       title: VectorStoreSearchRequest
       type: object
+    ListBatchesRequest:
+      description: Request model for listing batches.
+      properties:
+        after:
+          anyOf:
+            - type: string
+            - type: 'null'
+          description: Optional cursor for pagination. Returns batches after this ID.
+          nullable: true
+        limit:
+          default: 20
+          description: Maximum number of batches to return. Defaults to 20.
+          title: Limit
+          type: integer
+      title: ListBatchesRequest
+      type: object
     DialogType:
       description: Parameter type for dialog data with semantic output labels.
       properties:
diff --git a/docs/static/stainless-llama-stack-spec.yaml b/docs/static/stainless-llama-stack-spec.yaml
index be941f652..c477ae32c 100644
--- a/docs/static/stainless-llama-stack-spec.yaml
+++ b/docs/static/stainless-llama-stack-spec.yaml
@@ -48,14 +48,18 @@ paths:
             anyOf:
               - type: string
               - type: 'null'
+            description: Optional cursor for pagination. Returns batches after this ID.
             title: After
+          description: Optional cursor for pagination. Returns batches after this ID.
         - name: limit
           in: query
           required: false
           schema:
             type: integer
+            description: Maximum number of batches to return. Defaults to 20.
             default: 20
             title: Limit
+          description: Maximum number of batches to return. Defaults to 20.
     post:
       responses:
         '200':
@@ -12690,6 +12694,22 @@ components:
       - query
       title: VectorStoreSearchRequest
       type: object
+    ListBatchesRequest:
+      description: Request model for listing batches.
+      properties:
+        after:
+          anyOf:
+            - type: string
+            - type: 'null'
+          description: Optional cursor for pagination. Returns batches after this ID.
+          nullable: true
+        limit:
+          default: 20
+          description: Maximum number of batches to return. Defaults to 20.
+          title: Limit
+          type: integer
+      title: ListBatchesRequest
+      type: object
     DialogType:
       description: Parameter type for dialog data with semantic output labels.
       properties:
diff --git a/src/llama_stack/providers/inline/batches/reference/batches.py b/src/llama_stack/providers/inline/batches/reference/batches.py
index 73727799d..aaa105f28 100644
--- a/src/llama_stack/providers/inline/batches/reference/batches.py
+++ b/src/llama_stack/providers/inline/batches/reference/batches.py
@@ -11,7 +11,7 @@ import json
 import time
 import uuid
 from io import BytesIO
-from typing import Any, Literal
+from typing import Any
 
 from openai.types.batch import BatchError, Errors
 from pydantic import BaseModel
@@ -38,6 +38,7 @@ from llama_stack_api import (
     OpenAIUserMessageParam,
     ResourceNotFoundError,
 )
+from llama_stack_api.batches.models import CreateBatchRequest, ListBatchesRequest
 
 from .config import ReferenceBatchesImplConfig
 
@@ -140,11 +141,7 @@ class ReferenceBatchesImpl(Batches):
     # TODO (SECURITY): this currently works w/ configured api keys, not with x-llamastack-provider-data or with user policy restrictions
     async def create_batch(
         self,
-        input_file_id: str,
-        endpoint: str,
-        completion_window: Literal["24h"],
-        metadata: dict[str, str] | None = None,
-        idempotency_key: str | None = None,
+        request: CreateBatchRequest,
     ) -> BatchObject:
         """
         Create a new batch for processing multiple API requests.
@@ -185,14 +182,14 @@ class ReferenceBatchesImpl(Batches):
 
         # TODO: set expiration time for garbage collection
 
-        if endpoint not in ["/v1/chat/completions", "/v1/completions", "/v1/embeddings"]:
+        if request.endpoint not in ["/v1/chat/completions", "/v1/completions", "/v1/embeddings"]:
             raise ValueError(
-                f"Invalid endpoint: {endpoint}. Supported values: /v1/chat/completions, /v1/completions, /v1/embeddings. Code: invalid_value. Param: endpoint",
+                f"Invalid endpoint: {request.endpoint}. Supported values: /v1/chat/completions, /v1/completions, /v1/embeddings. Code: invalid_value. Param: endpoint",
             )
 
-        if completion_window != "24h":
+        if request.completion_window != "24h":
             raise ValueError(
-                f"Invalid completion_window: {completion_window}. Supported values are: 24h. Code: invalid_value. Param: completion_window",
+                f"Invalid completion_window: {request.completion_window}. Supported values are: 24h. Code: invalid_value. Param: completion_window",
             )
 
         batch_id = f"batch_{uuid.uuid4().hex[:16]}"
@@ -200,8 +197,8 @@ class ReferenceBatchesImpl(Batches):
         # For idempotent requests, use the idempotency key for the batch ID
         # This ensures the same key always maps to the same batch ID,
         # allowing us to detect parameter conflicts
-        if idempotency_key is not None:
-            hash_input = idempotency_key.encode("utf-8")
+        if request.idempotency_key is not None:
+            hash_input = request.idempotency_key.encode("utf-8")
             hash_digest = hashlib.sha256(hash_input).hexdigest()[:24]
             batch_id = f"batch_{hash_digest}"
 
@@ -209,13 +206,13 @@ class ReferenceBatchesImpl(Batches):
                 existing_batch = await self.retrieve_batch(batch_id)
 
                 if (
-                    existing_batch.input_file_id != input_file_id
-                    or existing_batch.endpoint != endpoint
-                    or existing_batch.completion_window != completion_window
-                    or existing_batch.metadata != metadata
+                    existing_batch.input_file_id != request.input_file_id
+                    or existing_batch.endpoint != request.endpoint
+                    or existing_batch.completion_window != request.completion_window
+                    or existing_batch.metadata != request.metadata
                 ):
                     raise ConflictError(
-                        f"Idempotency key '{idempotency_key}' was previously used with different parameters. "
+                        f"Idempotency key '{request.idempotency_key}' was previously used with different parameters. "
                         "Either use a new idempotency key or ensure all parameters match the original request."
                     )
 
@@ -230,12 +227,12 @@ class ReferenceBatchesImpl(Batches):
         batch = BatchObject(
             id=batch_id,
             object="batch",
-            endpoint=endpoint,
-            input_file_id=input_file_id,
-            completion_window=completion_window,
+            endpoint=request.endpoint,
+            input_file_id=request.input_file_id,
+            completion_window=request.completion_window,
             status="validating",
             created_at=current_time,
-            metadata=metadata,
+            metadata=request.metadata,
         )
 
         await self.kvstore.set(f"batch:{batch_id}", batch.to_json())
@@ -267,8 +264,7 @@ class ReferenceBatchesImpl(Batches):
 
     async def list_batches(
         self,
-        after: str | None = None,
-        limit: int = 20,
+        request: ListBatchesRequest,
     ) -> ListBatchesResponse:
         """
         List all batches, eventually only for the current user.
@@ -285,14 +281,14 @@ class ReferenceBatchesImpl(Batches):
         batches.sort(key=lambda b: b.created_at, reverse=True)
 
         start_idx = 0
-        if after:
+        if request.after:
             for i, batch in enumerate(batches):
-                if batch.id == after:
+                if batch.id == request.after:
                     start_idx = i + 1
                     break
 
-        page_batches = batches[start_idx : start_idx + limit]
-        has_more = (start_idx + limit) < len(batches)
+        page_batches = batches[start_idx : start_idx + request.limit]
+        has_more = (start_idx + request.limit) < len(batches)
 
         first_id = page_batches[0].id if page_batches else None
         last_id = page_batches[-1].id if page_batches else None
diff --git a/src/llama_stack_api/batches/__init__.py b/src/llama_stack_api/batches/__init__.py
index 6f778de8e..3d301598c 100644
--- a/src/llama_stack_api/batches/__init__.py
+++ b/src/llama_stack_api/batches/__init__.py
@@ -11,7 +11,7 @@ Pydantic models are defined in llama_stack_api.batches.models.
 The FastAPI router is defined in llama_stack_api.batches.routes.
 """
 
-from typing import Literal, Protocol, runtime_checkable
+from typing import Protocol, runtime_checkable
 
 try:
     from openai.types import Batch as BatchObject
@@ -19,7 +19,7 @@ except ImportError as e:
     raise ImportError("OpenAI package is required for batches API. Please install it with: pip install openai") from e
 
 # Import models for re-export
-from llama_stack_api.batches.models import ListBatchesResponse
+from llama_stack_api.batches.models import CreateBatchRequest, ListBatchesRequest, ListBatchesResponse
 
 
 @runtime_checkable
@@ -39,11 +39,7 @@ class Batches(Protocol):
 
     async def create_batch(
         self,
-        input_file_id: str,
-        endpoint: str,
-        completion_window: Literal["24h"],
-        metadata: dict[str, str] | None = None,
-        idempotency_key: str | None = None,
+        request: CreateBatchRequest,
     ) -> BatchObject: ...
 
     async def retrieve_batch(self, batch_id: str) -> BatchObject: ...
@@ -52,9 +48,8 @@ class Batches(Protocol):
 
     async def list_batches(
         self,
-        after: str | None = None,
-        limit: int = 20,
+        request: ListBatchesRequest,
     ) -> ListBatchesResponse: ...
 
-__all__ = ["Batches", "BatchObject", "ListBatchesResponse"]
+__all__ = ["Batches", "BatchObject", "CreateBatchRequest", "ListBatchesRequest", "ListBatchesResponse"]
diff --git a/src/llama_stack_api/batches/models.py b/src/llama_stack_api/batches/models.py
index fe449280d..bb6d7e3d0 100644
--- a/src/llama_stack_api/batches/models.py
+++ b/src/llama_stack_api/batches/models.py
@@ -37,6 +37,16 @@ class CreateBatchRequest(BaseModel):
     )
 
 
+@json_schema_type
+class ListBatchesRequest(BaseModel):
+    """Request model for listing batches."""
+
+    after: str | None = Field(
+        default=None, description="Optional cursor for pagination. Returns batches after this ID."
+    )
+    limit: int = Field(default=20, description="Maximum number of batches to return. Defaults to 20.")
+
+
 @json_schema_type
 class ListBatchesResponse(BaseModel):
     """Response containing a list of batch objects."""
@@ -48,4 +58,4 @@ class ListBatchesResponse(BaseModel):
     has_more: bool = Field(default=False, description="Whether there are more batches available")
 
 
-__all__ = ["CreateBatchRequest", "ListBatchesResponse", "BatchObject"]
+__all__ = ["CreateBatchRequest", "ListBatchesRequest", "ListBatchesResponse", "BatchObject"]
diff --git a/src/llama_stack_api/batches/routes.py b/src/llama_stack_api/batches/routes.py
index e8b6aaf41..adbc10be5 100644
--- a/src/llama_stack_api/batches/routes.py
+++ b/src/llama_stack_api/batches/routes.py
@@ -14,10 +14,10 @@ all API-related code together.
 from collections.abc import Callable
 from typing import Annotated
 
-from fastapi import APIRouter, Body, Depends
+from fastapi import APIRouter, Body, Depends, Query
 
 from llama_stack_api.batches import Batches, BatchObject, ListBatchesResponse
-from llama_stack_api.batches.models import CreateBatchRequest
+from llama_stack_api.batches.models import CreateBatchRequest, ListBatchesRequest
 from llama_stack_api.datatypes import Api
 from llama_stack_api.router_utils import standard_responses
 from llama_stack_api.version import LLAMA_STACK_API_V1
@@ -56,13 +56,7 @@ def create_router(impl_getter: Callable[[Api], Batches]) -> APIRouter:
         request: Annotated[CreateBatchRequest, Body(...)],
         svc: Annotated[Batches, Depends(get_batch_service)],
     ) -> BatchObject:
-        return await svc.create_batch(
-            input_file_id=request.input_file_id,
-            endpoint=request.endpoint,
-            completion_window=request.completion_window,
-            metadata=request.metadata,
-            idempotency_key=request.idempotency_key,
-        )
+        return await svc.create_batch(request)
 
     @router.get(
         "/batches/{batch_id}",
@@ -94,6 +88,15 @@ def create_router(impl_getter: Callable[[Api], Batches]) -> APIRouter:
     ) -> BatchObject:
         return await svc.cancel_batch(batch_id)
 
+    def get_list_batches_request(
+        after: Annotated[
+            str | None, Query(description="Optional cursor for pagination. Returns batches after this ID.")
+        ] = None,
+        limit: Annotated[int, Query(description="Maximum number of batches to return. Defaults to 20.")] = 20,
+    ) -> ListBatchesRequest:
+        """Dependency function to create ListBatchesRequest from query parameters."""
+        return ListBatchesRequest(after=after, limit=limit)
+
     @router.get(
         "/batches",
         response_model=ListBatchesResponse,
@@ -104,10 +107,9 @@ def create_router(impl_getter: Callable[[Api], Batches]) -> APIRouter:
         },
     )
     async def list_batches(
+        request: Annotated[ListBatchesRequest, Depends(get_list_batches_request)],
         svc: Annotated[Batches, Depends(get_batch_service)],
-        after: str | None = None,
-        limit: int = 20,
     ) -> ListBatchesResponse:
-        return await svc.list_batches(after=after, limit=limit)
+        return await svc.list_batches(request)
 
     return router
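
A usage illustration of the refactored create_batch surface: a minimal
sketch that assumes only the CreateBatchRequest model this patch wires
through. The `impl` object stands in for any Batches implementation, and
the file ID and metadata values are hypothetical.

    from llama_stack_api.batches.models import CreateBatchRequest

    async def submit_batch(impl):
        # Callers build one validated model; the router now passes it through
        # untouched and providers read fields off the object instead of five
        # positional arguments.
        request = CreateBatchRequest(
            input_file_id="file-abc123",      # hypothetical uploaded file ID
            endpoint="/v1/chat/completions",  # one of the three supported endpoints
            completion_window="24h",          # only "24h" is accepted
            metadata={"run": "nightly"},      # optional
            idempotency_key=None,             # optional
        )
        return await impl.create_batch(request)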
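
The reference provider's idempotency behavior is unchanged in substance,
only re-keyed onto the request model. A sketch of those semantics under the
same assumptions (the key and file ID below are hypothetical):

    async def demo_idempotency(impl):
        req = CreateBatchRequest(
            input_file_id="file-abc123",
            endpoint="/v1/completions",
            completion_window="24h",
            idempotency_key="nightly-2025-11-20",  # hypothetical key
        )
        first = await impl.create_batch(req)
        second = await impl.create_batch(req)  # same key, identical params
        # The batch ID is derived from sha256(idempotency_key), so both calls
        # return the same batch.
        assert first.id == second.id
        # Reusing the key with *different* parameters raises ConflictError instead.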
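
On the list side, FastAPI builds ListBatchesRequest from the query string
through the get_list_batches_request dependency, so GET
/v1/batches?after=...&limit=2 and the direct calls below should behave the
same. A cursor-pagination sketch, assuming the OpenAI-style data/last_id/
has_more response fields the reference implementation populates:

    from llama_stack_api.batches.models import ListBatchesRequest

    async def walk_batches(impl):
        page = await impl.list_batches(ListBatchesRequest(limit=2))
        while True:
            for batch in page.data:
                print(batch.id, batch.status)
            if not page.has_more:
                break
            # The last ID of the current page becomes the next `after` cursor.
            page = await impl.list_batches(ListBatchesRequest(after=page.last_id, limit=2))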