diff --git a/src/llama_stack_api/batches/fastapi_routes.py b/src/llama_stack_api/batches/fastapi_routes.py index b53a4fc03..dd5dc7a6c 100644 --- a/src/llama_stack_api/batches/fastapi_routes.py +++ b/src/llama_stack_api/batches/fastapi_routes.py @@ -13,7 +13,7 @@ all API-related code together. from typing import Annotated -from fastapi import APIRouter, Body, Depends, Path, Query +from fastapi import APIRouter, Body, Depends from llama_stack_api.batches import Batches, BatchObject, ListBatchesResponse from llama_stack_api.batches.models import ( @@ -22,32 +22,18 @@ from llama_stack_api.batches.models import ( ListBatchesRequest, RetrieveBatchRequest, ) -from llama_stack_api.router_utils import standard_responses +from llama_stack_api.router_utils import create_path_dependency, create_query_dependency, standard_responses from llama_stack_api.version import LLAMA_STACK_API_V1 - -def get_retrieve_batch_request( - batch_id: Annotated[str, Path(description="The ID of the batch to retrieve.")], -) -> RetrieveBatchRequest: - """Dependency function to create RetrieveBatchRequest from path parameter.""" - return RetrieveBatchRequest(batch_id=batch_id) +# Automatically generate dependency functions from Pydantic models +# This ensures the models are the single source of truth for descriptions +get_retrieve_batch_request = create_path_dependency(RetrieveBatchRequest) +get_cancel_batch_request = create_path_dependency(CancelBatchRequest) -def get_cancel_batch_request( - batch_id: Annotated[str, Path(description="The ID of the batch to cancel.")], -) -> CancelBatchRequest: - """Dependency function to create CancelBatchRequest from path parameter.""" - return CancelBatchRequest(batch_id=batch_id) - - -def get_list_batches_request( - after: Annotated[ - str | None, Query(description="Optional cursor for pagination. Returns batches after this ID.") - ] = None, - limit: Annotated[int, Query(description="Maximum number of batches to return. Defaults to 20.")] = 20, -) -> ListBatchesRequest: - """Dependency function to create ListBatchesRequest from query parameters.""" - return ListBatchesRequest(after=after, limit=limit) +# Automatically generate dependency function from Pydantic model +# This ensures the model is the single source of truth for descriptions and defaults +get_list_batches_request = create_query_dependency(ListBatchesRequest) def create_router(impl: Batches) -> APIRouter: diff --git a/src/llama_stack_api/router_utils.py b/src/llama_stack_api/router_utils.py index fd0efe060..5d934826c 100644 --- a/src/llama_stack_api/router_utils.py +++ b/src/llama_stack_api/router_utils.py @@ -11,7 +11,12 @@ These responses use OpenAPI $ref references to component responses defined in the OpenAPI specification. """ -from typing import Any +import inspect +from collections.abc import Callable +from typing import Annotated, Any, TypeVar + +from fastapi import Path, Query +from pydantic import BaseModel standard_responses: dict[int | str, dict[str, Any]] = { 400: {"$ref": "#/components/responses/BadRequest400"}, @@ -19,3 +24,146 @@ standard_responses: dict[int | str, dict[str, Any]] = { 500: {"$ref": "#/components/responses/InternalServerError500"}, "default": {"$ref": "#/components/responses/DefaultError"}, } + +T = TypeVar("T", bound=BaseModel) + + +def create_query_dependency[T: BaseModel](model_class: type[T]) -> Callable[..., T]: + """Create a FastAPI dependency function from a Pydantic model for query parameters. + + FastAPI does not natively support using Pydantic models as query parameters + without a dependency function. Using a dependency function typically leads to + duplication: field types, default values, and descriptions must be repeated in + `Query(...)` annotations even though they already exist in the Pydantic model. + + This function automatically generates a dependency function that extracts query parameters + from the request and constructs an instance of the Pydantic model. The descriptions and + defaults are automatically extracted from the model's Field definitions, making the model + the single source of truth. + + Args: + model_class: The Pydantic model class to create a dependency for + + Returns: + A dependency function that can be used with FastAPI's Depends() + + Example: + ```python + get_list_batches_request = create_query_dependency(ListBatchesRequest) + + @router.get("/batches") + async def list_batches( + request: Annotated[ListBatchesRequest, Depends(get_list_batches_request)] + ): + ... + ``` + """ + # Build function signature dynamically from model fields + annotations: dict[str, Any] = {} + defaults: dict[str, Any] = {} + + for field_name, field_info in model_class.model_fields.items(): + # Extract description from Field + description = field_info.description + + # Create Query annotation with description from model + query_annotation = Query(description=description) if description else Query() + + # Create Annotated type with Query + field_type = field_info.annotation + annotations[field_name] = Annotated[field_type, query_annotation] + + # Set default value from model + if field_info.default is not inspect.Parameter.empty: + defaults[field_name] = field_info.default + + # Create the dependency function dynamically + def dependency_func(**kwargs: Any) -> T: + return model_class(**kwargs) + + # Set function signature + sig_params = [] + for field_name, field_type in annotations.items(): + default = defaults.get(field_name, inspect.Parameter.empty) + param = inspect.Parameter( + field_name, + inspect.Parameter.POSITIONAL_OR_KEYWORD, + default=default, + annotation=field_type, + ) + sig_params.append(param) + + dependency_func.__signature__ = inspect.Signature(sig_params) + dependency_func.__annotations__ = annotations + dependency_func.__name__ = f"get_{model_class.__name__.lower()}_request" + + return dependency_func + + +def create_path_dependency[T: BaseModel](model_class: type[T]) -> Callable[..., T]: + """Create a FastAPI dependency function from a Pydantic model for path parameters. + + FastAPI requires path parameters to be explicitly annotated with `Path()`. When using + a Pydantic model that contains path parameters, you typically need a dependency function + that extracts the path parameter and constructs the model. This leads to duplication: + the parameter name, type, and description must be repeated in `Path(...)` annotations + even though they already exist in the Pydantic model. + + This function automatically generates a dependency function that extracts path parameters + from the request and constructs an instance of the Pydantic model. The descriptions are + automatically extracted from the model's Field definitions, making the model the single + source of truth. + + Args: + model_class: The Pydantic model class to create a dependency for. The model should + have exactly one field that represents the path parameter. + + Returns: + A dependency function that can be used with FastAPI's Depends() + + Example: + ```python + get_retrieve_batch_request = create_path_dependency(RetrieveBatchRequest) + + @router.get("/batches/{batch_id}") + async def retrieve_batch( + request: Annotated[RetrieveBatchRequest, Depends(get_retrieve_batch_request)] + ): + ... + ``` + """ + # Get the single field from the model (path parameter models typically have one field) + if len(model_class.model_fields) != 1: + raise ValueError( + f"Path parameter model {model_class.__name__} must have exactly one field, " + f"but has {len(model_class.model_fields)} fields" + ) + + field_name, field_info = next(iter(model_class.model_fields.items())) + + # Extract description from Field + description = field_info.description + + # Create Path annotation with description from model + path_annotation = Path(description=description) if description else Path() + + # Create Annotated type with Path + field_type = field_info.annotation + annotations: dict[str, Any] = {field_name: Annotated[field_type, path_annotation]} + + # Create the dependency function dynamically + def dependency_func(**kwargs: Any) -> T: + return model_class(**kwargs) + + # Set function signature + param = inspect.Parameter( + field_name, + inspect.Parameter.POSITIONAL_OR_KEYWORD, + annotation=annotations[field_name], + ) + + dependency_func.__signature__ = inspect.Signature([param]) + dependency_func.__annotations__ = annotations + dependency_func.__name__ = f"get_{model_class.__name__.lower()}_request" + + return dependency_func