Swapna Lekkala 2025-10-13 15:19:33 -07:00
parent da07772480
commit b5c08c72a7
4 changed files with 33 additions and 141 deletions


@@ -9,10 +9,9 @@ from unittest.mock import AsyncMock
 import pytest

 from llama_stack.apis.agents.openai_responses import (
-    OpenAIResponseContentPartRefusal,
     OpenAIResponseText,
 )
-from llama_stack.apis.inference import UserMessage
+from llama_stack.apis.safety import ModerationObject, ModerationObjectResults
 from llama_stack.apis.tools import ToolDef
 from llama_stack.providers.inline.agents.meta_reference.responses.streaming import (
     StreamingResponseOrchestrator,
@@ -79,12 +78,12 @@ def test_convert_tooldef_to_chat_tool_preserves_items_field():
     assert tags_param["items"] == {"type": "string"}


-async def test_check_input_safety_no_violation(mock_safety_api, mock_inference_api, mock_context):
-    """Test input shield validation with no violations."""
-    messages = [UserMessage(content="Hello world")]
+async def test_apply_guardrails_no_violation(mock_safety_api, mock_inference_api, mock_context):
+    """Test guardrails validation with no violations."""
+    text = "Hello world"
     guardrail_ids = ["llama-guard"]

-    # Mock successful shield validation (no violation)
+    # Mock successful guardrails validation (no violation)
     mock_response = AsyncMock()
     mock_response.violation = None
     mock_safety_api.run_shield.return_value = mock_response
@@ -102,7 +101,7 @@ async def test_check_input_safety_no_violation(mock_safety_api, mock_inference_a
         guardrail_ids=guardrail_ids,
     )

-    result = await orchestrator._check_input_safety(messages)
+    result = await orchestrator._apply_guardrails(text)

     assert result is None
     # Verify run_moderation was called with the correct model
@@ -112,13 +111,15 @@ async def test_check_input_safety_no_violation(mock_safety_api, mock_inference_a
     assert call_args[1]["model"] == "llama-guard-model" # The provider_resource_id from our mock


-async def test_check_input_safety_with_violation(mock_safety_api, mock_inference_api, mock_context):
-    """Test input shield validation with safety violation."""
-    messages = [UserMessage(content="Harmful content")]
+async def test_apply_guardrails_with_violation(mock_safety_api, mock_inference_api, mock_context):
+    """Test guardrails validation with safety violation."""
+    text = "Harmful content"
     guardrail_ids = ["llama-guard"]

     # Mock moderation to return flagged content
-    mock_safety_api.run_moderation.return_value = AsyncMock(flagged=True, categories={"violence": True})
+    flagged_result = ModerationObjectResults(flagged=True, categories={"violence": True})
+    mock_moderation_object = ModerationObject(id="test-mod-id", model="llama-guard-model", results=[flagged_result])
+    mock_safety_api.run_moderation.return_value = mock_moderation_object

     # Create orchestrator with safety components
     orchestrator = StreamingResponseOrchestrator(
@@ -133,14 +134,13 @@ async def test_check_input_safety_with_violation(mock_safety_api, mock_inference
         guardrail_ids=guardrail_ids,
     )

-    result = await orchestrator._check_input_safety(messages)
+    result = await orchestrator._apply_guardrails(text)

-    assert isinstance(result, OpenAIResponseContentPartRefusal)
-    assert result.refusal == "Content flagged by moderation"
+    assert result == "Content flagged by moderation"


-async def test_check_input_safety_empty_inputs(mock_safety_api, mock_inference_api, mock_context):
-    """Test input shield validation with empty inputs."""
+async def test_apply_guardrails_empty_inputs(mock_safety_api, mock_inference_api, mock_context):
+    """Test guardrails validation with empty inputs."""
     # Create orchestrator with safety components
     orchestrator = StreamingResponseOrchestrator(
         inference_api=mock_inference_api,
@@ -154,11 +154,11 @@ async def test_check_input_safety_empty_inputs(mock_safety_api, mock_inference_a
         guardrail_ids=[],
     )

-    # Test empty shield_ids
-    result = await orchestrator._check_input_safety([UserMessage(content="test")])
+    # Test empty guardrail_ids
+    result = await orchestrator._apply_guardrails("test")
     assert result is None

-    # Test empty messages
+    # Test empty text
    orchestrator.guardrail_ids = ["llama-guard"]
-    result = await orchestrator._check_input_safety([])
+    result = await orchestrator._apply_guardrails("")
     assert result is None
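
For reference, a minimal sketch of the contract these updated tests appear to exercise for StreamingResponseOrchestrator._apply_guardrails. Only run_moderation, ModerationObject.results, and the flagged attribute come from the diff above; the standalone function name, argument shape, and model resolution are assumptions, not the actual implementation.

# Sketch only, based on the mocks and assertions in the tests above.
# The real orchestrator method may resolve the guardrail id to a model differently,
# and may pass the text to run_moderation under a named parameter.
async def apply_guardrails_sketch(safety_api, guardrail_ids, guardrail_model, text):
    # No guardrails configured or empty text: nothing to check, so return None like the tests expect.
    if not guardrail_ids or not text:
        return None
    # The tests assert run_moderation is called with model=<provider_resource_id>.
    moderation = await safety_api.run_moderation(text, model=guardrail_model)
    # Any flagged ModerationObjectResults entry becomes a refusal message string.
    if any(result.flagged for result in moderation.results):
        return "Content flagged by moderation"
    return None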