chore: refactor Batches protocol to use request models

This commit refactors the Batches protocol to use Pydantic request
models for both create_batch and list_batches methods, improving
consistency, readability, and maintainability.

- create_batch now accepts a single CreateBatchRequest parameter instead
  of individual arguments. This aligns the protocol with FastAPI’s
  request model pattern, allowing the router to pass the request object
  directly without unpacking parameters. Provider implementations now
  access fields via request.input_file_id, request.endpoint, etc.

- list_batches now accepts a single ListBatchesRequest parameter,
  replacing individual query parameters. The model includes after and
  limit fields with proper OpenAPI descriptions. FastAPI automatically
  parses query parameters into the model for GET requests, keeping
  router code clean. Provider implementations access fields via
  request.after and request.limit.

Signed-off-by: Sébastien Han <seb@redhat.com>
This commit is contained in:
Sébastien Han 2025-11-20 15:54:07 +01:00
parent 00e7ea6c3b
commit 30cab02083
No known key found for this signature in database
9 changed files with 145 additions and 50 deletions

View file

@ -11,7 +11,7 @@ Pydantic models are defined in llama_stack_api.batches.models.
The FastAPI router is defined in llama_stack_api.batches.routes.
"""
from typing import Literal, Protocol, runtime_checkable
from typing import Protocol, runtime_checkable
try:
from openai.types import Batch as BatchObject
@ -19,7 +19,7 @@ except ImportError as e:
raise ImportError("OpenAI package is required for batches API. Please install it with: pip install openai") from e
# Import models for re-export
from llama_stack_api.batches.models import ListBatchesResponse
from llama_stack_api.batches.models import CreateBatchRequest, ListBatchesRequest, ListBatchesResponse
@runtime_checkable
@ -39,11 +39,7 @@ class Batches(Protocol):
async def create_batch(
self,
input_file_id: str,
endpoint: str,
completion_window: Literal["24h"],
metadata: dict[str, str] | None = None,
idempotency_key: str | None = None,
request: CreateBatchRequest,
) -> BatchObject: ...
async def retrieve_batch(self, batch_id: str) -> BatchObject: ...
@ -52,9 +48,8 @@ class Batches(Protocol):
async def list_batches(
self,
after: str | None = None,
limit: int = 20,
request: ListBatchesRequest,
) -> ListBatchesResponse: ...
__all__ = ["Batches", "BatchObject", "ListBatchesResponse"]
__all__ = ["Batches", "BatchObject", "CreateBatchRequest", "ListBatchesRequest", "ListBatchesResponse"]

View file

@ -37,6 +37,16 @@ class CreateBatchRequest(BaseModel):
)
@json_schema_type
class ListBatchesRequest(BaseModel):
"""Request model for listing batches."""
after: str | None = Field(
default=None, description="Optional cursor for pagination. Returns batches after this ID."
)
limit: int = Field(default=20, description="Maximum number of batches to return. Defaults to 20.")
@json_schema_type
class ListBatchesResponse(BaseModel):
"""Response containing a list of batch objects."""
@ -48,4 +58,4 @@ class ListBatchesResponse(BaseModel):
has_more: bool = Field(default=False, description="Whether there are more batches available")
__all__ = ["CreateBatchRequest", "ListBatchesResponse", "BatchObject"]
__all__ = ["CreateBatchRequest", "ListBatchesRequest", "ListBatchesResponse", "BatchObject"]

View file

@ -14,10 +14,10 @@ all API-related code together.
from collections.abc import Callable
from typing import Annotated
from fastapi import APIRouter, Body, Depends
from fastapi import APIRouter, Body, Depends, Query
from llama_stack_api.batches import Batches, BatchObject, ListBatchesResponse
from llama_stack_api.batches.models import CreateBatchRequest
from llama_stack_api.batches.models import CreateBatchRequest, ListBatchesRequest
from llama_stack_api.datatypes import Api
from llama_stack_api.router_utils import standard_responses
from llama_stack_api.version import LLAMA_STACK_API_V1
@ -56,13 +56,7 @@ def create_router(impl_getter: Callable[[Api], Batches]) -> APIRouter:
request: Annotated[CreateBatchRequest, Body(...)],
svc: Annotated[Batches, Depends(get_batch_service)],
) -> BatchObject:
return await svc.create_batch(
input_file_id=request.input_file_id,
endpoint=request.endpoint,
completion_window=request.completion_window,
metadata=request.metadata,
idempotency_key=request.idempotency_key,
)
return await svc.create_batch(request)
@router.get(
"/batches/{batch_id}",
@ -94,6 +88,15 @@ def create_router(impl_getter: Callable[[Api], Batches]) -> APIRouter:
) -> BatchObject:
return await svc.cancel_batch(batch_id)
def get_list_batches_request(
after: Annotated[
str | None, Query(description="Optional cursor for pagination. Returns batches after this ID.")
] = None,
limit: Annotated[int, Query(description="Maximum number of batches to return. Defaults to 20.")] = 20,
) -> ListBatchesRequest:
"""Dependency function to create ListBatchesRequest from query parameters."""
return ListBatchesRequest(after=after, limit=limit)
@router.get(
"/batches",
response_model=ListBatchesResponse,
@ -104,10 +107,9 @@ def create_router(impl_getter: Callable[[Api], Batches]) -> APIRouter:
},
)
async def list_batches(
request: Annotated[ListBatchesRequest, Depends(get_list_batches_request)],
svc: Annotated[Batches, Depends(get_batch_service)],
after: str | None = None,
limit: int = 20,
) -> ListBatchesResponse:
return await svc.list_batches(after=after, limit=limit)
return await svc.list_batches(request)
return router