chore: re-add x-llama-stack-extra-body-params

Signed-off-by: Sébastien Han <seb@redhat.com>
This commit is contained in:
Sébastien Han 2025-11-12 11:16:10 +01:00
parent c4cad890cc
commit 66056ddb87
No known key found for this signature in database
6 changed files with 842 additions and 339 deletions

View file

@ -36,6 +36,10 @@ _dynamic_models = []
# Cache for protocol methods to avoid repeated lookups
_protocol_methods_cache: dict[Api, dict[str, Any]] | None = None
# Global dict to store extra body field information by endpoint
# Key: (path, method) tuple, Value: list of (param_name, param_type, description) tuples
_extra_body_fields: dict[tuple[str, str], list[tuple[str, type, str | None]]] = {}
def create_llama_stack_app() -> FastAPI:
"""
@ -238,6 +242,15 @@ def _create_fastapi_endpoint(app: FastAPI, route, webmethod, api: Api):
response_description = _extract_response_description_from_docstring(webmethod, response_model, api, name)
is_post_put = any(method.upper() in ["POST", "PUT", "PATCH"] for method in methods)
# Retrieve and store extra body fields for this endpoint
func = _get_protocol_method(api, name)
extra_body_params = getattr(func, "_extra_body_params", []) if func else []
if extra_body_params:
global _extra_body_fields
for method in methods:
key = (fastapi_path, method.upper())
_extra_body_fields[key] = extra_body_params
if file_form_params and is_post_put:
signature_params = list(file_form_params)
param_annotations = {param.name: param.annotation for param in file_form_params}
@ -402,6 +415,13 @@ def _is_file_or_form_param(param_type: Any) -> bool:
return False
def _is_extra_body_field(metadata_item: Any) -> bool:
"""Check if a metadata item is an ExtraBodyField instance."""
from llama_stack.schema_utils import ExtraBodyField
return isinstance(metadata_item, ExtraBodyField)
def _find_models_for_endpoint(
webmethod, api: Api, method_name: str
) -> tuple[type | None, type | None, list[tuple[str, type, Any]], list[inspect.Parameter]]:
@ -433,6 +453,7 @@ def _find_models_for_endpoint(
query_parameters = []
file_form_params = []
path_params = set()
extra_body_params = []
# Extract path parameters from the route
if webmethod and hasattr(webmethod, "route"):
@ -462,6 +483,26 @@ def _find_models_for_endpoint(
file_form_params.append(param)
continue
# Check for ExtraBodyField in Annotated types
is_extra_body = False
extra_body_description = None
if get_origin(param_type) is Annotated:
args = get_args(param_type)
base_type = args[0] if args else param_type
metadata = args[1:] if len(args) > 1 else []
# Check if any metadata item is an ExtraBodyField
for metadata_item in metadata:
if _is_extra_body_field(metadata_item):
is_extra_body = True
extra_body_description = metadata_item.description
break
if is_extra_body:
# Store as extra body parameter - exclude from request model
extra_body_params.append((param_name, base_type, extra_body_description))
continue
# Check if it's a Pydantic model (for POST/PUT requests)
if hasattr(param_type, "model_json_schema"):
# Collect all body parameters including Pydantic models
@ -486,6 +527,12 @@ def _find_models_for_endpoint(
# Also make it safe for FastAPI to avoid forward reference issues
query_parameters.append((param_name, param_type, default_value))
# Store extra body fields for later use in post-processing
# We'll store them when the endpoint is created, as we need the full path
# For now, attach to the function for later retrieval
if extra_body_params:
func._extra_body_params = extra_body_params # type: ignore
# If there's exactly one body parameter and it's a Pydantic model, use it directly
# Otherwise, we'll create a combined request model from all parameters
if len(query_parameters) == 1:
@ -965,6 +1012,100 @@ def _clean_schema_descriptions(openapi_schema: dict[str, Any]) -> dict[str, Any]
return openapi_schema
def _add_extra_body_params_extension(openapi_schema: dict[str, Any]) -> dict[str, Any]:
"""
Add x-llama-stack-extra-body-params extension to requestBody for endpoints with ExtraBodyField parameters.
"""
if "paths" not in openapi_schema:
return openapi_schema
global _extra_body_fields
from pydantic import TypeAdapter
for path, path_item in openapi_schema["paths"].items():
if not isinstance(path_item, dict):
continue
for method in ["get", "post", "put", "delete", "patch", "head", "options"]:
if method not in path_item:
continue
operation = path_item[method]
if not isinstance(operation, dict):
continue
# Check if we have extra body fields for this path/method
key = (path, method.upper())
if key not in _extra_body_fields:
continue
extra_body_params = _extra_body_fields[key]
# Ensure requestBody exists
if "requestBody" not in operation:
continue
request_body = operation["requestBody"]
if not isinstance(request_body, dict):
continue
# Get the schema from requestBody
content = request_body.get("content", {})
json_content = content.get("application/json", {})
schema_ref = json_content.get("schema", {})
# Remove extra body fields from the schema if they exist as properties
# Handle both $ref schemas and inline schemas
if isinstance(schema_ref, dict):
if "$ref" in schema_ref:
# Schema is a reference - remove from the referenced schema
ref_path = schema_ref["$ref"]
if ref_path.startswith("#/components/schemas/"):
schema_name = ref_path.split("/")[-1]
if "components" in openapi_schema and "schemas" in openapi_schema["components"]:
schema_def = openapi_schema["components"]["schemas"].get(schema_name)
if isinstance(schema_def, dict) and "properties" in schema_def:
for param_name, _, _ in extra_body_params:
if param_name in schema_def["properties"]:
del schema_def["properties"][param_name]
# Also remove from required if present
if "required" in schema_def and param_name in schema_def["required"]:
schema_def["required"].remove(param_name)
elif "properties" in schema_ref:
# Schema is inline - remove directly from it
for param_name, _, _ in extra_body_params:
if param_name in schema_ref["properties"]:
del schema_ref["properties"][param_name]
# Also remove from required if present
if "required" in schema_ref and param_name in schema_ref["required"]:
schema_ref["required"].remove(param_name)
# Build the extra body params schema
extra_params_schema = {}
for param_name, param_type, description in extra_body_params:
try:
# Generate JSON schema for the parameter type
adapter = TypeAdapter(param_type)
param_schema = adapter.json_schema(ref_template="#/components/schemas/{model}")
# Add description if provided
if description:
param_schema["description"] = description
extra_params_schema[param_name] = param_schema
except Exception:
# If we can't generate schema, skip this parameter
continue
if extra_params_schema:
# Add the extension to requestBody
if "x-llama-stack-extra-body-params" not in request_body:
request_body["x-llama-stack-extra-body-params"] = extra_params_schema
return openapi_schema
def _remove_query_params_from_body_endpoints(openapi_schema: dict[str, Any]) -> dict[str, Any]:
"""
Remove query parameters from POST/PUT/PATCH endpoints that have a request body.
@ -1399,6 +1540,9 @@ def generate_openapi_spec(output_dir: str) -> dict[str, Any]:
# FastAPI sometimes infers parameters as query params even when they should be in the request body
openapi_schema = _remove_query_params_from_body_endpoints(openapi_schema)
# Add x-llama-stack-extra-body-params extension for ExtraBodyField parameters
openapi_schema = _add_extra_body_params_extension(openapi_schema)
# Split into stable (v1 only), experimental (v1alpha + v1beta), deprecated, and combined (stainless) specs
# Each spec needs its own deep copy of the full schema to avoid cross-contamination
import copy