chore: generate FastAPI dependency functions from Pydantic models to eliminate duplication

Added create_query_dependency() and create_path_dependency() helpers
that automatically generate FastAPI dependency functions from Pydantic
models. This makes the models the single source of truth for field
types, descriptions, and defaults, eliminating duplication between
models.py and fastapi_routes.py.

Signed-off-by: Sébastien Han <seb@redhat.com>
This commit is contained in:
Sébastien Han 2025-11-24 14:47:46 +01:00
parent 4f08a62fa1
commit a6aaf18bb6
No known key found for this signature in database
2 changed files with 158 additions and 24 deletions

View file

@ -13,7 +13,7 @@ all API-related code together.
from typing import Annotated from typing import Annotated
from fastapi import APIRouter, Body, Depends, Path, Query from fastapi import APIRouter, Body, Depends
from llama_stack_api.batches import Batches, BatchObject, ListBatchesResponse from llama_stack_api.batches import Batches, BatchObject, ListBatchesResponse
from llama_stack_api.batches.models import ( from llama_stack_api.batches.models import (
@ -22,32 +22,18 @@ from llama_stack_api.batches.models import (
ListBatchesRequest, ListBatchesRequest,
RetrieveBatchRequest, RetrieveBatchRequest,
) )
from llama_stack_api.router_utils import standard_responses from llama_stack_api.router_utils import create_path_dependency, create_query_dependency, standard_responses
from llama_stack_api.version import LLAMA_STACK_API_V1 from llama_stack_api.version import LLAMA_STACK_API_V1
# Automatically generate dependency functions from Pydantic models
def get_retrieve_batch_request( # This ensures the models are the single source of truth for descriptions
batch_id: Annotated[str, Path(description="The ID of the batch to retrieve.")], get_retrieve_batch_request = create_path_dependency(RetrieveBatchRequest)
) -> RetrieveBatchRequest: get_cancel_batch_request = create_path_dependency(CancelBatchRequest)
"""Dependency function to create RetrieveBatchRequest from path parameter."""
return RetrieveBatchRequest(batch_id=batch_id)
def get_cancel_batch_request( # Automatically generate dependency function from Pydantic model
batch_id: Annotated[str, Path(description="The ID of the batch to cancel.")], # This ensures the model is the single source of truth for descriptions and defaults
) -> CancelBatchRequest: get_list_batches_request = create_query_dependency(ListBatchesRequest)
"""Dependency function to create CancelBatchRequest from path parameter."""
return CancelBatchRequest(batch_id=batch_id)
def get_list_batches_request(
after: Annotated[
str | None, Query(description="Optional cursor for pagination. Returns batches after this ID.")
] = None,
limit: Annotated[int, Query(description="Maximum number of batches to return. Defaults to 20.")] = 20,
) -> ListBatchesRequest:
"""Dependency function to create ListBatchesRequest from query parameters."""
return ListBatchesRequest(after=after, limit=limit)
def create_router(impl: Batches) -> APIRouter: def create_router(impl: Batches) -> APIRouter:

View file

@ -11,7 +11,12 @@ These responses use OpenAPI $ref references to component responses defined
in the OpenAPI specification. in the OpenAPI specification.
""" """
from typing import Any import inspect
from collections.abc import Callable
from typing import Annotated, Any, TypeVar
from fastapi import Path, Query
from pydantic import BaseModel
standard_responses: dict[int | str, dict[str, Any]] = { standard_responses: dict[int | str, dict[str, Any]] = {
400: {"$ref": "#/components/responses/BadRequest400"}, 400: {"$ref": "#/components/responses/BadRequest400"},
@ -19,3 +24,146 @@ standard_responses: dict[int | str, dict[str, Any]] = {
500: {"$ref": "#/components/responses/InternalServerError500"}, 500: {"$ref": "#/components/responses/InternalServerError500"},
"default": {"$ref": "#/components/responses/DefaultError"}, "default": {"$ref": "#/components/responses/DefaultError"},
} }
T = TypeVar("T", bound=BaseModel)
def create_query_dependency[T: BaseModel](model_class: type[T]) -> Callable[..., T]:
"""Create a FastAPI dependency function from a Pydantic model for query parameters.
FastAPI does not natively support using Pydantic models as query parameters
without a dependency function. Using a dependency function typically leads to
duplication: field types, default values, and descriptions must be repeated in
`Query(...)` annotations even though they already exist in the Pydantic model.
This function automatically generates a dependency function that extracts query parameters
from the request and constructs an instance of the Pydantic model. The descriptions and
defaults are automatically extracted from the model's Field definitions, making the model
the single source of truth.
Args:
model_class: The Pydantic model class to create a dependency for
Returns:
A dependency function that can be used with FastAPI's Depends()
Example:
```python
get_list_batches_request = create_query_dependency(ListBatchesRequest)
@router.get("/batches")
async def list_batches(
request: Annotated[ListBatchesRequest, Depends(get_list_batches_request)]
):
...
```
"""
# Build function signature dynamically from model fields
annotations: dict[str, Any] = {}
defaults: dict[str, Any] = {}
for field_name, field_info in model_class.model_fields.items():
# Extract description from Field
description = field_info.description
# Create Query annotation with description from model
query_annotation = Query(description=description) if description else Query()
# Create Annotated type with Query
field_type = field_info.annotation
annotations[field_name] = Annotated[field_type, query_annotation]
# Set default value from model
if field_info.default is not inspect.Parameter.empty:
defaults[field_name] = field_info.default
# Create the dependency function dynamically
def dependency_func(**kwargs: Any) -> T:
return model_class(**kwargs)
# Set function signature
sig_params = []
for field_name, field_type in annotations.items():
default = defaults.get(field_name, inspect.Parameter.empty)
param = inspect.Parameter(
field_name,
inspect.Parameter.POSITIONAL_OR_KEYWORD,
default=default,
annotation=field_type,
)
sig_params.append(param)
dependency_func.__signature__ = inspect.Signature(sig_params)
dependency_func.__annotations__ = annotations
dependency_func.__name__ = f"get_{model_class.__name__.lower()}_request"
return dependency_func
def create_path_dependency[T: BaseModel](model_class: type[T]) -> Callable[..., T]:
"""Create a FastAPI dependency function from a Pydantic model for path parameters.
FastAPI requires path parameters to be explicitly annotated with `Path()`. When using
a Pydantic model that contains path parameters, you typically need a dependency function
that extracts the path parameter and constructs the model. This leads to duplication:
the parameter name, type, and description must be repeated in `Path(...)` annotations
even though they already exist in the Pydantic model.
This function automatically generates a dependency function that extracts path parameters
from the request and constructs an instance of the Pydantic model. The descriptions are
automatically extracted from the model's Field definitions, making the model the single
source of truth.
Args:
model_class: The Pydantic model class to create a dependency for. The model should
have exactly one field that represents the path parameter.
Returns:
A dependency function that can be used with FastAPI's Depends()
Example:
```python
get_retrieve_batch_request = create_path_dependency(RetrieveBatchRequest)
@router.get("/batches/{batch_id}")
async def retrieve_batch(
request: Annotated[RetrieveBatchRequest, Depends(get_retrieve_batch_request)]
):
...
```
"""
# Get the single field from the model (path parameter models typically have one field)
if len(model_class.model_fields) != 1:
raise ValueError(
f"Path parameter model {model_class.__name__} must have exactly one field, "
f"but has {len(model_class.model_fields)} fields"
)
field_name, field_info = next(iter(model_class.model_fields.items()))
# Extract description from Field
description = field_info.description
# Create Path annotation with description from model
path_annotation = Path(description=description) if description else Path()
# Create Annotated type with Path
field_type = field_info.annotation
annotations: dict[str, Any] = {field_name: Annotated[field_type, path_annotation]}
# Create the dependency function dynamically
def dependency_func(**kwargs: Any) -> T:
return model_class(**kwargs)
# Set function signature
param = inspect.Parameter(
field_name,
inspect.Parameter.POSITIONAL_OR_KEYWORD,
annotation=annotations[field_name],
)
dependency_func.__signature__ = inspect.Signature([param])
dependency_func.__annotations__ = annotations
dependency_func.__name__ = f"get_{model_class.__name__.lower()}_request"
return dependency_func