mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-03 18:00:36 +00:00
chore: move dep functions outside of create_router
Less indirection and clearer declarations. Signed-off-by: Sébastien Han <seb@redhat.com>
This commit is contained in:
parent
49005f1a39
commit
87e60bc48f
1 changed files with 28 additions and 33 deletions
|
|
@ -26,6 +26,30 @@ from llama_stack_api.router_utils import standard_responses
|
|||
from llama_stack_api.version import LLAMA_STACK_API_V1
|
||||
|
||||
|
||||
def get_retrieve_batch_request(
|
||||
batch_id: Annotated[str, Path(description="The ID of the batch to retrieve.")],
|
||||
) -> RetrieveBatchRequest:
|
||||
"""Dependency function to create RetrieveBatchRequest from path parameter."""
|
||||
return RetrieveBatchRequest(batch_id=batch_id)
|
||||
|
||||
|
||||
def get_cancel_batch_request(
|
||||
batch_id: Annotated[str, Path(description="The ID of the batch to cancel.")],
|
||||
) -> CancelBatchRequest:
|
||||
"""Dependency function to create CancelBatchRequest from path parameter."""
|
||||
return CancelBatchRequest(batch_id=batch_id)
|
||||
|
||||
|
||||
def get_list_batches_request(
|
||||
after: Annotated[
|
||||
str | None, Query(description="Optional cursor for pagination. Returns batches after this ID.")
|
||||
] = None,
|
||||
limit: Annotated[int, Query(description="Maximum number of batches to return. Defaults to 20.")] = 20,
|
||||
) -> ListBatchesRequest:
|
||||
"""Dependency function to create ListBatchesRequest from query parameters."""
|
||||
return ListBatchesRequest(after=after, limit=limit)
|
||||
|
||||
|
||||
def create_router(impl: Batches) -> APIRouter:
|
||||
"""Create a FastAPI router for the Batches API.
|
||||
|
||||
|
|
@ -41,10 +65,6 @@ def create_router(impl: Batches) -> APIRouter:
|
|||
responses=standard_responses,
|
||||
)
|
||||
|
||||
def get_batch_service() -> Batches:
|
||||
"""Dependency function to get the batch service implementation."""
|
||||
return impl
|
||||
|
||||
@router.post(
|
||||
"/batches",
|
||||
response_model=BatchObject,
|
||||
|
|
@ -57,15 +77,8 @@ def create_router(impl: Batches) -> APIRouter:
|
|||
)
|
||||
async def create_batch(
|
||||
request: Annotated[CreateBatchRequest, Body(...)],
|
||||
svc: Annotated[Batches, Depends(get_batch_service)],
|
||||
) -> BatchObject:
|
||||
return await svc.create_batch(request)
|
||||
|
||||
def get_retrieve_batch_request(
|
||||
batch_id: Annotated[str, Path(description="The ID of the batch to retrieve.")],
|
||||
) -> RetrieveBatchRequest:
|
||||
"""Dependency function to create RetrieveBatchRequest from path parameter."""
|
||||
return RetrieveBatchRequest(batch_id=batch_id)
|
||||
return await impl.create_batch(request)
|
||||
|
||||
@router.get(
|
||||
"/batches/{batch_id}",
|
||||
|
|
@ -78,15 +91,8 @@ def create_router(impl: Batches) -> APIRouter:
|
|||
)
|
||||
async def retrieve_batch(
|
||||
request: Annotated[RetrieveBatchRequest, Depends(get_retrieve_batch_request)],
|
||||
svc: Annotated[Batches, Depends(get_batch_service)],
|
||||
) -> BatchObject:
|
||||
return await svc.retrieve_batch(request)
|
||||
|
||||
def get_cancel_batch_request(
|
||||
batch_id: Annotated[str, Path(description="The ID of the batch to cancel.")],
|
||||
) -> CancelBatchRequest:
|
||||
"""Dependency function to create CancelBatchRequest from path parameter."""
|
||||
return CancelBatchRequest(batch_id=batch_id)
|
||||
return await impl.retrieve_batch(request)
|
||||
|
||||
@router.post(
|
||||
"/batches/{batch_id}/cancel",
|
||||
|
|
@ -99,18 +105,8 @@ def create_router(impl: Batches) -> APIRouter:
|
|||
)
|
||||
async def cancel_batch(
|
||||
request: Annotated[CancelBatchRequest, Depends(get_cancel_batch_request)],
|
||||
svc: Annotated[Batches, Depends(get_batch_service)],
|
||||
) -> BatchObject:
|
||||
return await svc.cancel_batch(request)
|
||||
|
||||
def get_list_batches_request(
|
||||
after: Annotated[
|
||||
str | None, Query(description="Optional cursor for pagination. Returns batches after this ID.")
|
||||
] = None,
|
||||
limit: Annotated[int, Query(description="Maximum number of batches to return. Defaults to 20.")] = 20,
|
||||
) -> ListBatchesRequest:
|
||||
"""Dependency function to create ListBatchesRequest from query parameters."""
|
||||
return ListBatchesRequest(after=after, limit=limit)
|
||||
return await impl.cancel_batch(request)
|
||||
|
||||
@router.get(
|
||||
"/batches",
|
||||
|
|
@ -123,8 +119,7 @@ def create_router(impl: Batches) -> APIRouter:
|
|||
)
|
||||
async def list_batches(
|
||||
request: Annotated[ListBatchesRequest, Depends(get_list_batches_request)],
|
||||
svc: Annotated[Batches, Depends(get_batch_service)],
|
||||
) -> ListBatchesResponse:
|
||||
return await svc.list_batches(request)
|
||||
return await impl.list_batches(request)
|
||||
|
||||
return router
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue