diff --git a/docs/openapi_generator/pyopenapi/generator.py b/docs/openapi_generator/pyopenapi/generator.py index bb8fa55ab..a8d6aaee9 100644 --- a/docs/openapi_generator/pyopenapi/generator.py +++ b/docs/openapi_generator/pyopenapi/generator.py @@ -50,6 +50,7 @@ from .specification import ( Document, Example, ExampleRef, + ExtraBodyParameter, MediaType, Operation, Parameter, @@ -677,6 +678,27 @@ class Generator: # parameters passed anywhere parameters = path_parameters + query_parameters + # Build extra body parameters documentation + extra_body_parameters = [] + for param_name, param_type, description in op.extra_body_params: + if is_type_optional(param_type): + inner_type: type = unwrap_optional_type(param_type) + required = False + else: + inner_type = param_type + required = True + + # Use description from ExtraBodyField if available, otherwise from docstring + param_description = description or doc_params.get(param_name) + + extra_body_param = ExtraBodyParameter( + name=param_name, + schema=self.schema_builder.classdef_to_ref(inner_type), + description=param_description, + required=required, + ) + extra_body_parameters.append(extra_body_param) + webmethod = getattr(op.func_ref, "__webmethod__", None) raw_bytes_request_body = False if webmethod: @@ -898,6 +920,7 @@ class Generator: deprecated=getattr(op.webmethod, "deprecated", False) or "DEPRECATED" in op.func_name, security=[] if op.public else None, + extraBodyParameters=extra_body_parameters if extra_body_parameters else None, ) def _get_api_stability_priority(self, api_level: str) -> int: diff --git a/docs/openapi_generator/pyopenapi/operations.py b/docs/openapi_generator/pyopenapi/operations.py index ce33d3bb9..2970d7e53 100644 --- a/docs/openapi_generator/pyopenapi/operations.py +++ b/docs/openapi_generator/pyopenapi/operations.py @@ -19,10 +19,12 @@ from llama_stack.strong_typing.inspection import get_signature from typing import get_origin, get_args -from fastapi import UploadFile +from fastapi import UploadFile from fastapi.params import File, Form from typing import Annotated +from llama_stack.schema_utils import ExtraBodyField + def split_prefix( s: str, sep: str, prefix: Union[str, Iterable[str]] @@ -89,6 +91,7 @@ class EndpointOperation: :param query_params: Parameters of the operation signature that are passed in the query string as `key=value` pairs. :param request_params: The parameter that corresponds to the data transmitted in the request body. :param multipart_params: Parameters that indicate multipart/form-data request body. + :param extra_body_params: Parameters that arrive via extra_body and are documented but not in SDK. :param event_type: The Python type of the data that is transmitted out-of-band (e.g. via websockets) while the operation is in progress. :param response_type: The Python type of the data that is transmitted in the response body. :param http_method: The HTTP method used to invoke the endpoint such as POST, GET or PUT. @@ -106,6 +109,7 @@ class EndpointOperation: query_params: List[OperationParameter] request_params: Optional[OperationParameter] multipart_params: List[OperationParameter] + extra_body_params: List[tuple[str, type, str | None]] event_type: Optional[type] response_type: type http_method: HTTPMethod @@ -265,6 +269,7 @@ def get_endpoint_operations( query_params = [] request_params = [] multipart_params = [] + extra_body_params = [] for param_name, parameter in signature.parameters.items(): param_type = _get_annotation_type(parameter.annotation, func_ref) @@ -279,6 +284,13 @@ def get_endpoint_operations( f"parameter '{param_name}' in function '{func_name}' has no type annotation" ) + # Check if this is an extra_body parameter + is_extra_body, extra_body_desc = _is_extra_body_param(param_type) + if is_extra_body: + # Store in a separate list for documentation + extra_body_params.append((param_name, param_type, extra_body_desc)) + continue # Skip adding to request_params + is_multipart = _is_multipart_param(param_type) if prefix in ["get", "delete"]: @@ -351,6 +363,7 @@ def get_endpoint_operations( query_params=query_params, request_params=request_params, multipart_params=multipart_params, + extra_body_params=extra_body_params, event_type=event_type, response_type=response_type, http_method=http_method, @@ -403,7 +416,7 @@ def get_endpoint_events(endpoint: type) -> Dict[str, type]: def _is_multipart_param(param_type: type) -> bool: """ Check if a parameter type indicates multipart form data. - + Returns True if the type is: - UploadFile - Annotated[UploadFile, File()] @@ -413,19 +426,38 @@ def _is_multipart_param(param_type: type) -> bool: """ if param_type is UploadFile: return True - + # Check for Annotated types origin = get_origin(param_type) if origin is None: return False - + if origin is Annotated: args = get_args(param_type) if len(args) < 2: return False - + # Check the annotations for File() or Form() for annotation in args[1:]: if isinstance(annotation, (File, Form)): return True return False + + +def _is_extra_body_param(param_type: type) -> tuple[bool, str | None]: + """ + Check if parameter is marked as coming from extra_body. + + Returns: + (is_extra_body, description): Tuple of boolean and optional description + """ + origin = get_origin(param_type) + if origin is Annotated: + args = get_args(param_type) + for annotation in args[1:]: + if isinstance(annotation, ExtraBodyField): + return True, annotation.description + # Also check by type name for cases where import matters + if type(annotation).__name__ == 'ExtraBodyField': + return True, getattr(annotation, 'description', None) + return False, None diff --git a/docs/openapi_generator/pyopenapi/specification.py b/docs/openapi_generator/pyopenapi/specification.py index d3e5a1f19..90bf54316 100644 --- a/docs/openapi_generator/pyopenapi/specification.py +++ b/docs/openapi_generator/pyopenapi/specification.py @@ -106,6 +106,15 @@ class Parameter: example: Optional[Any] = None +@dataclass +class ExtraBodyParameter: + """Represents a parameter that arrives via extra_body in the request.""" + name: str + schema: SchemaOrRef + description: Optional[str] = None + required: Optional[bool] = None + + @dataclass class Operation: responses: Dict[str, Union[Response, ResponseRef]] @@ -118,6 +127,7 @@ class Operation: callbacks: Optional[Dict[str, "Callback"]] = None security: Optional[List["SecurityRequirement"]] = None deprecated: Optional[bool] = None + extraBodyParameters: Optional[List[ExtraBodyParameter]] = None @dataclass diff --git a/docs/openapi_generator/pyopenapi/utility.py b/docs/openapi_generator/pyopenapi/utility.py index d302b114f..26ef22112 100644 --- a/docs/openapi_generator/pyopenapi/utility.py +++ b/docs/openapi_generator/pyopenapi/utility.py @@ -52,6 +52,17 @@ class Specification: if display_name: tag["x-displayName"] = display_name + # Handle operations to rename extraBodyParameters -> x-llama-stack-extra-body-params + paths = json_doc.get("paths", {}) + for path_item in paths.values(): + if isinstance(path_item, dict): + for method in ["get", "post", "put", "delete", "patch"]: + operation = path_item.get(method) + if operation and isinstance(operation, dict): + extra_body_params = operation.pop("extraBodyParameters", None) + if extra_body_params: + operation["x-llama-stack-extra-body-params"] = extra_body_params + return json_doc def get_json_string(self, pretty_print: bool = False) -> str: diff --git a/docs/static/deprecated-llama-stack-spec.html b/docs/static/deprecated-llama-stack-spec.html index 7edfe3f5d..ffda7552b 100644 --- a/docs/static/deprecated-llama-stack-spec.html +++ b/docs/static/deprecated-llama-stack-spec.html @@ -2132,7 +2132,27 @@ }, "required": true }, - "deprecated": true + "deprecated": true, + "x-llama-stack-extra-body-params": [ + { + "name": "shields", + "schema": { + "type": "array", + "items": { + "oneOf": [ + { + "type": "string" + }, + { + "$ref": "#/components/schemas/ResponseShieldSpec" + } + ] + } + }, + "description": "List of shields to apply during response generation. Shields provide safety and content moderation.", + "required": false + } + ] } }, "/v1/openai/v1/responses/{response_id}": { @@ -9521,6 +9541,21 @@ "title": "OpenAIResponseText", "description": "Text response configuration for OpenAI responses." }, + "ResponseShieldSpec": { + "type": "object", + "properties": { + "type": { + "type": "string", + "description": "The type/identifier of the shield." + } + }, + "additionalProperties": false, + "required": [ + "type" + ], + "title": "ResponseShieldSpec", + "description": "Specification for a shield to apply during response generation." + }, "OpenAIResponseInputTool": { "oneOf": [ { diff --git a/docs/static/deprecated-llama-stack-spec.yaml b/docs/static/deprecated-llama-stack-spec.yaml index ca832d46b..0e672f914 100644 --- a/docs/static/deprecated-llama-stack-spec.yaml +++ b/docs/static/deprecated-llama-stack-spec.yaml @@ -1559,6 +1559,18 @@ paths: $ref: '#/components/schemas/CreateOpenaiResponseRequest' required: true deprecated: true + x-llama-stack-extra-body-params: + - name: shields + schema: + type: array + items: + oneOf: + - type: string + - $ref: '#/components/schemas/ResponseShieldSpec' + description: >- + List of shields to apply during response generation. Shields provide safety + and content moderation. + required: false /v1/openai/v1/responses/{response_id}: get: responses: @@ -7076,6 +7088,18 @@ components: title: OpenAIResponseText description: >- Text response configuration for OpenAI responses. + ResponseShieldSpec: + type: object + properties: + type: + type: string + description: The type/identifier of the shield. + additionalProperties: false + required: + - type + title: ResponseShieldSpec + description: >- + Specification for a shield to apply during response generation. OpenAIResponseInputTool: oneOf: - $ref: '#/components/schemas/OpenAIResponseInputToolWebSearch' diff --git a/docs/static/llama-stack-spec.html b/docs/static/llama-stack-spec.html index 96e97035f..c570dcddf 100644 --- a/docs/static/llama-stack-spec.html +++ b/docs/static/llama-stack-spec.html @@ -1830,7 +1830,27 @@ }, "required": true }, - "deprecated": false + "deprecated": false, + "x-llama-stack-extra-body-params": [ + { + "name": "shields", + "schema": { + "type": "array", + "items": { + "oneOf": [ + { + "type": "string" + }, + { + "$ref": "#/components/schemas/ResponseShieldSpec" + } + ] + } + }, + "description": "List of shields to apply during response generation. Shields provide safety and content moderation.", + "required": false + } + ] } }, "/v1/responses/{response_id}": { @@ -7616,6 +7636,21 @@ "title": "OpenAIResponseText", "description": "Text response configuration for OpenAI responses." }, + "ResponseShieldSpec": { + "type": "object", + "properties": { + "type": { + "type": "string", + "description": "The type/identifier of the shield." + } + }, + "additionalProperties": false, + "required": [ + "type" + ], + "title": "ResponseShieldSpec", + "description": "Specification for a shield to apply during response generation." + }, "OpenAIResponseInputTool": { "oneOf": [ { diff --git a/docs/static/llama-stack-spec.yaml b/docs/static/llama-stack-spec.yaml index b9e03d614..3e1431b22 100644 --- a/docs/static/llama-stack-spec.yaml +++ b/docs/static/llama-stack-spec.yaml @@ -1411,6 +1411,18 @@ paths: $ref: '#/components/schemas/CreateOpenaiResponseRequest' required: true deprecated: false + x-llama-stack-extra-body-params: + - name: shields + schema: + type: array + items: + oneOf: + - type: string + - $ref: '#/components/schemas/ResponseShieldSpec' + description: >- + List of shields to apply during response generation. Shields provide safety + and content moderation. + required: false /v1/responses/{response_id}: get: responses: @@ -5739,6 +5751,18 @@ components: title: OpenAIResponseText description: >- Text response configuration for OpenAI responses. + ResponseShieldSpec: + type: object + properties: + type: + type: string + description: The type/identifier of the shield. + additionalProperties: false + required: + - type + title: ResponseShieldSpec + description: >- + Specification for a shield to apply during response generation. OpenAIResponseInputTool: oneOf: - $ref: '#/components/schemas/OpenAIResponseInputToolWebSearch' diff --git a/docs/static/stainless-llama-stack-spec.html b/docs/static/stainless-llama-stack-spec.html index 7ec48ef74..167a4aa3c 100644 --- a/docs/static/stainless-llama-stack-spec.html +++ b/docs/static/stainless-llama-stack-spec.html @@ -1830,7 +1830,27 @@ }, "required": true }, - "deprecated": false + "deprecated": false, + "x-llama-stack-extra-body-params": [ + { + "name": "shields", + "schema": { + "type": "array", + "items": { + "oneOf": [ + { + "type": "string" + }, + { + "$ref": "#/components/schemas/ResponseShieldSpec" + } + ] + } + }, + "description": "List of shields to apply during response generation. Shields provide safety and content moderation.", + "required": false + } + ] } }, "/v1/responses/{response_id}": { @@ -9625,6 +9645,21 @@ "title": "OpenAIResponseText", "description": "Text response configuration for OpenAI responses." }, + "ResponseShieldSpec": { + "type": "object", + "properties": { + "type": { + "type": "string", + "description": "The type/identifier of the shield." + } + }, + "additionalProperties": false, + "required": [ + "type" + ], + "title": "ResponseShieldSpec", + "description": "Specification for a shield to apply during response generation." + }, "OpenAIResponseInputTool": { "oneOf": [ { diff --git a/docs/static/stainless-llama-stack-spec.yaml b/docs/static/stainless-llama-stack-spec.yaml index 3bede159b..6dc1041f1 100644 --- a/docs/static/stainless-llama-stack-spec.yaml +++ b/docs/static/stainless-llama-stack-spec.yaml @@ -1414,6 +1414,18 @@ paths: $ref: '#/components/schemas/CreateOpenaiResponseRequest' required: true deprecated: false + x-llama-stack-extra-body-params: + - name: shields + schema: + type: array + items: + oneOf: + - type: string + - $ref: '#/components/schemas/ResponseShieldSpec' + description: >- + List of shields to apply during response generation. Shields provide safety + and content moderation. + required: false /v1/responses/{response_id}: get: responses: @@ -7184,6 +7196,18 @@ components: title: OpenAIResponseText description: >- Text response configuration for OpenAI responses. + ResponseShieldSpec: + type: object + properties: + type: + type: string + description: The type/identifier of the shield. + additionalProperties: false + required: + - type + title: ResponseShieldSpec + description: >- + Specification for a shield to apply during response generation. OpenAIResponseInputTool: oneOf: - $ref: '#/components/schemas/OpenAIResponseInputToolWebSearch' diff --git a/llama_stack/apis/agents/agents.py b/llama_stack/apis/agents/agents.py index 811fe6aa2..8be36b92f 100644 --- a/llama_stack/apis/agents/agents.py +++ b/llama_stack/apis/agents/agents.py @@ -28,7 +28,7 @@ from llama_stack.apis.inference import ( from llama_stack.apis.safety import SafetyViolation from llama_stack.apis.tools import ToolDef from llama_stack.apis.version import LLAMA_STACK_API_V1, LLAMA_STACK_API_V1ALPHA -from llama_stack.schema_utils import json_schema_type, register_schema, webmethod +from llama_stack.schema_utils import ExtraBodyField, json_schema_type, register_schema, webmethod from .openai_responses import ( ListOpenAIResponseInputItem, @@ -42,6 +42,20 @@ from .openai_responses import ( ) +@json_schema_type +class ResponseShieldSpec(BaseModel): + """Specification for a shield to apply during response generation. + + :param type: The type/identifier of the shield. + """ + + type: str + # TODO: more fields to be added for shield configuration + + +ResponseShield = str | ResponseShieldSpec + + class Attachment(BaseModel): """An attachment to an agent turn. @@ -805,6 +819,7 @@ class Agents(Protocol): tools: list[OpenAIResponseInputTool] | None = None, include: list[str] | None = None, max_infer_iters: int | None = 10, # this is an extension to the OpenAI API + shields: Annotated[list[ResponseShield] | None, ExtraBodyField("List of shields to apply during response generation. Shields provide safety and content moderation.")] = None, ) -> OpenAIResponseObject | AsyncIterator[OpenAIResponseObjectStream]: """Create a new OpenAI response. @@ -812,6 +827,7 @@ class Agents(Protocol): :param model: The underlying LLM used for completions. :param previous_response_id: (Optional) if specified, the new response will be a continuation of the previous response. This can be used to easily fork-off new responses from existing responses. :param include: (Optional) Additional fields to include in the response. + :param shields: (Optional) List of shields to apply during response generation. Can be shield IDs (strings) or shield specifications. :returns: An OpenAIResponseObject. """ ... diff --git a/llama_stack/schema_utils.py b/llama_stack/schema_utils.py index c58fcdd01..f3c0b5942 100644 --- a/llama_stack/schema_utils.py +++ b/llama_stack/schema_utils.py @@ -6,11 +6,50 @@ from collections.abc import Callable from dataclasses import dataclass -from typing import Any, TypeVar +from typing import Any, Generic, TypeVar from .strong_typing.schema import json_schema_type, register_schema # noqa: F401 +T = TypeVar("T") + + +class ExtraBodyField(Generic[T]): + """ + Marker annotation for parameters that arrive via extra_body in the client SDK. + + These parameters: + - Will NOT appear in the generated client SDK method signature + - WILL be documented in OpenAPI spec under x-llama-stack-extra-body-params + - MUST be passed via the extra_body parameter in client SDK calls + - WILL be available in server-side method signature with proper typing + + Example: + ```python + async def create_openai_response( + self, + input: str, + model: str, + shields: Annotated[list[str] | None, ExtraBodyField("List of shields to apply")] = None, + ) -> ResponseObject: + # shields is available here with proper typing + if shields: + print(f"Using shields: {shields}") + ``` + + Client usage: + ```python + client.responses.create( + input="hello", + model="llama-3", + extra_body={"shields": ["shield-1"]} + ) + ``` + """ + def __init__(self, description: str | None = None): + self.description = description + + @dataclass class WebMethod: level: str | None = None @@ -26,7 +65,7 @@ class WebMethod: deprecated: bool | None = False -T = TypeVar("T", bound=Callable[..., Any]) +CallableT = TypeVar("CallableT", bound=Callable[..., Any]) def webmethod( @@ -40,7 +79,7 @@ def webmethod( descriptive_name: str | None = None, required_scope: str | None = None, deprecated: bool | None = False, -) -> Callable[[T], T]: +) -> Callable[[CallableT], CallableT]: """ Decorator that supplies additional metadata to an endpoint operation function. @@ -51,7 +90,7 @@ def webmethod( :param required_scope: Required scope for this endpoint (e.g., 'monitoring.viewer'). """ - def wrap(func: T) -> T: + def wrap(func: CallableT) -> CallableT: webmethod_obj = WebMethod( route=route, method=method,