diff --git a/llama_stack/core/library_client.py b/llama_stack/core/library_client.py index e722e4de6..0d9f9f134 100644 --- a/llama_stack/core/library_client.py +++ b/llama_stack/core/library_client.py @@ -374,6 +374,10 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient): body = options.params or {} body |= options.json_data or {} + # Merge extra_json parameters (extra_body from SDK is converted to extra_json) + if hasattr(options, "extra_json") and options.extra_json: + body |= options.extra_json + matched_func, path_params, route_path, webmethod = find_matching_route(options.method, path, self.route_impls) body |= path_params diff --git a/llama_stack/providers/inline/agents/meta_reference/agents.py b/llama_stack/providers/inline/agents/meta_reference/agents.py index 8bdde86b0..5431e8f28 100644 --- a/llama_stack/providers/inline/agents/meta_reference/agents.py +++ b/llama_stack/providers/inline/agents/meta_reference/agents.py @@ -329,6 +329,7 @@ class MetaReferenceAgentsImpl(Agents): tools: list[OpenAIResponseInputTool] | None = None, include: list[str] | None = None, max_infer_iters: int | None = 10, + shields: list | None = None, ) -> OpenAIResponseObject: return await self.openai_responses_impl.create_openai_response( input, @@ -342,6 +343,7 @@ class MetaReferenceAgentsImpl(Agents): tools, include, max_infer_iters, + shields, ) async def list_openai_responses( diff --git a/llama_stack/providers/inline/agents/meta_reference/responses/openai_responses.py b/llama_stack/providers/inline/agents/meta_reference/responses/openai_responses.py index 352be3ded..8ccdcb0e1 100644 --- a/llama_stack/providers/inline/agents/meta_reference/responses/openai_responses.py +++ b/llama_stack/providers/inline/agents/meta_reference/responses/openai_responses.py @@ -208,10 +208,15 @@ class OpenAIResponsesImpl: tools: list[OpenAIResponseInputTool] | None = None, include: list[str] | None = None, max_infer_iters: int | None = 10, + shields: list | None = None, ): stream = bool(stream) text = OpenAIResponseText(format=OpenAIResponseTextFormat(type="text")) if text is None else text + # Shields parameter received via extra_body - not yet implemented + if shields is not None: + raise NotImplementedError("Shields parameter is not yet implemented in the meta-reference provider") + stream_gen = self._create_streaming_response( input=input, model=model, diff --git a/tests/integration/responses/test_extra_body_shields.py b/tests/integration/responses/test_extra_body_shields.py new file mode 100644 index 000000000..b0c6ec39a --- /dev/null +++ b/tests/integration/responses/test_extra_body_shields.py @@ -0,0 +1,83 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +""" +Test for extra_body parameter support with shields example. + +This test demonstrates that parameters marked with ExtraBodyField annotation +can be passed via extra_body in the client SDK and are received by the +server-side implementation. +""" + +import pytest +from llama_stack_client import APIStatusError + + +def test_shields_via_extra_body(compat_client, text_model_id): + """Test that shields parameter is received by the server and raises NotImplementedError.""" + + # Test with shields as list of strings (shield IDs) + with pytest.raises((APIStatusError, NotImplementedError)) as exc_info: + compat_client.responses.create( + model=text_model_id, + input="What is the capital of France?", + stream=False, + extra_body={ + "shields": ["test-shield-1", "test-shield-2"] + } + ) + + # Verify the error message indicates shields are not implemented + error_message = str(exc_info.value) + assert "not yet implemented" in error_message.lower() or "not implemented" in error_message.lower() + + + + +def test_response_without_shields_still_works(compat_client, text_model_id): + """Test that responses still work without shields parameter (backwards compatibility).""" + + response = compat_client.responses.create( + model=text_model_id, + input="Hello, world!", + stream=False, + ) + + # Verify response was created successfully + assert response.id is not None + assert response.output_text is not None + assert len(response.output_text) > 0 + + +def test_shields_parameter_received_end_to_end(compat_client, text_model_id): + """ + Test that shields parameter passed via extra_body reaches the server implementation. + + This verifies end-to-end that: + 1. The parameter can be passed via extra_body in the client SDK + 2. The parameter is properly routed through the API layers + 3. The server-side implementation receives the parameter (verified by NotImplementedError) + + The NotImplementedError proves the extra_body parameter reached the implementation, + as opposed to being rejected earlier due to signature mismatch or validation errors. + """ + # Test with shields parameter via extra_body + with pytest.raises((APIStatusError, NotImplementedError)) as exc_info: + compat_client.responses.create( + model=text_model_id, + input="Test message for shields verification", + stream=False, + extra_body={ + "shields": ["shield-1", "shield-2"] + } + ) + + # The NotImplementedError proves that: + # 1. extra_body.shields was parsed and passed to the API + # 2. The server-side implementation received the shields parameter + # 3. No signature mismatch or validation errors occurred + error_message = str(exc_info.value) + assert "not yet implemented" in error_message.lower() or "not implemented" in error_message.lower()