diff --git a/llama_stack/providers/inline/agents/meta_reference/responses/streaming.py b/llama_stack/providers/inline/agents/meta_reference/responses/streaming.py
index c23714617..0cb350df8 100644
--- a/llama_stack/providers/inline/agents/meta_reference/responses/streaming.py
+++ b/llama_stack/providers/inline/agents/meta_reference/responses/streaming.py
@@ -73,7 +73,7 @@ from .types import ChatCompletionContext, ChatCompletionResult
 from .utils import (
     convert_chat_choice_to_response_message,
     is_function_tool_call,
-    run_multiple_guardrails,
+    run_guardrails,
 )
 
 logger = get_logger(name=__name__, category="agents::meta_reference")
@@ -195,7 +195,7 @@ class StreamingResponseOrchestrator:
         # Input safety validation - check messages before processing
         if self.guardrail_ids:
             combined_text = interleaved_content_as_str([msg.content for msg in self.ctx.messages])
-            input_violation_message = await run_multiple_guardrails(self.safety_api, combined_text, self.guardrail_ids)
+            input_violation_message = await run_guardrails(self.safety_api, combined_text, self.guardrail_ids)
             if input_violation_message:
                 logger.info(f"Input guardrail violation: {input_violation_message}")
                 # Return refusal response immediately
@@ -708,7 +708,7 @@ class StreamingResponseOrchestrator:
         # Output Safety Validation for a chunk
         if self.guardrail_ids:
             accumulated_text = "".join(chat_response_content)
-            violation_message = await run_multiple_guardrails(self.safety_api, accumulated_text, self.guardrail_ids)
+            violation_message = await run_guardrails(self.safety_api, accumulated_text, self.guardrail_ids)
             if violation_message:
                 logger.info(f"Output guardrail violation: {violation_message}")
                 yield await self._create_refusal_response(violation_message)
diff --git a/llama_stack/providers/inline/agents/meta_reference/responses/utils.py b/llama_stack/providers/inline/agents/meta_reference/responses/utils.py
index 0a7538292..53f2d16ca 100644
--- a/llama_stack/providers/inline/agents/meta_reference/responses/utils.py
+++ b/llama_stack/providers/inline/agents/meta_reference/responses/utils.py
@@ -313,9 +313,9 @@ def is_function_tool_call(
     return False
 
 
-async def run_multiple_guardrails(safety_api: Safety, messages: str, guardrail_ids: list[str]) -> str | None:
-    """Run multiple guardrails against messages and return violation message if blocked."""
-    if not guardrail_ids or not messages:
+async def run_guardrails(safety_api: Safety, messages: str, guardrail_ids: list[str]) -> str | None:
+    """Run guardrails against messages and return violation message if blocked."""
+    if not messages:
         return None
 
     # Look up shields to get their provider_resource_id (actual model ID)
diff --git a/tests/unit/providers/agents/meta_reference/test_responses_safety_utils.py b/tests/unit/providers/agents/meta_reference/test_responses_safety_utils.py
index 288151432..9c5cc853c 100644
--- a/tests/unit/providers/agents/meta_reference/test_responses_safety_utils.py
+++ b/tests/unit/providers/agents/meta_reference/test_responses_safety_utils.py
@@ -9,11 +9,13 @@ from unittest.mock import AsyncMock
 import pytest
 
 from llama_stack.apis.agents.agents import ResponseGuardrailSpec
+from llama_stack.apis.safety import ModerationObject, ModerationObjectResults
 from llama_stack.providers.inline.agents.meta_reference.responses.openai_responses import (
     OpenAIResponsesImpl,
 )
 from llama_stack.providers.inline.agents.meta_reference.responses.utils import (
     extract_guardrail_ids,
+    run_guardrails,
 )
 
 
@@ -84,3 +86,70 @@ def test_extract_guardrail_ids_unknown_format(responses_impl):
     guardrails = ["valid-guardrail", unknown_object, "another-guardrail"]
     with pytest.raises(ValueError, match="Unknown guardrail format.*expected str or ResponseGuardrailSpec"):
         extract_guardrail_ids(guardrails)
+
+
+@pytest.fixture
+def mock_safety_api():
+    """Create mock safety API for guardrails testing."""
+    safety_api = AsyncMock()
+    # Mock the routing table and shields list for guardrails lookup
+    safety_api.routing_table = AsyncMock()
+    shield = AsyncMock()
+    shield.identifier = "llama-guard"
+    shield.provider_resource_id = "llama-guard-model"
+    safety_api.routing_table.list_shields.return_value = AsyncMock(data=[shield])
+    return safety_api
+
+
+async def test_run_guardrails_no_violation(mock_safety_api):
+    """Test guardrails validation with no violations."""
+    text = "Hello world"
+    guardrail_ids = ["llama-guard"]
+
+    # Mock moderation to return non-flagged content
+    unflagged_result = ModerationObjectResults(flagged=False, categories={"violence": False})
+    mock_moderation_object = ModerationObject(id="test-mod-id", model="llama-guard-model", results=[unflagged_result])
+    mock_safety_api.run_moderation.return_value = mock_moderation_object
+
+    result = await run_guardrails(mock_safety_api, text, guardrail_ids)
+
+    assert result is None
+    # Verify run_moderation was called with the correct model
+    mock_safety_api.run_moderation.assert_called_once()
+    call_args = mock_safety_api.run_moderation.call_args
+    assert call_args[1]["model"] == "llama-guard-model"
+
+
+async def test_run_guardrails_with_violation(mock_safety_api):
+    """Test guardrails validation with safety violation."""
+    text = "Harmful content"
+    guardrail_ids = ["llama-guard"]
+
+    # Mock moderation to return flagged content
+    flagged_result = ModerationObjectResults(
+        flagged=True,
+        categories={"violence": True},
+        user_message="Content flagged by moderation",
+        metadata={"violation_type": ["S1"]},
+    )
+    mock_moderation_object = ModerationObject(id="test-mod-id", model="llama-guard-model", results=[flagged_result])
+    mock_safety_api.run_moderation.return_value = mock_moderation_object
+
+    result = await run_guardrails(mock_safety_api, text, guardrail_ids)
+
+    assert result == "Content flagged by moderation (flagged for: violence) (violation type: S1)"
+
+
+async def test_run_guardrails_empty_inputs(mock_safety_api):
+    """Test guardrails validation with empty inputs."""
+    # Test empty guardrail_ids
+    result = await run_guardrails(mock_safety_api, "test", [])
+    assert result is None
+
+    # Test empty text
+    result = await run_guardrails(mock_safety_api, "", ["llama-guard"])
+    assert result is None
+
+    # Test both empty
+    result = await run_guardrails(mock_safety_api, "", [])
+    assert result is None
diff --git a/tests/unit/providers/inline/agents/meta_reference/responses/test_streaming.py b/tests/unit/providers/inline/agents/meta_reference/responses/test_streaming.py
index a82b94aa3..fff29928c 100644
--- a/tests/unit/providers/inline/agents/meta_reference/responses/test_streaming.py
+++ b/tests/unit/providers/inline/agents/meta_reference/responses/test_streaming.py
@@ -8,13 +8,8 @@ from unittest.mock import AsyncMock
 
 import pytest
 
-from llama_stack.apis.agents.openai_responses import (
-    OpenAIResponseText,
-)
-from llama_stack.apis.safety import ModerationObject, ModerationObjectResults
 from llama_stack.apis.tools import ToolDef
 from llama_stack.providers.inline.agents.meta_reference.responses.streaming import (
-    StreamingResponseOrchestrator,
     convert_tooldef_to_chat_tool,
 )
 from llama_stack.providers.inline.agents.meta_reference.responses.types import ChatCompletionContext
@@ -76,91 +71,3 @@ def test_convert_tooldef_to_chat_tool_preserves_items_field():
     assert tags_param["type"] == "array"
     assert "items" in tags_param, "items field should be preserved for array parameters"
     assert tags_param["items"] == {"type": "string"}
-
-
-async def test_apply_guardrails_no_violation(mock_safety_api, mock_inference_api, mock_context):
-    """Test guardrails validation with no violations."""
-    text = "Hello world"
-    guardrail_ids = ["llama-guard"]
-
-    # Mock successful guardrails validation (no violation)
-    mock_response = AsyncMock()
-    mock_response.violation = None
-    mock_safety_api.run_shield.return_value = mock_response
-
-    # Create orchestrator with safety components
-    orchestrator = StreamingResponseOrchestrator(
-        inference_api=mock_inference_api,
-        ctx=mock_context,
-        response_id="test_id",
-        created_at=1234567890,
-        text=OpenAIResponseText(),
-        max_infer_iters=5,
-        tool_executor=AsyncMock(),
-        safety_api=mock_safety_api,
-        guardrail_ids=guardrail_ids,
-    )
-
-    result = await orchestrator._apply_guardrails(text)
-
-    assert result is None
-    # Verify run_moderation was called with the correct model
-    mock_safety_api.run_moderation.assert_called_once()
-    # Get the actual call arguments
-    call_args = mock_safety_api.run_moderation.call_args
-    assert call_args[1]["model"] == "llama-guard-model"  # The provider_resource_id from our mock
-
-
-async def test_apply_guardrails_with_violation(mock_safety_api, mock_inference_api, mock_context):
-    """Test guardrails validation with safety violation."""
-    text = "Harmful content"
-    guardrail_ids = ["llama-guard"]
-
-    # Mock moderation to return flagged content
-    flagged_result = ModerationObjectResults(
-        flagged=True, categories={"violence": True}, user_message="Content flagged by moderation"
-    )
-    mock_moderation_object = ModerationObject(id="test-mod-id", model="llama-guard-model", results=[flagged_result])
-    mock_safety_api.run_moderation.return_value = mock_moderation_object
-
-    # Create orchestrator with safety components
-    orchestrator = StreamingResponseOrchestrator(
-        inference_api=mock_inference_api,
-        ctx=mock_context,
-        response_id="test_id",
-        created_at=1234567890,
-        text=OpenAIResponseText(),
-        max_infer_iters=5,
-        tool_executor=AsyncMock(),
-        safety_api=mock_safety_api,
-        guardrail_ids=guardrail_ids,
-    )
-
-    result = await orchestrator._apply_guardrails(text)
-
-    assert result == "Content flagged by moderation (flagged for: violence)"
-
-
-async def test_apply_guardrails_empty_inputs(mock_safety_api, mock_inference_api, mock_context):
-    """Test guardrails validation with empty inputs."""
-    # Create orchestrator with safety components
-    orchestrator = StreamingResponseOrchestrator(
-        inference_api=mock_inference_api,
-        ctx=mock_context,
-        response_id="test_id",
-        created_at=1234567890,
-        text=OpenAIResponseText(),
-        max_infer_iters=5,
-        tool_executor=AsyncMock(),
-        safety_api=mock_safety_api,
-        guardrail_ids=[],
-    )
-
-    # Test empty guardrail_ids
-    result = await orchestrator._apply_guardrails("test")
-    assert result is None
-
-    # Test empty text
-    orchestrator.guardrail_ids = ["llama-guard"]
-    result = await orchestrator._apply_guardrails("")
-    assert result is None