mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-04 02:03:44 +00:00
Merge branch 'main' into fix-type-hints-syntax
This commit is contained in:
commit
47027c65a0
12 changed files with 394 additions and 321 deletions
|
|
@ -15,6 +15,7 @@ import typing
|
|||
from typing import Annotated, Any, get_args, get_origin
|
||||
|
||||
from fastapi import FastAPI
|
||||
from fastapi.params import Body as FastAPIBody
|
||||
from pydantic import Field, create_model
|
||||
|
||||
from llama_stack.log import get_logger
|
||||
|
|
@ -26,6 +27,8 @@ from .state import _extra_body_fields, register_dynamic_model
|
|||
|
||||
logger = get_logger(name=__name__, category="core")
|
||||
|
||||
type QueryParameter = tuple[str, type, Any, bool]
|
||||
|
||||
|
||||
def _to_pascal_case(segment: str) -> str:
|
||||
tokens = re.findall(r"[A-Za-z]+|\d+", segment)
|
||||
|
|
@ -75,12 +78,12 @@ def _create_endpoint_with_request_model(
|
|||
return endpoint
|
||||
|
||||
|
||||
def _build_field_definitions(query_parameters: list[tuple[str, type, Any]], use_any: bool = False) -> dict[str, tuple]:
|
||||
def _build_field_definitions(query_parameters: list[QueryParameter], use_any: bool = False) -> dict[str, tuple]:
|
||||
"""Build field definitions for a Pydantic model from query parameters."""
|
||||
from typing import Any
|
||||
|
||||
field_definitions = {}
|
||||
for param_name, param_type, default_value in query_parameters:
|
||||
for param_name, param_type, default_value, _ in query_parameters:
|
||||
if use_any:
|
||||
field_definitions[param_name] = (Any, ... if default_value is inspect.Parameter.empty else default_value)
|
||||
continue
|
||||
|
|
@ -108,10 +111,10 @@ def _build_field_definitions(query_parameters: list[tuple[str, type, Any]], use_
|
|||
field_definitions[param_name] = (Any, ... if default_value is inspect.Parameter.empty else default_value)
|
||||
|
||||
# Ensure all parameters are included
|
||||
expected_params = {name for name, _, _ in query_parameters}
|
||||
expected_params = {name for name, _, _, _ in query_parameters}
|
||||
missing = expected_params - set(field_definitions.keys())
|
||||
if missing:
|
||||
for param_name, _, default_value in query_parameters:
|
||||
for param_name, _, default_value, _ in query_parameters:
|
||||
if param_name in missing:
|
||||
field_definitions[param_name] = (
|
||||
Any,
|
||||
|
|
@ -126,7 +129,7 @@ def _create_dynamic_request_model(
|
|||
webmethod,
|
||||
method_name: str,
|
||||
http_method: str,
|
||||
query_parameters: list[tuple[str, type, Any]],
|
||||
query_parameters: list[QueryParameter],
|
||||
use_any: bool = False,
|
||||
variant_suffix: str | None = None,
|
||||
) -> type | None:
|
||||
|
|
@ -143,12 +146,12 @@ def _create_dynamic_request_model(
|
|||
|
||||
|
||||
def _build_signature_params(
|
||||
query_parameters: list[tuple[str, type, Any]],
|
||||
query_parameters: list[QueryParameter],
|
||||
) -> tuple[list[inspect.Parameter], dict[str, type]]:
|
||||
"""Build signature parameters and annotations from query parameters."""
|
||||
signature_params = []
|
||||
param_annotations = {}
|
||||
for param_name, param_type, default_value in query_parameters:
|
||||
for param_name, param_type, default_value, _ in query_parameters:
|
||||
param_annotations[param_name] = param_type
|
||||
signature_params.append(
|
||||
inspect.Parameter(
|
||||
|
|
@ -219,6 +222,19 @@ def _is_extra_body_field(metadata_item: Any) -> bool:
|
|||
return isinstance(metadata_item, ExtraBodyField)
|
||||
|
||||
|
||||
def _should_embed_parameter(param_type: Any) -> bool:
|
||||
"""Determine whether a parameter should be embedded (wrapped) in the request body."""
|
||||
if get_origin(param_type) is Annotated:
|
||||
args = get_args(param_type)
|
||||
metadata = args[1:] if len(args) > 1 else []
|
||||
for metadata_item in metadata:
|
||||
if isinstance(metadata_item, FastAPIBody):
|
||||
# FastAPI treats embed=None as False, so default to False when unset.
|
||||
return bool(metadata_item.embed)
|
||||
# Unannotated parameters default to embed=True through create_dynamic_typed_route.
|
||||
return True
|
||||
|
||||
|
||||
def _is_async_iterator_type(type_obj: Any) -> bool:
|
||||
"""Check if a type is AsyncIterator or AsyncIterable."""
|
||||
from collections.abc import AsyncIterable, AsyncIterator
|
||||
|
|
@ -282,7 +298,7 @@ def _find_models_for_endpoint(
|
|||
|
||||
Returns:
|
||||
tuple: (request_model, response_model, query_parameters, file_form_params, streaming_response_model, response_schema_name)
|
||||
where query_parameters is a list of (name, type, default_value) tuples
|
||||
where query_parameters is a list of (name, type, default_value, should_embed) tuples
|
||||
and file_form_params is a list of inspect.Parameter objects for File()/Form() params
|
||||
and streaming_response_model is the model for streaming responses (AsyncIterator content)
|
||||
"""
|
||||
|
|
@ -299,7 +315,7 @@ def _find_models_for_endpoint(
|
|||
|
||||
# Find request model and collect all body parameters
|
||||
request_model = None
|
||||
query_parameters = []
|
||||
query_parameters: list[QueryParameter] = []
|
||||
file_form_params = []
|
||||
path_params = set()
|
||||
extra_body_params = []
|
||||
|
|
@ -325,6 +341,7 @@ def _find_models_for_endpoint(
|
|||
|
||||
# Check if it's a File() or Form() parameter - these need special handling
|
||||
param_type = param.annotation
|
||||
param_should_embed = _should_embed_parameter(param_type)
|
||||
if _is_file_or_form_param(param_type):
|
||||
# File() and Form() parameters must be in the function signature directly
|
||||
# They cannot be part of a Pydantic model
|
||||
|
|
@ -350,30 +367,14 @@ def _find_models_for_endpoint(
|
|||
# Store as extra body parameter - exclude from request model
|
||||
extra_body_params.append((param_name, base_type, extra_body_description))
|
||||
continue
|
||||
param_type = base_type
|
||||
|
||||
# Check if it's a Pydantic model (for POST/PUT requests)
|
||||
if hasattr(param_type, "model_json_schema"):
|
||||
# Collect all body parameters including Pydantic models
|
||||
# We'll decide later whether to use a single model or create a combined one
|
||||
query_parameters.append((param_name, param_type, param.default))
|
||||
elif get_origin(param_type) is Annotated:
|
||||
# Handle Annotated types - get the base type
|
||||
args = get_args(param_type)
|
||||
if args and hasattr(args[0], "model_json_schema"):
|
||||
# Collect Pydantic models from Annotated types
|
||||
query_parameters.append((param_name, args[0], param.default))
|
||||
else:
|
||||
# Regular annotated parameter (but not File/Form, already handled above)
|
||||
query_parameters.append((param_name, param_type, param.default))
|
||||
query_parameters.append((param_name, param_type, param.default, param_should_embed))
|
||||
else:
|
||||
# This is likely a body parameter for POST/PUT or query parameter for GET
|
||||
# Store the parameter info for later use
|
||||
# Preserve inspect.Parameter.empty to distinguish "no default" from "default=None"
|
||||
default_value = param.default
|
||||
|
||||
# Extract the base type from union types (e.g., str | None -> str)
|
||||
# Also make it safe for FastAPI to avoid forward reference issues
|
||||
query_parameters.append((param_name, param_type, default_value))
|
||||
# Regular annotated parameter (but not File/Form, already handled above)
|
||||
query_parameters.append((param_name, param_type, param.default, param_should_embed))
|
||||
|
||||
# Store extra body fields for later use in post-processing
|
||||
# We'll store them when the endpoint is created, as we need the full path
|
||||
|
|
@ -385,8 +386,8 @@ def _find_models_for_endpoint(
|
|||
# Otherwise, we'll create a combined request model from all parameters
|
||||
# BUT: For GET requests, never create a request body - all parameters should be query parameters
|
||||
if is_post_put and len(query_parameters) == 1:
|
||||
param_name, param_type, default_value = query_parameters[0]
|
||||
if hasattr(param_type, "model_json_schema"):
|
||||
param_name, param_type, default_value, should_embed = query_parameters[0]
|
||||
if hasattr(param_type, "model_json_schema") and not should_embed:
|
||||
request_model = param_type
|
||||
query_parameters = [] # Clear query_parameters so we use the single model
|
||||
|
||||
|
|
@ -495,7 +496,7 @@ def _create_fastapi_endpoint(app: FastAPI, route, webmethod, api: Api):
|
|||
if file_form_params and is_post_put:
|
||||
signature_params = list(file_form_params)
|
||||
param_annotations = {param.name: param.annotation for param in file_form_params}
|
||||
for param_name, param_type, default_value in query_parameters:
|
||||
for param_name, param_type, default_value, _ in query_parameters:
|
||||
signature_params.append(
|
||||
inspect.Parameter(
|
||||
param_name,
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue