fix tests and remove unwanted changes

This commit is contained in:
Swapna Lekkala 2025-10-13 13:41:35 -07:00
parent c10db23d7a
commit 06dcfd1915
8 changed files with 36 additions and 30 deletions

View file

@ -13,7 +13,6 @@ from llama_stack.apis.agents.openai_responses import (
OpenAIResponseText,
)
from llama_stack.apis.inference import UserMessage
from llama_stack.apis.safety import SafetyViolation, ViolationLevel
from llama_stack.apis.tools import ToolDef
from llama_stack.providers.inline.agents.meta_reference.responses.streaming import (
StreamingResponseOrchestrator,
@ -25,6 +24,14 @@ from llama_stack.providers.inline.agents.meta_reference.responses.types import C
@pytest.fixture
def mock_safety_api():
safety_api = AsyncMock()
# Mock the routing table and shields list for guardrails lookup
safety_api.routing_table = AsyncMock()
shield = AsyncMock()
shield.identifier = "llama-guard"
shield.provider_resource_id = "llama-guard-model"
safety_api.routing_table.list_shields.return_value = AsyncMock(data=[shield])
# Mock run_moderation to return non-flagged result by default
safety_api.run_moderation.return_value = AsyncMock(flagged=False)
return safety_api
@ -75,7 +82,7 @@ def test_convert_tooldef_to_chat_tool_preserves_items_field():
async def test_check_input_safety_no_violation(mock_safety_api, mock_inference_api, mock_context):
"""Test input shield validation with no violations."""
messages = [UserMessage(content="Hello world")]
shield_ids = ["llama-guard"]
guardrail_ids = ["llama-guard"]
# Mock successful shield validation (no violation)
mock_response = AsyncMock()
@ -92,27 +99,26 @@ async def test_check_input_safety_no_violation(mock_safety_api, mock_inference_a
max_infer_iters=5,
tool_executor=AsyncMock(),
safety_api=mock_safety_api,
shield_ids=shield_ids,
guardrail_ids=guardrail_ids,
)
result = await orchestrator._check_input_safety(messages)
assert result is None
mock_safety_api.run_shield.assert_called_once_with(shield_id="llama-guard", messages=messages, params={})
# Verify run_moderation was called with the correct model
mock_safety_api.run_moderation.assert_called_once()
# Get the actual call arguments
call_args = mock_safety_api.run_moderation.call_args
assert call_args[1]["model"] == "llama-guard-model" # The provider_resource_id from our mock
async def test_check_input_safety_with_violation(mock_safety_api, mock_inference_api, mock_context):
"""Test input shield validation with safety violation."""
messages = [UserMessage(content="Harmful content")]
shield_ids = ["llama-guard"]
guardrail_ids = ["llama-guard"]
# Mock shield violation
violation = SafetyViolation(
violation_level=ViolationLevel.ERROR, user_message="Content violates safety guidelines", metadata={}
)
mock_response = AsyncMock()
mock_response.violation = violation
mock_safety_api.run_shield.return_value = mock_response
# Mock moderation to return flagged content
mock_safety_api.run_moderation.return_value = AsyncMock(flagged=True, categories={"violence": True})
# Create orchestrator with safety components
orchestrator = StreamingResponseOrchestrator(
@ -124,13 +130,13 @@ async def test_check_input_safety_with_violation(mock_safety_api, mock_inference
max_infer_iters=5,
tool_executor=AsyncMock(),
safety_api=mock_safety_api,
shield_ids=shield_ids,
guardrail_ids=guardrail_ids,
)
result = await orchestrator._check_input_safety(messages)
assert isinstance(result, OpenAIResponseContentPartRefusal)
assert result.refusal == "Content violates safety guidelines"
assert result.refusal == "Content flagged by moderation"
async def test_check_input_safety_empty_inputs(mock_safety_api, mock_inference_api, mock_context):
@ -145,7 +151,7 @@ async def test_check_input_safety_empty_inputs(mock_safety_api, mock_inference_a
max_infer_iters=5,
tool_executor=AsyncMock(),
safety_api=mock_safety_api,
shield_ids=[],
guardrail_ids=[],
)
# Test empty shield_ids
@ -153,6 +159,6 @@ async def test_check_input_safety_empty_inputs(mock_safety_api, mock_inference_a
assert result is None
# Test empty messages
orchestrator.shield_ids = ["llama-guard"]
orchestrator.guardrail_ids = ["llama-guard"]
result = await orchestrator._check_input_safety([])
assert result is None