chore: refactor Batches protocol to use request models

This commit refactors the Batches protocol to use Pydantic request
models for both create_batch and list_batches methods, improving
consistency, readability, and maintainability.

- create_batch now accepts a single CreateBatchRequest parameter instead
  of individual arguments. This aligns the protocol with FastAPI’s
  request model pattern, allowing the router to pass the request object
  directly without unpacking parameters. Provider implementations now
  access fields via request.input_file_id, request.endpoint, etc.

- list_batches now accepts a single ListBatchesRequest parameter,
  replacing individual query parameters. The model includes after and
  limit fields with proper OpenAPI descriptions. FastAPI automatically
  parses query parameters into the model for GET requests, keeping
  router code clean. Provider implementations access fields via
  request.after and request.limit.

Signed-off-by: Sébastien Han <seb@redhat.com>
This commit is contained in:
Sébastien Han 2025-11-20 15:54:07 +01:00
parent 00e7ea6c3b
commit 30cab02083
No known key found for this signature in database
9 changed files with 145 additions and 50 deletions

View file

@@ -11,7 +11,7 @@ Pydantic models are defined in llama_stack_api.batches.models.
The FastAPI router is defined in llama_stack_api.batches.routes.
"""
from typing import Literal, Protocol, runtime_checkable
from typing import Protocol, runtime_checkable
try:
from openai.types import Batch as BatchObject
@@ -19,7 +19,7 @@ except ImportError as e:
raise ImportError("OpenAI package is required for batches API. Please install it with: pip install openai") from e
# Import models for re-export
from llama_stack_api.batches.models import ListBatchesResponse
from llama_stack_api.batches.models import CreateBatchRequest, ListBatchesRequest, ListBatchesResponse
@runtime_checkable
@@ -39,11 +39,7 @@ class Batches(Protocol):
async def create_batch(
self,
input_file_id: str,
endpoint: str,
completion_window: Literal["24h"],
metadata: dict[str, str] | None = None,
idempotency_key: str | None = None,
request: CreateBatchRequest,
) -> BatchObject: ...
async def retrieve_batch(self, batch_id: str) -> BatchObject: ...
@@ -52,9 +48,8 @@ class Batches(Protocol):
async def list_batches(
self,
after: str | None = None,
limit: int = 20,
request: ListBatchesRequest,
) -> ListBatchesResponse: ...
__all__ = ["Batches", "BatchObject", "ListBatchesResponse"]
__all__ = ["Batches", "BatchObject", "CreateBatchRequest", "ListBatchesRequest", "ListBatchesResponse"]