Merge branch 'responses-and-safety' into responses-and-safety

This commit is contained in:
slekkala1 2025-10-10 14:12:53 -07:00 committed by GitHub
commit 74cb26a021
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 91 additions and 1 deletions

View file

@ -18,6 +18,7 @@ from openai.types.chat.chat_completion_chunk import (
from llama_stack.apis.agents import Order
from llama_stack.apis.agents.openai_responses import (
ListOpenAIResponseInputItem,
OpenAIResponseContentPartRefusal,
OpenAIResponseInputMessageContentText,
OpenAIResponseInputToolFunction,
OpenAIResponseInputToolMCP,
@ -38,7 +39,9 @@ from llama_stack.apis.inference import (
OpenAIResponseFormatJSONObject,
OpenAIResponseFormatJSONSchema,
OpenAIUserMessageParam,
UserMessage,
)
from llama_stack.apis.safety import SafetyViolation, ViolationLevel
from llama_stack.apis.tools.tools import ListToolDefsResponse, ToolDef, ToolGroups, ToolInvocationResult, ToolRuntime
from llama_stack.core.access_control.access_control import default_policy
from llama_stack.core.datatypes import ResponsesStoreConfig
@ -84,6 +87,9 @@ def mock_vector_io_api():
@pytest.fixture
def mock_safety_api():
safety_api = AsyncMock()
return safety_api
def mock_conversations_api():
"""Mock conversations API for testing."""
mock_api = AsyncMock()
@ -103,6 +109,7 @@ def openai_responses_impl(
mock_tool_runtime_api,
mock_responses_store,
mock_vector_io_api,
mock_safety_api,
mock_conversations_api,
mock_safety_api,
):
@ -112,6 +119,7 @@ def openai_responses_impl(
tool_runtime_api=mock_tool_runtime_api,
responses_store=mock_responses_store,
vector_io_api=mock_vector_io_api,
safety_api=mock_safety_api,
conversations_api=mock_conversations_api,
safety_api=mock_safety_api,
)
@ -1090,3 +1098,52 @@ async def test_create_openai_response_with_invalid_text_format(openai_responses_
model=model,
text=OpenAIResponseText(format={"type": "invalid"}),
)
async def test_check_input_safety_no_violation(openai_responses_impl):
"""Test input shield validation with no violations."""
messages = [UserMessage(content="Hello world")]
shield_ids = ["llama-guard"]
# Mock successful shield validation (no violation)
mock_response = AsyncMock()
mock_response.violation = None
openai_responses_impl.safety_api.run_shield.return_value = mock_response
result = await openai_responses_impl._check_input_safety(messages, shield_ids)
assert result is None
openai_responses_impl.safety_api.run_shield.assert_called_once_with(
shield_id="llama-guard", messages=messages, params={}
)
async def test_check_input_safety_with_violation(openai_responses_impl):
"""Test input shield validation with safety violation."""
messages = [UserMessage(content="Harmful content")]
shield_ids = ["llama-guard"]
# Mock shield violation
violation = SafetyViolation(
violation_level=ViolationLevel.ERROR, user_message="Content violates safety guidelines", metadata={}
)
mock_response = AsyncMock()
mock_response.violation = violation
openai_responses_impl.safety_api.run_shield.return_value = mock_response
result = await openai_responses_impl._check_input_safety(messages, shield_ids)
assert isinstance(result, OpenAIResponseContentPartRefusal)
assert result.refusal == "Content violates safety guidelines"
assert result.type == "refusal"
async def test_check_input_safety_empty_inputs(openai_responses_impl):
"""Test input shield validation with empty inputs."""
# Test empty shield_ids
result = await openai_responses_impl._check_input_safety([UserMessage(content="test")], [])
assert result is None
# Test empty messages
result = await openai_responses_impl._check_input_safety([], ["llama-guard"])
assert result is None