From a6aaf18bb6dc3ed81a375eb0a8eabc445036f447 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=A9bastien=20Han?= Date: Mon, 24 Nov 2025 14:47:46 +0100 Subject: [PATCH] chore: generate FastAPI dependency functions from Pydantic models to eliminate duplication MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Added create_query_dependency() and create_path_dependency() helpers that automatically generate FastAPI dependency functions from Pydantic models. This makes the models the single source of truth for field types, descriptions, and defaults, eliminating duplication between models.py and fastapi_routes.py. Signed-off-by: Sébastien Han --- src/llama_stack_api/batches/fastapi_routes.py | 32 ++-- src/llama_stack_api/router_utils.py | 150 +++++++++++++++++- 2 files changed, 158 insertions(+), 24 deletions(-) diff --git a/src/llama_stack_api/batches/fastapi_routes.py b/src/llama_stack_api/batches/fastapi_routes.py index b53a4fc03..dd5dc7a6c 100644 --- a/src/llama_stack_api/batches/fastapi_routes.py +++ b/src/llama_stack_api/batches/fastapi_routes.py @@ -13,7 +13,7 @@ all API-related code together. from typing import Annotated -from fastapi import APIRouter, Body, Depends, Path, Query +from fastapi import APIRouter, Body, Depends from llama_stack_api.batches import Batches, BatchObject, ListBatchesResponse from llama_stack_api.batches.models import ( @@ -22,32 +22,18 @@ from llama_stack_api.batches.models import ( ListBatchesRequest, RetrieveBatchRequest, ) -from llama_stack_api.router_utils import standard_responses +from llama_stack_api.router_utils import create_path_dependency, create_query_dependency, standard_responses from llama_stack_api.version import LLAMA_STACK_API_V1 - -def get_retrieve_batch_request( - batch_id: Annotated[str, Path(description="The ID of the batch to retrieve.")], -) -> RetrieveBatchRequest: - """Dependency function to create RetrieveBatchRequest from path parameter.""" - return RetrieveBatchRequest(batch_id=batch_id) +# Automatically generate dependency functions from Pydantic models +# This ensures the models are the single source of truth for descriptions +get_retrieve_batch_request = create_path_dependency(RetrieveBatchRequest) +get_cancel_batch_request = create_path_dependency(CancelBatchRequest) -def get_cancel_batch_request( - batch_id: Annotated[str, Path(description="The ID of the batch to cancel.")], -) -> CancelBatchRequest: - """Dependency function to create CancelBatchRequest from path parameter.""" - return CancelBatchRequest(batch_id=batch_id) - - -def get_list_batches_request( - after: Annotated[ - str | None, Query(description="Optional cursor for pagination. Returns batches after this ID.") - ] = None, - limit: Annotated[int, Query(description="Maximum number of batches to return. Defaults to 20.")] = 20, -) -> ListBatchesRequest: - """Dependency function to create ListBatchesRequest from query parameters.""" - return ListBatchesRequest(after=after, limit=limit) +# Automatically generate dependency function from Pydantic model +# This ensures the model is the single source of truth for descriptions and defaults +get_list_batches_request = create_query_dependency(ListBatchesRequest) def create_router(impl: Batches) -> APIRouter: diff --git a/src/llama_stack_api/router_utils.py b/src/llama_stack_api/router_utils.py index fd0efe060..5d934826c 100644 --- a/src/llama_stack_api/router_utils.py +++ b/src/llama_stack_api/router_utils.py @@ -11,7 +11,12 @@ These responses use OpenAPI $ref references to component responses defined in the OpenAPI specification. """ -from typing import Any +import inspect +from collections.abc import Callable +from typing import Annotated, Any, TypeVar + +from fastapi import Path, Query +from pydantic import BaseModel standard_responses: dict[int | str, dict[str, Any]] = { 400: {"$ref": "#/components/responses/BadRequest400"}, @@ -19,3 +24,146 @@ standard_responses: dict[int | str, dict[str, Any]] = { 500: {"$ref": "#/components/responses/InternalServerError500"}, "default": {"$ref": "#/components/responses/DefaultError"}, } + +T = TypeVar("T", bound=BaseModel) + + +def create_query_dependency[T: BaseModel](model_class: type[T]) -> Callable[..., T]: + """Create a FastAPI dependency function from a Pydantic model for query parameters. + + FastAPI does not natively support using Pydantic models as query parameters + without a dependency function. Using a dependency function typically leads to + duplication: field types, default values, and descriptions must be repeated in + `Query(...)` annotations even though they already exist in the Pydantic model. + + This function automatically generates a dependency function that extracts query parameters + from the request and constructs an instance of the Pydantic model. The descriptions and + defaults are automatically extracted from the model's Field definitions, making the model + the single source of truth. + + Args: + model_class: The Pydantic model class to create a dependency for + + Returns: + A dependency function that can be used with FastAPI's Depends() + + Example: + ```python + get_list_batches_request = create_query_dependency(ListBatchesRequest) + + @router.get("/batches") + async def list_batches( + request: Annotated[ListBatchesRequest, Depends(get_list_batches_request)] + ): + ... + ``` + """ + # Build function signature dynamically from model fields + annotations: dict[str, Any] = {} + defaults: dict[str, Any] = {} + + for field_name, field_info in model_class.model_fields.items(): + # Extract description from Field + description = field_info.description + + # Create Query annotation with description from model + query_annotation = Query(description=description) if description else Query() + + # Create Annotated type with Query + field_type = field_info.annotation + annotations[field_name] = Annotated[field_type, query_annotation] + + # Set default value from model + if field_info.default is not inspect.Parameter.empty: + defaults[field_name] = field_info.default + + # Create the dependency function dynamically + def dependency_func(**kwargs: Any) -> T: + return model_class(**kwargs) + + # Set function signature + sig_params = [] + for field_name, field_type in annotations.items(): + default = defaults.get(field_name, inspect.Parameter.empty) + param = inspect.Parameter( + field_name, + inspect.Parameter.POSITIONAL_OR_KEYWORD, + default=default, + annotation=field_type, + ) + sig_params.append(param) + + dependency_func.__signature__ = inspect.Signature(sig_params) + dependency_func.__annotations__ = annotations + dependency_func.__name__ = f"get_{model_class.__name__.lower()}_request" + + return dependency_func + + +def create_path_dependency[T: BaseModel](model_class: type[T]) -> Callable[..., T]: + """Create a FastAPI dependency function from a Pydantic model for path parameters. + + FastAPI requires path parameters to be explicitly annotated with `Path()`. When using + a Pydantic model that contains path parameters, you typically need a dependency function + that extracts the path parameter and constructs the model. This leads to duplication: + the parameter name, type, and description must be repeated in `Path(...)` annotations + even though they already exist in the Pydantic model. + + This function automatically generates a dependency function that extracts path parameters + from the request and constructs an instance of the Pydantic model. The descriptions are + automatically extracted from the model's Field definitions, making the model the single + source of truth. + + Args: + model_class: The Pydantic model class to create a dependency for. The model should + have exactly one field that represents the path parameter. + + Returns: + A dependency function that can be used with FastAPI's Depends() + + Example: + ```python + get_retrieve_batch_request = create_path_dependency(RetrieveBatchRequest) + + @router.get("/batches/{batch_id}") + async def retrieve_batch( + request: Annotated[RetrieveBatchRequest, Depends(get_retrieve_batch_request)] + ): + ... + ``` + """ + # Get the single field from the model (path parameter models typically have one field) + if len(model_class.model_fields) != 1: + raise ValueError( + f"Path parameter model {model_class.__name__} must have exactly one field, " + f"but has {len(model_class.model_fields)} fields" + ) + + field_name, field_info = next(iter(model_class.model_fields.items())) + + # Extract description from Field + description = field_info.description + + # Create Path annotation with description from model + path_annotation = Path(description=description) if description else Path() + + # Create Annotated type with Path + field_type = field_info.annotation + annotations: dict[str, Any] = {field_name: Annotated[field_type, path_annotation]} + + # Create the dependency function dynamically + def dependency_func(**kwargs: Any) -> T: + return model_class(**kwargs) + + # Set function signature + param = inspect.Parameter( + field_name, + inspect.Parameter.POSITIONAL_OR_KEYWORD, + annotation=annotations[field_name], + ) + + dependency_func.__signature__ = inspect.Signature([param]) + dependency_func.__annotations__ = annotations + dependency_func.__name__ = f"get_{model_class.__name__.lower()}_request" + + return dependency_func