diff --git a/docs/docs/providers/agents/index.mdx b/docs/docs/providers/agents/index.mdx index 06eb104af..52b92734e 100644 --- a/docs/docs/providers/agents/index.mdx +++ b/docs/docs/providers/agents/index.mdx @@ -1,7 +1,7 @@ --- description: "Agents - APIs for creating and interacting with agentic systems." +APIs for creating and interacting with agentic systems." sidebar_label: Agents title: Agents --- @@ -12,6 +12,6 @@ title: Agents Agents - APIs for creating and interacting with agentic systems. +APIs for creating and interacting with agentic systems. This section contains documentation for all available providers for the **agents** API. diff --git a/docs/docs/providers/batches/index.mdx b/docs/docs/providers/batches/index.mdx index 2c64b277f..18e5e314d 100644 --- a/docs/docs/providers/batches/index.mdx +++ b/docs/docs/providers/batches/index.mdx @@ -1,14 +1,14 @@ --- description: "The Batches API enables efficient processing of multiple requests in a single operation, - particularly useful for processing large datasets, batch evaluation workflows, and - cost-effective inference at scale. +particularly useful for processing large datasets, batch evaluation workflows, and +cost-effective inference at scale. - The API is designed to allow use of openai client libraries for seamless integration. +The API is designed to allow use of openai client libraries for seamless integration. - This API provides the following extensions: - - idempotent batch creation +This API provides the following extensions: + - idempotent batch creation - Note: This API is currently under active development and may undergo changes." +Note: This API is currently under active development and may undergo changes." sidebar_label: Batches title: Batches --- @@ -18,14 +18,14 @@ title: Batches ## Overview The Batches API enables efficient processing of multiple requests in a single operation, - particularly useful for processing large datasets, batch evaluation workflows, and - cost-effective inference at scale. +particularly useful for processing large datasets, batch evaluation workflows, and +cost-effective inference at scale. - The API is designed to allow use of openai client libraries for seamless integration. +The API is designed to allow use of openai client libraries for seamless integration. - This API provides the following extensions: - - idempotent batch creation +This API provides the following extensions: + - idempotent batch creation - Note: This API is currently under active development and may undergo changes. +Note: This API is currently under active development and may undergo changes. This section contains documentation for all available providers for the **batches** API. diff --git a/docs/docs/providers/inference/index.mdx b/docs/docs/providers/inference/index.mdx index ebbaf1be1..1dc479675 100644 --- a/docs/docs/providers/inference/index.mdx +++ b/docs/docs/providers/inference/index.mdx @@ -1,9 +1,9 @@ --- description: "Llama Stack Inference API for generating completions, chat completions, and embeddings. - This API provides the raw interface to the underlying models. Two kinds of models are supported: - - LLM models: these models generate \"raw\" and \"chat\" (conversational) completions. - - Embedding models: these models generate embeddings to be used for semantic search." +This API provides the raw interface to the underlying models. Two kinds of models are supported: +- LLM models: these models generate \"raw\" and \"chat\" (conversational) completions. +- Embedding models: these models generate embeddings to be used for semantic search." sidebar_label: Inference title: Inference --- @@ -14,8 +14,8 @@ title: Inference Llama Stack Inference API for generating completions, chat completions, and embeddings. - This API provides the raw interface to the underlying models. Two kinds of models are supported: - - LLM models: these models generate "raw" and "chat" (conversational) completions. - - Embedding models: these models generate embeddings to be used for semantic search. +This API provides the raw interface to the underlying models. Two kinds of models are supported: +- LLM models: these models generate "raw" and "chat" (conversational) completions. +- Embedding models: these models generate embeddings to be used for semantic search. This section contains documentation for all available providers for the **inference** API. diff --git a/docs/openapi_generator/pyopenapi/generator.py b/docs/openapi_generator/pyopenapi/generator.py index bb8fa55ab..a8d6aaee9 100644 --- a/docs/openapi_generator/pyopenapi/generator.py +++ b/docs/openapi_generator/pyopenapi/generator.py @@ -50,6 +50,7 @@ from .specification import ( Document, Example, ExampleRef, + ExtraBodyParameter, MediaType, Operation, Parameter, @@ -677,6 +678,27 @@ class Generator: # parameters passed anywhere parameters = path_parameters + query_parameters + # Build extra body parameters documentation + extra_body_parameters = [] + for param_name, param_type, description in op.extra_body_params: + if is_type_optional(param_type): + inner_type: type = unwrap_optional_type(param_type) + required = False + else: + inner_type = param_type + required = True + + # Use description from ExtraBodyField if available, otherwise from docstring + param_description = description or doc_params.get(param_name) + + extra_body_param = ExtraBodyParameter( + name=param_name, + schema=self.schema_builder.classdef_to_ref(inner_type), + description=param_description, + required=required, + ) + extra_body_parameters.append(extra_body_param) + webmethod = getattr(op.func_ref, "__webmethod__", None) raw_bytes_request_body = False if webmethod: @@ -898,6 +920,7 @@ class Generator: deprecated=getattr(op.webmethod, "deprecated", False) or "DEPRECATED" in op.func_name, security=[] if op.public else None, + extraBodyParameters=extra_body_parameters if extra_body_parameters else None, ) def _get_api_stability_priority(self, api_level: str) -> int: diff --git a/docs/openapi_generator/pyopenapi/operations.py b/docs/openapi_generator/pyopenapi/operations.py index ce33d3bb9..2970d7e53 100644 --- a/docs/openapi_generator/pyopenapi/operations.py +++ b/docs/openapi_generator/pyopenapi/operations.py @@ -19,10 +19,12 @@ from llama_stack.strong_typing.inspection import get_signature from typing import get_origin, get_args -from fastapi import UploadFile +from fastapi import UploadFile from fastapi.params import File, Form from typing import Annotated +from llama_stack.schema_utils import ExtraBodyField + def split_prefix( s: str, sep: str, prefix: Union[str, Iterable[str]] @@ -89,6 +91,7 @@ class EndpointOperation: :param query_params: Parameters of the operation signature that are passed in the query string as `key=value` pairs. :param request_params: The parameter that corresponds to the data transmitted in the request body. :param multipart_params: Parameters that indicate multipart/form-data request body. + :param extra_body_params: Parameters that arrive via extra_body and are documented but not in SDK. :param event_type: The Python type of the data that is transmitted out-of-band (e.g. via websockets) while the operation is in progress. :param response_type: The Python type of the data that is transmitted in the response body. :param http_method: The HTTP method used to invoke the endpoint such as POST, GET or PUT. @@ -106,6 +109,7 @@ class EndpointOperation: query_params: List[OperationParameter] request_params: Optional[OperationParameter] multipart_params: List[OperationParameter] + extra_body_params: List[tuple[str, type, str | None]] event_type: Optional[type] response_type: type http_method: HTTPMethod @@ -265,6 +269,7 @@ def get_endpoint_operations( query_params = [] request_params = [] multipart_params = [] + extra_body_params = [] for param_name, parameter in signature.parameters.items(): param_type = _get_annotation_type(parameter.annotation, func_ref) @@ -279,6 +284,13 @@ def get_endpoint_operations( f"parameter '{param_name}' in function '{func_name}' has no type annotation" ) + # Check if this is an extra_body parameter + is_extra_body, extra_body_desc = _is_extra_body_param(param_type) + if is_extra_body: + # Store in a separate list for documentation + extra_body_params.append((param_name, param_type, extra_body_desc)) + continue # Skip adding to request_params + is_multipart = _is_multipart_param(param_type) if prefix in ["get", "delete"]: @@ -351,6 +363,7 @@ def get_endpoint_operations( query_params=query_params, request_params=request_params, multipart_params=multipart_params, + extra_body_params=extra_body_params, event_type=event_type, response_type=response_type, http_method=http_method, @@ -403,7 +416,7 @@ def get_endpoint_events(endpoint: type) -> Dict[str, type]: def _is_multipart_param(param_type: type) -> bool: """ Check if a parameter type indicates multipart form data. - + Returns True if the type is: - UploadFile - Annotated[UploadFile, File()] @@ -413,19 +426,38 @@ def _is_multipart_param(param_type: type) -> bool: """ if param_type is UploadFile: return True - + # Check for Annotated types origin = get_origin(param_type) if origin is None: return False - + if origin is Annotated: args = get_args(param_type) if len(args) < 2: return False - + # Check the annotations for File() or Form() for annotation in args[1:]: if isinstance(annotation, (File, Form)): return True return False + + +def _is_extra_body_param(param_type: type) -> tuple[bool, str | None]: + """ + Check if parameter is marked as coming from extra_body. + + Returns: + (is_extra_body, description): Tuple of boolean and optional description + """ + origin = get_origin(param_type) + if origin is Annotated: + args = get_args(param_type) + for annotation in args[1:]: + if isinstance(annotation, ExtraBodyField): + return True, annotation.description + # Also check by type name for cases where import matters + if type(annotation).__name__ == 'ExtraBodyField': + return True, getattr(annotation, 'description', None) + return False, None diff --git a/docs/openapi_generator/pyopenapi/specification.py b/docs/openapi_generator/pyopenapi/specification.py index d3e5a1f19..90bf54316 100644 --- a/docs/openapi_generator/pyopenapi/specification.py +++ b/docs/openapi_generator/pyopenapi/specification.py @@ -106,6 +106,15 @@ class Parameter: example: Optional[Any] = None +@dataclass +class ExtraBodyParameter: + """Represents a parameter that arrives via extra_body in the request.""" + name: str + schema: SchemaOrRef + description: Optional[str] = None + required: Optional[bool] = None + + @dataclass class Operation: responses: Dict[str, Union[Response, ResponseRef]] @@ -118,6 +127,7 @@ class Operation: callbacks: Optional[Dict[str, "Callback"]] = None security: Optional[List["SecurityRequirement"]] = None deprecated: Optional[bool] = None + extraBodyParameters: Optional[List[ExtraBodyParameter]] = None @dataclass diff --git a/docs/openapi_generator/pyopenapi/utility.py b/docs/openapi_generator/pyopenapi/utility.py index d302b114f..26ef22112 100644 --- a/docs/openapi_generator/pyopenapi/utility.py +++ b/docs/openapi_generator/pyopenapi/utility.py @@ -52,6 +52,17 @@ class Specification: if display_name: tag["x-displayName"] = display_name + # Handle operations to rename extraBodyParameters -> x-llama-stack-extra-body-params + paths = json_doc.get("paths", {}) + for path_item in paths.values(): + if isinstance(path_item, dict): + for method in ["get", "post", "put", "delete", "patch"]: + operation = path_item.get(method) + if operation and isinstance(operation, dict): + extra_body_params = operation.pop("extraBodyParameters", None) + if extra_body_params: + operation["x-llama-stack-extra-body-params"] = extra_body_params + return json_doc def get_json_string(self, pretty_print: bool = False) -> str: diff --git a/docs/static/deprecated-llama-stack-spec.html b/docs/static/deprecated-llama-stack-spec.html index 7edfe3f5d..ffda7552b 100644 --- a/docs/static/deprecated-llama-stack-spec.html +++ b/docs/static/deprecated-llama-stack-spec.html @@ -2132,7 +2132,27 @@ }, "required": true }, - "deprecated": true + "deprecated": true, + "x-llama-stack-extra-body-params": [ + { + "name": "shields", + "schema": { + "type": "array", + "items": { + "oneOf": [ + { + "type": "string" + }, + { + "$ref": "#/components/schemas/ResponseShieldSpec" + } + ] + } + }, + "description": "List of shields to apply during response generation. Shields provide safety and content moderation.", + "required": false + } + ] } }, "/v1/openai/v1/responses/{response_id}": { @@ -9521,6 +9541,21 @@ "title": "OpenAIResponseText", "description": "Text response configuration for OpenAI responses." }, + "ResponseShieldSpec": { + "type": "object", + "properties": { + "type": { + "type": "string", + "description": "The type/identifier of the shield." + } + }, + "additionalProperties": false, + "required": [ + "type" + ], + "title": "ResponseShieldSpec", + "description": "Specification for a shield to apply during response generation." + }, "OpenAIResponseInputTool": { "oneOf": [ { diff --git a/docs/static/deprecated-llama-stack-spec.yaml b/docs/static/deprecated-llama-stack-spec.yaml index ca832d46b..0e672f914 100644 --- a/docs/static/deprecated-llama-stack-spec.yaml +++ b/docs/static/deprecated-llama-stack-spec.yaml @@ -1559,6 +1559,18 @@ paths: $ref: '#/components/schemas/CreateOpenaiResponseRequest' required: true deprecated: true + x-llama-stack-extra-body-params: + - name: shields + schema: + type: array + items: + oneOf: + - type: string + - $ref: '#/components/schemas/ResponseShieldSpec' + description: >- + List of shields to apply during response generation. Shields provide safety + and content moderation. + required: false /v1/openai/v1/responses/{response_id}: get: responses: @@ -7076,6 +7088,18 @@ components: title: OpenAIResponseText description: >- Text response configuration for OpenAI responses. + ResponseShieldSpec: + type: object + properties: + type: + type: string + description: The type/identifier of the shield. + additionalProperties: false + required: + - type + title: ResponseShieldSpec + description: >- + Specification for a shield to apply during response generation. OpenAIResponseInputTool: oneOf: - $ref: '#/components/schemas/OpenAIResponseInputToolWebSearch' diff --git a/docs/static/llama-stack-spec.html b/docs/static/llama-stack-spec.html index 96e97035f..c570dcddf 100644 --- a/docs/static/llama-stack-spec.html +++ b/docs/static/llama-stack-spec.html @@ -1830,7 +1830,27 @@ }, "required": true }, - "deprecated": false + "deprecated": false, + "x-llama-stack-extra-body-params": [ + { + "name": "shields", + "schema": { + "type": "array", + "items": { + "oneOf": [ + { + "type": "string" + }, + { + "$ref": "#/components/schemas/ResponseShieldSpec" + } + ] + } + }, + "description": "List of shields to apply during response generation. Shields provide safety and content moderation.", + "required": false + } + ] } }, "/v1/responses/{response_id}": { @@ -7616,6 +7636,21 @@ "title": "OpenAIResponseText", "description": "Text response configuration for OpenAI responses." }, + "ResponseShieldSpec": { + "type": "object", + "properties": { + "type": { + "type": "string", + "description": "The type/identifier of the shield." + } + }, + "additionalProperties": false, + "required": [ + "type" + ], + "title": "ResponseShieldSpec", + "description": "Specification for a shield to apply during response generation." + }, "OpenAIResponseInputTool": { "oneOf": [ { diff --git a/docs/static/llama-stack-spec.yaml b/docs/static/llama-stack-spec.yaml index b9e03d614..3e1431b22 100644 --- a/docs/static/llama-stack-spec.yaml +++ b/docs/static/llama-stack-spec.yaml @@ -1411,6 +1411,18 @@ paths: $ref: '#/components/schemas/CreateOpenaiResponseRequest' required: true deprecated: false + x-llama-stack-extra-body-params: + - name: shields + schema: + type: array + items: + oneOf: + - type: string + - $ref: '#/components/schemas/ResponseShieldSpec' + description: >- + List of shields to apply during response generation. Shields provide safety + and content moderation. + required: false /v1/responses/{response_id}: get: responses: @@ -5739,6 +5751,18 @@ components: title: OpenAIResponseText description: >- Text response configuration for OpenAI responses. + ResponseShieldSpec: + type: object + properties: + type: + type: string + description: The type/identifier of the shield. + additionalProperties: false + required: + - type + title: ResponseShieldSpec + description: >- + Specification for a shield to apply during response generation. OpenAIResponseInputTool: oneOf: - $ref: '#/components/schemas/OpenAIResponseInputToolWebSearch' diff --git a/docs/static/stainless-llama-stack-spec.html b/docs/static/stainless-llama-stack-spec.html index 7ec48ef74..167a4aa3c 100644 --- a/docs/static/stainless-llama-stack-spec.html +++ b/docs/static/stainless-llama-stack-spec.html @@ -1830,7 +1830,27 @@ }, "required": true }, - "deprecated": false + "deprecated": false, + "x-llama-stack-extra-body-params": [ + { + "name": "shields", + "schema": { + "type": "array", + "items": { + "oneOf": [ + { + "type": "string" + }, + { + "$ref": "#/components/schemas/ResponseShieldSpec" + } + ] + } + }, + "description": "List of shields to apply during response generation. Shields provide safety and content moderation.", + "required": false + } + ] } }, "/v1/responses/{response_id}": { @@ -9625,6 +9645,21 @@ "title": "OpenAIResponseText", "description": "Text response configuration for OpenAI responses." }, + "ResponseShieldSpec": { + "type": "object", + "properties": { + "type": { + "type": "string", + "description": "The type/identifier of the shield." + } + }, + "additionalProperties": false, + "required": [ + "type" + ], + "title": "ResponseShieldSpec", + "description": "Specification for a shield to apply during response generation." + }, "OpenAIResponseInputTool": { "oneOf": [ { diff --git a/docs/static/stainless-llama-stack-spec.yaml b/docs/static/stainless-llama-stack-spec.yaml index 3bede159b..6dc1041f1 100644 --- a/docs/static/stainless-llama-stack-spec.yaml +++ b/docs/static/stainless-llama-stack-spec.yaml @@ -1414,6 +1414,18 @@ paths: $ref: '#/components/schemas/CreateOpenaiResponseRequest' required: true deprecated: false + x-llama-stack-extra-body-params: + - name: shields + schema: + type: array + items: + oneOf: + - type: string + - $ref: '#/components/schemas/ResponseShieldSpec' + description: >- + List of shields to apply during response generation. Shields provide safety + and content moderation. + required: false /v1/responses/{response_id}: get: responses: @@ -7184,6 +7196,18 @@ components: title: OpenAIResponseText description: >- Text response configuration for OpenAI responses. + ResponseShieldSpec: + type: object + properties: + type: + type: string + description: The type/identifier of the shield. + additionalProperties: false + required: + - type + title: ResponseShieldSpec + description: >- + Specification for a shield to apply during response generation. OpenAIResponseInputTool: oneOf: - $ref: '#/components/schemas/OpenAIResponseInputToolWebSearch' diff --git a/llama_stack/apis/agents/agents.py b/llama_stack/apis/agents/agents.py index 811fe6aa2..cdf47308e 100644 --- a/llama_stack/apis/agents/agents.py +++ b/llama_stack/apis/agents/agents.py @@ -28,7 +28,7 @@ from llama_stack.apis.inference import ( from llama_stack.apis.safety import SafetyViolation from llama_stack.apis.tools import ToolDef from llama_stack.apis.version import LLAMA_STACK_API_V1, LLAMA_STACK_API_V1ALPHA -from llama_stack.schema_utils import json_schema_type, register_schema, webmethod +from llama_stack.schema_utils import ExtraBodyField, json_schema_type, register_schema, webmethod from .openai_responses import ( ListOpenAIResponseInputItem, @@ -42,6 +42,20 @@ from .openai_responses import ( ) +@json_schema_type +class ResponseShieldSpec(BaseModel): + """Specification for a shield to apply during response generation. + + :param type: The type/identifier of the shield. + """ + + type: str + # TODO: more fields to be added for shield configuration + + +ResponseShield = str | ResponseShieldSpec + + class Attachment(BaseModel): """An attachment to an agent turn. @@ -805,6 +819,12 @@ class Agents(Protocol): tools: list[OpenAIResponseInputTool] | None = None, include: list[str] | None = None, max_infer_iters: int | None = 10, # this is an extension to the OpenAI API + shields: Annotated[ + list[ResponseShield] | None, + ExtraBodyField( + "List of shields to apply during response generation. Shields provide safety and content moderation." + ), + ] = None, ) -> OpenAIResponseObject | AsyncIterator[OpenAIResponseObjectStream]: """Create a new OpenAI response. @@ -812,6 +832,7 @@ class Agents(Protocol): :param model: The underlying LLM used for completions. :param previous_response_id: (Optional) if specified, the new response will be a continuation of the previous response. This can be used to easily fork-off new responses from existing responses. :param include: (Optional) Additional fields to include in the response. + :param shields: (Optional) List of shields to apply during response generation. Can be shield IDs (strings) or shield specifications. :returns: An OpenAIResponseObject. """ ... diff --git a/llama_stack/core/library_client.py b/llama_stack/core/library_client.py index e722e4de6..0d9f9f134 100644 --- a/llama_stack/core/library_client.py +++ b/llama_stack/core/library_client.py @@ -374,6 +374,10 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient): body = options.params or {} body |= options.json_data or {} + # Merge extra_json parameters (extra_body from SDK is converted to extra_json) + if hasattr(options, "extra_json") and options.extra_json: + body |= options.extra_json + matched_func, path_params, route_path, webmethod = find_matching_route(options.method, path, self.route_impls) body |= path_params diff --git a/llama_stack/providers/inline/agents/meta_reference/agents.py b/llama_stack/providers/inline/agents/meta_reference/agents.py index 8bdde86b0..5431e8f28 100644 --- a/llama_stack/providers/inline/agents/meta_reference/agents.py +++ b/llama_stack/providers/inline/agents/meta_reference/agents.py @@ -329,6 +329,7 @@ class MetaReferenceAgentsImpl(Agents): tools: list[OpenAIResponseInputTool] | None = None, include: list[str] | None = None, max_infer_iters: int | None = 10, + shields: list | None = None, ) -> OpenAIResponseObject: return await self.openai_responses_impl.create_openai_response( input, @@ -342,6 +343,7 @@ class MetaReferenceAgentsImpl(Agents): tools, include, max_infer_iters, + shields, ) async def list_openai_responses( diff --git a/llama_stack/providers/inline/agents/meta_reference/responses/openai_responses.py b/llama_stack/providers/inline/agents/meta_reference/responses/openai_responses.py index 352be3ded..8ccdcb0e1 100644 --- a/llama_stack/providers/inline/agents/meta_reference/responses/openai_responses.py +++ b/llama_stack/providers/inline/agents/meta_reference/responses/openai_responses.py @@ -208,10 +208,15 @@ class OpenAIResponsesImpl: tools: list[OpenAIResponseInputTool] | None = None, include: list[str] | None = None, max_infer_iters: int | None = 10, + shields: list | None = None, ): stream = bool(stream) text = OpenAIResponseText(format=OpenAIResponseTextFormat(type="text")) if text is None else text + # Shields parameter received via extra_body - not yet implemented + if shields is not None: + raise NotImplementedError("Shields parameter is not yet implemented in the meta-reference provider") + stream_gen = self._create_streaming_response( input=input, model=model, diff --git a/llama_stack/schema_utils.py b/llama_stack/schema_utils.py index c58fcdd01..c17d6e353 100644 --- a/llama_stack/schema_utils.py +++ b/llama_stack/schema_utils.py @@ -11,6 +11,43 @@ from typing import Any, TypeVar from .strong_typing.schema import json_schema_type, register_schema # noqa: F401 +class ExtraBodyField[T]: + """ + Marker annotation for parameters that arrive via extra_body in the client SDK. + + These parameters: + - Will NOT appear in the generated client SDK method signature + - WILL be documented in OpenAPI spec under x-llama-stack-extra-body-params + - MUST be passed via the extra_body parameter in client SDK calls + - WILL be available in server-side method signature with proper typing + + Example: + ```python + async def create_openai_response( + self, + input: str, + model: str, + shields: Annotated[ + list[str] | None, ExtraBodyField("List of shields to apply") + ] = None, + ) -> ResponseObject: + # shields is available here with proper typing + if shields: + print(f"Using shields: {shields}") + ``` + + Client usage: + ```python + client.responses.create( + input="hello", model="llama-3", extra_body={"shields": ["shield-1"]} + ) + ``` + """ + + def __init__(self, description: str | None = None): + self.description = description + + @dataclass class WebMethod: level: str | None = None @@ -26,7 +63,7 @@ class WebMethod: deprecated: bool | None = False -T = TypeVar("T", bound=Callable[..., Any]) +CallableT = TypeVar("CallableT", bound=Callable[..., Any]) def webmethod( @@ -40,7 +77,7 @@ def webmethod( descriptive_name: str | None = None, required_scope: str | None = None, deprecated: bool | None = False, -) -> Callable[[T], T]: +) -> Callable[[CallableT], CallableT]: """ Decorator that supplies additional metadata to an endpoint operation function. @@ -51,7 +88,7 @@ def webmethod( :param required_scope: Required scope for this endpoint (e.g., 'monitoring.viewer'). """ - def wrap(func: T) -> T: + def wrap(func: CallableT) -> CallableT: webmethod_obj = WebMethod( route=route, method=method, diff --git a/tests/integration/responses/test_extra_body_shields.py b/tests/integration/responses/test_extra_body_shields.py new file mode 100644 index 000000000..3dedb287a --- /dev/null +++ b/tests/integration/responses/test_extra_body_shields.py @@ -0,0 +1,33 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +""" +Test for extra_body parameter support with shields example. + +This test demonstrates that parameters marked with ExtraBodyField annotation +can be passed via extra_body in the client SDK and are received by the +server-side implementation. +""" + +import pytest +from llama_stack_client import APIStatusError + + +def test_shields_via_extra_body(compat_client, text_model_id): + """Test that shields parameter is received by the server and raises NotImplementedError.""" + + # Test with shields as list of strings (shield IDs) + with pytest.raises((APIStatusError, NotImplementedError)) as exc_info: + compat_client.responses.create( + model=text_model_id, + input="What is the capital of France?", + stream=False, + extra_body={"shields": ["test-shield-1", "test-shield-2"]}, + ) + + # Verify the error message indicates shields are not implemented + error_message = str(exc_info.value) + assert "not yet implemented" in error_message.lower() or "not implemented" in error_message.lower()