mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-16 14:52:40 +00:00
fix tests
This commit is contained in:
parent
9a4d3d7576
commit
6e028023f9
4 changed files with 75 additions and 99 deletions
|
|
@ -9,11 +9,13 @@ from unittest.mock import AsyncMock
|
|||
import pytest
|
||||
|
||||
from llama_stack.apis.agents.agents import ResponseGuardrailSpec
|
||||
from llama_stack.apis.safety import ModerationObject, ModerationObjectResults
|
||||
from llama_stack.providers.inline.agents.meta_reference.responses.openai_responses import (
|
||||
OpenAIResponsesImpl,
|
||||
)
|
||||
from llama_stack.providers.inline.agents.meta_reference.responses.utils import (
|
||||
extract_guardrail_ids,
|
||||
run_guardrails,
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -84,3 +86,70 @@ def test_extract_guardrail_ids_unknown_format(responses_impl):
|
|||
guardrails = ["valid-guardrail", unknown_object, "another-guardrail"]
|
||||
with pytest.raises(ValueError, match="Unknown guardrail format.*expected str or ResponseGuardrailSpec"):
|
||||
extract_guardrail_ids(guardrails)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_safety_api():
|
||||
"""Create mock safety API for guardrails testing."""
|
||||
safety_api = AsyncMock()
|
||||
# Mock the routing table and shields list for guardrails lookup
|
||||
safety_api.routing_table = AsyncMock()
|
||||
shield = AsyncMock()
|
||||
shield.identifier = "llama-guard"
|
||||
shield.provider_resource_id = "llama-guard-model"
|
||||
safety_api.routing_table.list_shields.return_value = AsyncMock(data=[shield])
|
||||
return safety_api
|
||||
|
||||
|
||||
async def test_run_guardrails_no_violation(mock_safety_api):
|
||||
"""Test guardrails validation with no violations."""
|
||||
text = "Hello world"
|
||||
guardrail_ids = ["llama-guard"]
|
||||
|
||||
# Mock moderation to return non-flagged content
|
||||
unflagged_result = ModerationObjectResults(flagged=False, categories={"violence": False})
|
||||
mock_moderation_object = ModerationObject(id="test-mod-id", model="llama-guard-model", results=[unflagged_result])
|
||||
mock_safety_api.run_moderation.return_value = mock_moderation_object
|
||||
|
||||
result = await run_guardrails(mock_safety_api, text, guardrail_ids)
|
||||
|
||||
assert result is None
|
||||
# Verify run_moderation was called with the correct model
|
||||
mock_safety_api.run_moderation.assert_called_once()
|
||||
call_args = mock_safety_api.run_moderation.call_args
|
||||
assert call_args[1]["model"] == "llama-guard-model"
|
||||
|
||||
|
||||
async def test_run_guardrails_with_violation(mock_safety_api):
|
||||
"""Test guardrails validation with safety violation."""
|
||||
text = "Harmful content"
|
||||
guardrail_ids = ["llama-guard"]
|
||||
|
||||
# Mock moderation to return flagged content
|
||||
flagged_result = ModerationObjectResults(
|
||||
flagged=True,
|
||||
categories={"violence": True},
|
||||
user_message="Content flagged by moderation",
|
||||
metadata={"violation_type": ["S1"]},
|
||||
)
|
||||
mock_moderation_object = ModerationObject(id="test-mod-id", model="llama-guard-model", results=[flagged_result])
|
||||
mock_safety_api.run_moderation.return_value = mock_moderation_object
|
||||
|
||||
result = await run_guardrails(mock_safety_api, text, guardrail_ids)
|
||||
|
||||
assert result == "Content flagged by moderation (flagged for: violence) (violation type: S1)"
|
||||
|
||||
|
||||
async def test_run_guardrails_empty_inputs(mock_safety_api):
|
||||
"""Test guardrails validation with empty inputs."""
|
||||
# Test empty guardrail_ids
|
||||
result = await run_guardrails(mock_safety_api, "test", [])
|
||||
assert result is None
|
||||
|
||||
# Test empty text
|
||||
result = await run_guardrails(mock_safety_api, "", ["llama-guard"])
|
||||
assert result is None
|
||||
|
||||
# Test both empty
|
||||
result = await run_guardrails(mock_safety_api, "", [])
|
||||
assert result is None
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue