feat(api): add extra_body parameter support with shields example

Introduce ExtraBodyField annotation to enable parameters that arrive via extra_body in client SDKs but are accessible server-side with full typing. These parameters are documented in OpenAPI specs under x-llama-stack-extra-body-params but excluded from generated SDK signatures. Add shields parameter to create_openai_response as the first implementation using this pattern.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
Ashwin Bharambe 2025-10-03 10:35:33 -07:00
parent ce77c27ff8
commit 79f889d3f0
12 changed files with 321 additions and 13 deletions

View file

@ -50,6 +50,7 @@ from .specification import (
Document,
Example,
ExampleRef,
ExtraBodyParameter,
MediaType,
Operation,
Parameter,
@ -677,6 +678,27 @@ class Generator:
# parameters passed anywhere
parameters = path_parameters + query_parameters
# Build extra body parameters documentation
extra_body_parameters = []
for param_name, param_type, description in op.extra_body_params:
if is_type_optional(param_type):
inner_type: type = unwrap_optional_type(param_type)
required = False
else:
inner_type = param_type
required = True
# Use description from ExtraBodyField if available, otherwise from docstring
param_description = description or doc_params.get(param_name)
extra_body_param = ExtraBodyParameter(
name=param_name,
schema=self.schema_builder.classdef_to_ref(inner_type),
description=param_description,
required=required,
)
extra_body_parameters.append(extra_body_param)
webmethod = getattr(op.func_ref, "__webmethod__", None)
raw_bytes_request_body = False
if webmethod:
@ -898,6 +920,7 @@ class Generator:
deprecated=getattr(op.webmethod, "deprecated", False)
or "DEPRECATED" in op.func_name,
security=[] if op.public else None,
extraBodyParameters=extra_body_parameters if extra_body_parameters else None,
)
def _get_api_stability_priority(self, api_level: str) -> int:

View file

@ -19,10 +19,12 @@ from llama_stack.strong_typing.inspection import get_signature
from typing import get_origin, get_args
from fastapi import UploadFile
from fastapi import UploadFile
from fastapi.params import File, Form
from typing import Annotated
from llama_stack.schema_utils import ExtraBodyField
def split_prefix(
s: str, sep: str, prefix: Union[str, Iterable[str]]
@ -89,6 +91,7 @@ class EndpointOperation:
:param query_params: Parameters of the operation signature that are passed in the query string as `key=value` pairs.
:param request_params: The parameter that corresponds to the data transmitted in the request body.
:param multipart_params: Parameters that indicate multipart/form-data request body.
:param extra_body_params: Parameters that arrive via extra_body and are documented but not in SDK.
:param event_type: The Python type of the data that is transmitted out-of-band (e.g. via websockets) while the operation is in progress.
:param response_type: The Python type of the data that is transmitted in the response body.
:param http_method: The HTTP method used to invoke the endpoint such as POST, GET or PUT.
@ -106,6 +109,7 @@ class EndpointOperation:
query_params: List[OperationParameter]
request_params: Optional[OperationParameter]
multipart_params: List[OperationParameter]
extra_body_params: List[tuple[str, type, str | None]]
event_type: Optional[type]
response_type: type
http_method: HTTPMethod
@ -265,6 +269,7 @@ def get_endpoint_operations(
query_params = []
request_params = []
multipart_params = []
extra_body_params = []
for param_name, parameter in signature.parameters.items():
param_type = _get_annotation_type(parameter.annotation, func_ref)
@ -279,6 +284,13 @@ def get_endpoint_operations(
f"parameter '{param_name}' in function '{func_name}' has no type annotation"
)
# Check if this is an extra_body parameter
is_extra_body, extra_body_desc = _is_extra_body_param(param_type)
if is_extra_body:
# Store in a separate list for documentation
extra_body_params.append((param_name, param_type, extra_body_desc))
continue # Skip adding to request_params
is_multipart = _is_multipart_param(param_type)
if prefix in ["get", "delete"]:
@ -351,6 +363,7 @@ def get_endpoint_operations(
query_params=query_params,
request_params=request_params,
multipart_params=multipart_params,
extra_body_params=extra_body_params,
event_type=event_type,
response_type=response_type,
http_method=http_method,
@ -403,7 +416,7 @@ def get_endpoint_events(endpoint: type) -> Dict[str, type]:
def _is_multipart_param(param_type: type) -> bool:
"""
Check if a parameter type indicates multipart form data.
Returns True if the type is:
- UploadFile
- Annotated[UploadFile, File()]
@ -413,19 +426,38 @@ def _is_multipart_param(param_type: type) -> bool:
"""
if param_type is UploadFile:
return True
# Check for Annotated types
origin = get_origin(param_type)
if origin is None:
return False
if origin is Annotated:
args = get_args(param_type)
if len(args) < 2:
return False
# Check the annotations for File() or Form()
for annotation in args[1:]:
if isinstance(annotation, (File, Form)):
return True
return False
def _is_extra_body_param(param_type: type) -> tuple[bool, str | None]:
"""
Check if parameter is marked as coming from extra_body.
Returns:
(is_extra_body, description): Tuple of boolean and optional description
"""
origin = get_origin(param_type)
if origin is Annotated:
args = get_args(param_type)
for annotation in args[1:]:
if isinstance(annotation, ExtraBodyField):
return True, annotation.description
# Also check by type name for cases where import matters
if type(annotation).__name__ == 'ExtraBodyField':
return True, getattr(annotation, 'description', None)
return False, None

View file

@ -106,6 +106,15 @@ class Parameter:
example: Optional[Any] = None
@dataclass
class ExtraBodyParameter:
"""Represents a parameter that arrives via extra_body in the request."""
name: str
schema: SchemaOrRef
description: Optional[str] = None
required: Optional[bool] = None
@dataclass
class Operation:
responses: Dict[str, Union[Response, ResponseRef]]
@ -118,6 +127,7 @@ class Operation:
callbacks: Optional[Dict[str, "Callback"]] = None
security: Optional[List["SecurityRequirement"]] = None
deprecated: Optional[bool] = None
extraBodyParameters: Optional[List[ExtraBodyParameter]] = None
@dataclass

View file

@ -52,6 +52,17 @@ class Specification:
if display_name:
tag["x-displayName"] = display_name
# Handle operations to rename extraBodyParameters -> x-llama-stack-extra-body-params
paths = json_doc.get("paths", {})
for path_item in paths.values():
if isinstance(path_item, dict):
for method in ["get", "post", "put", "delete", "patch"]:
operation = path_item.get(method)
if operation and isinstance(operation, dict):
extra_body_params = operation.pop("extraBodyParameters", None)
if extra_body_params:
operation["x-llama-stack-extra-body-params"] = extra_body_params
return json_doc
def get_json_string(self, pretty_print: bool = False) -> str: