feat: Add responses and safety impl extra_body
This commit is contained in:
parent 0a96a7faa5
commit ad4362e48d

163 changed files with 29338 additions and 141 deletions
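The commit title indicates that shield configuration for the Responses and Safety implementations is carried in the request's extra_body rather than in a first-class API field. A hypothetical client-side call is sketched below; the base URL, model name, and the "shields" field name are assumptions for illustration and are not confirmed by this diff.

# Hypothetical usage sketch, assuming an OpenAI-compatible Llama Stack endpoint.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8321/v1", api_key="none")

response = client.responses.create(
    model="llama3.2:3b",
    input="Hello world",
    extra_body={"shields": ["llama-guard"]},  # field name assumed, not taken from this diff
)
print(response.output_text)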
@@ -4,10 +4,44 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from unittest.mock import AsyncMock

import pytest

from llama_stack.apis.agents.openai_responses import (
    OpenAIResponseContentPartRefusal,
    OpenAIResponseText,
)
from llama_stack.apis.inference import UserMessage
from llama_stack.apis.safety import SafetyViolation, ViolationLevel
from llama_stack.apis.tools import ToolDef
from llama_stack.providers.inline.agents.meta_reference.responses.streaming import (
    StreamingResponseOrchestrator,
    convert_tooldef_to_chat_tool,
)
from llama_stack.providers.inline.agents.meta_reference.responses.types import ChatCompletionContext


@pytest.fixture
def mock_safety_api():
    safety_api = AsyncMock()
    return safety_api


@pytest.fixture
def mock_inference_api():
    inference_api = AsyncMock()
    return inference_api


@pytest.fixture
def mock_context():
    context = AsyncMock(spec=ChatCompletionContext)
    # Add required attributes that StreamingResponseOrchestrator expects
    context.tool_context = AsyncMock()
    context.tool_context.previous_tools = {}
    context.messages = []
    return context


def test_convert_tooldef_to_chat_tool_preserves_items_field():
@@ -36,3 +70,89 @@ def test_convert_tooldef_to_chat_tool_preserves_items_field():
    assert tags_param["type"] == "array"
    assert "items" in tags_param, "items field should be preserved for array parameters"
    assert tags_param["items"] == {"type": "string"}


async def test_check_input_safety_no_violation(mock_safety_api, mock_inference_api, mock_context):
    """Test input shield validation with no violations."""
    messages = [UserMessage(content="Hello world")]
    shield_ids = ["llama-guard"]

    # Mock successful shield validation (no violation)
    mock_response = AsyncMock()
    mock_response.violation = None
    mock_safety_api.run_shield.return_value = mock_response

    # Create orchestrator with safety components
    orchestrator = StreamingResponseOrchestrator(
        inference_api=mock_inference_api,
        ctx=mock_context,
        response_id="test_id",
        created_at=1234567890,
        text=OpenAIResponseText(),
        max_infer_iters=5,
        tool_executor=AsyncMock(),
        safety_api=mock_safety_api,
        shield_ids=shield_ids,
    )

    result = await orchestrator._check_input_safety(messages)

    assert result is None
    mock_safety_api.run_shield.assert_called_once_with(shield_id="llama-guard", messages=messages, params={})


async def test_check_input_safety_with_violation(mock_safety_api, mock_inference_api, mock_context):
    """Test input shield validation with safety violation."""
    messages = [UserMessage(content="Harmful content")]
    shield_ids = ["llama-guard"]

    # Mock shield violation
    violation = SafetyViolation(
        violation_level=ViolationLevel.ERROR, user_message="Content violates safety guidelines", metadata={}
    )
    mock_response = AsyncMock()
    mock_response.violation = violation
    mock_safety_api.run_shield.return_value = mock_response

    # Create orchestrator with safety components
    orchestrator = StreamingResponseOrchestrator(
        inference_api=mock_inference_api,
        ctx=mock_context,
        response_id="test_id",
        created_at=1234567890,
        text=OpenAIResponseText(),
        max_infer_iters=5,
        tool_executor=AsyncMock(),
        safety_api=mock_safety_api,
        shield_ids=shield_ids,
    )

    result = await orchestrator._check_input_safety(messages)

    assert isinstance(result, OpenAIResponseContentPartRefusal)
    assert result.refusal == "Content violates safety guidelines"


async def test_check_input_safety_empty_inputs(mock_safety_api, mock_inference_api, mock_context):
    """Test input shield validation with empty inputs."""
    # Create orchestrator with safety components
    orchestrator = StreamingResponseOrchestrator(
        inference_api=mock_inference_api,
        ctx=mock_context,
        response_id="test_id",
        created_at=1234567890,
        text=OpenAIResponseText(),
        max_infer_iters=5,
        tool_executor=AsyncMock(),
        safety_api=mock_safety_api,
        shield_ids=[],
    )

    # Test empty shield_ids
    result = await orchestrator._check_input_safety([UserMessage(content="test")])
    assert result is None

    # Test empty messages
    orchestrator.shield_ids = ["llama-guard"]
    result = await orchestrator._check_input_safety([])
    assert result is None
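The tests above pin down the expected behavior of _check_input_safety without showing its body: it returns None when no shields or no messages are given, returns None when run_shield reports no violation, and returns a refusal content part carrying the violation's user_message otherwise. A minimal sketch consistent with those expectations follows; it is an inference from the tests, not the actual implementation in this commit.

# Hedged sketch of _check_input_safety, reconstructed only from the test expectations above.
async def _check_input_safety(self, messages):
    # No shields configured, or nothing to check: no refusal.
    if not self.shield_ids or not messages:
        return None
    for shield_id in self.shield_ids:
        response = await self.safety_api.run_shield(shield_id=shield_id, messages=messages, params={})
        if response.violation:
            # Surface the violation as a refusal content part instead of raising,
            # so the orchestrator can emit it as part of the streamed response.
            return OpenAIResponseContentPartRefusal(refusal=response.violation.user_message)
    return None

Returning a refusal content part rather than raising matches the violation test, which asserts an OpenAIResponseContentPartRefusal with the shield's user_message as the refusal text.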