llama-stack (mirror of https://github.com/meta-llama/llama-stack.git)

commit 6e028023f9
parent 9a4d3d7576

    fix tests

4 changed files with 75 additions and 99 deletions
@@ -73,7 +73,7 @@ from .types import ChatCompletionContext, ChatCompletionResult
 from .utils import (
     convert_chat_choice_to_response_message,
     is_function_tool_call,
-    run_multiple_guardrails,
+    run_guardrails,
 )
 
 logger = get_logger(name=__name__, category="agents::meta_reference")
@@ -195,7 +195,7 @@ class StreamingResponseOrchestrator:
         # Input safety validation - check messages before processing
         if self.guardrail_ids:
             combined_text = interleaved_content_as_str([msg.content for msg in self.ctx.messages])
-            input_violation_message = await run_multiple_guardrails(self.safety_api, combined_text, self.guardrail_ids)
+            input_violation_message = await run_guardrails(self.safety_api, combined_text, self.guardrail_ids)
             if input_violation_message:
                 logger.info(f"Input guardrail violation: {input_violation_message}")
                 # Return refusal response immediately
@@ -708,7 +708,7 @@ class StreamingResponseOrchestrator:
         # Output Safety Validation for a chunk
         if self.guardrail_ids:
             accumulated_text = "".join(chat_response_content)
-            violation_message = await run_multiple_guardrails(self.safety_api, accumulated_text, self.guardrail_ids)
+            violation_message = await run_guardrails(self.safety_api, accumulated_text, self.guardrail_ids)
             if violation_message:
                 logger.info(f"Output guardrail violation: {violation_message}")
                 yield await self._create_refusal_response(violation_message)
@@ -313,9 +313,9 @@ def is_function_tool_call(
     return False
 
 
-async def run_multiple_guardrails(safety_api: Safety, messages: str, guardrail_ids: list[str]) -> str | None:
-    """Run multiple guardrails against messages and return violation message if blocked."""
-    if not guardrail_ids or not messages:
+async def run_guardrails(safety_api: Safety, messages: str, guardrail_ids: list[str]) -> str | None:
+    """Run guardrails against messages and return violation message if blocked."""
+    if not messages:
         return None
 
     # Look up shields to get their provider_resource_id (actual model ID)
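
Note (not part of the diff): the hunk above shows only the renamed signature, the relaxed early return, and the shield-lookup comment; the body of run_guardrails is elided. Below is a minimal sketch of what the helper plausibly does, inferred from the call sites and the new unit tests further down. The list_shields lookup, the run_moderation keyword names, and the violation-message formatting are assumptions, not the actual implementation.

# Hedged sketch only -- reconstructed from the tests in this commit, not copied from utils.py.
from llama_stack.apis.safety import Safety


async def run_guardrails(safety_api: Safety, messages: str, guardrail_ids: list[str]) -> str | None:
    """Run guardrails against messages and return violation message if blocked."""
    if not messages:
        return None

    # Look up shields to get their provider_resource_id (actual model ID).
    shields = (await safety_api.routing_table.list_shields()).data
    model_ids = [s.provider_resource_id for s in shields if s.identifier in guardrail_ids]

    for model_id in model_ids:
        # The tests only assert that run_moderation is called with model=<provider_resource_id>;
        # the name of the text keyword ("input") is an assumption.
        moderation = await safety_api.run_moderation(input=messages, model=model_id)
        for result in moderation.results:
            if result.flagged:
                flagged = ", ".join(c for c, hit in result.categories.items() if hit)
                message = f"{result.user_message} (flagged for: {flagged})"
                violation_types = (result.metadata or {}).get("violation_type")
                if violation_types:
                    message += f" (violation type: {', '.join(violation_types)})"
                return message
    return None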
@@ -9,11 +9,13 @@ from unittest.mock import AsyncMock
 import pytest
 
 from llama_stack.apis.agents.agents import ResponseGuardrailSpec
+from llama_stack.apis.safety import ModerationObject, ModerationObjectResults
 from llama_stack.providers.inline.agents.meta_reference.responses.openai_responses import (
     OpenAIResponsesImpl,
 )
 from llama_stack.providers.inline.agents.meta_reference.responses.utils import (
     extract_guardrail_ids,
+    run_guardrails,
 )
 
 
@@ -84,3 +86,70 @@ def test_extract_guardrail_ids_unknown_format(responses_impl):
     guardrails = ["valid-guardrail", unknown_object, "another-guardrail"]
     with pytest.raises(ValueError, match="Unknown guardrail format.*expected str or ResponseGuardrailSpec"):
         extract_guardrail_ids(guardrails)
+
+
+@pytest.fixture
+def mock_safety_api():
+    """Create mock safety API for guardrails testing."""
+    safety_api = AsyncMock()
+    # Mock the routing table and shields list for guardrails lookup
+    safety_api.routing_table = AsyncMock()
+    shield = AsyncMock()
+    shield.identifier = "llama-guard"
+    shield.provider_resource_id = "llama-guard-model"
+    safety_api.routing_table.list_shields.return_value = AsyncMock(data=[shield])
+    return safety_api
+
+
+async def test_run_guardrails_no_violation(mock_safety_api):
+    """Test guardrails validation with no violations."""
+    text = "Hello world"
+    guardrail_ids = ["llama-guard"]
+
+    # Mock moderation to return non-flagged content
+    unflagged_result = ModerationObjectResults(flagged=False, categories={"violence": False})
+    mock_moderation_object = ModerationObject(id="test-mod-id", model="llama-guard-model", results=[unflagged_result])
+    mock_safety_api.run_moderation.return_value = mock_moderation_object
+
+    result = await run_guardrails(mock_safety_api, text, guardrail_ids)
+
+    assert result is None
+    # Verify run_moderation was called with the correct model
+    mock_safety_api.run_moderation.assert_called_once()
+    call_args = mock_safety_api.run_moderation.call_args
+    assert call_args[1]["model"] == "llama-guard-model"
+
+
+async def test_run_guardrails_with_violation(mock_safety_api):
+    """Test guardrails validation with safety violation."""
+    text = "Harmful content"
+    guardrail_ids = ["llama-guard"]
+
+    # Mock moderation to return flagged content
+    flagged_result = ModerationObjectResults(
+        flagged=True,
+        categories={"violence": True},
+        user_message="Content flagged by moderation",
+        metadata={"violation_type": ["S1"]},
+    )
+    mock_moderation_object = ModerationObject(id="test-mod-id", model="llama-guard-model", results=[flagged_result])
+    mock_safety_api.run_moderation.return_value = mock_moderation_object
+
+    result = await run_guardrails(mock_safety_api, text, guardrail_ids)
+
+    assert result == "Content flagged by moderation (flagged for: violence) (violation type: S1)"
+
+
+async def test_run_guardrails_empty_inputs(mock_safety_api):
+    """Test guardrails validation with empty inputs."""
+    # Test empty guardrail_ids
+    result = await run_guardrails(mock_safety_api, "test", [])
+    assert result is None
+
+    # Test empty text
+    result = await run_guardrails(mock_safety_api, "", ["llama-guard"])
+    assert result is None
+
+    # Test both empty
+    result = await run_guardrails(mock_safety_api, "", [])
+    assert result is None
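
The new tests above rely on pytest fixtures; for a quick standalone check, the same mocking pattern can be exercised directly. This is a usage sketch mirroring test_run_guardrails_with_violation, not part of the change; the expected output is the string asserted by that test.

import asyncio
from unittest.mock import AsyncMock

from llama_stack.apis.safety import ModerationObject, ModerationObjectResults
from llama_stack.providers.inline.agents.meta_reference.responses.utils import run_guardrails


async def main() -> None:
    # Same shape as the mock_safety_api fixture: one shield mapping to a moderation model.
    safety_api = AsyncMock()
    shield = AsyncMock()
    shield.identifier = "llama-guard"
    shield.provider_resource_id = "llama-guard-model"
    safety_api.routing_table.list_shields.return_value = AsyncMock(data=[shield])

    # Moderation result flagged for violence, as in test_run_guardrails_with_violation.
    flagged_result = ModerationObjectResults(
        flagged=True,
        categories={"violence": True},
        user_message="Content flagged by moderation",
        metadata={"violation_type": ["S1"]},
    )
    safety_api.run_moderation.return_value = ModerationObject(
        id="demo-mod-id", model="llama-guard-model", results=[flagged_result]
    )

    print(await run_guardrails(safety_api, "Harmful content", ["llama-guard"]))
    # Expected per the test: Content flagged by moderation (flagged for: violence) (violation type: S1)


if __name__ == "__main__":
    asyncio.run(main())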
@@ -8,13 +8,8 @@ from unittest.mock import AsyncMock
 
 import pytest
 
-from llama_stack.apis.agents.openai_responses import (
-    OpenAIResponseText,
-)
-from llama_stack.apis.safety import ModerationObject, ModerationObjectResults
 from llama_stack.apis.tools import ToolDef
 from llama_stack.providers.inline.agents.meta_reference.responses.streaming import (
-    StreamingResponseOrchestrator,
     convert_tooldef_to_chat_tool,
 )
 from llama_stack.providers.inline.agents.meta_reference.responses.types import ChatCompletionContext
@@ -76,91 +71,3 @@ def test_convert_tooldef_to_chat_tool_preserves_items_field():
     assert tags_param["type"] == "array"
     assert "items" in tags_param, "items field should be preserved for array parameters"
     assert tags_param["items"] == {"type": "string"}
-
-
-async def test_apply_guardrails_no_violation(mock_safety_api, mock_inference_api, mock_context):
-    """Test guardrails validation with no violations."""
-    text = "Hello world"
-    guardrail_ids = ["llama-guard"]
-
-    # Mock successful guardrails validation (no violation)
-    mock_response = AsyncMock()
-    mock_response.violation = None
-    mock_safety_api.run_shield.return_value = mock_response
-
-    # Create orchestrator with safety components
-    orchestrator = StreamingResponseOrchestrator(
-        inference_api=mock_inference_api,
-        ctx=mock_context,
-        response_id="test_id",
-        created_at=1234567890,
-        text=OpenAIResponseText(),
-        max_infer_iters=5,
-        tool_executor=AsyncMock(),
-        safety_api=mock_safety_api,
-        guardrail_ids=guardrail_ids,
-    )
-
-    result = await orchestrator._apply_guardrails(text)
-
-    assert result is None
-    # Verify run_moderation was called with the correct model
-    mock_safety_api.run_moderation.assert_called_once()
-    # Get the actual call arguments
-    call_args = mock_safety_api.run_moderation.call_args
-    assert call_args[1]["model"] == "llama-guard-model"  # The provider_resource_id from our mock
-
-
-async def test_apply_guardrails_with_violation(mock_safety_api, mock_inference_api, mock_context):
-    """Test guardrails validation with safety violation."""
-    text = "Harmful content"
-    guardrail_ids = ["llama-guard"]
-
-    # Mock moderation to return flagged content
-    flagged_result = ModerationObjectResults(
-        flagged=True, categories={"violence": True}, user_message="Content flagged by moderation"
-    )
-    mock_moderation_object = ModerationObject(id="test-mod-id", model="llama-guard-model", results=[flagged_result])
-    mock_safety_api.run_moderation.return_value = mock_moderation_object
-
-    # Create orchestrator with safety components
-    orchestrator = StreamingResponseOrchestrator(
-        inference_api=mock_inference_api,
-        ctx=mock_context,
-        response_id="test_id",
-        created_at=1234567890,
-        text=OpenAIResponseText(),
-        max_infer_iters=5,
-        tool_executor=AsyncMock(),
-        safety_api=mock_safety_api,
-        guardrail_ids=guardrail_ids,
-    )
-
-    result = await orchestrator._apply_guardrails(text)
-
-    assert result == "Content flagged by moderation (flagged for: violence)"
-
-
-async def test_apply_guardrails_empty_inputs(mock_safety_api, mock_inference_api, mock_context):
-    """Test guardrails validation with empty inputs."""
-    # Create orchestrator with safety components
-    orchestrator = StreamingResponseOrchestrator(
-        inference_api=mock_inference_api,
-        ctx=mock_context,
-        response_id="test_id",
-        created_at=1234567890,
-        text=OpenAIResponseText(),
-        max_infer_iters=5,
-        tool_executor=AsyncMock(),
-        safety_api=mock_safety_api,
-        guardrail_ids=[],
-    )
-
-    # Test empty guardrail_ids
-    result = await orchestrator._apply_guardrails("test")
-    assert result is None
-
-    # Test empty text
-    orchestrator.guardrail_ids = ["llama-guard"]
-    result = await orchestrator._apply_guardrails("")
-    assert result is None