mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-04 02:03:44 +00:00
chore: refactor Batches protocol to use request models
This commit refactors the Batches protocol to use Pydantic request models for both create_batch and list_batches methods, improving consistency, readability, and maintainability. - create_batch now accepts a single CreateBatchRequest parameter instead of individual arguments. This aligns the protocol with FastAPI’s request model pattern, allowing the router to pass the request object directly without unpacking parameters. Provider implementations now access fields via request.input_file_id, request.endpoint, etc. - list_batches now accepts a single ListBatchesRequest parameter, replacing individual query parameters. The model includes after and limit fields with proper OpenAPI descriptions. FastAPI automatically parses query parameters into the model for GET requests, keeping router code clean. Provider implementations access fields via request.after and request.limit. Signed-off-by: Sébastien Han <seb@redhat.com>
This commit is contained in:
parent
00e7ea6c3b
commit
30cab02083
9 changed files with 145 additions and 50 deletions
|
|
@ -11,7 +11,7 @@ Pydantic models are defined in llama_stack_api.batches.models.
|
|||
The FastAPI router is defined in llama_stack_api.batches.routes.
|
||||
"""
|
||||
|
||||
from typing import Literal, Protocol, runtime_checkable
|
||||
from typing import Protocol, runtime_checkable
|
||||
|
||||
try:
|
||||
from openai.types import Batch as BatchObject
|
||||
|
|
@ -19,7 +19,7 @@ except ImportError as e:
|
|||
raise ImportError("OpenAI package is required for batches API. Please install it with: pip install openai") from e
|
||||
|
||||
# Import models for re-export
|
||||
from llama_stack_api.batches.models import ListBatchesResponse
|
||||
from llama_stack_api.batches.models import CreateBatchRequest, ListBatchesRequest, ListBatchesResponse
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
|
|
@ -39,11 +39,7 @@ class Batches(Protocol):
|
|||
|
||||
async def create_batch(
|
||||
self,
|
||||
input_file_id: str,
|
||||
endpoint: str,
|
||||
completion_window: Literal["24h"],
|
||||
metadata: dict[str, str] | None = None,
|
||||
idempotency_key: str | None = None,
|
||||
request: CreateBatchRequest,
|
||||
) -> BatchObject: ...
|
||||
|
||||
async def retrieve_batch(self, batch_id: str) -> BatchObject: ...
|
||||
|
|
@ -52,9 +48,8 @@ class Batches(Protocol):
|
|||
|
||||
async def list_batches(
|
||||
self,
|
||||
after: str | None = None,
|
||||
limit: int = 20,
|
||||
request: ListBatchesRequest,
|
||||
) -> ListBatchesResponse: ...
|
||||
|
||||
|
||||
__all__ = ["Batches", "BatchObject", "ListBatchesResponse"]
|
||||
__all__ = ["Batches", "BatchObject", "CreateBatchRequest", "ListBatchesRequest", "ListBatchesResponse"]
|
||||
|
|
|
|||
|
|
@ -37,6 +37,16 @@ class CreateBatchRequest(BaseModel):
|
|||
)
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ListBatchesRequest(BaseModel):
|
||||
"""Request model for listing batches."""
|
||||
|
||||
after: str | None = Field(
|
||||
default=None, description="Optional cursor for pagination. Returns batches after this ID."
|
||||
)
|
||||
limit: int = Field(default=20, description="Maximum number of batches to return. Defaults to 20.")
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ListBatchesResponse(BaseModel):
|
||||
"""Response containing a list of batch objects."""
|
||||
|
|
@ -48,4 +58,4 @@ class ListBatchesResponse(BaseModel):
|
|||
has_more: bool = Field(default=False, description="Whether there are more batches available")
|
||||
|
||||
|
||||
__all__ = ["CreateBatchRequest", "ListBatchesResponse", "BatchObject"]
|
||||
__all__ = ["CreateBatchRequest", "ListBatchesRequest", "ListBatchesResponse", "BatchObject"]
|
||||
|
|
|
|||
|
|
@ -14,10 +14,10 @@ all API-related code together.
|
|||
from collections.abc import Callable
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import APIRouter, Body, Depends
|
||||
from fastapi import APIRouter, Body, Depends, Query
|
||||
|
||||
from llama_stack_api.batches import Batches, BatchObject, ListBatchesResponse
|
||||
from llama_stack_api.batches.models import CreateBatchRequest
|
||||
from llama_stack_api.batches.models import CreateBatchRequest, ListBatchesRequest
|
||||
from llama_stack_api.datatypes import Api
|
||||
from llama_stack_api.router_utils import standard_responses
|
||||
from llama_stack_api.version import LLAMA_STACK_API_V1
|
||||
|
|
@ -56,13 +56,7 @@ def create_router(impl_getter: Callable[[Api], Batches]) -> APIRouter:
|
|||
request: Annotated[CreateBatchRequest, Body(...)],
|
||||
svc: Annotated[Batches, Depends(get_batch_service)],
|
||||
) -> BatchObject:
|
||||
return await svc.create_batch(
|
||||
input_file_id=request.input_file_id,
|
||||
endpoint=request.endpoint,
|
||||
completion_window=request.completion_window,
|
||||
metadata=request.metadata,
|
||||
idempotency_key=request.idempotency_key,
|
||||
)
|
||||
return await svc.create_batch(request)
|
||||
|
||||
@router.get(
|
||||
"/batches/{batch_id}",
|
||||
|
|
@ -94,6 +88,15 @@ def create_router(impl_getter: Callable[[Api], Batches]) -> APIRouter:
|
|||
) -> BatchObject:
|
||||
return await svc.cancel_batch(batch_id)
|
||||
|
||||
def get_list_batches_request(
|
||||
after: Annotated[
|
||||
str | None, Query(description="Optional cursor for pagination. Returns batches after this ID.")
|
||||
] = None,
|
||||
limit: Annotated[int, Query(description="Maximum number of batches to return. Defaults to 20.")] = 20,
|
||||
) -> ListBatchesRequest:
|
||||
"""Dependency function to create ListBatchesRequest from query parameters."""
|
||||
return ListBatchesRequest(after=after, limit=limit)
|
||||
|
||||
@router.get(
|
||||
"/batches",
|
||||
response_model=ListBatchesResponse,
|
||||
|
|
@ -104,10 +107,9 @@ def create_router(impl_getter: Callable[[Api], Batches]) -> APIRouter:
|
|||
},
|
||||
)
|
||||
async def list_batches(
|
||||
request: Annotated[ListBatchesRequest, Depends(get_list_batches_request)],
|
||||
svc: Annotated[Batches, Depends(get_batch_service)],
|
||||
after: str | None = None,
|
||||
limit: int = 20,
|
||||
) -> ListBatchesResponse:
|
||||
return await svc.list_batches(after=after, limit=limit)
|
||||
return await svc.list_batches(request)
|
||||
|
||||
return router
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue