chore: refactor Batches protocol to use request models

This commit refactors the Batches protocol to use Pydantic request
models for both create_batch and list_batches methods, improving
consistency, readability, and maintainability.

- create_batch now accepts a single CreateBatchRequest parameter instead
  of individual arguments. This aligns the protocol with FastAPI’s
  request-model pattern, allowing the router to pass the request object
  directly without unpacking parameters. Provider implementations now
  access fields via request.input_file_id, request.endpoint, etc., as
  sketched below.
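
  A minimal sketch of the new calling convention (the model fields mirror
  the router diff below; the Literal type for completion_window is carried
  over from the old signature, and the standalone model and return type
  here are illustrative, not the actual llama_stack_api definitions):

      from typing import Literal
      from pydantic import BaseModel

      class CreateBatchRequest(BaseModel):
          input_file_id: str
          endpoint: str
          completion_window: Literal["24h"]
          metadata: dict[str, str] | None = None
          idempotency_key: str | None = None

      async def create_batch(request: CreateBatchRequest) -> dict:
          # Fields are read off the model instead of separate keyword arguments.
          return {"endpoint": request.endpoint, "input_file_id": request.input_file_id}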

- list_batches now accepts a single ListBatchesRequest parameter,
  replacing individual query parameters. The model includes after and
  limit fields with proper OpenAPI descriptions. For GET requests, a
  small FastAPI Depends() helper builds the model from the query
  parameters, keeping router code clean. Provider implementations access
  fields via request.after and request.limit; see the sketch below.
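
  A self-contained sketch of that dependency pattern (the app setup and
  route body are illustrative; the model fields and dependency shape
  follow the router diff below):

      from typing import Annotated
      from fastapi import Depends, FastAPI, Query
      from pydantic import BaseModel, Field

      class ListBatchesRequest(BaseModel):
          after: str | None = Field(default=None, description="Cursor for pagination.")
          limit: int = Field(default=20, description="Maximum number of batches to return.")

      def get_list_batches_request(
          after: Annotated[str | None, Query()] = None,
          limit: Annotated[int, Query()] = 20,
      ) -> ListBatchesRequest:
          return ListBatchesRequest(after=after, limit=limit)

      app = FastAPI()

      @app.get("/v1/batches")
      async def list_batches(
          request: Annotated[ListBatchesRequest, Depends(get_list_batches_request)],
      ) -> dict:
          # GET /v1/batches?after=batch_x&limit=5 yields
          # ListBatchesRequest(after="batch_x", limit=5)
          return {"after": request.after, "limit": request.limit}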

Signed-off-by: Sébastien Han <seb@redhat.com>
9 changed files with 145 additions and 50 deletions

View file

@@ -48,14 +48,18 @@ paths:
           anyOf:
             - type: string
             - type: 'null'
+          description: Optional cursor for pagination. Returns batches after this ID.
           title: After
+        description: Optional cursor for pagination. Returns batches after this ID.
       - name: limit
         in: query
         required: false
         schema:
           type: integer
+          description: Maximum number of batches to return. Defaults to 20.
           default: 20
           title: Limit
+        description: Maximum number of batches to return. Defaults to 20.
     post:
       responses:
         '200':
@@ -12690,6 +12694,22 @@ components:
         - query
       title: VectorStoreSearchRequest
       type: object
+    ListBatchesRequest:
+      description: Request model for listing batches.
+      properties:
+        after:
+          anyOf:
+            - type: string
+            - type: 'null'
+          description: Optional cursor for pagination. Returns batches after this ID.
+          nullable: true
+        limit:
+          default: 20
+          description: Maximum number of batches to return. Defaults to 20.
+          title: Limit
+          type: integer
+      title: ListBatchesRequest
+      type: object
     DialogType:
       description: Parameter type for dialog data with semantic output labels.
       properties:

View file

@@ -9531,6 +9531,22 @@ components:
         - query
       title: VectorStoreSearchRequest
       type: object
+    ListBatchesRequest:
+      description: Request model for listing batches.
+      properties:
+        after:
+          anyOf:
+            - type: string
+            - type: 'null'
+          description: Optional cursor for pagination. Returns batches after this ID.
+          nullable: true
+        limit:
+          default: 20
+          description: Maximum number of batches to return. Defaults to 20.
+          title: Limit
+          type: integer
+      title: ListBatchesRequest
+      type: object
     DialogType:
       description: Parameter type for dialog data with semantic output labels.
       properties:

View file

@@ -8544,6 +8544,22 @@ components:
         - query
       title: VectorStoreSearchRequest
       type: object
+    ListBatchesRequest:
+      description: Request model for listing batches.
+      properties:
+        after:
+          anyOf:
+            - type: string
+            - type: 'null'
+          description: Optional cursor for pagination. Returns batches after this ID.
+          nullable: true
+        limit:
+          default: 20
+          description: Maximum number of batches to return. Defaults to 20.
+          title: Limit
+          type: integer
+      title: ListBatchesRequest
+      type: object
     DialogType:
       description: Parameter type for dialog data with semantic output labels.
       properties:

View file

@@ -46,14 +46,18 @@ paths:
           anyOf:
             - type: string
             - type: 'null'
+          description: Optional cursor for pagination. Returns batches after this ID.
           title: After
+        description: Optional cursor for pagination. Returns batches after this ID.
       - name: limit
         in: query
         required: false
         schema:
           type: integer
+          description: Maximum number of batches to return. Defaults to 20.
           default: 20
           title: Limit
+        description: Maximum number of batches to return. Defaults to 20.
     post:
       responses:
         '200':
@@ -11419,6 +11423,22 @@ components:
         - query
       title: VectorStoreSearchRequest
       type: object
+    ListBatchesRequest:
+      description: Request model for listing batches.
+      properties:
+        after:
+          anyOf:
+            - type: string
+            - type: 'null'
+          description: Optional cursor for pagination. Returns batches after this ID.
+          nullable: true
+        limit:
+          default: 20
+          description: Maximum number of batches to return. Defaults to 20.
+          title: Limit
+          type: integer
+      title: ListBatchesRequest
+      type: object
     DialogType:
       description: Parameter type for dialog data with semantic output labels.
       properties:

View file

@@ -48,14 +48,18 @@ paths:
           anyOf:
             - type: string
             - type: 'null'
+          description: Optional cursor for pagination. Returns batches after this ID.
           title: After
+        description: Optional cursor for pagination. Returns batches after this ID.
       - name: limit
         in: query
         required: false
         schema:
           type: integer
+          description: Maximum number of batches to return. Defaults to 20.
           default: 20
           title: Limit
+        description: Maximum number of batches to return. Defaults to 20.
     post:
       responses:
         '200':
@@ -12690,6 +12694,22 @@ components:
         - query
       title: VectorStoreSearchRequest
       type: object
+    ListBatchesRequest:
+      description: Request model for listing batches.
+      properties:
+        after:
+          anyOf:
+            - type: string
+            - type: 'null'
+          description: Optional cursor for pagination. Returns batches after this ID.
+          nullable: true
+        limit:
+          default: 20
+          description: Maximum number of batches to return. Defaults to 20.
+          title: Limit
+          type: integer
+      title: ListBatchesRequest
+      type: object
     DialogType:
       description: Parameter type for dialog data with semantic output labels.
       properties:

View file

@@ -11,7 +11,7 @@ import json
 import time
 import uuid
 from io import BytesIO
-from typing import Any, Literal
+from typing import Any

 from openai.types.batch import BatchError, Errors
 from pydantic import BaseModel
@@ -38,6 +38,7 @@ from llama_stack_api import (
     OpenAIUserMessageParam,
     ResourceNotFoundError,
 )
+from llama_stack_api.batches.models import CreateBatchRequest, ListBatchesRequest

 from .config import ReferenceBatchesImplConfig
@@ -140,11 +141,7 @@ class ReferenceBatchesImpl(Batches):
     # TODO (SECURITY): this currently works w/ configured api keys, not with x-llamastack-provider-data or with user policy restrictions
     async def create_batch(
         self,
-        input_file_id: str,
-        endpoint: str,
-        completion_window: Literal["24h"],
-        metadata: dict[str, str] | None = None,
-        idempotency_key: str | None = None,
+        request: CreateBatchRequest,
     ) -> BatchObject:
         """
         Create a new batch for processing multiple API requests.
@@ -185,14 +182,14 @@ class ReferenceBatchesImpl(Batches):

         # TODO: set expiration time for garbage collection

-        if endpoint not in ["/v1/chat/completions", "/v1/completions", "/v1/embeddings"]:
+        if request.endpoint not in ["/v1/chat/completions", "/v1/completions", "/v1/embeddings"]:
             raise ValueError(
-                f"Invalid endpoint: {endpoint}. Supported values: /v1/chat/completions, /v1/completions, /v1/embeddings. Code: invalid_value. Param: endpoint",
+                f"Invalid endpoint: {request.endpoint}. Supported values: /v1/chat/completions, /v1/completions, /v1/embeddings. Code: invalid_value. Param: endpoint",
             )

-        if completion_window != "24h":
+        if request.completion_window != "24h":
             raise ValueError(
-                f"Invalid completion_window: {completion_window}. Supported values are: 24h. Code: invalid_value. Param: completion_window",
+                f"Invalid completion_window: {request.completion_window}. Supported values are: 24h. Code: invalid_value. Param: completion_window",
             )

         batch_id = f"batch_{uuid.uuid4().hex[:16]}"
@@ -200,8 +197,8 @@ class ReferenceBatchesImpl(Batches):
         # For idempotent requests, use the idempotency key for the batch ID
         # This ensures the same key always maps to the same batch ID,
         # allowing us to detect parameter conflicts
-        if idempotency_key is not None:
-            hash_input = idempotency_key.encode("utf-8")
+        if request.idempotency_key is not None:
+            hash_input = request.idempotency_key.encode("utf-8")
             hash_digest = hashlib.sha256(hash_input).hexdigest()[:24]
             batch_id = f"batch_{hash_digest}"
@@ -209,13 +206,13 @@ class ReferenceBatchesImpl(Batches):
             existing_batch = await self.retrieve_batch(batch_id)

             if (
-                existing_batch.input_file_id != input_file_id
-                or existing_batch.endpoint != endpoint
-                or existing_batch.completion_window != completion_window
-                or existing_batch.metadata != metadata
+                existing_batch.input_file_id != request.input_file_id
+                or existing_batch.endpoint != request.endpoint
+                or existing_batch.completion_window != request.completion_window
+                or existing_batch.metadata != request.metadata
             ):
                 raise ConflictError(
-                    f"Idempotency key '{idempotency_key}' was previously used with different parameters. "
+                    f"Idempotency key '{request.idempotency_key}' was previously used with different parameters. "
                     "Either use a new idempotency key or ensure all parameters match the original request."
                 )
@@ -230,12 +227,12 @@ class ReferenceBatchesImpl(Batches):
         batch = BatchObject(
             id=batch_id,
             object="batch",
-            endpoint=endpoint,
-            input_file_id=input_file_id,
-            completion_window=completion_window,
+            endpoint=request.endpoint,
+            input_file_id=request.input_file_id,
+            completion_window=request.completion_window,
             status="validating",
             created_at=current_time,
-            metadata=metadata,
+            metadata=request.metadata,
         )

         await self.kvstore.set(f"batch:{batch_id}", batch.to_json())
@@ -267,8 +264,7 @@ class ReferenceBatchesImpl(Batches):
     async def list_batches(
         self,
-        after: str | None = None,
-        limit: int = 20,
+        request: ListBatchesRequest,
     ) -> ListBatchesResponse:
         """
         List all batches, eventually only for the current user.
@@ -285,14 +281,14 @@ class ReferenceBatchesImpl(Batches):
         batches.sort(key=lambda b: b.created_at, reverse=True)

         start_idx = 0
-        if after:
+        if request.after:
             for i, batch in enumerate(batches):
-                if batch.id == after:
+                if batch.id == request.after:
                     start_idx = i + 1
                     break

-        page_batches = batches[start_idx : start_idx + limit]
-        has_more = (start_idx + limit) < len(batches)
+        page_batches = batches[start_idx : start_idx + request.limit]
+        has_more = (start_idx + request.limit) < len(batches)

         first_id = page_batches[0].id if page_batches else None
         last_id = page_batches[-1].id if page_batches else None

View file

@@ -11,7 +11,7 @@ Pydantic models are defined in llama_stack_api.batches.models.
 The FastAPI router is defined in llama_stack_api.batches.routes.
 """

-from typing import Literal, Protocol, runtime_checkable
+from typing import Protocol, runtime_checkable

 try:
     from openai.types import Batch as BatchObject
@@ -19,7 +19,7 @@ except ImportError as e:
     raise ImportError("OpenAI package is required for batches API. Please install it with: pip install openai") from e

 # Import models for re-export
-from llama_stack_api.batches.models import ListBatchesResponse
+from llama_stack_api.batches.models import CreateBatchRequest, ListBatchesRequest, ListBatchesResponse


 @runtime_checkable
@@ -39,11 +39,7 @@ class Batches(Protocol):
     async def create_batch(
         self,
-        input_file_id: str,
-        endpoint: str,
-        completion_window: Literal["24h"],
-        metadata: dict[str, str] | None = None,
-        idempotency_key: str | None = None,
+        request: CreateBatchRequest,
     ) -> BatchObject: ...

     async def retrieve_batch(self, batch_id: str) -> BatchObject: ...
@@ -52,9 +48,8 @@ class Batches(Protocol):
     async def list_batches(
         self,
-        after: str | None = None,
-        limit: int = 20,
+        request: ListBatchesRequest,
     ) -> ListBatchesResponse: ...


-__all__ = ["Batches", "BatchObject", "ListBatchesResponse"]
+__all__ = ["Batches", "BatchObject", "CreateBatchRequest", "ListBatchesRequest", "ListBatchesResponse"]

View file

@@ -37,6 +37,16 @@ class CreateBatchRequest(BaseModel):
     )


+@json_schema_type
+class ListBatchesRequest(BaseModel):
+    """Request model for listing batches."""
+
+    after: str | None = Field(
+        default=None, description="Optional cursor for pagination. Returns batches after this ID."
+    )
+    limit: int = Field(default=20, description="Maximum number of batches to return. Defaults to 20.")
+
+
 @json_schema_type
 class ListBatchesResponse(BaseModel):
     """Response containing a list of batch objects."""
@@ -48,4 +58,4 @@ class ListBatchesResponse(BaseModel):
     has_more: bool = Field(default=False, description="Whether there are more batches available")


-__all__ = ["CreateBatchRequest", "ListBatchesResponse", "BatchObject"]
+__all__ = ["CreateBatchRequest", "ListBatchesRequest", "ListBatchesResponse", "BatchObject"]

View file

@@ -14,10 +14,10 @@ all API-related code together.
 from collections.abc import Callable
 from typing import Annotated

-from fastapi import APIRouter, Body, Depends
+from fastapi import APIRouter, Body, Depends, Query

 from llama_stack_api.batches import Batches, BatchObject, ListBatchesResponse
-from llama_stack_api.batches.models import CreateBatchRequest
+from llama_stack_api.batches.models import CreateBatchRequest, ListBatchesRequest
 from llama_stack_api.datatypes import Api
 from llama_stack_api.router_utils import standard_responses
 from llama_stack_api.version import LLAMA_STACK_API_V1
@@ -56,13 +56,7 @@ def create_router(impl_getter: Callable[[Api], Batches]) -> APIRouter:
         request: Annotated[CreateBatchRequest, Body(...)],
         svc: Annotated[Batches, Depends(get_batch_service)],
     ) -> BatchObject:
-        return await svc.create_batch(
-            input_file_id=request.input_file_id,
-            endpoint=request.endpoint,
-            completion_window=request.completion_window,
-            metadata=request.metadata,
-            idempotency_key=request.idempotency_key,
-        )
+        return await svc.create_batch(request)

     @router.get(
         "/batches/{batch_id}",
@@ -94,6 +88,15 @@ def create_router(impl_getter: Callable[[Api], Batches]) -> APIRouter:
     ) -> BatchObject:
         return await svc.cancel_batch(batch_id)

+    def get_list_batches_request(
+        after: Annotated[
+            str | None, Query(description="Optional cursor for pagination. Returns batches after this ID.")
+        ] = None,
+        limit: Annotated[int, Query(description="Maximum number of batches to return. Defaults to 20.")] = 20,
+    ) -> ListBatchesRequest:
+        """Dependency function to create ListBatchesRequest from query parameters."""
+        return ListBatchesRequest(after=after, limit=limit)
+
     @router.get(
         "/batches",
         response_model=ListBatchesResponse,
@@ -104,10 +107,9 @@ def create_router(impl_getter: Callable[[Api], Batches]) -> APIRouter:
         },
     )
     async def list_batches(
+        request: Annotated[ListBatchesRequest, Depends(get_list_batches_request)],
         svc: Annotated[Batches, Depends(get_batch_service)],
-        after: str | None = None,
-        limit: int = 20,
     ) -> ListBatchesResponse:
-        return await svc.list_batches(after=after, limit=limit)
+        return await svc.list_batches(request)

     return router