feat(api): add extra_body parameter support with shields example

Introduce ExtraBodyField annotation to enable parameters that arrive via extra_body in client SDKs but are accessible server-side with full typing. These parameters are documented in OpenAPI specs under x-llama-stack-extra-body-params but excluded from generated SDK signatures. Add shields parameter to create_openai_response as the first implementation using this pattern.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
Ashwin Bharambe 2025-10-03 10:35:33 -07:00
parent ce77c27ff8
commit 79f889d3f0
12 changed files with 321 additions and 13 deletions

View file

@ -50,6 +50,7 @@ from .specification import (
Document, Document,
Example, Example,
ExampleRef, ExampleRef,
ExtraBodyParameter,
MediaType, MediaType,
Operation, Operation,
Parameter, Parameter,
@ -677,6 +678,27 @@ class Generator:
# parameters passed anywhere # parameters passed anywhere
parameters = path_parameters + query_parameters parameters = path_parameters + query_parameters
# Build extra body parameters documentation
extra_body_parameters = []
for param_name, param_type, description in op.extra_body_params:
if is_type_optional(param_type):
inner_type: type = unwrap_optional_type(param_type)
required = False
else:
inner_type = param_type
required = True
# Use description from ExtraBodyField if available, otherwise from docstring
param_description = description or doc_params.get(param_name)
extra_body_param = ExtraBodyParameter(
name=param_name,
schema=self.schema_builder.classdef_to_ref(inner_type),
description=param_description,
required=required,
)
extra_body_parameters.append(extra_body_param)
webmethod = getattr(op.func_ref, "__webmethod__", None) webmethod = getattr(op.func_ref, "__webmethod__", None)
raw_bytes_request_body = False raw_bytes_request_body = False
if webmethod: if webmethod:
@ -898,6 +920,7 @@ class Generator:
deprecated=getattr(op.webmethod, "deprecated", False) deprecated=getattr(op.webmethod, "deprecated", False)
or "DEPRECATED" in op.func_name, or "DEPRECATED" in op.func_name,
security=[] if op.public else None, security=[] if op.public else None,
extraBodyParameters=extra_body_parameters if extra_body_parameters else None,
) )
def _get_api_stability_priority(self, api_level: str) -> int: def _get_api_stability_priority(self, api_level: str) -> int:

View file

@ -19,10 +19,12 @@ from llama_stack.strong_typing.inspection import get_signature
from typing import get_origin, get_args from typing import get_origin, get_args
from fastapi import UploadFile from fastapi import UploadFile
from fastapi.params import File, Form from fastapi.params import File, Form
from typing import Annotated from typing import Annotated
from llama_stack.schema_utils import ExtraBodyField
def split_prefix( def split_prefix(
s: str, sep: str, prefix: Union[str, Iterable[str]] s: str, sep: str, prefix: Union[str, Iterable[str]]
@ -89,6 +91,7 @@ class EndpointOperation:
:param query_params: Parameters of the operation signature that are passed in the query string as `key=value` pairs. :param query_params: Parameters of the operation signature that are passed in the query string as `key=value` pairs.
:param request_params: The parameter that corresponds to the data transmitted in the request body. :param request_params: The parameter that corresponds to the data transmitted in the request body.
:param multipart_params: Parameters that indicate multipart/form-data request body. :param multipart_params: Parameters that indicate multipart/form-data request body.
:param extra_body_params: Parameters that arrive via extra_body and are documented but not in SDK.
:param event_type: The Python type of the data that is transmitted out-of-band (e.g. via websockets) while the operation is in progress. :param event_type: The Python type of the data that is transmitted out-of-band (e.g. via websockets) while the operation is in progress.
:param response_type: The Python type of the data that is transmitted in the response body. :param response_type: The Python type of the data that is transmitted in the response body.
:param http_method: The HTTP method used to invoke the endpoint such as POST, GET or PUT. :param http_method: The HTTP method used to invoke the endpoint such as POST, GET or PUT.
@ -106,6 +109,7 @@ class EndpointOperation:
query_params: List[OperationParameter] query_params: List[OperationParameter]
request_params: Optional[OperationParameter] request_params: Optional[OperationParameter]
multipart_params: List[OperationParameter] multipart_params: List[OperationParameter]
extra_body_params: List[tuple[str, type, str | None]]
event_type: Optional[type] event_type: Optional[type]
response_type: type response_type: type
http_method: HTTPMethod http_method: HTTPMethod
@ -265,6 +269,7 @@ def get_endpoint_operations(
query_params = [] query_params = []
request_params = [] request_params = []
multipart_params = [] multipart_params = []
extra_body_params = []
for param_name, parameter in signature.parameters.items(): for param_name, parameter in signature.parameters.items():
param_type = _get_annotation_type(parameter.annotation, func_ref) param_type = _get_annotation_type(parameter.annotation, func_ref)
@ -279,6 +284,13 @@ def get_endpoint_operations(
f"parameter '{param_name}' in function '{func_name}' has no type annotation" f"parameter '{param_name}' in function '{func_name}' has no type annotation"
) )
# Check if this is an extra_body parameter
is_extra_body, extra_body_desc = _is_extra_body_param(param_type)
if is_extra_body:
# Store in a separate list for documentation
extra_body_params.append((param_name, param_type, extra_body_desc))
continue # Skip adding to request_params
is_multipart = _is_multipart_param(param_type) is_multipart = _is_multipart_param(param_type)
if prefix in ["get", "delete"]: if prefix in ["get", "delete"]:
@ -351,6 +363,7 @@ def get_endpoint_operations(
query_params=query_params, query_params=query_params,
request_params=request_params, request_params=request_params,
multipart_params=multipart_params, multipart_params=multipart_params,
extra_body_params=extra_body_params,
event_type=event_type, event_type=event_type,
response_type=response_type, response_type=response_type,
http_method=http_method, http_method=http_method,
@ -403,7 +416,7 @@ def get_endpoint_events(endpoint: type) -> Dict[str, type]:
def _is_multipart_param(param_type: type) -> bool: def _is_multipart_param(param_type: type) -> bool:
""" """
Check if a parameter type indicates multipart form data. Check if a parameter type indicates multipart form data.
Returns True if the type is: Returns True if the type is:
- UploadFile - UploadFile
- Annotated[UploadFile, File()] - Annotated[UploadFile, File()]
@ -413,19 +426,38 @@ def _is_multipart_param(param_type: type) -> bool:
""" """
if param_type is UploadFile: if param_type is UploadFile:
return True return True
# Check for Annotated types # Check for Annotated types
origin = get_origin(param_type) origin = get_origin(param_type)
if origin is None: if origin is None:
return False return False
if origin is Annotated: if origin is Annotated:
args = get_args(param_type) args = get_args(param_type)
if len(args) < 2: if len(args) < 2:
return False return False
# Check the annotations for File() or Form() # Check the annotations for File() or Form()
for annotation in args[1:]: for annotation in args[1:]:
if isinstance(annotation, (File, Form)): if isinstance(annotation, (File, Form)):
return True return True
return False return False
def _is_extra_body_param(param_type: type) -> tuple[bool, str | None]:
"""
Check if parameter is marked as coming from extra_body.
Returns:
(is_extra_body, description): Tuple of boolean and optional description
"""
origin = get_origin(param_type)
if origin is Annotated:
args = get_args(param_type)
for annotation in args[1:]:
if isinstance(annotation, ExtraBodyField):
return True, annotation.description
# Also check by type name for cases where import matters
if type(annotation).__name__ == 'ExtraBodyField':
return True, getattr(annotation, 'description', None)
return False, None

View file

@ -106,6 +106,15 @@ class Parameter:
example: Optional[Any] = None example: Optional[Any] = None
@dataclass
class ExtraBodyParameter:
"""Represents a parameter that arrives via extra_body in the request."""
name: str
schema: SchemaOrRef
description: Optional[str] = None
required: Optional[bool] = None
@dataclass @dataclass
class Operation: class Operation:
responses: Dict[str, Union[Response, ResponseRef]] responses: Dict[str, Union[Response, ResponseRef]]
@ -118,6 +127,7 @@ class Operation:
callbacks: Optional[Dict[str, "Callback"]] = None callbacks: Optional[Dict[str, "Callback"]] = None
security: Optional[List["SecurityRequirement"]] = None security: Optional[List["SecurityRequirement"]] = None
deprecated: Optional[bool] = None deprecated: Optional[bool] = None
extraBodyParameters: Optional[List[ExtraBodyParameter]] = None
@dataclass @dataclass

View file

@ -52,6 +52,17 @@ class Specification:
if display_name: if display_name:
tag["x-displayName"] = display_name tag["x-displayName"] = display_name
# Handle operations to rename extraBodyParameters -> x-llama-stack-extra-body-params
paths = json_doc.get("paths", {})
for path_item in paths.values():
if isinstance(path_item, dict):
for method in ["get", "post", "put", "delete", "patch"]:
operation = path_item.get(method)
if operation and isinstance(operation, dict):
extra_body_params = operation.pop("extraBodyParameters", None)
if extra_body_params:
operation["x-llama-stack-extra-body-params"] = extra_body_params
return json_doc return json_doc
def get_json_string(self, pretty_print: bool = False) -> str: def get_json_string(self, pretty_print: bool = False) -> str:

View file

@ -2132,7 +2132,27 @@
}, },
"required": true "required": true
}, },
"deprecated": true "deprecated": true,
"x-llama-stack-extra-body-params": [
{
"name": "shields",
"schema": {
"type": "array",
"items": {
"oneOf": [
{
"type": "string"
},
{
"$ref": "#/components/schemas/ResponseShieldSpec"
}
]
}
},
"description": "List of shields to apply during response generation. Shields provide safety and content moderation.",
"required": false
}
]
} }
}, },
"/v1/openai/v1/responses/{response_id}": { "/v1/openai/v1/responses/{response_id}": {
@ -9521,6 +9541,21 @@
"title": "OpenAIResponseText", "title": "OpenAIResponseText",
"description": "Text response configuration for OpenAI responses." "description": "Text response configuration for OpenAI responses."
}, },
"ResponseShieldSpec": {
"type": "object",
"properties": {
"type": {
"type": "string",
"description": "The type/identifier of the shield."
}
},
"additionalProperties": false,
"required": [
"type"
],
"title": "ResponseShieldSpec",
"description": "Specification for a shield to apply during response generation."
},
"OpenAIResponseInputTool": { "OpenAIResponseInputTool": {
"oneOf": [ "oneOf": [
{ {

View file

@ -1559,6 +1559,18 @@ paths:
$ref: '#/components/schemas/CreateOpenaiResponseRequest' $ref: '#/components/schemas/CreateOpenaiResponseRequest'
required: true required: true
deprecated: true deprecated: true
x-llama-stack-extra-body-params:
- name: shields
schema:
type: array
items:
oneOf:
- type: string
- $ref: '#/components/schemas/ResponseShieldSpec'
description: >-
List of shields to apply during response generation. Shields provide safety
and content moderation.
required: false
/v1/openai/v1/responses/{response_id}: /v1/openai/v1/responses/{response_id}:
get: get:
responses: responses:
@ -7076,6 +7088,18 @@ components:
title: OpenAIResponseText title: OpenAIResponseText
description: >- description: >-
Text response configuration for OpenAI responses. Text response configuration for OpenAI responses.
ResponseShieldSpec:
type: object
properties:
type:
type: string
description: The type/identifier of the shield.
additionalProperties: false
required:
- type
title: ResponseShieldSpec
description: >-
Specification for a shield to apply during response generation.
OpenAIResponseInputTool: OpenAIResponseInputTool:
oneOf: oneOf:
- $ref: '#/components/schemas/OpenAIResponseInputToolWebSearch' - $ref: '#/components/schemas/OpenAIResponseInputToolWebSearch'

View file

@ -1830,7 +1830,27 @@
}, },
"required": true "required": true
}, },
"deprecated": false "deprecated": false,
"x-llama-stack-extra-body-params": [
{
"name": "shields",
"schema": {
"type": "array",
"items": {
"oneOf": [
{
"type": "string"
},
{
"$ref": "#/components/schemas/ResponseShieldSpec"
}
]
}
},
"description": "List of shields to apply during response generation. Shields provide safety and content moderation.",
"required": false
}
]
} }
}, },
"/v1/responses/{response_id}": { "/v1/responses/{response_id}": {
@ -7616,6 +7636,21 @@
"title": "OpenAIResponseText", "title": "OpenAIResponseText",
"description": "Text response configuration for OpenAI responses." "description": "Text response configuration for OpenAI responses."
}, },
"ResponseShieldSpec": {
"type": "object",
"properties": {
"type": {
"type": "string",
"description": "The type/identifier of the shield."
}
},
"additionalProperties": false,
"required": [
"type"
],
"title": "ResponseShieldSpec",
"description": "Specification for a shield to apply during response generation."
},
"OpenAIResponseInputTool": { "OpenAIResponseInputTool": {
"oneOf": [ "oneOf": [
{ {

View file

@ -1411,6 +1411,18 @@ paths:
$ref: '#/components/schemas/CreateOpenaiResponseRequest' $ref: '#/components/schemas/CreateOpenaiResponseRequest'
required: true required: true
deprecated: false deprecated: false
x-llama-stack-extra-body-params:
- name: shields
schema:
type: array
items:
oneOf:
- type: string
- $ref: '#/components/schemas/ResponseShieldSpec'
description: >-
List of shields to apply during response generation. Shields provide safety
and content moderation.
required: false
/v1/responses/{response_id}: /v1/responses/{response_id}:
get: get:
responses: responses:
@ -5739,6 +5751,18 @@ components:
title: OpenAIResponseText title: OpenAIResponseText
description: >- description: >-
Text response configuration for OpenAI responses. Text response configuration for OpenAI responses.
ResponseShieldSpec:
type: object
properties:
type:
type: string
description: The type/identifier of the shield.
additionalProperties: false
required:
- type
title: ResponseShieldSpec
description: >-
Specification for a shield to apply during response generation.
OpenAIResponseInputTool: OpenAIResponseInputTool:
oneOf: oneOf:
- $ref: '#/components/schemas/OpenAIResponseInputToolWebSearch' - $ref: '#/components/schemas/OpenAIResponseInputToolWebSearch'

View file

@ -1830,7 +1830,27 @@
}, },
"required": true "required": true
}, },
"deprecated": false "deprecated": false,
"x-llama-stack-extra-body-params": [
{
"name": "shields",
"schema": {
"type": "array",
"items": {
"oneOf": [
{
"type": "string"
},
{
"$ref": "#/components/schemas/ResponseShieldSpec"
}
]
}
},
"description": "List of shields to apply during response generation. Shields provide safety and content moderation.",
"required": false
}
]
} }
}, },
"/v1/responses/{response_id}": { "/v1/responses/{response_id}": {
@ -9625,6 +9645,21 @@
"title": "OpenAIResponseText", "title": "OpenAIResponseText",
"description": "Text response configuration for OpenAI responses." "description": "Text response configuration for OpenAI responses."
}, },
"ResponseShieldSpec": {
"type": "object",
"properties": {
"type": {
"type": "string",
"description": "The type/identifier of the shield."
}
},
"additionalProperties": false,
"required": [
"type"
],
"title": "ResponseShieldSpec",
"description": "Specification for a shield to apply during response generation."
},
"OpenAIResponseInputTool": { "OpenAIResponseInputTool": {
"oneOf": [ "oneOf": [
{ {

View file

@ -1414,6 +1414,18 @@ paths:
$ref: '#/components/schemas/CreateOpenaiResponseRequest' $ref: '#/components/schemas/CreateOpenaiResponseRequest'
required: true required: true
deprecated: false deprecated: false
x-llama-stack-extra-body-params:
- name: shields
schema:
type: array
items:
oneOf:
- type: string
- $ref: '#/components/schemas/ResponseShieldSpec'
description: >-
List of shields to apply during response generation. Shields provide safety
and content moderation.
required: false
/v1/responses/{response_id}: /v1/responses/{response_id}:
get: get:
responses: responses:
@ -7184,6 +7196,18 @@ components:
title: OpenAIResponseText title: OpenAIResponseText
description: >- description: >-
Text response configuration for OpenAI responses. Text response configuration for OpenAI responses.
ResponseShieldSpec:
type: object
properties:
type:
type: string
description: The type/identifier of the shield.
additionalProperties: false
required:
- type
title: ResponseShieldSpec
description: >-
Specification for a shield to apply during response generation.
OpenAIResponseInputTool: OpenAIResponseInputTool:
oneOf: oneOf:
- $ref: '#/components/schemas/OpenAIResponseInputToolWebSearch' - $ref: '#/components/schemas/OpenAIResponseInputToolWebSearch'

View file

@ -28,7 +28,7 @@ from llama_stack.apis.inference import (
from llama_stack.apis.safety import SafetyViolation from llama_stack.apis.safety import SafetyViolation
from llama_stack.apis.tools import ToolDef from llama_stack.apis.tools import ToolDef
from llama_stack.apis.version import LLAMA_STACK_API_V1, LLAMA_STACK_API_V1ALPHA from llama_stack.apis.version import LLAMA_STACK_API_V1, LLAMA_STACK_API_V1ALPHA
from llama_stack.schema_utils import json_schema_type, register_schema, webmethod from llama_stack.schema_utils import ExtraBodyField, json_schema_type, register_schema, webmethod
from .openai_responses import ( from .openai_responses import (
ListOpenAIResponseInputItem, ListOpenAIResponseInputItem,
@ -42,6 +42,20 @@ from .openai_responses import (
) )
@json_schema_type
class ResponseShieldSpec(BaseModel):
"""Specification for a shield to apply during response generation.
:param type: The type/identifier of the shield.
"""
type: str
# TODO: more fields to be added for shield configuration
ResponseShield = str | ResponseShieldSpec
class Attachment(BaseModel): class Attachment(BaseModel):
"""An attachment to an agent turn. """An attachment to an agent turn.
@ -805,6 +819,7 @@ class Agents(Protocol):
tools: list[OpenAIResponseInputTool] | None = None, tools: list[OpenAIResponseInputTool] | None = None,
include: list[str] | None = None, include: list[str] | None = None,
max_infer_iters: int | None = 10, # this is an extension to the OpenAI API max_infer_iters: int | None = 10, # this is an extension to the OpenAI API
shields: Annotated[list[ResponseShield] | None, ExtraBodyField("List of shields to apply during response generation. Shields provide safety and content moderation.")] = None,
) -> OpenAIResponseObject | AsyncIterator[OpenAIResponseObjectStream]: ) -> OpenAIResponseObject | AsyncIterator[OpenAIResponseObjectStream]:
"""Create a new OpenAI response. """Create a new OpenAI response.
@ -812,6 +827,7 @@ class Agents(Protocol):
:param model: The underlying LLM used for completions. :param model: The underlying LLM used for completions.
:param previous_response_id: (Optional) if specified, the new response will be a continuation of the previous response. This can be used to easily fork-off new responses from existing responses. :param previous_response_id: (Optional) if specified, the new response will be a continuation of the previous response. This can be used to easily fork-off new responses from existing responses.
:param include: (Optional) Additional fields to include in the response. :param include: (Optional) Additional fields to include in the response.
:param shields: (Optional) List of shields to apply during response generation. Can be shield IDs (strings) or shield specifications.
:returns: An OpenAIResponseObject. :returns: An OpenAIResponseObject.
""" """
... ...

View file

@ -6,11 +6,50 @@
from collections.abc import Callable from collections.abc import Callable
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, TypeVar from typing import Any, Generic, TypeVar
from .strong_typing.schema import json_schema_type, register_schema # noqa: F401 from .strong_typing.schema import json_schema_type, register_schema # noqa: F401
T = TypeVar("T")
class ExtraBodyField(Generic[T]):
"""
Marker annotation for parameters that arrive via extra_body in the client SDK.
These parameters:
- Will NOT appear in the generated client SDK method signature
- WILL be documented in OpenAPI spec under x-llama-stack-extra-body-params
- MUST be passed via the extra_body parameter in client SDK calls
- WILL be available in server-side method signature with proper typing
Example:
```python
async def create_openai_response(
self,
input: str,
model: str,
shields: Annotated[list[str] | None, ExtraBodyField("List of shields to apply")] = None,
) -> ResponseObject:
# shields is available here with proper typing
if shields:
print(f"Using shields: {shields}")
```
Client usage:
```python
client.responses.create(
input="hello",
model="llama-3",
extra_body={"shields": ["shield-1"]}
)
```
"""
def __init__(self, description: str | None = None):
self.description = description
@dataclass @dataclass
class WebMethod: class WebMethod:
level: str | None = None level: str | None = None
@ -26,7 +65,7 @@ class WebMethod:
deprecated: bool | None = False deprecated: bool | None = False
T = TypeVar("T", bound=Callable[..., Any]) CallableT = TypeVar("CallableT", bound=Callable[..., Any])
def webmethod( def webmethod(
@ -40,7 +79,7 @@ def webmethod(
descriptive_name: str | None = None, descriptive_name: str | None = None,
required_scope: str | None = None, required_scope: str | None = None,
deprecated: bool | None = False, deprecated: bool | None = False,
) -> Callable[[T], T]: ) -> Callable[[CallableT], CallableT]:
""" """
Decorator that supplies additional metadata to an endpoint operation function. Decorator that supplies additional metadata to an endpoint operation function.
@ -51,7 +90,7 @@ def webmethod(
:param required_scope: Required scope for this endpoint (e.g., 'monitoring.viewer'). :param required_scope: Required scope for this endpoint (e.g., 'monitoring.viewer').
""" """
def wrap(func: T) -> T: def wrap(func: CallableT) -> CallableT:
webmethod_obj = WebMethod( webmethod_obj = WebMethod(
route=route, route=route,
method=method, method=method,