mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-22 08:17:18 +00:00
feat: Add responses and safety impl with extra body
This commit is contained in:
parent
548ccff368
commit
e09401805f
15 changed files with 877 additions and 9 deletions
|
@ -18,6 +18,7 @@ from openai.types.chat.chat_completion_chunk import (
|
|||
from llama_stack.apis.agents import Order
|
||||
from llama_stack.apis.agents.openai_responses import (
|
||||
ListOpenAIResponseInputItem,
|
||||
OpenAIResponseContentPartRefusal,
|
||||
OpenAIResponseInputMessageContentText,
|
||||
OpenAIResponseInputToolFunction,
|
||||
OpenAIResponseInputToolMCP,
|
||||
|
@ -38,8 +39,11 @@ from llama_stack.apis.inference import (
|
|||
OpenAIResponseFormatJSONObject,
|
||||
OpenAIResponseFormatJSONSchema,
|
||||
OpenAIUserMessageParam,
|
||||
UserMessage,
|
||||
)
|
||||
from llama_stack.apis.tools.tools import ListToolDefsResponse, ToolDef, ToolGroups, ToolInvocationResult, ToolRuntime
|
||||
from llama_stack.apis.safety import SafetyViolation, ViolationLevel
|
||||
from llama_stack.apis.tools.tools import ToolDef, ToolGroups, ToolInvocationResult, ToolRuntime
|
||||
from llama_stack.core.access_control.access_control import default_policy
|
||||
from llama_stack.core.datatypes import ResponsesStoreConfig
|
||||
from llama_stack.providers.inline.agents.meta_reference.responses.openai_responses import (
|
||||
|
@ -83,9 +87,20 @@ def mock_vector_io_api():
|
|||
return vector_io_api
|
||||
|
||||
|
||||
@pytest.fixture
def mock_safety_api():
    """Provide an ``AsyncMock`` that stands in for the safety API."""
    return AsyncMock()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def openai_responses_impl(
|
||||
mock_inference_api, mock_tool_groups_api, mock_tool_runtime_api, mock_responses_store, mock_vector_io_api
|
||||
mock_inference_api,
|
||||
mock_tool_groups_api,
|
||||
mock_tool_runtime_api,
|
||||
mock_responses_store,
|
||||
mock_vector_io_api,
|
||||
mock_safety_api,
|
||||
):
|
||||
return OpenAIResponsesImpl(
|
||||
inference_api=mock_inference_api,
|
||||
|
@ -93,6 +108,7 @@ def openai_responses_impl(
|
|||
tool_runtime_api=mock_tool_runtime_api,
|
||||
responses_store=mock_responses_store,
|
||||
vector_io_api=mock_vector_io_api,
|
||||
safety_api=mock_safety_api,
|
||||
)
|
||||
|
||||
|
||||
|
@ -1066,3 +1082,57 @@ async def test_create_openai_response_with_invalid_text_format(openai_responses_
|
|||
model=model,
|
||||
text=OpenAIResponseText(format={"type": "invalid"}),
|
||||
)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Shield Validation Tests
|
||||
# ============================================================================
|
||||
|
||||
|
||||
async def test_check_input_safety_no_violation(openai_responses_impl):
    """A shield response that carries no violation produces no refusal."""
    user_messages = [UserMessage(content="Hello world")]

    # The shield runs and reports nothing objectionable.
    shield_response = AsyncMock()
    shield_response.violation = None
    run_shield = openai_responses_impl.safety_api.run_shield
    run_shield.return_value = shield_response

    outcome = await openai_responses_impl._check_input_safety(user_messages, ["llama-guard"])

    assert outcome is None
    run_shield.assert_called_once_with(shield_id="llama-guard", messages=user_messages, params={})
|
||||
|
||||
|
||||
async def test_check_input_safety_with_violation(openai_responses_impl):
    """A shield violation comes back as a refusal content part carrying the user message."""
    user_messages = [UserMessage(content="Harmful content")]

    # The shield flags the input with an error-level violation.
    shield_response = AsyncMock()
    shield_response.violation = SafetyViolation(
        violation_level=ViolationLevel.ERROR,
        user_message="Content violates safety guidelines",
        metadata={},
    )
    openai_responses_impl.safety_api.run_shield.return_value = shield_response

    outcome = await openai_responses_impl._check_input_safety(user_messages, ["llama-guard"])

    assert isinstance(outcome, OpenAIResponseContentPartRefusal)
    assert outcome.refusal == "Content violates safety guidelines"
    assert outcome.type == "refusal"
|
||||
|
||||
|
||||
async def test_check_input_safety_empty_inputs(openai_responses_impl):
    """Neither an empty shield list nor an empty message list produces a refusal."""
    check = openai_responses_impl._check_input_safety

    # Messages present, but no shields requested.
    assert await check([UserMessage(content="test")], []) is None

    # Shields requested, but no messages to validate.
    assert await check([], ["llama-guard"]) is None
|
||||
|
|
|
@ -0,0 +1,256 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from llama_stack.apis.agents.agents import ResponseShieldSpec
|
||||
from llama_stack.apis.inference import (
|
||||
CompletionMessage,
|
||||
StopReason,
|
||||
SystemMessage,
|
||||
UserMessage,
|
||||
)
|
||||
from llama_stack.providers.inline.agents.meta_reference.responses.openai_responses import (
|
||||
OpenAIResponsesImpl,
|
||||
)
|
||||
from llama_stack.providers.inline.agents.meta_reference.responses.utils import (
|
||||
convert_openai_to_inference_messages,
|
||||
extract_shield_ids,
|
||||
extract_text_content,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
def mock_apis():
    """Build one ``AsyncMock`` per dependency, keyed by the keyword name it is passed under."""
    dependency_names = (
        "inference_api",
        "tool_groups_api",
        "tool_runtime_api",
        "responses_store",
        "vector_io_api",
        "safety_api",
    )
    return {name: AsyncMock() for name in dependency_names}
|
||||
|
||||
|
||||
@pytest.fixture
def responses_impl(mock_apis):
    """Instantiate ``OpenAIResponsesImpl`` with every dependency replaced by a mock."""
    impl = OpenAIResponsesImpl(**mock_apis)
    return impl
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Shield ID Extraction Tests
|
||||
# ============================================================================
|
||||
|
||||
|
||||
def test_extract_shield_ids_from_strings(responses_impl):
    """Plain string shield IDs come back unchanged and in order."""
    expected = ["llama-guard", "content-filter", "nsfw-detector"]
    # Pass a copy so the expectation cannot be affected by the call.
    assert extract_shield_ids(list(expected)) == expected
|
||||
|
||||
|
||||
def test_extract_shield_ids_from_objects(responses_impl):
    """``ResponseShieldSpec`` objects are reduced to their ``type`` values."""
    shield_types = ["llama-guard", "content-filter"]
    specs = [ResponseShieldSpec(type=shield_type) for shield_type in shield_types]
    assert extract_shield_ids(specs) == shield_types
|
||||
|
||||
|
||||
def test_extract_shield_ids_mixed_formats(responses_impl):
    """Strings and ``ResponseShieldSpec`` objects may be interleaved in one list."""
    mixed = ["llama-guard", ResponseShieldSpec(type="content-filter"), "nsfw-detector"]
    extracted = extract_shield_ids(mixed)
    assert extracted == ["llama-guard", "content-filter", "nsfw-detector"]
|
||||
|
||||
|
||||
def test_extract_shield_ids_none_input(responses_impl):
    """``None`` in place of a shield list gives an empty list."""
    assert extract_shield_ids(None) == []
|
||||
|
||||
|
||||
def test_extract_shield_ids_empty_list(responses_impl):
    """An empty shield list gives an empty list."""
    assert extract_shield_ids([]) == []
|
||||
|
||||
|
||||
def test_extract_shield_ids_unknown_format(responses_impl, caplog):
    """An entry that is neither a string nor a ``ResponseShieldSpec`` is dropped with a warning."""
    # A plain dict is not one of the supported shield formats.
    unsupported_entry = {"invalid": "format"}

    extracted = extract_shield_ids(["valid-shield", unsupported_entry, "another-shield"])

    assert extracted == ["valid-shield", "another-shield"]
    assert "Unknown shield format" in caplog.text
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Text Content Extraction Tests
|
||||
# ============================================================================
|
||||
|
||||
|
||||
def test_extract_text_content_string(responses_impl):
    """String content is returned as-is."""
    assert extract_text_content("Hello world") == "Hello world"
|
||||
|
||||
|
||||
def test_extract_text_content_list_with_text(responses_impl):
    """A list of text-bearing parts is combined into one string."""
    parts = [MagicMock(text=chunk) for chunk in ("Hello ", "world")]
    assert extract_text_content(parts) == "Hello world"
|
||||
|
||||
|
||||
def test_extract_text_content_list_with_refusal(responses_impl):
    """Refusal parts are left out of the extracted text."""
    leading_text = MagicMock(text="Hello")
    trailing_text = MagicMock(text="world")

    # A refusal part has no ``text``; deleting the attribute makes the mock report it as absent.
    refusal = MagicMock()
    refusal.type = "refusal"
    refusal.refusal = "Blocked"
    del refusal.text

    extracted = extract_text_content([leading_text, refusal, trailing_text])
    assert extracted == "Hello world"
|
||||
|
||||
|
||||
def test_extract_text_content_empty_list(responses_impl):
    """An empty content list gives ``None``."""
    assert extract_text_content([]) is None
|
||||
|
||||
|
||||
def test_extract_text_content_no_text_parts(responses_impl):
    """A content list whose parts all lack ``text`` gives ``None``."""
    # Image part: carries a URL, no text.
    image_part = MagicMock()
    image_part.type = "image"
    image_part.image_url = "http://example.com"

    # Refusal part: carries a refusal message, no text.
    refusal_part = MagicMock()
    refusal_part.type = "refusal"
    refusal_part.refusal = "Blocked"

    # MagicMock fabricates any attribute on access, so a hasattr() guard here
    # would always be true. Delete ``text`` unconditionally so later lookups
    # raise AttributeError and the mocks behave like non-text parts.
    del image_part.text
    del refusal_part.text

    result = extract_text_content([image_part, refusal_part])
    assert result is None
|
||||
|
||||
|
||||
def test_extract_text_content_none_input(responses_impl):
    """``None`` content gives ``None``."""
    assert extract_text_content(None) is None
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Message Conversion Tests
|
||||
# ============================================================================
|
||||
|
||||
|
||||
def test_convert_user_message(responses_impl):
    """A ``user`` role message converts to a ``UserMessage`` with the same content."""
    source = MagicMock(role="user", content="Hello world")

    converted = convert_openai_to_inference_messages([source])

    assert len(converted) == 1
    message = converted[0]
    assert isinstance(message, UserMessage)
    assert message.content == "Hello world"
|
||||
|
||||
|
||||
def test_convert_system_message(responses_impl):
    """A ``system`` role message converts to a ``SystemMessage`` with the same content."""
    source = MagicMock(role="system", content="You are helpful")

    converted = convert_openai_to_inference_messages([source])

    assert len(converted) == 1
    message = converted[0]
    assert isinstance(message, SystemMessage)
    assert message.content == "You are helpful"
|
||||
|
||||
|
||||
def test_convert_assistant_message(responses_impl):
    """An ``assistant`` role message converts to an end-of-turn ``CompletionMessage``."""
    source = MagicMock(role="assistant", content="I can help")

    converted = convert_openai_to_inference_messages([source])

    assert len(converted) == 1
    message = converted[0]
    assert isinstance(message, CompletionMessage)
    assert message.content == "I can help"
    assert message.stop_reason == StopReason.end_of_turn
|
||||
|
||||
|
||||
def test_convert_tool_message_skipped(responses_impl):
    """A ``tool`` role message produces no converted message."""
    source = MagicMock(role="tool", content="Tool result")

    converted = convert_openai_to_inference_messages([source])

    assert len(converted) == 0
|
||||
|
||||
|
||||
def test_convert_complex_content(responses_impl):
    """List-structured content is flattened to a single string on the converted message."""
    content_parts = [MagicMock(text=chunk) for chunk in ("Analyze this: ", "important content")]
    source = MagicMock(role="user", content=content_parts)

    converted = convert_openai_to_inference_messages([source])

    assert len(converted) == 1
    message = converted[0]
    assert isinstance(message, UserMessage)
    assert message.content == "Analyze this: important content"
|
||||
|
||||
|
||||
def test_convert_empty_content_skipped(responses_impl):
    """A message whose content list is empty produces no converted message."""
    source = MagicMock(role="user", content=[])

    converted = convert_openai_to_inference_messages([source])

    assert len(converted) == 0
|
||||
|
||||
|
||||
def test_convert_assistant_message_dict_format(responses_impl):
    """An assistant message supplied as a plain dict converts like the object form."""
    assistant_text = "Violent content refers to media, materials, or expressions"
    source = {"role": "assistant", "content": assistant_text}

    converted = convert_openai_to_inference_messages([source])

    assert len(converted) == 1
    message = converted[0]
    assert isinstance(message, CompletionMessage)
    assert message.content == assistant_text
    assert message.stop_reason == StopReason.end_of_turn
|
Loading…
Add table
Add a link
Reference in a new issue