diff --git a/llama_stack/providers/inline/agents/meta_reference/responses/streaming.py b/llama_stack/providers/inline/agents/meta_reference/responses/streaming.py
index 07b6b285d..798393e28 100644
--- a/llama_stack/providers/inline/agents/meta_reference/responses/streaming.py
+++ b/llama_stack/providers/inline/agents/meta_reference/responses/streaming.py
@@ -64,7 +64,6 @@ from llama_stack.apis.inference import (
     OpenAIChatCompletionToolCall,
     OpenAIChoice,
     OpenAIMessageParam,
-    OpenAIUserMessageParam,
 )
 from llama_stack.log import get_logger
 from llama_stack.providers.utils.inference.prompt_adapter import interleaved_content_as_str
@@ -136,33 +135,16 @@ class StreamingResponseOrchestrator:
         # Track if we've sent a refusal response
         self.violation_detected = False
 
-    async def _check_input_safety(
-        self, messages: list[OpenAIUserMessageParam]
-    ) -> OpenAIResponseContentPartRefusal | None:
-        """Validate input messages against guardrails. Returns refusal content if violation found."""
-        combined_text = interleaved_content_as_str([msg.content for msg in messages])
-
-        if not combined_text:
+    async def _apply_guardrails(self, text: str, context: str = "content") -> str | None:
+        """Apply guardrails to text content. Returns violation message if blocked."""
+        if not self.guardrail_ids or not text:
             return None
 
         try:
-            await run_multiple_guardrails(self.safety_api, combined_text, self.guardrail_ids)
+            await run_multiple_guardrails(self.safety_api, text, self.guardrail_ids)
         except SafetyException as e:
-            logger.info(f"Input guardrail violation: {e.violation.user_message}")
-            return OpenAIResponseContentPartRefusal(
-                refusal=e.violation.user_message or "Content blocked by safety guardrails"
-            )
-
-    async def _check_output_stream_chunk_safety(self, accumulated_text: str) -> str | None:
-        """Check accumulated streaming text content against guardrails. Returns violation message if blocked."""
-        if not self.guardrail_ids or not accumulated_text:
-            return None
-
-        try:
-            await run_multiple_guardrails(self.safety_api, accumulated_text, self.guardrail_ids)
-        except SafetyException as e:
-            logger.info(f"Output guardrail violation: {e.violation.user_message}")
-            return e.violation.user_message or "Generated content blocked by safety guardrails"
+            logger.info(f"{context.capitalize()} guardrail violation: {e.violation.user_message}")
+            return e.violation.user_message or f"{context.capitalize()} blocked by safety guardrails"
 
     async def _create_refusal_response(self, violation_message: str) -> OpenAIResponseObjectStream:
         """Create a refusal response to replace streaming content."""
@@ -224,10 +206,11 @@ class StreamingResponseOrchestrator:
 
         # Input safety validation - check messages before processing
        if self.guardrail_ids:
-            input_refusal = await self._check_input_safety(self.ctx.messages)
-            if input_refusal:
+            combined_text = interleaved_content_as_str([msg.content for msg in self.ctx.messages])
+            input_violation_message = await self._apply_guardrails(combined_text, "input")
+            if input_violation_message:
                 # Return refusal response immediately
-                yield await self._create_refusal_response(input_refusal.refusal)
+                yield await self._create_refusal_response(input_violation_message)
                 return
 
         async for stream_event in self._process_tools(output_messages):
@@ -733,10 +716,10 @@ class StreamingResponseOrchestrator:
                                 response_tool_call.function.arguments or ""
                             ) + tool_call.function.arguments
 
-            # Safety check after processing all choices in this chunk
+            # Output safety validation for a chunk
             if chat_response_content:
                 accumulated_text = "".join(chat_response_content)
-                violation_message = await self._check_output_stream_chunk_safety(accumulated_text)
+                violation_message = await self._apply_guardrails(accumulated_text, "output")
                 if violation_message:
                     yield await self._create_refusal_response(violation_message)
                     self.violation_detected = True
diff --git a/llama_stack/providers/inline/agents/meta_reference/responses/utils.py b/llama_stack/providers/inline/agents/meta_reference/responses/utils.py
index fa6cb115f..756e8405e 100644
--- a/llama_stack/providers/inline/agents/meta_reference/responses/utils.py
+++ b/llama_stack/providers/inline/agents/meta_reference/responses/utils.py
@@ -365,20 +365,3 @@ def extract_guardrail_ids(guardrails: list | None) -> list[str]:
             raise ValueError(f"Unknown guardrail format: {guardrail}, expected str or ResponseGuardrailSpec")
 
     return guardrail_ids
-
-
-def extract_text_content(content: str | list | None) -> str | None:
-    """Extract text content from OpenAI message content (string or complex structure)."""
-    if isinstance(content, str):
-        return content
-    elif isinstance(content, list):
-        # Handle complex content - extract text parts only
-        text_parts = []
-        for part in content:
-            if hasattr(part, "text"):
-                text_parts.append(part.text)
-            elif hasattr(part, "type") and part.type == "refusal":
-                # Skip refusal parts - don't validate them again
-                continue
-        return " ".join(text_parts) if text_parts else None
-    return None
diff --git a/tests/unit/providers/agents/meta_reference/test_responses_safety_utils.py b/tests/unit/providers/agents/meta_reference/test_responses_safety_utils.py
index 65a44be59..288151432 100644
--- a/tests/unit/providers/agents/meta_reference/test_responses_safety_utils.py
+++ b/tests/unit/providers/agents/meta_reference/test_responses_safety_utils.py
@@ -4,7 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from unittest.mock import AsyncMock, MagicMock
+from unittest.mock import AsyncMock
 
 import pytest
 
@@ -14,7 +14,6 @@ from llama_stack.providers.inline.agents.meta_reference.responses.openai_respons
 )
 from llama_stack.providers.inline.agents.meta_reference.responses.utils import (
     extract_guardrail_ids,
-    extract_text_content,
 )
 
 
@@ -85,76 +84,3 @@ def test_extract_guardrail_ids_unknown_format(responses_impl):
     guardrails = ["valid-guardrail", unknown_object, "another-guardrail"]
     with pytest.raises(ValueError, match="Unknown guardrail format.*expected str or ResponseGuardrailSpec"):
         extract_guardrail_ids(guardrails)
-
-
-def test_extract_text_content_string(responses_impl):
-    """Test extraction from simple string content."""
-    content = "Hello world"
-    result = extract_text_content(content)
-    assert result == "Hello world"
-
-
-def test_extract_text_content_list_with_text(responses_impl):
-    """Test extraction from list content with text parts."""
-    content = [
-        MagicMock(text="Hello "),
-        MagicMock(text="world"),
-    ]
-    result = extract_text_content(content)
-    assert result == "Hello world"
-
-
-def test_extract_text_content_list_with_refusal(responses_impl):
-    """Test extraction skips refusal parts."""
-    # Create text parts
-    text_part1 = MagicMock()
-    text_part1.text = "Hello"
-
-    text_part2 = MagicMock()
-    text_part2.text = "world"
-
-    # Create refusal part (no text attribute)
-    refusal_part = MagicMock()
-    refusal_part.type = "refusal"
-    refusal_part.refusal = "Blocked"
-    del refusal_part.text  # Remove text attribute
-
-    content = [text_part1, refusal_part, text_part2]
-    result = extract_text_content(content)
-    assert result == "Hello world"
-
-
-def test_extract_text_content_empty_list(responses_impl):
-    """Test extraction from empty list returns None."""
-    content = []
-    result = extract_text_content(content)
-    assert result is None
-
-
-def test_extract_text_content_no_text_parts(responses_impl):
-    """Test extraction with no text parts returns None."""
-    # Create image part (no text attribute)
-    image_part = MagicMock()
-    image_part.type = "image"
-    image_part.image_url = "http://example.com"
-
-    # Create refusal part (no text attribute)
-    refusal_part = MagicMock()
-    refusal_part.type = "refusal"
-    refusal_part.refusal = "Blocked"
-
-    # Explicitly remove text attributes to simulate non-text parts
-    if hasattr(image_part, "text"):
-        delattr(image_part, "text")
-    if hasattr(refusal_part, "text"):
-        delattr(refusal_part, "text")
-
-    content = [image_part, refusal_part]
-    result = extract_text_content(content)
-    assert result is None
-
-
-def test_extract_text_content_none_input(responses_impl):
-    """Test extraction with None input returns None."""
-    result = extract_text_content(None)
-    assert result is None
diff --git a/tests/unit/providers/inline/agents/meta_reference/responses/test_streaming.py b/tests/unit/providers/inline/agents/meta_reference/responses/test_streaming.py
index 8f657eed3..db018508f 100644
--- a/tests/unit/providers/inline/agents/meta_reference/responses/test_streaming.py
+++ b/tests/unit/providers/inline/agents/meta_reference/responses/test_streaming.py
@@ -9,10 +9,9 @@ from unittest.mock import AsyncMock
 import pytest
 
 from llama_stack.apis.agents.openai_responses import (
-    OpenAIResponseContentPartRefusal,
     OpenAIResponseText,
 )
-from llama_stack.apis.inference import UserMessage
+from llama_stack.apis.safety import ModerationObject, ModerationObjectResults
 from llama_stack.apis.tools import ToolDef
 from llama_stack.providers.inline.agents.meta_reference.responses.streaming import (
     StreamingResponseOrchestrator,
@@ -79,12 +78,12 @@ def test_convert_tooldef_to_chat_tool_preserves_items_field():
     assert tags_param["items"] == {"type": "string"}
 
 
-async def test_check_input_safety_no_violation(mock_safety_api, mock_inference_api, mock_context):
-    """Test input shield validation with no violations."""
-    messages = [UserMessage(content="Hello world")]
+async def test_apply_guardrails_no_violation(mock_safety_api, mock_inference_api, mock_context):
+    """Test guardrails validation with no violations."""
+    text = "Hello world"
     guardrail_ids = ["llama-guard"]
 
-    # Mock successful shield validation (no violation)
+    # Mock successful guardrails validation (no violation)
     mock_response = AsyncMock()
     mock_response.violation = None
     mock_safety_api.run_shield.return_value = mock_response
@@ -102,7 +101,7 @@ async def test_check_input_safety_no_violation(mock_safety_api, mock_inference_a
         guardrail_ids=guardrail_ids,
     )
 
-    result = await orchestrator._check_input_safety(messages)
+    result = await orchestrator._apply_guardrails(text)
     assert result is None
 
     # Verify run_moderation was called with the correct model
@@ -112,13 +111,15 @@ async def test_check_input_safety_no_violation(mock_safety_api, mock_inference_a
     assert call_args[1]["model"] == "llama-guard-model"  # The provider_resource_id from our mock
 
 
-async def test_check_input_safety_with_violation(mock_safety_api, mock_inference_api, mock_context):
-    """Test input shield validation with safety violation."""
-    messages = [UserMessage(content="Harmful content")]
+async def test_apply_guardrails_with_violation(mock_safety_api, mock_inference_api, mock_context):
+    """Test guardrails validation with safety violation."""
+    text = "Harmful content"
     guardrail_ids = ["llama-guard"]
 
     # Mock moderation to return flagged content
-    mock_safety_api.run_moderation.return_value = AsyncMock(flagged=True, categories={"violence": True})
+    flagged_result = ModerationObjectResults(flagged=True, categories={"violence": True})
+    mock_moderation_object = ModerationObject(id="test-mod-id", model="llama-guard-model", results=[flagged_result])
+    mock_safety_api.run_moderation.return_value = mock_moderation_object
 
     # Create orchestrator with safety components
     orchestrator = StreamingResponseOrchestrator(
@@ -133,14 +134,13 @@ async def test_check_input_safety_with_violation(mock_safety_api, mock_inference
         guardrail_ids=guardrail_ids,
     )
 
-    result = await orchestrator._check_input_safety(messages)
+    result = await orchestrator._apply_guardrails(text)
 
-    assert isinstance(result, OpenAIResponseContentPartRefusal)
-    assert result.refusal == "Content flagged by moderation"
+    assert result == "Content flagged by moderation"
 
 
-async def test_check_input_safety_empty_inputs(mock_safety_api, mock_inference_api, mock_context):
-    """Test input shield validation with empty inputs."""
+async def test_apply_guardrails_empty_inputs(mock_safety_api, mock_inference_api, mock_context):
+    """Test guardrails validation with empty inputs."""
     # Create orchestrator with safety components
     orchestrator = StreamingResponseOrchestrator(
         inference_api=mock_inference_api,
@@ -154,11 +154,11 @@ async def test_check_input_safety_empty_inputs(mock_safety_api, mock_inference_a
         guardrail_ids=[],
    )
 
-    # Test empty shield_ids
-    result = await orchestrator._check_input_safety([UserMessage(content="test")])
+    # Test empty guardrail_ids
+    result = await orchestrator._apply_guardrails("test")
     assert result is None
 
-    # Test empty messages
+    # Test empty text
     orchestrator.guardrail_ids = ["llama-guard"]
-    result = await orchestrator._check_input_safety([])
+    result = await orchestrator._apply_guardrails("")
     assert result is None
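
For reviewers, a minimal standalone sketch of the contract this refactor converges on: one helper serves both the input and output paths and returns a violation message (str) or None, which the orchestrator then turns into a refusal event. The names below (apply_guardrails, run_guardrails, the simplified SafetyException) are illustrative stand-ins, not the real llama_stack APIs, and the guardrail check is faked with a keyword match.

    import asyncio


    class SafetyException(Exception):
        """Stand-in for the real SafetyException; carries only a user-facing message."""

        def __init__(self, user_message: str) -> None:
            super().__init__(user_message)
            self.user_message = user_message


    async def run_guardrails(safety_api: object, text: str, guardrail_ids: list[str]) -> None:
        """Stand-in for run_multiple_guardrails: raises SafetyException on a violation."""
        if "harmful" in text.lower():
            raise SafetyException("Content flagged by moderation")


    async def apply_guardrails(
        safety_api: object, guardrail_ids: list[str], text: str, context: str = "content"
    ) -> str | None:
        """Mirror of the consolidated helper: return a violation message if blocked, else None."""
        if not guardrail_ids or not text:
            return None
        try:
            await run_guardrails(safety_api, text, guardrail_ids)
        except SafetyException as e:
            return e.user_message or f"{context.capitalize()} blocked by safety guardrails"
        return None


    async def main() -> None:
        # Input path: benign text produces no violation, so streaming proceeds.
        print(await apply_guardrails(None, ["llama-guard"], "Hello world", "input"))
        # Output path: flagged text yields a message the caller wraps in a refusal response.
        print(await apply_guardrails(None, ["llama-guard"], "harmful content", "output"))


    if __name__ == "__main__":
        asyncio.run(main())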