fix tests

Swapna Lekkala 2025-10-14 13:44:30 -07:00
parent 9a4d3d7576
commit 6e028023f9
4 changed files with 75 additions and 99 deletions

View file

@@ -9,11 +9,13 @@ from unittest.mock import AsyncMock
import pytest
from llama_stack.apis.agents.agents import ResponseGuardrailSpec
from llama_stack.apis.safety import ModerationObject, ModerationObjectResults
from llama_stack.providers.inline.agents.meta_reference.responses.openai_responses import (
    OpenAIResponsesImpl,
)
from llama_stack.providers.inline.agents.meta_reference.responses.utils import (
    extract_guardrail_ids,
    run_guardrails,
)
@@ -84,3 +86,70 @@ def test_extract_guardrail_ids_unknown_format(responses_impl):
guardrails = ["valid-guardrail", unknown_object, "another-guardrail"]
with pytest.raises(ValueError, match="Unknown guardrail format.*expected str or ResponseGuardrailSpec"):
extract_guardrail_ids(guardrails)
@pytest.fixture
def mock_safety_api():
    """Create mock safety API for guardrails testing."""
    safety_api = AsyncMock()
    # Mock the routing table and shields list for guardrails lookup
    safety_api.routing_table = AsyncMock()
    shield = AsyncMock()
    shield.identifier = "llama-guard"
    shield.provider_resource_id = "llama-guard-model"
    safety_api.routing_table.list_shields.return_value = AsyncMock(data=[shield])
    return safety_api


async def test_run_guardrails_no_violation(mock_safety_api):
    """Test guardrails validation with no violations."""
    text = "Hello world"
    guardrail_ids = ["llama-guard"]

    # Mock moderation to return non-flagged content
    unflagged_result = ModerationObjectResults(flagged=False, categories={"violence": False})
    mock_moderation_object = ModerationObject(id="test-mod-id", model="llama-guard-model", results=[unflagged_result])
    mock_safety_api.run_moderation.return_value = mock_moderation_object

    result = await run_guardrails(mock_safety_api, text, guardrail_ids)

    assert result is None
    # Verify run_moderation was called with the correct model
    mock_safety_api.run_moderation.assert_called_once()
    call_args = mock_safety_api.run_moderation.call_args
    assert call_args[1]["model"] == "llama-guard-model"


async def test_run_guardrails_with_violation(mock_safety_api):
    """Test guardrails validation with safety violation."""
    text = "Harmful content"
    guardrail_ids = ["llama-guard"]

    # Mock moderation to return flagged content
    flagged_result = ModerationObjectResults(
        flagged=True,
        categories={"violence": True},
        user_message="Content flagged by moderation",
        metadata={"violation_type": ["S1"]},
    )
    mock_moderation_object = ModerationObject(id="test-mod-id", model="llama-guard-model", results=[flagged_result])
    mock_safety_api.run_moderation.return_value = mock_moderation_object

    result = await run_guardrails(mock_safety_api, text, guardrail_ids)

    assert result == "Content flagged by moderation (flagged for: violence) (violation type: S1)"


async def test_run_guardrails_empty_inputs(mock_safety_api):
    """Test guardrails validation with empty inputs."""
    # Test empty guardrail_ids
    result = await run_guardrails(mock_safety_api, "test", [])
    assert result is None

    # Test empty text
    result = await run_guardrails(mock_safety_api, "", ["llama-guard"])
    assert result is None

    # Test both empty
    result = await run_guardrails(mock_safety_api, "", [])
    assert result is None
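For context, here is a minimal sketch of the behavior these new tests pin down. It is inferred from the assertions above, not taken from the actual run_guardrails implementation in utils.py; the shield lookup via routing_table.list_shields, the input= keyword, and the message formatting are assumptions.

async def run_guardrails_sketch(safety_api, text: str, guardrail_ids: list[str]) -> str | None:
    """Return a violation message if any guardrail flags the text, else None (sketch only)."""
    if not text or not guardrail_ids:
        return None

    # Resolve guardrail ids to the underlying moderation model via the shields routing table
    # (assumption; matches the mock_safety_api fixture above).
    shields = (await safety_api.routing_table.list_shields()).data
    id_to_model = {s.identifier: s.provider_resource_id for s in shields}

    for guardrail_id in guardrail_ids:
        model = id_to_model.get(guardrail_id, guardrail_id)
        moderation = await safety_api.run_moderation(input=text, model=model)
        for result in moderation.results:
            if not result.flagged:
                continue
            # Build the message the tests assert on: user_message, flagged categories,
            # and optional violation types from metadata.
            flagged = ", ".join(cat for cat, hit in (result.categories or {}).items() if hit)
            message = f"{result.user_message} (flagged for: {flagged})"
            violation_types = (result.metadata or {}).get("violation_type") or []
            if violation_types:
                message += f" (violation type: {', '.join(violation_types)})"
            return message
    return None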

View file

@@ -8,13 +8,8 @@ from unittest.mock import AsyncMock
import pytest
from llama_stack.apis.agents.openai_responses import (
    OpenAIResponseText,
)
from llama_stack.apis.safety import ModerationObject, ModerationObjectResults
from llama_stack.apis.tools import ToolDef
from llama_stack.providers.inline.agents.meta_reference.responses.streaming import (
    StreamingResponseOrchestrator,
    convert_tooldef_to_chat_tool,
)
from llama_stack.providers.inline.agents.meta_reference.responses.types import ChatCompletionContext
@@ -76,91 +71,3 @@ def test_convert_tooldef_to_chat_tool_preserves_items_field():
assert tags_param["type"] == "array"
assert "items" in tags_param, "items field should be preserved for array parameters"
assert tags_param["items"] == {"type": "string"}
async def test_apply_guardrails_no_violation(mock_safety_api, mock_inference_api, mock_context):
    """Test guardrails validation with no violations."""
    text = "Hello world"
    guardrail_ids = ["llama-guard"]

    # Mock successful guardrails validation (no violation)
    mock_response = AsyncMock()
    mock_response.violation = None
    mock_safety_api.run_shield.return_value = mock_response

    # Create orchestrator with safety components
    orchestrator = StreamingResponseOrchestrator(
        inference_api=mock_inference_api,
        ctx=mock_context,
        response_id="test_id",
        created_at=1234567890,
        text=OpenAIResponseText(),
        max_infer_iters=5,
        tool_executor=AsyncMock(),
        safety_api=mock_safety_api,
        guardrail_ids=guardrail_ids,
    )

    result = await orchestrator._apply_guardrails(text)

    assert result is None
    # Verify run_moderation was called with the correct model
    mock_safety_api.run_moderation.assert_called_once()
    # Get the actual call arguments
    call_args = mock_safety_api.run_moderation.call_args
    assert call_args[1]["model"] == "llama-guard-model"  # The provider_resource_id from our mock


async def test_apply_guardrails_with_violation(mock_safety_api, mock_inference_api, mock_context):
    """Test guardrails validation with safety violation."""
    text = "Harmful content"
    guardrail_ids = ["llama-guard"]

    # Mock moderation to return flagged content
    flagged_result = ModerationObjectResults(
        flagged=True, categories={"violence": True}, user_message="Content flagged by moderation"
    )
    mock_moderation_object = ModerationObject(id="test-mod-id", model="llama-guard-model", results=[flagged_result])
    mock_safety_api.run_moderation.return_value = mock_moderation_object

    # Create orchestrator with safety components
    orchestrator = StreamingResponseOrchestrator(
        inference_api=mock_inference_api,
        ctx=mock_context,
        response_id="test_id",
        created_at=1234567890,
        text=OpenAIResponseText(),
        max_infer_iters=5,
        tool_executor=AsyncMock(),
        safety_api=mock_safety_api,
        guardrail_ids=guardrail_ids,
    )

    result = await orchestrator._apply_guardrails(text)

    assert result == "Content flagged by moderation (flagged for: violence)"


async def test_apply_guardrails_empty_inputs(mock_safety_api, mock_inference_api, mock_context):
    """Test guardrails validation with empty inputs."""
    # Create orchestrator with safety components
    orchestrator = StreamingResponseOrchestrator(
        inference_api=mock_inference_api,
        ctx=mock_context,
        response_id="test_id",
        created_at=1234567890,
        text=OpenAIResponseText(),
        max_infer_iters=5,
        tool_executor=AsyncMock(),
        safety_api=mock_safety_api,
        guardrail_ids=[],
    )

    # Test empty guardrail_ids
    result = await orchestrator._apply_guardrails("test")
    assert result is None

    # Test empty text
    orchestrator.guardrail_ids = ["llama-guard"]
    result = await orchestrator._apply_guardrails("")
    assert result is None
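The orchestrator-level tests above are removed, while equivalent coverage now lives in the utils-level tests added in the first file. A plausible reading is that _apply_guardrails has become a thin wrapper over the shared run_guardrails helper; the sketch below illustrates that assumption only, since the orchestrator's actual code is not part of this diff. GuardrailsMixinSketch is a hypothetical name.

from llama_stack.providers.inline.agents.meta_reference.responses.utils import run_guardrails


class GuardrailsMixinSketch:
    """Hypothetical sketch of the delegation; not the shipped StreamingResponseOrchestrator."""

    def __init__(self, safety_api, guardrail_ids):
        self.safety_api = safety_api
        self.guardrail_ids = guardrail_ids

    async def _apply_guardrails(self, text: str) -> str | None:
        # Delegate to the shared helper; returns a violation message or None.
        return await run_guardrails(self.safety_api, text, self.guardrail_ids)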