From ada18ec399fbc0323dd4be6d32f966ea5f090f1b Mon Sep 17 00:00:00 2001
From: Swapna Lekkala
Date: Wed, 15 Oct 2025 13:16:37 -0700
Subject: [PATCH] queue and stream events for safe chunk

---
 .../meta_reference/responses/streaming.py     | 39 ++++++++----
 .../agents/meta_reference/responses/utils.py  |  2 -
 .../agents/test_openai_responses.py           | 61 ++++++++++++++++---
 3 files changed, 79 insertions(+), 23 deletions(-)

diff --git a/llama_stack/providers/inline/agents/meta_reference/responses/streaming.py b/llama_stack/providers/inline/agents/meta_reference/responses/streaming.py
index 839974ec7..03b49e3ac 100644
--- a/llama_stack/providers/inline/agents/meta_reference/responses/streaming.py
+++ b/llama_stack/providers/inline/agents/meta_reference/responses/streaming.py
@@ -566,6 +566,9 @@ class StreamingResponseOrchestrator:
             # Accumulate usage from chunks (typically in final chunk with stream_options)
             self._accumulate_chunk_usage(chunk)
 
+            # Track deltas for this specific chunk for guardrail validation
+            chunk_events: list[OpenAIResponseObjectStream] = []
+
             for chunk_choice in chunk.choices:
                 # Emit incremental text content as delta events
                 if chunk_choice.delta.content:
@@ -601,15 +604,19 @@ class StreamingResponseOrchestrator:
                             sequence_number=self.sequence_number,
                         )
                     self.sequence_number += 1
-                    # Skip Emitting text content delta event if guardrails are configured, only emits chunks after guardrails are applied
-                    if not self.guardrail_ids:
-                        yield OpenAIResponseObjectStreamResponseOutputTextDelta(
-                            content_index=content_index,
-                            delta=chunk_choice.delta.content,
-                            item_id=message_item_id,
-                            output_index=message_output_index,
-                            sequence_number=self.sequence_number,
-                        )
+
+                    text_delta_event = OpenAIResponseObjectStreamResponseOutputTextDelta(
+                        content_index=content_index,
+                        delta=chunk_choice.delta.content,
+                        item_id=message_item_id,
+                        output_index=message_output_index,
+                        sequence_number=self.sequence_number,
+                    )
+                    # Buffer text delta events for guardrail check
+                    if self.guardrail_ids:
+                        chunk_events.append(text_delta_event)
+                    else:
+                        yield text_delta_event
 
                     # Collect content for final response
                     chat_response_content.append(chunk_choice.delta.content or "")
@@ -625,7 +632,11 @@ class StreamingResponseOrchestrator:
                         message_item_id=message_item_id,
                         message_output_index=message_output_index,
                     ):
-                        yield event
+                        # Buffer reasoning events for guardrail check
+                        if self.guardrail_ids:
+                            chunk_events.append(event)
+                        else:
+                            yield event
                     reasoning_part_emitted = True
                     reasoning_text_accumulated.append(chunk_choice.delta.reasoning_content)
 
@@ -707,15 +718,21 @@ class StreamingResponseOrchestrator:
                                 response_tool_call.function.arguments or ""
                             ) + tool_call.function.arguments
 
-            # Output Safety Validation for a chunk
+            # Output Safety Validation for this chunk
             if self.guardrail_ids:
+                # Check guardrails on accumulated text so far
                 accumulated_text = "".join(chat_response_content)
                 violation_message = await run_guardrails(self.safety_api, accumulated_text, self.guardrail_ids)
                 if violation_message:
                     logger.info(f"Output guardrail violation: {violation_message}")
+                    chunk_events.clear()
                     yield await self._create_refusal_response(violation_message)
                     self.violation_detected = True
                     return
+                else:
+                    # No violation detected, emit all content events for this chunk
+                    for event in chunk_events:
+                        yield event
 
         # Emit arguments.done events for completed tool calls (differentiate between MCP and function calls)
         for tool_call_index in sorted(chat_response_tool_calls.keys()):
diff --git a/llama_stack/providers/inline/agents/meta_reference/responses/utils.py b/llama_stack/providers/inline/agents/meta_reference/responses/utils.py
index 53f2d16ca..7ca8af632 100644
--- a/llama_stack/providers/inline/agents/meta_reference/responses/utils.py
+++ b/llama_stack/providers/inline/agents/meta_reference/responses/utils.py
@@ -323,13 +323,11 @@ async def run_guardrails(safety_api: Safety, messages: str, guardrail_ids: list[
     shields_list = await safety_api.routing_table.list_shields()
 
     for guardrail_id in guardrail_ids:
-        # Find the shield with this identifier
         matching_shields = [shield for shield in shields_list.data if shield.identifier == guardrail_id]
         if matching_shields:
             model_id = matching_shields[0].provider_resource_id
             model_ids.append(model_id)
         else:
-            # If no shield found, raise an error
             raise ValueError(f"No shield found with identifier '{guardrail_id}'")
 
     guardrail_tasks = [safety_api.run_moderation(messages, model=model_id) for model_id in model_ids]
diff --git a/tests/integration/agents/test_openai_responses.py b/tests/integration/agents/test_openai_responses.py
index 88d113a9c..675e2b904 100644
--- a/tests/integration/agents/test_openai_responses.py
+++ b/tests/integration/agents/test_openai_responses.py
@@ -346,22 +346,18 @@ def test_input_safety_guardrails(compat_client, text_model_id, stream):
 
 
 @pytest.mark.parametrize("stream", [True, False])
-def test_output_safety_guardrails(compat_client, text_model_id, stream):
-    """Test that output safety guardrails can block generated content."""
+def test_output_safety_guardrails_unsafe_content(compat_client, text_model_id, stream):
+    """Test that output safety guardrails block unsafe content."""
     if not isinstance(compat_client, OpenAI):
         pytest.skip("OpenAI client is required until responses API exists in llama-stack-client")
 
-    # Use an ambiguous prompt that could potentially generate content requiring guardrail evaluation
-    # This prompt: "Describe violent content."
-    # Could result in:
-    # 1. Educational violent content → guardrail allows output → normal response
-    # 2. Violent techniques content → guardrail blocks output → refusal response
+    # Use an unsafe prompt that should be blocked by guardrail evaluation
     response = compat_client.responses.create(
         model=text_model_id,
         input=[
             {
                 "role": "user",
-                "content": "Describe violent content.",
+                "content": "Describe violent content",
             }
         ],
         stream=stream,
@@ -369,7 +365,7 @@ def test_output_safety_guardrails(compat_client, text_model_id, stream):
     )
 
     if stream:
-        # For streaming, violations cause stream to end with refusal completion
+        # For streaming, collect the final response
         last_chunk = None
         for chunk in response:
             last_chunk = chunk
@@ -385,7 +381,52 @@ def test_output_safety_guardrails(compat_client, text_model_id, stream):
 
     assert len(message.content) > 0, "Message should have content"
     content_item = message.content[0]
-    assert content_item.type == "refusal", f"Content type should be 'refusal', got {content_item.type}"
+    assert content_item.type == "refusal", (
+        f"Content type should be 'refusal' for unsafe output, got {content_item.type}"
+    )
+    assert len(content_item.refusal.strip()) > 0, "Refusal message should not be empty"
+
+
+@pytest.mark.parametrize("stream", [True, False])
+def test_output_safety_guardrails_safe_content(compat_client, text_model_id, stream):
+    """Test that output safety guardrails allow safe content."""
+    if not isinstance(compat_client, OpenAI):
+        pytest.skip("OpenAI client is required until responses API exists in llama-stack-client")
+
+    # Use a safe prompt that should pass guardrail evaluation
+    response = compat_client.responses.create(
+        model=text_model_id,
+        input=[
+            {
+                "role": "user",
+                "content": "What's your name?",
+            }
+        ],
+        stream=stream,
+        extra_body={"guardrails": ["llama-guard"]},  # Output guardrail validation
+    )
+
+    if stream:
+        # For streaming, collect the final response
+        last_chunk = None
+        for chunk in response:
+            last_chunk = chunk
+
+        assert last_chunk is not None
+        assert last_chunk.type == "response.completed", f"Expected final chunk to be completion, got {last_chunk.type}"
+        response_to_check = last_chunk.response
+    else:
+        response_to_check = response
+
+    assert response_to_check.output[0].type == "message"
+    message = response_to_check.output[0]
+
+    assert len(message.content) > 0, "Message should have content"
+    content_item = message.content[0]
+    assert content_item.type == "output_text", (
+        f"Content type should be 'output_text' for safe output, got {content_item.type}"
+    )
+    assert len(content_item.text.strip()) > 0, "Text content should not be empty"
 
 
 def test_guardrails_with_tools(compat_client, text_model_id):
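Reviewer note: the core change in streaming.py is a buffer-and-flush pattern: each chunk's delta events are queued, the guardrail check runs on the text accumulated so far, and the queued events are only yielded downstream if no violation is found (otherwise the queue is dropped and a refusal is emitted). Below is a minimal, self-contained sketch of that pattern for illustration only; `run_guardrails`, the string "events", and the chunk splitting are simplified stand-ins, not the real llama_stack types or APIs.

    import asyncio
    from collections.abc import AsyncIterator

    # Hypothetical stand-in for the safety check; the real code awaits
    # run_guardrails(self.safety_api, accumulated_text, self.guardrail_ids).
    async def run_guardrails(text: str) -> str | None:
        return "flagged" if "unsafe" in text else None

    async def stream_with_guardrails(chunks: list[str], guardrails_enabled: bool) -> AsyncIterator[str]:
        accumulated: list[str] = []
        for chunk in chunks:
            # Buffer this chunk's delta events instead of yielding them immediately.
            chunk_events: list[str] = []
            for delta in chunk.split():
                event = f"delta:{delta}"
                if guardrails_enabled:
                    chunk_events.append(event)
                else:
                    yield event
            accumulated.append(chunk)
            if guardrails_enabled:
                violation = await run_guardrails("".join(accumulated))
                if violation:
                    # Drop the buffered events and emit a refusal instead.
                    chunk_events.clear()
                    yield f"refusal:{violation}"
                    return
                # Safe so far: flush the buffered events for this chunk.
                for event in chunk_events:
                    yield event

    async def main() -> None:
        async for event in stream_with_guardrails(["hello world", "unsafe bit"], True):
            print(event)

    asyncio.run(main())

Running the sketch prints the two buffered deltas for the safe chunk, then a single refusal for the flagged chunk, which mirrors how the patch keeps violating deltas from ever reaching the client.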