chore: refactor Batches protocol to use request models

This commit refactors the Batches protocol to use Pydantic request
models for both the create_batch and list_batches methods, improving
consistency, readability, and maintainability.

- create_batch now accepts a single CreateBatchRequest parameter instead
  of individual arguments. This aligns the protocol with FastAPI’s
  request model pattern, allowing the router to pass the request object
  directly without unpacking parameters. Provider implementations now
  access fields via request.input_file_id, request.endpoint, etc. (see
  the provider sketch below).

- list_batches now accepts a single ListBatchesRequest parameter,
  replacing individual query parameters. The model defines after and
  limit fields with OpenAPI descriptions, and a small Depends()
  dependency lets FastAPI parse the query parameters into the model for
  GET requests, keeping the router handler clean. Provider
  implementations access fields via request.after and request.limit
  (see the router sketch below).
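
As an illustration (not part of this diff), a provider implementation
under the new signature could look roughly like the sketch below; the
class name and validation body are hypothetical:

  import time
  import uuid

  from llama_stack_api.batches.models import CreateBatchRequest
  from openai.types import Batch as BatchObject

  class InMemoryBatchesImpl:
      # Illustrative stub, not the reference provider implementation.
      async def create_batch(self, request: CreateBatchRequest) -> BatchObject:
          # All inputs now arrive on the request model rather than as
          # individual keyword arguments.
          if request.endpoint not in ("/v1/chat/completions", "/v1/completions", "/v1/embeddings"):
              raise ValueError(f"Invalid endpoint: {request.endpoint}")
          return BatchObject(
              id=f"batch_{uuid.uuid4().hex[:16]}",
              object="batch",
              endpoint=request.endpoint,
              input_file_id=request.input_file_id,
              completion_window=request.completion_window,
              status="validating",
              created_at=int(time.time()),
              metadata=request.metadata,
          )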
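
For the GET side, a condensed router sketch mirroring the pattern in
the diff (route path and error responses trimmed for brevity):

  from typing import Annotated

  from fastapi import APIRouter, Depends, Query

  from llama_stack_api.batches.models import ListBatchesRequest

  router = APIRouter()

  def get_list_batches_request(
      after: Annotated[str | None, Query(description="Cursor for pagination.")] = None,
      limit: Annotated[int, Query(description="Maximum number of batches to return.")] = 20,
  ) -> ListBatchesRequest:
      # FastAPI resolves ?after=...&limit=... and hands the values to this
      # dependency, which builds the Pydantic request model.
      return ListBatchesRequest(after=after, limit=limit)

  @router.get("/batches")
  async def list_batches(
      request: Annotated[ListBatchesRequest, Depends(get_list_batches_request)],
  ) -> ListBatchesRequest:
      # A real handler would forward the model to the Batches provider,
      # e.g. `return await svc.list_batches(request)`.
      return request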

Signed-off-by: Sébastien Han <seb@redhat.com>
Sébastien Han 2025-11-20 15:54:07 +01:00
parent 00e7ea6c3b
commit 30cab02083
9 changed files with 145 additions and 50 deletions

@@ -48,14 +48,18 @@ paths:
anyOf:
- type: string
- type: 'null'
description: Optional cursor for pagination. Returns batches after this ID.
title: After
description: Optional cursor for pagination. Returns batches after this ID.
- name: limit
in: query
required: false
schema:
type: integer
description: Maximum number of batches to return. Defaults to 20.
default: 20
title: Limit
description: Maximum number of batches to return. Defaults to 20.
post:
responses:
'200':
@@ -12690,6 +12694,22 @@ components:
- query
title: VectorStoreSearchRequest
type: object
ListBatchesRequest:
description: Request model for listing batches.
properties:
after:
anyOf:
- type: string
- type: 'null'
description: Optional cursor for pagination. Returns batches after this ID.
nullable: true
limit:
default: 20
description: Maximum number of batches to return. Defaults to 20.
title: Limit
type: integer
title: ListBatchesRequest
type: object
DialogType:
description: Parameter type for dialog data with semantic output labels.
properties:

@@ -9531,6 +9531,22 @@ components:
- query
title: VectorStoreSearchRequest
type: object
ListBatchesRequest:
description: Request model for listing batches.
properties:
after:
anyOf:
- type: string
- type: 'null'
description: Optional cursor for pagination. Returns batches after this ID.
nullable: true
limit:
default: 20
description: Maximum number of batches to return. Defaults to 20.
title: Limit
type: integer
title: ListBatchesRequest
type: object
DialogType:
description: Parameter type for dialog data with semantic output labels.
properties:

@@ -8544,6 +8544,22 @@ components:
- query
title: VectorStoreSearchRequest
type: object
ListBatchesRequest:
description: Request model for listing batches.
properties:
after:
anyOf:
- type: string
- type: 'null'
description: Optional cursor for pagination. Returns batches after this ID.
nullable: true
limit:
default: 20
description: Maximum number of batches to return. Defaults to 20.
title: Limit
type: integer
title: ListBatchesRequest
type: object
DialogType:
description: Parameter type for dialog data with semantic output labels.
properties:

@@ -46,14 +46,18 @@ paths:
anyOf:
- type: string
- type: 'null'
description: Optional cursor for pagination. Returns batches after this ID.
title: After
description: Optional cursor for pagination. Returns batches after this ID.
- name: limit
in: query
required: false
schema:
type: integer
description: Maximum number of batches to return. Defaults to 20.
default: 20
title: Limit
description: Maximum number of batches to return. Defaults to 20.
post:
responses:
'200':
@@ -11419,6 +11423,22 @@ components:
- query
title: VectorStoreSearchRequest
type: object
ListBatchesRequest:
description: Request model for listing batches.
properties:
after:
anyOf:
- type: string
- type: 'null'
description: Optional cursor for pagination. Returns batches after this ID.
nullable: true
limit:
default: 20
description: Maximum number of batches to return. Defaults to 20.
title: Limit
type: integer
title: ListBatchesRequest
type: object
DialogType:
description: Parameter type for dialog data with semantic output labels.
properties:

@@ -48,14 +48,18 @@ paths:
anyOf:
- type: string
- type: 'null'
description: Optional cursor for pagination. Returns batches after this ID.
title: After
description: Optional cursor for pagination. Returns batches after this ID.
- name: limit
in: query
required: false
schema:
type: integer
description: Maximum number of batches to return. Defaults to 20.
default: 20
title: Limit
description: Maximum number of batches to return. Defaults to 20.
post:
responses:
'200':
@@ -12690,6 +12694,22 @@ components:
- query
title: VectorStoreSearchRequest
type: object
ListBatchesRequest:
description: Request model for listing batches.
properties:
after:
anyOf:
- type: string
- type: 'null'
description: Optional cursor for pagination. Returns batches after this ID.
nullable: true
limit:
default: 20
description: Maximum number of batches to return. Defaults to 20.
title: Limit
type: integer
title: ListBatchesRequest
type: object
DialogType:
description: Parameter type for dialog data with semantic output labels.
properties:

@@ -11,7 +11,7 @@ import json
import time
import uuid
from io import BytesIO
from typing import Any, Literal
from typing import Any
from openai.types.batch import BatchError, Errors
from pydantic import BaseModel
@@ -38,6 +38,7 @@ from llama_stack_api import (
OpenAIUserMessageParam,
ResourceNotFoundError,
)
from llama_stack_api.batches.models import CreateBatchRequest, ListBatchesRequest
from .config import ReferenceBatchesImplConfig
@@ -140,11 +141,7 @@ class ReferenceBatchesImpl(Batches):
# TODO (SECURITY): this currently works w/ configured api keys, not with x-llamastack-provider-data or with user policy restrictions
async def create_batch(
self,
input_file_id: str,
endpoint: str,
completion_window: Literal["24h"],
metadata: dict[str, str] | None = None,
idempotency_key: str | None = None,
request: CreateBatchRequest,
) -> BatchObject:
"""
Create a new batch for processing multiple API requests.
@@ -185,14 +182,14 @@ class ReferenceBatchesImpl(Batches):
# TODO: set expiration time for garbage collection
if endpoint not in ["/v1/chat/completions", "/v1/completions", "/v1/embeddings"]:
if request.endpoint not in ["/v1/chat/completions", "/v1/completions", "/v1/embeddings"]:
raise ValueError(
f"Invalid endpoint: {endpoint}. Supported values: /v1/chat/completions, /v1/completions, /v1/embeddings. Code: invalid_value. Param: endpoint",
f"Invalid endpoint: {request.endpoint}. Supported values: /v1/chat/completions, /v1/completions, /v1/embeddings. Code: invalid_value. Param: endpoint",
)
if completion_window != "24h":
if request.completion_window != "24h":
raise ValueError(
f"Invalid completion_window: {completion_window}. Supported values are: 24h. Code: invalid_value. Param: completion_window",
f"Invalid completion_window: {request.completion_window}. Supported values are: 24h. Code: invalid_value. Param: completion_window",
)
batch_id = f"batch_{uuid.uuid4().hex[:16]}"
@@ -200,8 +197,8 @@ class ReferenceBatchesImpl(Batches):
# For idempotent requests, use the idempotency key for the batch ID
# This ensures the same key always maps to the same batch ID,
# allowing us to detect parameter conflicts
if idempotency_key is not None:
hash_input = idempotency_key.encode("utf-8")
if request.idempotency_key is not None:
hash_input = request.idempotency_key.encode("utf-8")
hash_digest = hashlib.sha256(hash_input).hexdigest()[:24]
batch_id = f"batch_{hash_digest}"
@@ -209,13 +206,13 @@ class ReferenceBatchesImpl(Batches):
existing_batch = await self.retrieve_batch(batch_id)
if (
existing_batch.input_file_id != input_file_id
or existing_batch.endpoint != endpoint
or existing_batch.completion_window != completion_window
or existing_batch.metadata != metadata
existing_batch.input_file_id != request.input_file_id
or existing_batch.endpoint != request.endpoint
or existing_batch.completion_window != request.completion_window
or existing_batch.metadata != request.metadata
):
raise ConflictError(
f"Idempotency key '{idempotency_key}' was previously used with different parameters. "
f"Idempotency key '{request.idempotency_key}' was previously used with different parameters. "
"Either use a new idempotency key or ensure all parameters match the original request."
)
@@ -230,12 +227,12 @@ class ReferenceBatchesImpl(Batches):
batch = BatchObject(
id=batch_id,
object="batch",
endpoint=endpoint,
input_file_id=input_file_id,
completion_window=completion_window,
endpoint=request.endpoint,
input_file_id=request.input_file_id,
completion_window=request.completion_window,
status="validating",
created_at=current_time,
metadata=metadata,
metadata=request.metadata,
)
await self.kvstore.set(f"batch:{batch_id}", batch.to_json())
@@ -267,8 +264,7 @@ class ReferenceBatchesImpl(Batches):
async def list_batches(
self,
after: str | None = None,
limit: int = 20,
request: ListBatchesRequest,
) -> ListBatchesResponse:
"""
List all batches, eventually only for the current user.
@@ -285,14 +281,14 @@ class ReferenceBatchesImpl(Batches):
batches.sort(key=lambda b: b.created_at, reverse=True)
start_idx = 0
if after:
if request.after:
for i, batch in enumerate(batches):
if batch.id == after:
if batch.id == request.after:
start_idx = i + 1
break
page_batches = batches[start_idx : start_idx + limit]
has_more = (start_idx + limit) < len(batches)
page_batches = batches[start_idx : start_idx + request.limit]
has_more = (start_idx + request.limit) < len(batches)
first_id = page_batches[0].id if page_batches else None
last_id = page_batches[-1].id if page_batches else None

@@ -11,7 +11,7 @@ Pydantic models are defined in llama_stack_api.batches.models.
The FastAPI router is defined in llama_stack_api.batches.routes.
"""
from typing import Literal, Protocol, runtime_checkable
from typing import Protocol, runtime_checkable
try:
from openai.types import Batch as BatchObject
@@ -19,7 +19,7 @@ except ImportError as e:
raise ImportError("OpenAI package is required for batches API. Please install it with: pip install openai") from e
# Import models for re-export
from llama_stack_api.batches.models import ListBatchesResponse
from llama_stack_api.batches.models import CreateBatchRequest, ListBatchesRequest, ListBatchesResponse
@runtime_checkable
@@ -39,11 +39,7 @@ class Batches(Protocol):
async def create_batch(
self,
input_file_id: str,
endpoint: str,
completion_window: Literal["24h"],
metadata: dict[str, str] | None = None,
idempotency_key: str | None = None,
request: CreateBatchRequest,
) -> BatchObject: ...
async def retrieve_batch(self, batch_id: str) -> BatchObject: ...
@@ -52,9 +48,8 @@ class Batches(Protocol):
async def list_batches(
self,
after: str | None = None,
limit: int = 20,
request: ListBatchesRequest,
) -> ListBatchesResponse: ...
__all__ = ["Batches", "BatchObject", "ListBatchesResponse"]
__all__ = ["Batches", "BatchObject", "CreateBatchRequest", "ListBatchesRequest", "ListBatchesResponse"]

@@ -37,6 +37,16 @@ class CreateBatchRequest(BaseModel):
)
@json_schema_type
class ListBatchesRequest(BaseModel):
"""Request model for listing batches."""
after: str | None = Field(
default=None, description="Optional cursor for pagination. Returns batches after this ID."
)
limit: int = Field(default=20, description="Maximum number of batches to return. Defaults to 20.")
@json_schema_type
class ListBatchesResponse(BaseModel):
"""Response containing a list of batch objects."""
@@ -48,4 +58,4 @@ class ListBatchesResponse(BaseModel):
has_more: bool = Field(default=False, description="Whether there are more batches available")
__all__ = ["CreateBatchRequest", "ListBatchesResponse", "BatchObject"]
__all__ = ["CreateBatchRequest", "ListBatchesRequest", "ListBatchesResponse", "BatchObject"]

@@ -14,10 +14,10 @@ all API-related code together.
from collections.abc import Callable
from typing import Annotated
from fastapi import APIRouter, Body, Depends
from fastapi import APIRouter, Body, Depends, Query
from llama_stack_api.batches import Batches, BatchObject, ListBatchesResponse
from llama_stack_api.batches.models import CreateBatchRequest
from llama_stack_api.batches.models import CreateBatchRequest, ListBatchesRequest
from llama_stack_api.datatypes import Api
from llama_stack_api.router_utils import standard_responses
from llama_stack_api.version import LLAMA_STACK_API_V1
@@ -56,13 +56,7 @@ def create_router(impl_getter: Callable[[Api], Batches]) -> APIRouter:
request: Annotated[CreateBatchRequest, Body(...)],
svc: Annotated[Batches, Depends(get_batch_service)],
) -> BatchObject:
return await svc.create_batch(
input_file_id=request.input_file_id,
endpoint=request.endpoint,
completion_window=request.completion_window,
metadata=request.metadata,
idempotency_key=request.idempotency_key,
)
return await svc.create_batch(request)
@router.get(
"/batches/{batch_id}",
@@ -94,6 +88,15 @@ def create_router(impl_getter: Callable[[Api], Batches]) -> APIRouter:
) -> BatchObject:
return await svc.cancel_batch(batch_id)
def get_list_batches_request(
after: Annotated[
str | None, Query(description="Optional cursor for pagination. Returns batches after this ID.")
] = None,
limit: Annotated[int, Query(description="Maximum number of batches to return. Defaults to 20.")] = 20,
) -> ListBatchesRequest:
"""Dependency function to create ListBatchesRequest from query parameters."""
return ListBatchesRequest(after=after, limit=limit)
@router.get(
"/batches",
response_model=ListBatchesResponse,
@@ -104,10 +107,9 @@ def create_router(impl_getter: Callable[[Api], Batches]) -> APIRouter:
},
)
async def list_batches(
request: Annotated[ListBatchesRequest, Depends(get_list_batches_request)],
svc: Annotated[Batches, Depends(get_batch_service)],
after: str | None = None,
limit: int = 20,
) -> ListBatchesResponse:
return await svc.list_batches(after=after, limit=limit)
return await svc.list_batches(request)
return router