feat: Add responses and safety impl with extra body

Swapna Lekkala 2025-10-10 07:12:51 -07:00
parent 548ccff368
commit e09401805f
15 changed files with 877 additions and 9 deletions

@@ -297,3 +297,130 @@ def test_function_call_output_response_with_none_arguments(openai_client, client
assert response.output[0].type == "function_call"
assert response.output[0].arguments == "{}"
_ = response.output[0].call_id
@pytest.mark.parametrize("stream", [True, False])
def test_input_safety_shields(compat_client, text_model_id, stream):
"""Test that input safety shields block inappropriate content."""
if not isinstance(compat_client, OpenAI):
pytest.skip("OpenAI client is required until responses API exists in llama-stack-client")
# Test with shield that should block inappropriate content
# Note: This test assumes shields are configured in the test environment
inappropriate_message = "How can I hurt someone?"
response = compat_client.responses.create(
model=text_model_id,
input=[
{
"role": "user",
"content": inappropriate_message,
}
],
stream=stream,
extra_body={"shields": ["llama-guard"]}, # Shields passed via extra_body
)
if stream:
# Input violations cause immediate refusal - only one completion chunk
final_response = None
for chunk in response:
if chunk.type == "response.completed":
final_response = chunk.response
break
assert final_response is not None
response_to_check = final_response
else:
response_to_check = response
# Check if the response contains a refusal (indicating shield blocked content)
assert response_to_check.output[0].type == "message"
message = response_to_check.output[0]
# Shield should ALWAYS block inappropriate input - expect refusal response
assert len(message.content) > 0, "Content should not be empty"
content_item = message.content[0]
assert content_item.type == "refusal", f"Expected refusal content, but got: {content_item.type}"
assert len(content_item.refusal.strip()) > 0, "Refusal message should not be empty"
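
# --- Editor's note (not part of this commit) -----------------------------------
# The OpenAI Python client merges extra_body keys into the top-level JSON
# payload, so the request issued by the test above carries "shields" as a
# sibling of "model" and "input". The exact endpoint path depends on the
# llama-stack deployment; the request body is roughly:
#
#     {
#         "model": "<text_model_id>",
#         "input": [{"role": "user", "content": "How can I hurt someone?"}],
#         "stream": false,
#         "shields": ["llama-guard"]
#     }
# --------------------------------------------------------------------------------
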
@pytest.mark.parametrize("stream", [True, False])
def test_output_safety_shields(compat_client, text_model_id, stream):
"""Test that output safety shields can block generated content."""
if not isinstance(compat_client, OpenAI):
pytest.skip("OpenAI client is required until responses API exists in llama-stack-client")
# Use a prompt that asks the model to describe violent content. The generated
# output is expected to trip the output shield, so the test asserts that the
# final message is a refusal rather than normal assistant text.
response = compat_client.responses.create(
model=text_model_id,
input=[
{
"role": "user",
"content": "Describe violent content.",
}
],
stream=stream,
extra_body={"shields": ["llama-guard"]}, # Output shield validation
)
if stream:
# For streaming, violations cause stream to end with refusal completion
last_chunk = None
for chunk in response:
last_chunk = chunk
assert last_chunk.type == "response.completed", f"Expected final chunk to be completion, got {last_chunk.type}"
response_to_check = last_chunk.response
else:
response_to_check = response
# The output shield should have blocked the generated content and surfaced a refusal
assert response_to_check.output[0].type == "message"
message = response_to_check.output[0]
assert len(message.content) > 0, "Message should have content"
content_item = message.content[0]
assert content_item.type == "refusal", f"Content type should be 'refusal', got {content_item.type}"
def test_shields_with_tools(compat_client, text_model_id):
"""Test that shields work correctly when tools are present."""
if not isinstance(compat_client, OpenAI):
pytest.skip("OpenAI client is required until responses API exists in llama-stack-client")
response = compat_client.responses.create(
model=text_model_id,
input=[
{
"role": "user",
"content": "What's the weather like? Please help me in a safe and appropriate way.",
}
],
tools=[
{
"type": "function",
"name": "get_weather",
"description": "Get the weather in a given city",
"parameters": {
"type": "object",
"properties": {
"city": {"type": "string", "description": "The city to get the weather for"},
},
},
}
],
extra_body={"shields": ["llama-guard"]},
stream=False,
)
# Verify response completes successfully with tools and shields
assert response.id is not None
assert len(response.output) > 0
# Response should be either a function call or a message
output_type = response.output[0].type
assert output_type in ["function_call", "message"]
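
# --- Editor's sketch (not part of this commit) ----------------------------------
# How a caller might consume the response above: with shields and tools active,
# the first output item can be a function_call, a normal message, or a refusal
# message. Attribute names follow the OpenAI responses output types exercised
# elsewhere in these tests; this is an illustration, not code from this commit.
def handle_first_output(response):
    output = response.output[0]
    if output.type == "function_call":
        return ("tool_call", output.name, output.arguments, output.call_id)
    part = output.content[0]
    if part.type == "refusal":
        return ("refusal", part.refusal)
    return ("text", part.text)
# ---------------------------------------------------------------------------------
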

@@ -18,6 +18,7 @@ from openai.types.chat.chat_completion_chunk import (
from llama_stack.apis.agents import Order
from llama_stack.apis.agents.openai_responses import (
ListOpenAIResponseInputItem,
OpenAIResponseContentPartRefusal,
OpenAIResponseInputMessageContentText,
OpenAIResponseInputToolFunction,
OpenAIResponseInputToolMCP,
@@ -38,8 +39,11 @@ from llama_stack.apis.inference import (
OpenAIResponseFormatJSONObject,
OpenAIResponseFormatJSONSchema,
OpenAIUserMessageParam,
UserMessage,
)
from llama_stack.apis.tools.tools import ListToolDefsResponse, ToolDef, ToolGroups, ToolInvocationResult, ToolRuntime
from llama_stack.apis.safety import SafetyViolation, ViolationLevel
from llama_stack.apis.tools.tools import ToolDef, ToolGroups, ToolInvocationResult, ToolRuntime
from llama_stack.core.access_control.access_control import default_policy
from llama_stack.core.datatypes import ResponsesStoreConfig
from llama_stack.providers.inline.agents.meta_reference.responses.openai_responses import (
@@ -83,9 +87,20 @@ def mock_vector_io_api():
return vector_io_api
@pytest.fixture
def mock_safety_api():
safety_api = AsyncMock()
return safety_api
@pytest.fixture
def openai_responses_impl(
mock_inference_api, mock_tool_groups_api, mock_tool_runtime_api, mock_responses_store, mock_vector_io_api
mock_inference_api,
mock_tool_groups_api,
mock_tool_runtime_api,
mock_responses_store,
mock_vector_io_api,
mock_safety_api,
):
return OpenAIResponsesImpl(
inference_api=mock_inference_api,
@@ -93,6 +108,7 @@ def openai_responses_impl(
tool_runtime_api=mock_tool_runtime_api,
responses_store=mock_responses_store,
vector_io_api=mock_vector_io_api,
safety_api=mock_safety_api,
)
@@ -1066,3 +1082,57 @@ async def test_create_openai_response_with_invalid_text_format(openai_responses_
model=model,
text=OpenAIResponseText(format={"type": "invalid"}),
)
# ============================================================================
# Shield Validation Tests
# ============================================================================
async def test_check_input_safety_no_violation(openai_responses_impl):
"""Test input shield validation with no violations."""
messages = [UserMessage(content="Hello world")]
shield_ids = ["llama-guard"]
# Mock successful shield validation (no violation)
mock_response = AsyncMock()
mock_response.violation = None
openai_responses_impl.safety_api.run_shield.return_value = mock_response
result = await openai_responses_impl._check_input_safety(messages, shield_ids)
assert result is None
openai_responses_impl.safety_api.run_shield.assert_called_once_with(
shield_id="llama-guard", messages=messages, params={}
)
async def test_check_input_safety_with_violation(openai_responses_impl):
"""Test input shield validation with safety violation."""
messages = [UserMessage(content="Harmful content")]
shield_ids = ["llama-guard"]
# Mock shield violation
violation = SafetyViolation(
violation_level=ViolationLevel.ERROR, user_message="Content violates safety guidelines", metadata={}
)
mock_response = AsyncMock()
mock_response.violation = violation
openai_responses_impl.safety_api.run_shield.return_value = mock_response
result = await openai_responses_impl._check_input_safety(messages, shield_ids)
assert isinstance(result, OpenAIResponseContentPartRefusal)
assert result.refusal == "Content violates safety guidelines"
assert result.type == "refusal"
async def test_check_input_safety_empty_inputs(openai_responses_impl):
"""Test input shield validation with empty inputs."""
# Test empty shield_ids
result = await openai_responses_impl._check_input_safety([UserMessage(content="test")], [])
assert result is None
# Test empty messages
result = await openai_responses_impl._check_input_safety([], ["llama-guard"])
assert result is None
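
# --- Editor's sketch (not part of this diff) -------------------------------------
# A minimal _check_input_safety consistent with the tests above: run each shield
# over the converted messages and turn the first violation into a refusal content
# part (OpenAIResponseContentPartRefusal, imported earlier in this file). The
# shipped implementation in openai_responses.py may differ, e.g. in logging or
# violation-level handling; this is an illustration only.
async def _check_input_safety(self, messages, shield_ids):
    if not messages or not shield_ids:
        return None
    for shield_id in shield_ids:
        response = await self.safety_api.run_shield(shield_id=shield_id, messages=messages, params={})
        if response.violation:
            return OpenAIResponseContentPartRefusal(refusal=response.violation.user_message)
    return None
# ----------------------------------------------------------------------------------
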

@@ -0,0 +1,256 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from unittest.mock import AsyncMock, MagicMock
import pytest
from llama_stack.apis.agents.agents import ResponseShieldSpec
from llama_stack.apis.inference import (
CompletionMessage,
StopReason,
SystemMessage,
UserMessage,
)
from llama_stack.providers.inline.agents.meta_reference.responses.openai_responses import (
OpenAIResponsesImpl,
)
from llama_stack.providers.inline.agents.meta_reference.responses.utils import (
convert_openai_to_inference_messages,
extract_shield_ids,
extract_text_content,
)
@pytest.fixture
def mock_apis():
"""Create mock APIs for testing."""
return {
"inference_api": AsyncMock(),
"tool_groups_api": AsyncMock(),
"tool_runtime_api": AsyncMock(),
"responses_store": AsyncMock(),
"vector_io_api": AsyncMock(),
"safety_api": AsyncMock(),
}
@pytest.fixture
def responses_impl(mock_apis):
"""Create OpenAIResponsesImpl instance with mocked dependencies."""
return OpenAIResponsesImpl(**mock_apis)
# ============================================================================
# Shield ID Extraction Tests
# ============================================================================
def test_extract_shield_ids_from_strings(responses_impl):
"""Test extraction from simple string shield IDs."""
shields = ["llama-guard", "content-filter", "nsfw-detector"]
result = extract_shield_ids(shields)
assert result == ["llama-guard", "content-filter", "nsfw-detector"]
def test_extract_shield_ids_from_objects(responses_impl):
"""Test extraction from ResponseShieldSpec objects."""
shields = [
ResponseShieldSpec(type="llama-guard"),
ResponseShieldSpec(type="content-filter"),
]
result = extract_shield_ids(shields)
assert result == ["llama-guard", "content-filter"]
def test_extract_shield_ids_mixed_formats(responses_impl):
"""Test extraction from mixed string and object formats."""
shields = [
"llama-guard",
ResponseShieldSpec(type="content-filter"),
"nsfw-detector",
]
result = extract_shield_ids(shields)
assert result == ["llama-guard", "content-filter", "nsfw-detector"]
def test_extract_shield_ids_none_input(responses_impl):
"""Test extraction with None input."""
result = extract_shield_ids(None)
assert result == []
def test_extract_shield_ids_empty_list(responses_impl):
"""Test extraction with empty list."""
result = extract_shield_ids([])
assert result == []
def test_extract_shield_ids_unknown_format(responses_impl, caplog):
"""Test extraction with unknown shield format logs warning."""
# Create an object that's neither string nor ResponseShieldSpec
unknown_object = {"invalid": "format"} # Plain dict, not ResponseShieldSpec
shields = ["valid-shield", unknown_object, "another-shield"]
result = extract_shield_ids(shields)
assert result == ["valid-shield", "another-shield"]
assert "Unknown shield format" in caplog.text
# ============================================================================
# Text Content Extraction Tests
# ============================================================================
def test_extract_text_content_string(responses_impl):
"""Test extraction from simple string content."""
content = "Hello world"
result = extract_text_content(content)
assert result == "Hello world"
def test_extract_text_content_list_with_text(responses_impl):
"""Test extraction from list content with text parts."""
content = [
MagicMock(text="Hello "),
MagicMock(text="world"),
]
result = extract_text_content(content)
assert result == "Hello world"
def test_extract_text_content_list_with_refusal(responses_impl):
"""Test extraction skips refusal parts."""
# Create text parts
text_part1 = MagicMock()
text_part1.text = "Hello"
text_part2 = MagicMock()
text_part2.text = "world"
# Create refusal part (no text attribute)
refusal_part = MagicMock()
refusal_part.type = "refusal"
refusal_part.refusal = "Blocked"
del refusal_part.text # Remove text attribute
content = [text_part1, refusal_part, text_part2]
result = extract_text_content(content)
assert result == "Hello world"
def test_extract_text_content_empty_list(responses_impl):
"""Test extraction from empty list returns None."""
content = []
result = extract_text_content(content)
assert result is None
def test_extract_text_content_no_text_parts(responses_impl):
"""Test extraction with no text parts returns None."""
# Create image part (no text attribute)
image_part = MagicMock()
image_part.type = "image"
image_part.image_url = "http://example.com"
# Create refusal part (no text attribute)
refusal_part = MagicMock()
refusal_part.type = "refusal"
refusal_part.refusal = "Blocked"
# Explicitly remove text attributes to simulate non-text parts
if hasattr(image_part, "text"):
delattr(image_part, "text")
if hasattr(refusal_part, "text"):
delattr(refusal_part, "text")
content = [image_part, refusal_part]
result = extract_text_content(content)
assert result is None
def test_extract_text_content_none_input(responses_impl):
"""Test extraction with None input returns None."""
result = extract_text_content(None)
assert result is None
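
# --- Editor's sketch (not part of this diff) -------------------------------------
# An extract_text_content consistent with the tests above. Stripping each text
# part and joining with a single space is an assumption that satisfies both list
# cases shown ("Hello " + "world" and "Hello" + "world" both yield "Hello world");
# the shipped utils.py may join differently.
def extract_text_content(content):
    if content is None:
        return None
    if isinstance(content, str):
        return content
    texts = [part.text.strip() for part in content if hasattr(part, "text")]
    return " ".join(texts) if texts else None
# ----------------------------------------------------------------------------------
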
# ============================================================================
# Message Conversion Tests
# ============================================================================
def test_convert_user_message(responses_impl):
"""Test conversion of user message."""
openai_msg = MagicMock(role="user", content="Hello world")
result = convert_openai_to_inference_messages([openai_msg])
assert len(result) == 1
assert isinstance(result[0], UserMessage)
assert result[0].content == "Hello world"
def test_convert_system_message(responses_impl):
"""Test conversion of system message."""
openai_msg = MagicMock(role="system", content="You are helpful")
result = convert_openai_to_inference_messages([openai_msg])
assert len(result) == 1
assert isinstance(result[0], SystemMessage)
assert result[0].content == "You are helpful"
def test_convert_assistant_message(responses_impl):
"""Test conversion of assistant message."""
openai_msg = MagicMock(role="assistant", content="I can help")
result = convert_openai_to_inference_messages([openai_msg])
assert len(result) == 1
assert isinstance(result[0], CompletionMessage)
assert result[0].content == "I can help"
assert result[0].stop_reason == StopReason.end_of_turn
def test_convert_tool_message_skipped(responses_impl):
"""Test that tool messages are skipped."""
openai_msg = MagicMock(role="tool", content="Tool result")
result = convert_openai_to_inference_messages([openai_msg])
assert len(result) == 0
def test_convert_complex_content(responses_impl):
"""Test conversion with complex content structure."""
openai_msg = MagicMock(
role="user",
content=[
MagicMock(text="Analyze this: "),
MagicMock(text="important content"),
],
)
result = convert_openai_to_inference_messages([openai_msg])
assert len(result) == 1
assert isinstance(result[0], UserMessage)
assert result[0].content == "Analyze this: important content"
def test_convert_empty_content_skipped(responses_impl):
"""Test that messages with no extractable content are skipped."""
openai_msg = MagicMock(role="user", content=[])
result = convert_openai_to_inference_messages([openai_msg])
assert len(result) == 0
def test_convert_assistant_message_dict_format(responses_impl):
"""Test conversion of assistant message in dictionary format."""
dict_msg = {"role": "assistant", "content": "Violent content refers to media, materials, or expressions"}
result = convert_openai_to_inference_messages([dict_msg])
assert len(result) == 1
assert isinstance(result[0], CompletionMessage)
assert result[0].content == "Violent content refers to media, materials, or expressions"
assert result[0].stop_reason == StopReason.end_of_turn
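
# --- Editor's sketch (not part of this diff) -------------------------------------
# A convert_openai_to_inference_messages consistent with the tests above: it
# accepts both attribute-style and dict messages, maps user/system/assistant roles
# onto UserMessage/SystemMessage/CompletionMessage, and skips tool messages as well
# as messages with no extractable text (delegating to extract_text_content). The
# shipped utils.py may differ in details.
def convert_openai_to_inference_messages(openai_messages):
    messages = []
    for msg in openai_messages:
        role = msg["role"] if isinstance(msg, dict) else msg.role
        content = msg["content"] if isinstance(msg, dict) else msg.content
        text = extract_text_content(content)
        if text is None:
            continue  # nothing for the shields to check in this message
        if role == "user":
            messages.append(UserMessage(content=text))
        elif role == "system":
            messages.append(SystemMessage(content=text))
        elif role == "assistant":
            messages.append(CompletionMessage(content=text, stop_reason=StopReason.end_of_turn))
        # tool messages (and any unrecognized role) are intentionally skipped
    return messages
# ----------------------------------------------------------------------------------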