feat: Add responses and safety impl extra_body

Swapna Lekkala 2025-10-10 15:03:34 -07:00
parent 0a96a7faa5
commit ad4362e48d
163 changed files with 29338 additions and 141 deletions
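The tests in this commit suggest shields ride along on a request via the OpenAI client's extra_body escape hatch and reach the responses implementation as shield IDs. A minimal usage sketch; the "shields" key, base URL, and model name are assumptions, not confirmed by this diff:

from openai import OpenAI

client = OpenAI(base_url="http://localhost:8321/v1", api_key="none")  # assumed local Llama Stack server

# extra_body forwards fields the OpenAI SDK does not model natively. Per the
# tests below, each entry may be a plain shield ID string or a spec object
# carrying a "type" field.
response = client.responses.create(
    model="my-model",  # hypothetical model ID
    input="Hello world",
    extra_body={"shields": ["llama-guard"]},
)
print(response.output_text)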

@@ -91,6 +91,12 @@ def mock_conversations_api():
    return mock_api


@pytest.fixture
def mock_safety_api():
    safety_api = AsyncMock()
    return safety_api


@pytest.fixture
def openai_responses_impl(
mock_inference_api,
@@ -98,6 +104,7 @@ def openai_responses_impl(
mock_tool_runtime_api,
mock_responses_store,
mock_vector_io_api,
mock_safety_api,
mock_conversations_api,
):
return OpenAIResponsesImpl(
@@ -106,6 +113,7 @@ def openai_responses_impl(
tool_runtime_api=mock_tool_runtime_api,
responses_store=mock_responses_store,
vector_io_api=mock_vector_io_api,
safety_api=mock_safety_api,
conversations_api=mock_conversations_api,
)
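The fixture now threads a mocked safety_api into the implementation, so the constructor presumably grew a matching keyword argument. A sketch of the assumed signature, with dependency names taken from the fixtures in this diff (the body is conjecture, not the shipped code):

class OpenAIResponsesImpl:
    def __init__(
        self,
        inference_api,
        tool_groups_api,
        tool_runtime_api,
        responses_store,
        vector_io_api,
        conversations_api,
        safety_api,  # new dependency introduced by this commit
    ):
        # Stored so the streaming orchestrator can run shields per request.
        self.safety_api = safety_api
        ...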

@@ -0,0 +1,160 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from unittest.mock import AsyncMock, MagicMock

import pytest

from llama_stack.apis.agents.agents import ResponseShieldSpec
from llama_stack.providers.inline.agents.meta_reference.responses.openai_responses import (
OpenAIResponsesImpl,
)
from llama_stack.providers.inline.agents.meta_reference.responses.utils import (
extract_shield_ids,
extract_text_content,
)


@pytest.fixture
def mock_apis():
"""Create mock APIs for testing."""
return {
"inference_api": AsyncMock(),
"tool_groups_api": AsyncMock(),
"tool_runtime_api": AsyncMock(),
"responses_store": AsyncMock(),
"vector_io_api": AsyncMock(),
"conversations_api": AsyncMock(),
"safety_api": AsyncMock(),
    }


@pytest.fixture
def responses_impl(mock_apis):
"""Create OpenAIResponsesImpl instance with mocked dependencies."""
    return OpenAIResponsesImpl(**mock_apis)


def test_extract_shield_ids_from_strings(responses_impl):
"""Test extraction from simple string shield IDs."""
shields = ["llama-guard", "content-filter", "nsfw-detector"]
result = extract_shield_ids(shields)
assert result == ["llama-guard", "content-filter", "nsfw-detector"]
def test_extract_shield_ids_from_objects(responses_impl):
"""Test extraction from ResponseShieldSpec objects."""
shields = [
ResponseShieldSpec(type="llama-guard"),
ResponseShieldSpec(type="content-filter"),
]
result = extract_shield_ids(shields)
assert result == ["llama-guard", "content-filter"]
def test_extract_shield_ids_mixed_formats(responses_impl):
"""Test extraction from mixed string and object formats."""
shields = [
"llama-guard",
ResponseShieldSpec(type="content-filter"),
"nsfw-detector",
]
result = extract_shield_ids(shields)
assert result == ["llama-guard", "content-filter", "nsfw-detector"]
def test_extract_shield_ids_none_input(responses_impl):
"""Test extraction with None input."""
result = extract_shield_ids(None)
    assert result == []


def test_extract_shield_ids_empty_list(responses_impl):
"""Test extraction with empty list."""
result = extract_shield_ids([])
    assert result == []


def test_extract_shield_ids_unknown_format(responses_impl):
"""Test extraction with unknown shield format raises ValueError."""
# Create an object that's neither string nor ResponseShieldSpec
unknown_object = {"invalid": "format"} # Plain dict, not ResponseShieldSpec
shields = ["valid-shield", unknown_object, "another-shield"]
with pytest.raises(ValueError, match="Unknown shield format.*expected str or ResponseShieldSpec"):
extract_shield_ids(shields)
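These six tests pin down the contract of extract_shield_ids: falsy input yields an empty list, strings pass through unchanged, ResponseShieldSpec objects contribute their type, and anything else raises. A minimal sketch that would satisfy them; the actual implementation in responses/utils.py may differ:

def extract_shield_ids(shields) -> list[str]:
    """Normalize a mixed shields argument into a flat list of shield IDs."""
    if not shields:
        return []
    ids: list[str] = []
    for shield in shields:
        if isinstance(shield, str):
            ids.append(shield)
        elif isinstance(shield, ResponseShieldSpec):
            ids.append(shield.type)
        else:
            raise ValueError(f"Unknown shield format: {shield!r}, expected str or ResponseShieldSpec")
    return ids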


def test_extract_text_content_string(responses_impl):
"""Test extraction from simple string content."""
content = "Hello world"
result = extract_text_content(content)
assert result == "Hello world"
def test_extract_text_content_list_with_text(responses_impl):
"""Test extraction from list content with text parts."""
content = [
MagicMock(text="Hello "),
MagicMock(text="world"),
]
result = extract_text_content(content)
assert result == "Hello world"
def test_extract_text_content_list_with_refusal(responses_impl):
"""Test extraction skips refusal parts."""
# Create text parts
text_part1 = MagicMock()
text_part1.text = "Hello"
text_part2 = MagicMock()
text_part2.text = "world"
# Create refusal part (no text attribute)
refusal_part = MagicMock()
refusal_part.type = "refusal"
refusal_part.refusal = "Blocked"
del refusal_part.text # Remove text attribute
content = [text_part1, refusal_part, text_part2]
result = extract_text_content(content)
assert result == "Hello world"
def test_extract_text_content_empty_list(responses_impl):
"""Test extraction from empty list returns None."""
content = []
result = extract_text_content(content)
    assert result is None


def test_extract_text_content_no_text_parts(responses_impl):
"""Test extraction with no text parts returns None."""
# Create image part (no text attribute)
image_part = MagicMock()
image_part.type = "image"
image_part.image_url = "http://example.com"
# Create refusal part (no text attribute)
refusal_part = MagicMock()
refusal_part.type = "refusal"
refusal_part.refusal = "Blocked"
# Explicitly remove text attributes to simulate non-text parts
if hasattr(image_part, "text"):
delattr(image_part, "text")
if hasattr(refusal_part, "text"):
delattr(refusal_part, "text")
content = [image_part, refusal_part]
result = extract_text_content(content)
    assert result is None


def test_extract_text_content_none_input(responses_impl):
"""Test extraction with None input returns None."""
result = extract_text_content(None)
assert result is None
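Note that the two happy-path list tests only agree if each part's text is stripped and the surviving parts are joined with a single space ("Hello " + "world" and "Hello" + "world" must both become "Hello world"). A sketch consistent with all six tests, again an assumption rather than the shipped code:

def extract_text_content(content) -> str | None:
    """Pull plain text out of message content, skipping refusal/image parts."""
    if content is None:
        return None
    if isinstance(content, str):
        return content
    # List content: keep only parts that expose a text attribute.
    texts = [part.text.strip() for part in content if hasattr(part, "text")]
    texts = [t for t in texts if t]
    return " ".join(texts) if texts else None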

@@ -4,10 +4,44 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from unittest.mock import AsyncMock

import pytest

from llama_stack.apis.agents.openai_responses import (
OpenAIResponseContentPartRefusal,
OpenAIResponseText,
)
from llama_stack.apis.inference import UserMessage
from llama_stack.apis.safety import SafetyViolation, ViolationLevel
from llama_stack.apis.tools import ToolDef
from llama_stack.providers.inline.agents.meta_reference.responses.streaming import (
StreamingResponseOrchestrator,
convert_tooldef_to_chat_tool,
)
from llama_stack.providers.inline.agents.meta_reference.responses.types import ChatCompletionContext


@pytest.fixture
def mock_safety_api():
safety_api = AsyncMock()
    return safety_api


@pytest.fixture
def mock_inference_api():
inference_api = AsyncMock()
    return inference_api


@pytest.fixture
def mock_context():
context = AsyncMock(spec=ChatCompletionContext)
# Add required attributes that StreamingResponseOrchestrator expects
context.tool_context = AsyncMock()
context.tool_context.previous_tools = {}
context.messages = []
    return context


def test_convert_tooldef_to_chat_tool_preserves_items_field():
@@ -36,3 +70,89 @@ def test_convert_tooldef_to_chat_tool_preserves_items_field():
assert tags_param["type"] == "array"
assert "items" in tags_param, "items field should be preserved for array parameters"
assert tags_param["items"] == {"type": "string"}
async def test_check_input_safety_no_violation(mock_safety_api, mock_inference_api, mock_context):
"""Test input shield validation with no violations."""
messages = [UserMessage(content="Hello world")]
shield_ids = ["llama-guard"]
# Mock successful shield validation (no violation)
mock_response = AsyncMock()
mock_response.violation = None
mock_safety_api.run_shield.return_value = mock_response
# Create orchestrator with safety components
orchestrator = StreamingResponseOrchestrator(
inference_api=mock_inference_api,
ctx=mock_context,
response_id="test_id",
created_at=1234567890,
text=OpenAIResponseText(),
max_infer_iters=5,
tool_executor=AsyncMock(),
safety_api=mock_safety_api,
shield_ids=shield_ids,
)
result = await orchestrator._check_input_safety(messages)
assert result is None
    mock_safety_api.run_shield.assert_called_once_with(shield_id="llama-guard", messages=messages, params={})


async def test_check_input_safety_with_violation(mock_safety_api, mock_inference_api, mock_context):
"""Test input shield validation with safety violation."""
messages = [UserMessage(content="Harmful content")]
shield_ids = ["llama-guard"]
# Mock shield violation
violation = SafetyViolation(
violation_level=ViolationLevel.ERROR, user_message="Content violates safety guidelines", metadata={}
)
mock_response = AsyncMock()
mock_response.violation = violation
mock_safety_api.run_shield.return_value = mock_response
# Create orchestrator with safety components
orchestrator = StreamingResponseOrchestrator(
inference_api=mock_inference_api,
ctx=mock_context,
response_id="test_id",
created_at=1234567890,
text=OpenAIResponseText(),
max_infer_iters=5,
tool_executor=AsyncMock(),
safety_api=mock_safety_api,
shield_ids=shield_ids,
)
result = await orchestrator._check_input_safety(messages)
assert isinstance(result, OpenAIResponseContentPartRefusal)
    assert result.refusal == "Content violates safety guidelines"


async def test_check_input_safety_empty_inputs(mock_safety_api, mock_inference_api, mock_context):
"""Test input shield validation with empty inputs."""
# Create orchestrator with safety components
orchestrator = StreamingResponseOrchestrator(
inference_api=mock_inference_api,
ctx=mock_context,
response_id="test_id",
created_at=1234567890,
text=OpenAIResponseText(),
max_infer_iters=5,
tool_executor=AsyncMock(),
safety_api=mock_safety_api,
shield_ids=[],
)
# Test empty shield_ids
result = await orchestrator._check_input_safety([UserMessage(content="test")])
assert result is None
# Test empty messages
orchestrator.shield_ids = ["llama-guard"]
result = await orchestrator._check_input_safety([])
assert result is None
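Read together, the three orchestrator tests imply roughly this behavior for _check_input_safety: return None when there are no shields or no messages, run each configured shield through the Safety API with empty params, and surface the first violation's user_message as a refusal content part. A sketch under those assumptions, not the shipped implementation:

async def _check_input_safety(self, messages):
    """Run configured input shields; return a refusal part on violation."""
    if not self.shield_ids or not messages:
        return None
    for shield_id in self.shield_ids:
        response = await self.safety_api.run_shield(
            shield_id=shield_id, messages=messages, params={}
        )
        if response.violation:
            return OpenAIResponseContentPartRefusal(refusal=response.violation.user_message)
    return None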