chore: same as previous commit but for more fields

Signed-off-by: Sébastien Han <seb@redhat.com>
Sébastien Han 2025-11-20 16:12:52 +01:00
parent 30cab02083
commit 20030429e7
9 changed files with 210 additions and 25 deletions


@@ -19,7 +19,13 @@ except ImportError as e:
     raise ImportError("OpenAI package is required for batches API. Please install it with: pip install openai") from e
 # Import models for re-export
-from llama_stack_api.batches.models import CreateBatchRequest, ListBatchesRequest, ListBatchesResponse
+from llama_stack_api.batches.models import (
+    CancelBatchRequest,
+    CreateBatchRequest,
+    ListBatchesRequest,
+    ListBatchesResponse,
+    RetrieveBatchRequest,
+)
 @runtime_checkable
@@ -42,9 +48,15 @@ class Batches(Protocol):
         request: CreateBatchRequest,
     ) -> BatchObject: ...
-    async def retrieve_batch(self, batch_id: str) -> BatchObject: ...
+    async def retrieve_batch(
+        self,
+        request: RetrieveBatchRequest,
+    ) -> BatchObject: ...
-    async def cancel_batch(self, batch_id: str) -> BatchObject: ...
+    async def cancel_batch(
+        self,
+        request: CancelBatchRequest,
+    ) -> BatchObject: ...
     async def list_batches(
         self,
@@ -52,4 +64,12 @@ class Batches(Protocol):
     ) -> ListBatchesResponse: ...
-__all__ = ["Batches", "BatchObject", "CreateBatchRequest", "ListBatchesRequest", "ListBatchesResponse"]
+__all__ = [
+    "Batches",
+    "BatchObject",
+    "CreateBatchRequest",
+    "ListBatchesRequest",
+    "RetrieveBatchRequest",
+    "CancelBatchRequest",
+    "ListBatchesResponse",
+]
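For illustration, here is a minimal caller sketch against the updated Batches protocol. The cancel_batch_by_id and get_batch_by_id helpers and the batches argument are hypothetical; the request models and method signatures are the ones introduced above.

from llama_stack_api.batches.models import CancelBatchRequest, RetrieveBatchRequest

async def cancel_batch_by_id(batches, batch_id: str):
    # Before this commit the protocol took a bare string:
    #     await batches.cancel_batch(batch_id)
    # After it, the identifier is wrapped in a request model:
    return await batches.cancel_batch(CancelBatchRequest(batch_id=batch_id))

async def get_batch_by_id(batches, batch_id: str):
    return await batches.retrieve_batch(RetrieveBatchRequest(batch_id=batch_id))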


@@ -47,6 +47,20 @@ class ListBatchesRequest(BaseModel):
     limit: int = Field(default=20, description="Maximum number of batches to return. Defaults to 20.")
+@json_schema_type
+class RetrieveBatchRequest(BaseModel):
+    """Request model for retrieving a batch."""
+    batch_id: str = Field(..., description="The ID of the batch to retrieve.")
+@json_schema_type
+class CancelBatchRequest(BaseModel):
+    """Request model for canceling a batch."""
+    batch_id: str = Field(..., description="The ID of the batch to cancel.")
 @json_schema_type
 class ListBatchesResponse(BaseModel):
     """Response containing a list of batch objects."""
@@ -58,4 +72,11 @@ class ListBatchesResponse(BaseModel):
     has_more: bool = Field(default=False, description="Whether there are more batches available")
-__all__ = ["CreateBatchRequest", "ListBatchesRequest", "ListBatchesResponse", "BatchObject"]
+__all__ = [
+    "CreateBatchRequest",
+    "ListBatchesRequest",
+    "RetrieveBatchRequest",
+    "CancelBatchRequest",
+    "ListBatchesResponse",
+    "BatchObject",
+]
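As a quick sanity check of the new request models (a sketch; the batch_abc123 value is made up), batch_id is a required field on both:

from pydantic import ValidationError
from llama_stack_api.batches.models import CancelBatchRequest, RetrieveBatchRequest

req = RetrieveBatchRequest(batch_id="batch_abc123")
print(req.model_dump())          # {'batch_id': 'batch_abc123'}

try:
    CancelBatchRequest()         # batch_id is Field(...), so omitting it fails validation
except ValidationError as exc:
    print(exc.errors()[0]["type"])   # 'missing' under Pydantic v2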


@@ -14,10 +14,15 @@ all API-related code together.
 from collections.abc import Callable
 from typing import Annotated
-from fastapi import APIRouter, Body, Depends, Query
+from fastapi import APIRouter, Body, Depends, Path, Query
 from llama_stack_api.batches import Batches, BatchObject, ListBatchesResponse
-from llama_stack_api.batches.models import CreateBatchRequest, ListBatchesRequest
+from llama_stack_api.batches.models import (
+    CancelBatchRequest,
+    CreateBatchRequest,
+    ListBatchesRequest,
+    RetrieveBatchRequest,
+)
 from llama_stack_api.datatypes import Api
 from llama_stack_api.router_utils import standard_responses
 from llama_stack_api.version import LLAMA_STACK_API_V1
@@ -58,6 +63,12 @@ def create_router(impl_getter: Callable[[Api], Batches]) -> APIRouter:
     ) -> BatchObject:
         return await svc.create_batch(request)
+    def get_retrieve_batch_request(
+        batch_id: Annotated[str, Path(description="The ID of the batch to retrieve.")],
+    ) -> RetrieveBatchRequest:
+        """Dependency function to create RetrieveBatchRequest from path parameter."""
+        return RetrieveBatchRequest(batch_id=batch_id)
     @router.get(
         "/batches/{batch_id}",
         response_model=BatchObject,
@@ -68,10 +79,16 @@
         },
     )
     async def retrieve_batch(
-        batch_id: str,
+        request: Annotated[RetrieveBatchRequest, Depends(get_retrieve_batch_request)],
         svc: Annotated[Batches, Depends(get_batch_service)],
     ) -> BatchObject:
-        return await svc.retrieve_batch(batch_id)
+        return await svc.retrieve_batch(request)
+    def get_cancel_batch_request(
+        batch_id: Annotated[str, Path(description="The ID of the batch to cancel.")],
+    ) -> CancelBatchRequest:
+        """Dependency function to create CancelBatchRequest from path parameter."""
+        return CancelBatchRequest(batch_id=batch_id)
     @router.post(
         "/batches/{batch_id}/cancel",
@@ -83,10 +100,10 @@
         },
     )
     async def cancel_batch(
-        batch_id: str,
+        request: Annotated[CancelBatchRequest, Depends(get_cancel_batch_request)],
         svc: Annotated[Batches, Depends(get_batch_service)],
     ) -> BatchObject:
-        return await svc.cancel_batch(batch_id)
+        return await svc.cancel_batch(request)
     def get_list_batches_request(
         after: Annotated[
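The retrieve and cancel endpoints now receive their request models through FastAPI dependencies instead of bare path parameters. Below is a self-contained sketch of that pattern, using standalone names that mirror (but are not) the actual router wiring above:

from typing import Annotated

from fastapi import Depends, FastAPI, Path
from fastapi.testclient import TestClient
from pydantic import BaseModel, Field

class RetrieveBatchRequest(BaseModel):
    batch_id: str = Field(..., description="The ID of the batch to retrieve.")

def get_retrieve_batch_request(
    batch_id: Annotated[str, Path(description="The ID of the batch to retrieve.")],
) -> RetrieveBatchRequest:
    # FastAPI resolves the {batch_id} path parameter, then the handler receives a typed model.
    return RetrieveBatchRequest(batch_id=batch_id)

app = FastAPI()

@app.get("/batches/{batch_id}")
async def retrieve_batch(
    request: Annotated[RetrieveBatchRequest, Depends(get_retrieve_batch_request)],
) -> dict:
    return {"received_batch_id": request.batch_id}

print(TestClient(app).get("/batches/batch_abc123").json())
# {'received_batch_id': 'batch_abc123'}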