From 7b5dd54b7a06ca3a84b427884dd8fa2b133e4ff7 Mon Sep 17 00:00:00 2001
From: Swapna Lekkala
Date: Fri, 10 Oct 2025 09:38:33 -0700
Subject: [PATCH] fix tests

---
 .../meta_reference/responses/openai_responses.py    |  2 +-
 .../agents/meta_reference/responses/streaming.py    |  7 ++++++-
 .../inline/agents/meta_reference/responses/utils.py |  2 +-
 .../meta_reference/test_responses_safety_utils.py   | 10 +++++-----
 4 files changed, 13 insertions(+), 8 deletions(-)

diff --git a/llama_stack/providers/inline/agents/meta_reference/responses/openai_responses.py b/llama_stack/providers/inline/agents/meta_reference/responses/openai_responses.py
index 83a321eae..edb843158 100644
--- a/llama_stack/providers/inline/agents/meta_reference/responses/openai_responses.py
+++ b/llama_stack/providers/inline/agents/meta_reference/responses/openai_responses.py
@@ -345,7 +345,7 @@ class OpenAIResponsesImpl:
             return

         # Structured outputs
-        response_format = convert_response_text_to_chat_response_format(text)
+        response_format = await convert_response_text_to_chat_response_format(text)

         ctx = ChatCompletionContext(
             model=model,
diff --git a/llama_stack/providers/inline/agents/meta_reference/responses/streaming.py b/llama_stack/providers/inline/agents/meta_reference/responses/streaming.py
index c95a2e732..3cf8d3b0c 100644
--- a/llama_stack/providers/inline/agents/meta_reference/responses/streaming.py
+++ b/llama_stack/providers/inline/agents/meta_reference/responses/streaming.py
@@ -13,8 +13,8 @@ from llama_stack.apis.agents.openai_responses import (
     ApprovalFilter,
     MCPListToolsTool,
     OpenAIResponseContentPartOutputText,
-    OpenAIResponseError,
     OpenAIResponseContentPartRefusal,
+    OpenAIResponseError,
     OpenAIResponseInputTool,
     OpenAIResponseInputToolMCP,
     OpenAIResponseMCPApprovalRequest,
@@ -226,6 +226,11 @@ class StreamingResponseOrchestrator:
                     completion_result_data = stream_event_or_result
                 else:
                     yield stream_event_or_result
+
+            # If a violation was detected, skip the rest of processing since we already sent a refusal
+            if self.violation_detected:
+                return
+
             if not completion_result_data:
                 raise ValueError("Streaming chunk processor failed to return completion data")
             last_completion_result = completion_result_data
diff --git a/llama_stack/providers/inline/agents/meta_reference/responses/utils.py b/llama_stack/providers/inline/agents/meta_reference/responses/utils.py
index 9d691c3b9..74d2c48b4 100644
--- a/llama_stack/providers/inline/agents/meta_reference/responses/utils.py
+++ b/llama_stack/providers/inline/agents/meta_reference/responses/utils.py
@@ -178,7 +178,7 @@ async def convert_response_input_to_chat_messages(
                 pass
             else:
                 content = await convert_response_content_to_chat_content(input_item.content)
-                message_type = get_message_type_by_role(input_item.role)
+                message_type = await get_message_type_by_role(input_item.role)
                 if message_type is None:
                     raise ValueError(
                         f"Llama Stack OpenAI Responses does not yet support message role '{input_item.role}' in this context"
diff --git a/tests/unit/providers/agents/meta_reference/test_responses_safety_utils.py b/tests/unit/providers/agents/meta_reference/test_responses_safety_utils.py
index c489337bb..1cf25eb74 100644
--- a/tests/unit/providers/agents/meta_reference/test_responses_safety_utils.py
+++ b/tests/unit/providers/agents/meta_reference/test_responses_safety_utils.py
@@ -84,14 +84,14 @@ def test_extract_shield_ids_empty_list(responses_impl):
     assert result == []


-def test_extract_shield_ids_unknown_format(responses_impl, caplog):
-    """Test extraction with unknown shield format logs warning."""
+def test_extract_shield_ids_unknown_format(responses_impl):
+    """Test extraction with unknown shield format raises ValueError."""
     # Create an object that's neither string nor ResponseShieldSpec
     unknown_object = {"invalid": "format"}  # Plain dict, not ResponseShieldSpec
     shields = ["valid-shield", unknown_object, "another-shield"]
-    result = extract_shield_ids(shields)
-    assert result == ["valid-shield", "another-shield"]
-    assert "Unknown shield format" in caplog.text
+
+    with pytest.raises(ValueError, match="Unknown shield format.*expected str or ResponseShieldSpec"):
+        extract_shield_ids(shields)


 def test_extract_text_content_string(responses_impl):
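
Note on the test change above: extract_shield_ids now treats an unrecognized shield entry as a hard error rather than logging a warning and skipping it, which is what the switch from caplog to pytest.raises asserts. Below is a minimal sketch of an implementation satisfying that contract; it is not the actual llama_stack helper, and the ResponseShieldSpec stub with its "type" field is an assumption for illustration only.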
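    # Sketch of the stricter contract asserted by
    # test_extract_shield_ids_unknown_format; not the real llama_stack code.
    from pydantic import BaseModel

    class ResponseShieldSpec(BaseModel):
        type: str  # assumed field carrying the shield identifier

    def extract_shield_ids(shields: list) -> list[str]:
        shield_ids: list[str] = []
        for shield in shields:
            if isinstance(shield, str):
                # Plain string entries are taken as shield IDs directly
                shield_ids.append(shield)
            elif isinstance(shield, ResponseShieldSpec):
                shield_ids.append(shield.type)
            else:
                # Anything else is now a hard error, matching the test's
                # expected message: "Unknown shield format ... expected str
                # or ResponseShieldSpec"
                raise ValueError(
                    f"Unknown shield format: {shield!r}; expected str or ResponseShieldSpec"
                )
        return shield_ids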