mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-17 06:02:36 +00:00
address comments
This commit is contained in:
parent
c66757ea4d
commit
0efdc46d89
25 changed files with 1251 additions and 77 deletions
|
|
@ -18,7 +18,6 @@ from openai.types.chat.chat_completion_chunk import (
|
|||
from llama_stack.apis.agents import Order
|
||||
from llama_stack.apis.agents.openai_responses import (
|
||||
ListOpenAIResponseInputItem,
|
||||
OpenAIResponseContentPartRefusal,
|
||||
OpenAIResponseInputMessageContentText,
|
||||
OpenAIResponseInputToolFunction,
|
||||
OpenAIResponseInputToolMCP,
|
||||
|
|
@ -39,9 +38,7 @@ from llama_stack.apis.inference import (
|
|||
OpenAIResponseFormatJSONObject,
|
||||
OpenAIResponseFormatJSONSchema,
|
||||
OpenAIUserMessageParam,
|
||||
UserMessage,
|
||||
)
|
||||
from llama_stack.apis.safety import SafetyViolation, ViolationLevel
|
||||
from llama_stack.apis.tools.tools import ListToolDefsResponse, ToolDef, ToolGroups, ToolInvocationResult, ToolRuntime
|
||||
from llama_stack.core.access_control.access_control import default_policy
|
||||
from llama_stack.core.datatypes import ResponsesStoreConfig
|
||||
|
|
@ -1093,52 +1090,3 @@ async def test_create_openai_response_with_invalid_text_format(openai_responses_
|
|||
model=model,
|
||||
text=OpenAIResponseText(format={"type": "invalid"}),
|
||||
)
|
||||
|
||||
|
||||
async def test_check_input_safety_no_violation(openai_responses_impl):
|
||||
"""Test input shield validation with no violations."""
|
||||
messages = [UserMessage(content="Hello world")]
|
||||
shield_ids = ["llama-guard"]
|
||||
|
||||
# Mock successful shield validation (no violation)
|
||||
mock_response = AsyncMock()
|
||||
mock_response.violation = None
|
||||
openai_responses_impl.safety_api.run_shield.return_value = mock_response
|
||||
|
||||
result = await openai_responses_impl._check_input_safety(messages, shield_ids)
|
||||
|
||||
assert result is None
|
||||
openai_responses_impl.safety_api.run_shield.assert_called_once_with(
|
||||
shield_id="llama-guard", messages=messages, params={}
|
||||
)
|
||||
|
||||
|
||||
async def test_check_input_safety_with_violation(openai_responses_impl):
|
||||
"""Test input shield validation with safety violation."""
|
||||
messages = [UserMessage(content="Harmful content")]
|
||||
shield_ids = ["llama-guard"]
|
||||
|
||||
# Mock shield violation
|
||||
violation = SafetyViolation(
|
||||
violation_level=ViolationLevel.ERROR, user_message="Content violates safety guidelines", metadata={}
|
||||
)
|
||||
mock_response = AsyncMock()
|
||||
mock_response.violation = violation
|
||||
openai_responses_impl.safety_api.run_shield.return_value = mock_response
|
||||
|
||||
result = await openai_responses_impl._check_input_safety(messages, shield_ids)
|
||||
|
||||
assert isinstance(result, OpenAIResponseContentPartRefusal)
|
||||
assert result.refusal == "Content violates safety guidelines"
|
||||
assert result.type == "refusal"
|
||||
|
||||
|
||||
async def test_check_input_safety_empty_inputs(openai_responses_impl):
|
||||
"""Test input shield validation with empty inputs."""
|
||||
# Test empty shield_ids
|
||||
result = await openai_responses_impl._check_input_safety([UserMessage(content="test")], [])
|
||||
assert result is None
|
||||
|
||||
# Test empty messages
|
||||
result = await openai_responses_impl._check_input_safety([], ["llama-guard"])
|
||||
assert result is None
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue