mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-03 18:00:36 +00:00
chore: generate FastAPI dependency functions from Pydantic models to eliminate duplication
Added create_query_dependency() and create_path_dependency() helpers that automatically generate FastAPI dependency functions from Pydantic models. This makes the models the single source of truth for field types, descriptions, and defaults, eliminating duplication between models.py and fastapi_routes.py. Signed-off-by: Sébastien Han <seb@redhat.com>
This commit is contained in:
parent
4f08a62fa1
commit
a6aaf18bb6
2 changed files with 158 additions and 24 deletions
|
|
@ -13,7 +13,7 @@ all API-related code together.
|
||||||
|
|
||||||
from typing import Annotated
|
from typing import Annotated
|
||||||
|
|
||||||
from fastapi import APIRouter, Body, Depends, Path, Query
|
from fastapi import APIRouter, Body, Depends
|
||||||
|
|
||||||
from llama_stack_api.batches import Batches, BatchObject, ListBatchesResponse
|
from llama_stack_api.batches import Batches, BatchObject, ListBatchesResponse
|
||||||
from llama_stack_api.batches.models import (
|
from llama_stack_api.batches.models import (
|
||||||
|
|
@ -22,32 +22,18 @@ from llama_stack_api.batches.models import (
|
||||||
ListBatchesRequest,
|
ListBatchesRequest,
|
||||||
RetrieveBatchRequest,
|
RetrieveBatchRequest,
|
||||||
)
|
)
|
||||||
from llama_stack_api.router_utils import standard_responses
|
from llama_stack_api.router_utils import create_path_dependency, create_query_dependency, standard_responses
|
||||||
from llama_stack_api.version import LLAMA_STACK_API_V1
|
from llama_stack_api.version import LLAMA_STACK_API_V1
|
||||||
|
|
||||||
|
# Automatically generate dependency functions from Pydantic models
|
||||||
def get_retrieve_batch_request(
|
# This ensures the models are the single source of truth for descriptions
|
||||||
batch_id: Annotated[str, Path(description="The ID of the batch to retrieve.")],
|
get_retrieve_batch_request = create_path_dependency(RetrieveBatchRequest)
|
||||||
) -> RetrieveBatchRequest:
|
get_cancel_batch_request = create_path_dependency(CancelBatchRequest)
|
||||||
"""Dependency function to create RetrieveBatchRequest from path parameter."""
|
|
||||||
return RetrieveBatchRequest(batch_id=batch_id)
|
|
||||||
|
|
||||||
|
|
||||||
def get_cancel_batch_request(
|
# Automatically generate dependency function from Pydantic model
|
||||||
batch_id: Annotated[str, Path(description="The ID of the batch to cancel.")],
|
# This ensures the model is the single source of truth for descriptions and defaults
|
||||||
) -> CancelBatchRequest:
|
get_list_batches_request = create_query_dependency(ListBatchesRequest)
|
||||||
"""Dependency function to create CancelBatchRequest from path parameter."""
|
|
||||||
return CancelBatchRequest(batch_id=batch_id)
|
|
||||||
|
|
||||||
|
|
||||||
def get_list_batches_request(
|
|
||||||
after: Annotated[
|
|
||||||
str | None, Query(description="Optional cursor for pagination. Returns batches after this ID.")
|
|
||||||
] = None,
|
|
||||||
limit: Annotated[int, Query(description="Maximum number of batches to return. Defaults to 20.")] = 20,
|
|
||||||
) -> ListBatchesRequest:
|
|
||||||
"""Dependency function to create ListBatchesRequest from query parameters."""
|
|
||||||
return ListBatchesRequest(after=after, limit=limit)
|
|
||||||
|
|
||||||
|
|
||||||
def create_router(impl: Batches) -> APIRouter:
|
def create_router(impl: Batches) -> APIRouter:
|
||||||
|
|
|
||||||
|
|
@ -11,7 +11,12 @@ These responses use OpenAPI $ref references to component responses defined
|
||||||
in the OpenAPI specification.
|
in the OpenAPI specification.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import Any
|
import inspect
|
||||||
|
from collections.abc import Callable
|
||||||
|
from typing import Annotated, Any, TypeVar
|
||||||
|
|
||||||
|
from fastapi import Path, Query
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
standard_responses: dict[int | str, dict[str, Any]] = {
|
standard_responses: dict[int | str, dict[str, Any]] = {
|
||||||
400: {"$ref": "#/components/responses/BadRequest400"},
|
400: {"$ref": "#/components/responses/BadRequest400"},
|
||||||
|
|
@ -19,3 +24,146 @@ standard_responses: dict[int | str, dict[str, Any]] = {
|
||||||
500: {"$ref": "#/components/responses/InternalServerError500"},
|
500: {"$ref": "#/components/responses/InternalServerError500"},
|
||||||
"default": {"$ref": "#/components/responses/DefaultError"},
|
"default": {"$ref": "#/components/responses/DefaultError"},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
T = TypeVar("T", bound=BaseModel)
|
||||||
|
|
||||||
|
|
||||||
|
def create_query_dependency[T: BaseModel](model_class: type[T]) -> Callable[..., T]:
|
||||||
|
"""Create a FastAPI dependency function from a Pydantic model for query parameters.
|
||||||
|
|
||||||
|
FastAPI does not natively support using Pydantic models as query parameters
|
||||||
|
without a dependency function. Using a dependency function typically leads to
|
||||||
|
duplication: field types, default values, and descriptions must be repeated in
|
||||||
|
`Query(...)` annotations even though they already exist in the Pydantic model.
|
||||||
|
|
||||||
|
This function automatically generates a dependency function that extracts query parameters
|
||||||
|
from the request and constructs an instance of the Pydantic model. The descriptions and
|
||||||
|
defaults are automatically extracted from the model's Field definitions, making the model
|
||||||
|
the single source of truth.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_class: The Pydantic model class to create a dependency for
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A dependency function that can be used with FastAPI's Depends()
|
||||||
|
|
||||||
|
Example:
|
||||||
|
```python
|
||||||
|
get_list_batches_request = create_query_dependency(ListBatchesRequest)
|
||||||
|
|
||||||
|
@router.get("/batches")
|
||||||
|
async def list_batches(
|
||||||
|
request: Annotated[ListBatchesRequest, Depends(get_list_batches_request)]
|
||||||
|
):
|
||||||
|
...
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
# Build function signature dynamically from model fields
|
||||||
|
annotations: dict[str, Any] = {}
|
||||||
|
defaults: dict[str, Any] = {}
|
||||||
|
|
||||||
|
for field_name, field_info in model_class.model_fields.items():
|
||||||
|
# Extract description from Field
|
||||||
|
description = field_info.description
|
||||||
|
|
||||||
|
# Create Query annotation with description from model
|
||||||
|
query_annotation = Query(description=description) if description else Query()
|
||||||
|
|
||||||
|
# Create Annotated type with Query
|
||||||
|
field_type = field_info.annotation
|
||||||
|
annotations[field_name] = Annotated[field_type, query_annotation]
|
||||||
|
|
||||||
|
# Set default value from model
|
||||||
|
if field_info.default is not inspect.Parameter.empty:
|
||||||
|
defaults[field_name] = field_info.default
|
||||||
|
|
||||||
|
# Create the dependency function dynamically
|
||||||
|
def dependency_func(**kwargs: Any) -> T:
|
||||||
|
return model_class(**kwargs)
|
||||||
|
|
||||||
|
# Set function signature
|
||||||
|
sig_params = []
|
||||||
|
for field_name, field_type in annotations.items():
|
||||||
|
default = defaults.get(field_name, inspect.Parameter.empty)
|
||||||
|
param = inspect.Parameter(
|
||||||
|
field_name,
|
||||||
|
inspect.Parameter.POSITIONAL_OR_KEYWORD,
|
||||||
|
default=default,
|
||||||
|
annotation=field_type,
|
||||||
|
)
|
||||||
|
sig_params.append(param)
|
||||||
|
|
||||||
|
dependency_func.__signature__ = inspect.Signature(sig_params)
|
||||||
|
dependency_func.__annotations__ = annotations
|
||||||
|
dependency_func.__name__ = f"get_{model_class.__name__.lower()}_request"
|
||||||
|
|
||||||
|
return dependency_func
|
||||||
|
|
||||||
|
|
||||||
|
def create_path_dependency[T: BaseModel](model_class: type[T]) -> Callable[..., T]:
|
||||||
|
"""Create a FastAPI dependency function from a Pydantic model for path parameters.
|
||||||
|
|
||||||
|
FastAPI requires path parameters to be explicitly annotated with `Path()`. When using
|
||||||
|
a Pydantic model that contains path parameters, you typically need a dependency function
|
||||||
|
that extracts the path parameter and constructs the model. This leads to duplication:
|
||||||
|
the parameter name, type, and description must be repeated in `Path(...)` annotations
|
||||||
|
even though they already exist in the Pydantic model.
|
||||||
|
|
||||||
|
This function automatically generates a dependency function that extracts path parameters
|
||||||
|
from the request and constructs an instance of the Pydantic model. The descriptions are
|
||||||
|
automatically extracted from the model's Field definitions, making the model the single
|
||||||
|
source of truth.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_class: The Pydantic model class to create a dependency for. The model should
|
||||||
|
have exactly one field that represents the path parameter.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A dependency function that can be used with FastAPI's Depends()
|
||||||
|
|
||||||
|
Example:
|
||||||
|
```python
|
||||||
|
get_retrieve_batch_request = create_path_dependency(RetrieveBatchRequest)
|
||||||
|
|
||||||
|
@router.get("/batches/{batch_id}")
|
||||||
|
async def retrieve_batch(
|
||||||
|
request: Annotated[RetrieveBatchRequest, Depends(get_retrieve_batch_request)]
|
||||||
|
):
|
||||||
|
...
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
# Get the single field from the model (path parameter models typically have one field)
|
||||||
|
if len(model_class.model_fields) != 1:
|
||||||
|
raise ValueError(
|
||||||
|
f"Path parameter model {model_class.__name__} must have exactly one field, "
|
||||||
|
f"but has {len(model_class.model_fields)} fields"
|
||||||
|
)
|
||||||
|
|
||||||
|
field_name, field_info = next(iter(model_class.model_fields.items()))
|
||||||
|
|
||||||
|
# Extract description from Field
|
||||||
|
description = field_info.description
|
||||||
|
|
||||||
|
# Create Path annotation with description from model
|
||||||
|
path_annotation = Path(description=description) if description else Path()
|
||||||
|
|
||||||
|
# Create Annotated type with Path
|
||||||
|
field_type = field_info.annotation
|
||||||
|
annotations: dict[str, Any] = {field_name: Annotated[field_type, path_annotation]}
|
||||||
|
|
||||||
|
# Create the dependency function dynamically
|
||||||
|
def dependency_func(**kwargs: Any) -> T:
|
||||||
|
return model_class(**kwargs)
|
||||||
|
|
||||||
|
# Set function signature
|
||||||
|
param = inspect.Parameter(
|
||||||
|
field_name,
|
||||||
|
inspect.Parameter.POSITIONAL_OR_KEYWORD,
|
||||||
|
annotation=annotations[field_name],
|
||||||
|
)
|
||||||
|
|
||||||
|
dependency_func.__signature__ = inspect.Signature([param])
|
||||||
|
dependency_func.__annotations__ = annotations
|
||||||
|
dependency_func.__name__ = f"get_{model_class.__name__.lower()}_request"
|
||||||
|
|
||||||
|
return dependency_func
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue