chore: move dep functions outside of create_router

Less indirection and clearer declarations.

Signed-off-by: Sébastien Han <seb@redhat.com>
This commit is contained in:
Sébastien Han 2025-11-24 11:30:44 +01:00
parent 49005f1a39
commit 87e60bc48f
No known key found for this signature in database

View file

@ -26,6 +26,30 @@ from llama_stack_api.router_utils import standard_responses
from llama_stack_api.version import LLAMA_STACK_API_V1 from llama_stack_api.version import LLAMA_STACK_API_V1
def get_retrieve_batch_request(
batch_id: Annotated[str, Path(description="The ID of the batch to retrieve.")],
) -> RetrieveBatchRequest:
"""Dependency function to create RetrieveBatchRequest from path parameter."""
return RetrieveBatchRequest(batch_id=batch_id)
def get_cancel_batch_request(
batch_id: Annotated[str, Path(description="The ID of the batch to cancel.")],
) -> CancelBatchRequest:
"""Dependency function to create CancelBatchRequest from path parameter."""
return CancelBatchRequest(batch_id=batch_id)
def get_list_batches_request(
after: Annotated[
str | None, Query(description="Optional cursor for pagination. Returns batches after this ID.")
] = None,
limit: Annotated[int, Query(description="Maximum number of batches to return. Defaults to 20.")] = 20,
) -> ListBatchesRequest:
"""Dependency function to create ListBatchesRequest from query parameters."""
return ListBatchesRequest(after=after, limit=limit)
def create_router(impl: Batches) -> APIRouter: def create_router(impl: Batches) -> APIRouter:
"""Create a FastAPI router for the Batches API. """Create a FastAPI router for the Batches API.
@ -41,10 +65,6 @@ def create_router(impl: Batches) -> APIRouter:
responses=standard_responses, responses=standard_responses,
) )
def get_batch_service() -> Batches:
"""Dependency function to get the batch service implementation."""
return impl
@router.post( @router.post(
"/batches", "/batches",
response_model=BatchObject, response_model=BatchObject,
@ -57,15 +77,8 @@ def create_router(impl: Batches) -> APIRouter:
) )
async def create_batch( async def create_batch(
request: Annotated[CreateBatchRequest, Body(...)], request: Annotated[CreateBatchRequest, Body(...)],
svc: Annotated[Batches, Depends(get_batch_service)],
) -> BatchObject: ) -> BatchObject:
return await svc.create_batch(request) return await impl.create_batch(request)
def get_retrieve_batch_request(
batch_id: Annotated[str, Path(description="The ID of the batch to retrieve.")],
) -> RetrieveBatchRequest:
"""Dependency function to create RetrieveBatchRequest from path parameter."""
return RetrieveBatchRequest(batch_id=batch_id)
@router.get( @router.get(
"/batches/{batch_id}", "/batches/{batch_id}",
@ -78,15 +91,8 @@ def create_router(impl: Batches) -> APIRouter:
) )
async def retrieve_batch( async def retrieve_batch(
request: Annotated[RetrieveBatchRequest, Depends(get_retrieve_batch_request)], request: Annotated[RetrieveBatchRequest, Depends(get_retrieve_batch_request)],
svc: Annotated[Batches, Depends(get_batch_service)],
) -> BatchObject: ) -> BatchObject:
return await svc.retrieve_batch(request) return await impl.retrieve_batch(request)
def get_cancel_batch_request(
batch_id: Annotated[str, Path(description="The ID of the batch to cancel.")],
) -> CancelBatchRequest:
"""Dependency function to create CancelBatchRequest from path parameter."""
return CancelBatchRequest(batch_id=batch_id)
@router.post( @router.post(
"/batches/{batch_id}/cancel", "/batches/{batch_id}/cancel",
@ -99,18 +105,8 @@ def create_router(impl: Batches) -> APIRouter:
) )
async def cancel_batch( async def cancel_batch(
request: Annotated[CancelBatchRequest, Depends(get_cancel_batch_request)], request: Annotated[CancelBatchRequest, Depends(get_cancel_batch_request)],
svc: Annotated[Batches, Depends(get_batch_service)],
) -> BatchObject: ) -> BatchObject:
return await svc.cancel_batch(request) return await impl.cancel_batch(request)
def get_list_batches_request(
after: Annotated[
str | None, Query(description="Optional cursor for pagination. Returns batches after this ID.")
] = None,
limit: Annotated[int, Query(description="Maximum number of batches to return. Defaults to 20.")] = 20,
) -> ListBatchesRequest:
"""Dependency function to create ListBatchesRequest from query parameters."""
return ListBatchesRequest(after=after, limit=limit)
@router.get( @router.get(
"/batches", "/batches",
@ -123,8 +119,7 @@ def create_router(impl: Batches) -> APIRouter:
) )
async def list_batches( async def list_batches(
request: Annotated[ListBatchesRequest, Depends(get_list_batches_request)], request: Annotated[ListBatchesRequest, Depends(get_list_batches_request)],
svc: Annotated[Batches, Depends(get_batch_service)],
) -> ListBatchesResponse: ) -> ListBatchesResponse:
return await svc.list_batches(request) return await impl.list_batches(request)
return router return router