From 79f889d3f01079d2bc828c160f2c0aa225d45c78 Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Fri, 3 Oct 2025 10:35:33 -0700 Subject: [PATCH 1/5] feat(api): add extra_body parameter support with shields example MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Introduce ExtraBodyField annotation to enable parameters that arrive via extra_body in client SDKs but are accessible server-side with full typing. These parameters are documented in OpenAPI specs under x-llama-stack-extra-body-params but excluded from generated SDK signatures. Add shields parameter to create_openai_response as the first implementation using this pattern. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude --- docs/openapi_generator/pyopenapi/generator.py | 23 +++++++++ .../openapi_generator/pyopenapi/operations.py | 42 +++++++++++++++-- .../pyopenapi/specification.py | 10 ++++ docs/openapi_generator/pyopenapi/utility.py | 11 +++++ docs/static/deprecated-llama-stack-spec.html | 37 ++++++++++++++- docs/static/deprecated-llama-stack-spec.yaml | 24 ++++++++++ docs/static/llama-stack-spec.html | 37 ++++++++++++++- docs/static/llama-stack-spec.yaml | 24 ++++++++++ docs/static/stainless-llama-stack-spec.html | 37 ++++++++++++++- docs/static/stainless-llama-stack-spec.yaml | 24 ++++++++++ llama_stack/apis/agents/agents.py | 18 ++++++- llama_stack/schema_utils.py | 47 +++++++++++++++++-- 12 files changed, 321 insertions(+), 13 deletions(-) diff --git a/docs/openapi_generator/pyopenapi/generator.py b/docs/openapi_generator/pyopenapi/generator.py index bb8fa55ab..a8d6aaee9 100644 --- a/docs/openapi_generator/pyopenapi/generator.py +++ b/docs/openapi_generator/pyopenapi/generator.py @@ -50,6 +50,7 @@ from .specification import ( Document, Example, ExampleRef, + ExtraBodyParameter, MediaType, Operation, Parameter, @@ -677,6 +678,27 @@ class Generator: # parameters passed anywhere parameters = path_parameters + query_parameters + # Build extra body parameters documentation + extra_body_parameters = [] + for param_name, param_type, description in op.extra_body_params: + if is_type_optional(param_type): + inner_type: type = unwrap_optional_type(param_type) + required = False + else: + inner_type = param_type + required = True + + # Use description from ExtraBodyField if available, otherwise from docstring + param_description = description or doc_params.get(param_name) + + extra_body_param = ExtraBodyParameter( + name=param_name, + schema=self.schema_builder.classdef_to_ref(inner_type), + description=param_description, + required=required, + ) + extra_body_parameters.append(extra_body_param) + webmethod = getattr(op.func_ref, "__webmethod__", None) raw_bytes_request_body = False if webmethod: @@ -898,6 +920,7 @@ class Generator: deprecated=getattr(op.webmethod, "deprecated", False) or "DEPRECATED" in op.func_name, security=[] if op.public else None, + extraBodyParameters=extra_body_parameters if extra_body_parameters else None, ) def _get_api_stability_priority(self, api_level: str) -> int: diff --git a/docs/openapi_generator/pyopenapi/operations.py b/docs/openapi_generator/pyopenapi/operations.py index ce33d3bb9..2970d7e53 100644 --- a/docs/openapi_generator/pyopenapi/operations.py +++ b/docs/openapi_generator/pyopenapi/operations.py @@ -19,10 +19,12 @@ from llama_stack.strong_typing.inspection import get_signature from typing import get_origin, get_args -from fastapi import UploadFile +from fastapi import UploadFile from fastapi.params import File, Form from typing import Annotated +from llama_stack.schema_utils import ExtraBodyField + def split_prefix( s: str, sep: str, prefix: Union[str, Iterable[str]] @@ -89,6 +91,7 @@ class EndpointOperation: :param query_params: Parameters of the operation signature that are passed in the query string as `key=value` pairs. :param request_params: The parameter that corresponds to the data transmitted in the request body. :param multipart_params: Parameters that indicate multipart/form-data request body. + :param extra_body_params: Parameters that arrive via extra_body and are documented but not in SDK. :param event_type: The Python type of the data that is transmitted out-of-band (e.g. via websockets) while the operation is in progress. :param response_type: The Python type of the data that is transmitted in the response body. :param http_method: The HTTP method used to invoke the endpoint such as POST, GET or PUT. @@ -106,6 +109,7 @@ class EndpointOperation: query_params: List[OperationParameter] request_params: Optional[OperationParameter] multipart_params: List[OperationParameter] + extra_body_params: List[tuple[str, type, str | None]] event_type: Optional[type] response_type: type http_method: HTTPMethod @@ -265,6 +269,7 @@ def get_endpoint_operations( query_params = [] request_params = [] multipart_params = [] + extra_body_params = [] for param_name, parameter in signature.parameters.items(): param_type = _get_annotation_type(parameter.annotation, func_ref) @@ -279,6 +284,13 @@ def get_endpoint_operations( f"parameter '{param_name}' in function '{func_name}' has no type annotation" ) + # Check if this is an extra_body parameter + is_extra_body, extra_body_desc = _is_extra_body_param(param_type) + if is_extra_body: + # Store in a separate list for documentation + extra_body_params.append((param_name, param_type, extra_body_desc)) + continue # Skip adding to request_params + is_multipart = _is_multipart_param(param_type) if prefix in ["get", "delete"]: @@ -351,6 +363,7 @@ def get_endpoint_operations( query_params=query_params, request_params=request_params, multipart_params=multipart_params, + extra_body_params=extra_body_params, event_type=event_type, response_type=response_type, http_method=http_method, @@ -403,7 +416,7 @@ def get_endpoint_events(endpoint: type) -> Dict[str, type]: def _is_multipart_param(param_type: type) -> bool: """ Check if a parameter type indicates multipart form data. - + Returns True if the type is: - UploadFile - Annotated[UploadFile, File()] @@ -413,19 +426,38 @@ def _is_multipart_param(param_type: type) -> bool: """ if param_type is UploadFile: return True - + # Check for Annotated types origin = get_origin(param_type) if origin is None: return False - + if origin is Annotated: args = get_args(param_type) if len(args) < 2: return False - + # Check the annotations for File() or Form() for annotation in args[1:]: if isinstance(annotation, (File, Form)): return True return False + + +def _is_extra_body_param(param_type: type) -> tuple[bool, str | None]: + """ + Check if parameter is marked as coming from extra_body. + + Returns: + (is_extra_body, description): Tuple of boolean and optional description + """ + origin = get_origin(param_type) + if origin is Annotated: + args = get_args(param_type) + for annotation in args[1:]: + if isinstance(annotation, ExtraBodyField): + return True, annotation.description + # Also check by type name for cases where import matters + if type(annotation).__name__ == 'ExtraBodyField': + return True, getattr(annotation, 'description', None) + return False, None diff --git a/docs/openapi_generator/pyopenapi/specification.py b/docs/openapi_generator/pyopenapi/specification.py index d3e5a1f19..90bf54316 100644 --- a/docs/openapi_generator/pyopenapi/specification.py +++ b/docs/openapi_generator/pyopenapi/specification.py @@ -106,6 +106,15 @@ class Parameter: example: Optional[Any] = None +@dataclass +class ExtraBodyParameter: + """Represents a parameter that arrives via extra_body in the request.""" + name: str + schema: SchemaOrRef + description: Optional[str] = None + required: Optional[bool] = None + + @dataclass class Operation: responses: Dict[str, Union[Response, ResponseRef]] @@ -118,6 +127,7 @@ class Operation: callbacks: Optional[Dict[str, "Callback"]] = None security: Optional[List["SecurityRequirement"]] = None deprecated: Optional[bool] = None + extraBodyParameters: Optional[List[ExtraBodyParameter]] = None @dataclass diff --git a/docs/openapi_generator/pyopenapi/utility.py b/docs/openapi_generator/pyopenapi/utility.py index d302b114f..26ef22112 100644 --- a/docs/openapi_generator/pyopenapi/utility.py +++ b/docs/openapi_generator/pyopenapi/utility.py @@ -52,6 +52,17 @@ class Specification: if display_name: tag["x-displayName"] = display_name + # Handle operations to rename extraBodyParameters -> x-llama-stack-extra-body-params + paths = json_doc.get("paths", {}) + for path_item in paths.values(): + if isinstance(path_item, dict): + for method in ["get", "post", "put", "delete", "patch"]: + operation = path_item.get(method) + if operation and isinstance(operation, dict): + extra_body_params = operation.pop("extraBodyParameters", None) + if extra_body_params: + operation["x-llama-stack-extra-body-params"] = extra_body_params + return json_doc def get_json_string(self, pretty_print: bool = False) -> str: diff --git a/docs/static/deprecated-llama-stack-spec.html b/docs/static/deprecated-llama-stack-spec.html index 7edfe3f5d..ffda7552b 100644 --- a/docs/static/deprecated-llama-stack-spec.html +++ b/docs/static/deprecated-llama-stack-spec.html @@ -2132,7 +2132,27 @@ }, "required": true }, - "deprecated": true + "deprecated": true, + "x-llama-stack-extra-body-params": [ + { + "name": "shields", + "schema": { + "type": "array", + "items": { + "oneOf": [ + { + "type": "string" + }, + { + "$ref": "#/components/schemas/ResponseShieldSpec" + } + ] + } + }, + "description": "List of shields to apply during response generation. Shields provide safety and content moderation.", + "required": false + } + ] } }, "/v1/openai/v1/responses/{response_id}": { @@ -9521,6 +9541,21 @@ "title": "OpenAIResponseText", "description": "Text response configuration for OpenAI responses." }, + "ResponseShieldSpec": { + "type": "object", + "properties": { + "type": { + "type": "string", + "description": "The type/identifier of the shield." + } + }, + "additionalProperties": false, + "required": [ + "type" + ], + "title": "ResponseShieldSpec", + "description": "Specification for a shield to apply during response generation." + }, "OpenAIResponseInputTool": { "oneOf": [ { diff --git a/docs/static/deprecated-llama-stack-spec.yaml b/docs/static/deprecated-llama-stack-spec.yaml index ca832d46b..0e672f914 100644 --- a/docs/static/deprecated-llama-stack-spec.yaml +++ b/docs/static/deprecated-llama-stack-spec.yaml @@ -1559,6 +1559,18 @@ paths: $ref: '#/components/schemas/CreateOpenaiResponseRequest' required: true deprecated: true + x-llama-stack-extra-body-params: + - name: shields + schema: + type: array + items: + oneOf: + - type: string + - $ref: '#/components/schemas/ResponseShieldSpec' + description: >- + List of shields to apply during response generation. Shields provide safety + and content moderation. + required: false /v1/openai/v1/responses/{response_id}: get: responses: @@ -7076,6 +7088,18 @@ components: title: OpenAIResponseText description: >- Text response configuration for OpenAI responses. + ResponseShieldSpec: + type: object + properties: + type: + type: string + description: The type/identifier of the shield. + additionalProperties: false + required: + - type + title: ResponseShieldSpec + description: >- + Specification for a shield to apply during response generation. OpenAIResponseInputTool: oneOf: - $ref: '#/components/schemas/OpenAIResponseInputToolWebSearch' diff --git a/docs/static/llama-stack-spec.html b/docs/static/llama-stack-spec.html index 96e97035f..c570dcddf 100644 --- a/docs/static/llama-stack-spec.html +++ b/docs/static/llama-stack-spec.html @@ -1830,7 +1830,27 @@ }, "required": true }, - "deprecated": false + "deprecated": false, + "x-llama-stack-extra-body-params": [ + { + "name": "shields", + "schema": { + "type": "array", + "items": { + "oneOf": [ + { + "type": "string" + }, + { + "$ref": "#/components/schemas/ResponseShieldSpec" + } + ] + } + }, + "description": "List of shields to apply during response generation. Shields provide safety and content moderation.", + "required": false + } + ] } }, "/v1/responses/{response_id}": { @@ -7616,6 +7636,21 @@ "title": "OpenAIResponseText", "description": "Text response configuration for OpenAI responses." }, + "ResponseShieldSpec": { + "type": "object", + "properties": { + "type": { + "type": "string", + "description": "The type/identifier of the shield." + } + }, + "additionalProperties": false, + "required": [ + "type" + ], + "title": "ResponseShieldSpec", + "description": "Specification for a shield to apply during response generation." + }, "OpenAIResponseInputTool": { "oneOf": [ { diff --git a/docs/static/llama-stack-spec.yaml b/docs/static/llama-stack-spec.yaml index b9e03d614..3e1431b22 100644 --- a/docs/static/llama-stack-spec.yaml +++ b/docs/static/llama-stack-spec.yaml @@ -1411,6 +1411,18 @@ paths: $ref: '#/components/schemas/CreateOpenaiResponseRequest' required: true deprecated: false + x-llama-stack-extra-body-params: + - name: shields + schema: + type: array + items: + oneOf: + - type: string + - $ref: '#/components/schemas/ResponseShieldSpec' + description: >- + List of shields to apply during response generation. Shields provide safety + and content moderation. + required: false /v1/responses/{response_id}: get: responses: @@ -5739,6 +5751,18 @@ components: title: OpenAIResponseText description: >- Text response configuration for OpenAI responses. + ResponseShieldSpec: + type: object + properties: + type: + type: string + description: The type/identifier of the shield. + additionalProperties: false + required: + - type + title: ResponseShieldSpec + description: >- + Specification for a shield to apply during response generation. OpenAIResponseInputTool: oneOf: - $ref: '#/components/schemas/OpenAIResponseInputToolWebSearch' diff --git a/docs/static/stainless-llama-stack-spec.html b/docs/static/stainless-llama-stack-spec.html index 7ec48ef74..167a4aa3c 100644 --- a/docs/static/stainless-llama-stack-spec.html +++ b/docs/static/stainless-llama-stack-spec.html @@ -1830,7 +1830,27 @@ }, "required": true }, - "deprecated": false + "deprecated": false, + "x-llama-stack-extra-body-params": [ + { + "name": "shields", + "schema": { + "type": "array", + "items": { + "oneOf": [ + { + "type": "string" + }, + { + "$ref": "#/components/schemas/ResponseShieldSpec" + } + ] + } + }, + "description": "List of shields to apply during response generation. Shields provide safety and content moderation.", + "required": false + } + ] } }, "/v1/responses/{response_id}": { @@ -9625,6 +9645,21 @@ "title": "OpenAIResponseText", "description": "Text response configuration for OpenAI responses." }, + "ResponseShieldSpec": { + "type": "object", + "properties": { + "type": { + "type": "string", + "description": "The type/identifier of the shield." + } + }, + "additionalProperties": false, + "required": [ + "type" + ], + "title": "ResponseShieldSpec", + "description": "Specification for a shield to apply during response generation." + }, "OpenAIResponseInputTool": { "oneOf": [ { diff --git a/docs/static/stainless-llama-stack-spec.yaml b/docs/static/stainless-llama-stack-spec.yaml index 3bede159b..6dc1041f1 100644 --- a/docs/static/stainless-llama-stack-spec.yaml +++ b/docs/static/stainless-llama-stack-spec.yaml @@ -1414,6 +1414,18 @@ paths: $ref: '#/components/schemas/CreateOpenaiResponseRequest' required: true deprecated: false + x-llama-stack-extra-body-params: + - name: shields + schema: + type: array + items: + oneOf: + - type: string + - $ref: '#/components/schemas/ResponseShieldSpec' + description: >- + List of shields to apply during response generation. Shields provide safety + and content moderation. + required: false /v1/responses/{response_id}: get: responses: @@ -7184,6 +7196,18 @@ components: title: OpenAIResponseText description: >- Text response configuration for OpenAI responses. + ResponseShieldSpec: + type: object + properties: + type: + type: string + description: The type/identifier of the shield. + additionalProperties: false + required: + - type + title: ResponseShieldSpec + description: >- + Specification for a shield to apply during response generation. OpenAIResponseInputTool: oneOf: - $ref: '#/components/schemas/OpenAIResponseInputToolWebSearch' diff --git a/llama_stack/apis/agents/agents.py b/llama_stack/apis/agents/agents.py index 811fe6aa2..8be36b92f 100644 --- a/llama_stack/apis/agents/agents.py +++ b/llama_stack/apis/agents/agents.py @@ -28,7 +28,7 @@ from llama_stack.apis.inference import ( from llama_stack.apis.safety import SafetyViolation from llama_stack.apis.tools import ToolDef from llama_stack.apis.version import LLAMA_STACK_API_V1, LLAMA_STACK_API_V1ALPHA -from llama_stack.schema_utils import json_schema_type, register_schema, webmethod +from llama_stack.schema_utils import ExtraBodyField, json_schema_type, register_schema, webmethod from .openai_responses import ( ListOpenAIResponseInputItem, @@ -42,6 +42,20 @@ from .openai_responses import ( ) +@json_schema_type +class ResponseShieldSpec(BaseModel): + """Specification for a shield to apply during response generation. + + :param type: The type/identifier of the shield. + """ + + type: str + # TODO: more fields to be added for shield configuration + + +ResponseShield = str | ResponseShieldSpec + + class Attachment(BaseModel): """An attachment to an agent turn. @@ -805,6 +819,7 @@ class Agents(Protocol): tools: list[OpenAIResponseInputTool] | None = None, include: list[str] | None = None, max_infer_iters: int | None = 10, # this is an extension to the OpenAI API + shields: Annotated[list[ResponseShield] | None, ExtraBodyField("List of shields to apply during response generation. Shields provide safety and content moderation.")] = None, ) -> OpenAIResponseObject | AsyncIterator[OpenAIResponseObjectStream]: """Create a new OpenAI response. @@ -812,6 +827,7 @@ class Agents(Protocol): :param model: The underlying LLM used for completions. :param previous_response_id: (Optional) if specified, the new response will be a continuation of the previous response. This can be used to easily fork-off new responses from existing responses. :param include: (Optional) Additional fields to include in the response. + :param shields: (Optional) List of shields to apply during response generation. Can be shield IDs (strings) or shield specifications. :returns: An OpenAIResponseObject. """ ... diff --git a/llama_stack/schema_utils.py b/llama_stack/schema_utils.py index c58fcdd01..f3c0b5942 100644 --- a/llama_stack/schema_utils.py +++ b/llama_stack/schema_utils.py @@ -6,11 +6,50 @@ from collections.abc import Callable from dataclasses import dataclass -from typing import Any, TypeVar +from typing import Any, Generic, TypeVar from .strong_typing.schema import json_schema_type, register_schema # noqa: F401 +T = TypeVar("T") + + +class ExtraBodyField(Generic[T]): + """ + Marker annotation for parameters that arrive via extra_body in the client SDK. + + These parameters: + - Will NOT appear in the generated client SDK method signature + - WILL be documented in OpenAPI spec under x-llama-stack-extra-body-params + - MUST be passed via the extra_body parameter in client SDK calls + - WILL be available in server-side method signature with proper typing + + Example: + ```python + async def create_openai_response( + self, + input: str, + model: str, + shields: Annotated[list[str] | None, ExtraBodyField("List of shields to apply")] = None, + ) -> ResponseObject: + # shields is available here with proper typing + if shields: + print(f"Using shields: {shields}") + ``` + + Client usage: + ```python + client.responses.create( + input="hello", + model="llama-3", + extra_body={"shields": ["shield-1"]} + ) + ``` + """ + def __init__(self, description: str | None = None): + self.description = description + + @dataclass class WebMethod: level: str | None = None @@ -26,7 +65,7 @@ class WebMethod: deprecated: bool | None = False -T = TypeVar("T", bound=Callable[..., Any]) +CallableT = TypeVar("CallableT", bound=Callable[..., Any]) def webmethod( @@ -40,7 +79,7 @@ def webmethod( descriptive_name: str | None = None, required_scope: str | None = None, deprecated: bool | None = False, -) -> Callable[[T], T]: +) -> Callable[[CallableT], CallableT]: """ Decorator that supplies additional metadata to an endpoint operation function. @@ -51,7 +90,7 @@ def webmethod( :param required_scope: Required scope for this endpoint (e.g., 'monitoring.viewer'). """ - def wrap(func: T) -> T: + def wrap(func: CallableT) -> CallableT: webmethod_obj = WebMethod( route=route, method=method, From 2a54a2433f1cb92d54d4b1147a82ae45209a1060 Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Fri, 3 Oct 2025 10:38:49 -0700 Subject: [PATCH 2/5] : add integration test --- llama_stack/core/library_client.py | 4 + .../inline/agents/meta_reference/agents.py | 2 + .../responses/openai_responses.py | 5 ++ .../responses/test_extra_body_shields.py | 83 +++++++++++++++++++ 4 files changed, 94 insertions(+) create mode 100644 tests/integration/responses/test_extra_body_shields.py diff --git a/llama_stack/core/library_client.py b/llama_stack/core/library_client.py index e722e4de6..0d9f9f134 100644 --- a/llama_stack/core/library_client.py +++ b/llama_stack/core/library_client.py @@ -374,6 +374,10 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient): body = options.params or {} body |= options.json_data or {} + # Merge extra_json parameters (extra_body from SDK is converted to extra_json) + if hasattr(options, "extra_json") and options.extra_json: + body |= options.extra_json + matched_func, path_params, route_path, webmethod = find_matching_route(options.method, path, self.route_impls) body |= path_params diff --git a/llama_stack/providers/inline/agents/meta_reference/agents.py b/llama_stack/providers/inline/agents/meta_reference/agents.py index 8bdde86b0..5431e8f28 100644 --- a/llama_stack/providers/inline/agents/meta_reference/agents.py +++ b/llama_stack/providers/inline/agents/meta_reference/agents.py @@ -329,6 +329,7 @@ class MetaReferenceAgentsImpl(Agents): tools: list[OpenAIResponseInputTool] | None = None, include: list[str] | None = None, max_infer_iters: int | None = 10, + shields: list | None = None, ) -> OpenAIResponseObject: return await self.openai_responses_impl.create_openai_response( input, @@ -342,6 +343,7 @@ class MetaReferenceAgentsImpl(Agents): tools, include, max_infer_iters, + shields, ) async def list_openai_responses( diff --git a/llama_stack/providers/inline/agents/meta_reference/responses/openai_responses.py b/llama_stack/providers/inline/agents/meta_reference/responses/openai_responses.py index 352be3ded..8ccdcb0e1 100644 --- a/llama_stack/providers/inline/agents/meta_reference/responses/openai_responses.py +++ b/llama_stack/providers/inline/agents/meta_reference/responses/openai_responses.py @@ -208,10 +208,15 @@ class OpenAIResponsesImpl: tools: list[OpenAIResponseInputTool] | None = None, include: list[str] | None = None, max_infer_iters: int | None = 10, + shields: list | None = None, ): stream = bool(stream) text = OpenAIResponseText(format=OpenAIResponseTextFormat(type="text")) if text is None else text + # Shields parameter received via extra_body - not yet implemented + if shields is not None: + raise NotImplementedError("Shields parameter is not yet implemented in the meta-reference provider") + stream_gen = self._create_streaming_response( input=input, model=model, diff --git a/tests/integration/responses/test_extra_body_shields.py b/tests/integration/responses/test_extra_body_shields.py new file mode 100644 index 000000000..b0c6ec39a --- /dev/null +++ b/tests/integration/responses/test_extra_body_shields.py @@ -0,0 +1,83 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +""" +Test for extra_body parameter support with shields example. + +This test demonstrates that parameters marked with ExtraBodyField annotation +can be passed via extra_body in the client SDK and are received by the +server-side implementation. +""" + +import pytest +from llama_stack_client import APIStatusError + + +def test_shields_via_extra_body(compat_client, text_model_id): + """Test that shields parameter is received by the server and raises NotImplementedError.""" + + # Test with shields as list of strings (shield IDs) + with pytest.raises((APIStatusError, NotImplementedError)) as exc_info: + compat_client.responses.create( + model=text_model_id, + input="What is the capital of France?", + stream=False, + extra_body={ + "shields": ["test-shield-1", "test-shield-2"] + } + ) + + # Verify the error message indicates shields are not implemented + error_message = str(exc_info.value) + assert "not yet implemented" in error_message.lower() or "not implemented" in error_message.lower() + + + + +def test_response_without_shields_still_works(compat_client, text_model_id): + """Test that responses still work without shields parameter (backwards compatibility).""" + + response = compat_client.responses.create( + model=text_model_id, + input="Hello, world!", + stream=False, + ) + + # Verify response was created successfully + assert response.id is not None + assert response.output_text is not None + assert len(response.output_text) > 0 + + +def test_shields_parameter_received_end_to_end(compat_client, text_model_id): + """ + Test that shields parameter passed via extra_body reaches the server implementation. + + This verifies end-to-end that: + 1. The parameter can be passed via extra_body in the client SDK + 2. The parameter is properly routed through the API layers + 3. The server-side implementation receives the parameter (verified by NotImplementedError) + + The NotImplementedError proves the extra_body parameter reached the implementation, + as opposed to being rejected earlier due to signature mismatch or validation errors. + """ + # Test with shields parameter via extra_body + with pytest.raises((APIStatusError, NotImplementedError)) as exc_info: + compat_client.responses.create( + model=text_model_id, + input="Test message for shields verification", + stream=False, + extra_body={ + "shields": ["shield-1", "shield-2"] + } + ) + + # The NotImplementedError proves that: + # 1. extra_body.shields was parsed and passed to the API + # 2. The server-side implementation received the shields parameter + # 3. No signature mismatch or validation errors occurred + error_message = str(exc_info.value) + assert "not yet implemented" in error_message.lower() or "not implemented" in error_message.lower() From cbe7391574c70c5845d89b17bb9c39e4b52e4e68 Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Fri, 3 Oct 2025 11:11:32 -0700 Subject: [PATCH 3/5] precommit --- docs/docs/providers/agents/index.mdx | 4 ++-- docs/docs/providers/batches/index.mdx | 24 +++++++++---------- docs/docs/providers/inference/index.mdx | 12 +++++----- llama_stack/apis/agents/agents.py | 7 +++++- llama_stack/schema_utils.py | 16 ++++++------- .../responses/test_extra_body_shields.py | 10 ++------ 6 files changed, 35 insertions(+), 38 deletions(-) diff --git a/docs/docs/providers/agents/index.mdx b/docs/docs/providers/agents/index.mdx index 06eb104af..52b92734e 100644 --- a/docs/docs/providers/agents/index.mdx +++ b/docs/docs/providers/agents/index.mdx @@ -1,7 +1,7 @@ --- description: "Agents - APIs for creating and interacting with agentic systems." +APIs for creating and interacting with agentic systems." sidebar_label: Agents title: Agents --- @@ -12,6 +12,6 @@ title: Agents Agents - APIs for creating and interacting with agentic systems. +APIs for creating and interacting with agentic systems. This section contains documentation for all available providers for the **agents** API. diff --git a/docs/docs/providers/batches/index.mdx b/docs/docs/providers/batches/index.mdx index 2c64b277f..18e5e314d 100644 --- a/docs/docs/providers/batches/index.mdx +++ b/docs/docs/providers/batches/index.mdx @@ -1,14 +1,14 @@ --- description: "The Batches API enables efficient processing of multiple requests in a single operation, - particularly useful for processing large datasets, batch evaluation workflows, and - cost-effective inference at scale. +particularly useful for processing large datasets, batch evaluation workflows, and +cost-effective inference at scale. - The API is designed to allow use of openai client libraries for seamless integration. +The API is designed to allow use of openai client libraries for seamless integration. - This API provides the following extensions: - - idempotent batch creation +This API provides the following extensions: + - idempotent batch creation - Note: This API is currently under active development and may undergo changes." +Note: This API is currently under active development and may undergo changes." sidebar_label: Batches title: Batches --- @@ -18,14 +18,14 @@ title: Batches ## Overview The Batches API enables efficient processing of multiple requests in a single operation, - particularly useful for processing large datasets, batch evaluation workflows, and - cost-effective inference at scale. +particularly useful for processing large datasets, batch evaluation workflows, and +cost-effective inference at scale. - The API is designed to allow use of openai client libraries for seamless integration. +The API is designed to allow use of openai client libraries for seamless integration. - This API provides the following extensions: - - idempotent batch creation +This API provides the following extensions: + - idempotent batch creation - Note: This API is currently under active development and may undergo changes. +Note: This API is currently under active development and may undergo changes. This section contains documentation for all available providers for the **batches** API. diff --git a/docs/docs/providers/inference/index.mdx b/docs/docs/providers/inference/index.mdx index ebbaf1be1..1dc479675 100644 --- a/docs/docs/providers/inference/index.mdx +++ b/docs/docs/providers/inference/index.mdx @@ -1,9 +1,9 @@ --- description: "Llama Stack Inference API for generating completions, chat completions, and embeddings. - This API provides the raw interface to the underlying models. Two kinds of models are supported: - - LLM models: these models generate \"raw\" and \"chat\" (conversational) completions. - - Embedding models: these models generate embeddings to be used for semantic search." +This API provides the raw interface to the underlying models. Two kinds of models are supported: +- LLM models: these models generate \"raw\" and \"chat\" (conversational) completions. +- Embedding models: these models generate embeddings to be used for semantic search." sidebar_label: Inference title: Inference --- @@ -14,8 +14,8 @@ title: Inference Llama Stack Inference API for generating completions, chat completions, and embeddings. - This API provides the raw interface to the underlying models. Two kinds of models are supported: - - LLM models: these models generate "raw" and "chat" (conversational) completions. - - Embedding models: these models generate embeddings to be used for semantic search. +This API provides the raw interface to the underlying models. Two kinds of models are supported: +- LLM models: these models generate "raw" and "chat" (conversational) completions. +- Embedding models: these models generate embeddings to be used for semantic search. This section contains documentation for all available providers for the **inference** API. diff --git a/llama_stack/apis/agents/agents.py b/llama_stack/apis/agents/agents.py index 8be36b92f..cdf47308e 100644 --- a/llama_stack/apis/agents/agents.py +++ b/llama_stack/apis/agents/agents.py @@ -819,7 +819,12 @@ class Agents(Protocol): tools: list[OpenAIResponseInputTool] | None = None, include: list[str] | None = None, max_infer_iters: int | None = 10, # this is an extension to the OpenAI API - shields: Annotated[list[ResponseShield] | None, ExtraBodyField("List of shields to apply during response generation. Shields provide safety and content moderation.")] = None, + shields: Annotated[ + list[ResponseShield] | None, + ExtraBodyField( + "List of shields to apply during response generation. Shields provide safety and content moderation." + ), + ] = None, ) -> OpenAIResponseObject | AsyncIterator[OpenAIResponseObjectStream]: """Create a new OpenAI response. diff --git a/llama_stack/schema_utils.py b/llama_stack/schema_utils.py index f3c0b5942..8e6c53cc7 100644 --- a/llama_stack/schema_utils.py +++ b/llama_stack/schema_utils.py @@ -6,15 +6,12 @@ from collections.abc import Callable from dataclasses import dataclass -from typing import Any, Generic, TypeVar +from typing import Any from .strong_typing.schema import json_schema_type, register_schema # noqa: F401 -T = TypeVar("T") - - -class ExtraBodyField(Generic[T]): +class ExtraBodyField[T]: """ Marker annotation for parameters that arrive via extra_body in the client SDK. @@ -30,7 +27,9 @@ class ExtraBodyField(Generic[T]): self, input: str, model: str, - shields: Annotated[list[str] | None, ExtraBodyField("List of shields to apply")] = None, + shields: Annotated[ + list[str] | None, ExtraBodyField("List of shields to apply") + ] = None, ) -> ResponseObject: # shields is available here with proper typing if shields: @@ -40,12 +39,11 @@ class ExtraBodyField(Generic[T]): Client usage: ```python client.responses.create( - input="hello", - model="llama-3", - extra_body={"shields": ["shield-1"]} + input="hello", model="llama-3", extra_body={"shields": ["shield-1"]} ) ``` """ + def __init__(self, description: str | None = None): self.description = description diff --git a/tests/integration/responses/test_extra_body_shields.py b/tests/integration/responses/test_extra_body_shields.py index b0c6ec39a..f20cd24ba 100644 --- a/tests/integration/responses/test_extra_body_shields.py +++ b/tests/integration/responses/test_extra_body_shields.py @@ -25,9 +25,7 @@ def test_shields_via_extra_body(compat_client, text_model_id): model=text_model_id, input="What is the capital of France?", stream=False, - extra_body={ - "shields": ["test-shield-1", "test-shield-2"] - } + extra_body={"shields": ["test-shield-1", "test-shield-2"]}, ) # Verify the error message indicates shields are not implemented @@ -35,8 +33,6 @@ def test_shields_via_extra_body(compat_client, text_model_id): assert "not yet implemented" in error_message.lower() or "not implemented" in error_message.lower() - - def test_response_without_shields_still_works(compat_client, text_model_id): """Test that responses still work without shields parameter (backwards compatibility).""" @@ -70,9 +66,7 @@ def test_shields_parameter_received_end_to_end(compat_client, text_model_id): model=text_model_id, input="Test message for shields verification", stream=False, - extra_body={ - "shields": ["shield-1", "shield-2"] - } + extra_body={"shields": ["shield-1", "shield-2"]}, ) # The NotImplementedError proves that: From 41d644b5811be1226dcbe3192b0e96c298702c24 Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Fri, 3 Oct 2025 11:14:09 -0700 Subject: [PATCH 4/5] fix --- llama_stack/schema_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/llama_stack/schema_utils.py b/llama_stack/schema_utils.py index 8e6c53cc7..c17d6e353 100644 --- a/llama_stack/schema_utils.py +++ b/llama_stack/schema_utils.py @@ -6,7 +6,7 @@ from collections.abc import Callable from dataclasses import dataclass -from typing import Any +from typing import Any, TypeVar from .strong_typing.schema import json_schema_type, register_schema # noqa: F401 From 0814a3de47edff3b5545d8ca3bbf47363773e6ba Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Fri, 3 Oct 2025 11:20:23 -0700 Subject: [PATCH 5/5] fix --- .../responses/test_extra_body_shields.py | 44 ------------------- 1 file changed, 44 deletions(-) diff --git a/tests/integration/responses/test_extra_body_shields.py b/tests/integration/responses/test_extra_body_shields.py index f20cd24ba..3dedb287a 100644 --- a/tests/integration/responses/test_extra_body_shields.py +++ b/tests/integration/responses/test_extra_body_shields.py @@ -31,47 +31,3 @@ def test_shields_via_extra_body(compat_client, text_model_id): # Verify the error message indicates shields are not implemented error_message = str(exc_info.value) assert "not yet implemented" in error_message.lower() or "not implemented" in error_message.lower() - - -def test_response_without_shields_still_works(compat_client, text_model_id): - """Test that responses still work without shields parameter (backwards compatibility).""" - - response = compat_client.responses.create( - model=text_model_id, - input="Hello, world!", - stream=False, - ) - - # Verify response was created successfully - assert response.id is not None - assert response.output_text is not None - assert len(response.output_text) > 0 - - -def test_shields_parameter_received_end_to_end(compat_client, text_model_id): - """ - Test that shields parameter passed via extra_body reaches the server implementation. - - This verifies end-to-end that: - 1. The parameter can be passed via extra_body in the client SDK - 2. The parameter is properly routed through the API layers - 3. The server-side implementation receives the parameter (verified by NotImplementedError) - - The NotImplementedError proves the extra_body parameter reached the implementation, - as opposed to being rejected earlier due to signature mismatch or validation errors. - """ - # Test with shields parameter via extra_body - with pytest.raises((APIStatusError, NotImplementedError)) as exc_info: - compat_client.responses.create( - model=text_model_id, - input="Test message for shields verification", - stream=False, - extra_body={"shields": ["shield-1", "shield-2"]}, - ) - - # The NotImplementedError proves that: - # 1. extra_body.shields was parsed and passed to the API - # 2. The server-side implementation received the shields parameter - # 3. No signature mismatch or validation errors occurred - error_message = str(exc_info.value) - assert "not yet implemented" in error_message.lower() or "not implemented" in error_message.lower()