feat: Add responses and safety impl with extra body

Swapna Lekkala 2025-10-10 07:12:51 -07:00
parent 548ccff368
commit e09401805f
15 changed files with 877 additions and 9 deletions
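In practice, the shields added here ride along on a regular Responses API call through the OpenAI client's extra_body, as the integration tests further down exercise. A minimal sketch of the client-side shape (the base URL, API key, and model name are placeholders, not part of this commit):

from openai import OpenAI

# Hypothetical connection details; any Llama Stack deployment exposing the
# OpenAI-compatible Responses API is addressed the same way.
client = OpenAI(base_url="http://localhost:8321/v1/openai/v1", api_key="none")

response = client.responses.create(
    model="example-model-id",  # placeholder model identifier
    input=[{"role": "user", "content": "How can I hurt someone?"}],
    # Shields are not part of the OpenAI schema, so they travel via extra_body.
    extra_body={"shields": ["llama-guard"]},
)

# If a shield fires, the single output message carries a refusal content part
# instead of normal assistant text.
print(response.output[0].content[0].type)  # "refusal" when blocked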

View file

@ -8821,6 +8821,25 @@
}
}
},
"OpenAIResponseContentPartRefusal": {
"type": "object",
"properties": {
"type": {
"type": "string",
"const": "refusal",
"default": "refusal"
},
"refusal": {
"type": "string"
}
},
"additionalProperties": false,
"required": [
"type",
"refusal"
],
"title": "OpenAIResponseContentPartRefusal"
},
"OpenAIResponseError": {
"type": "object",
"properties": {
@ -9395,6 +9414,23 @@
}
},
"OpenAIResponseOutputMessageContent": {
"oneOf": [
{
"$ref": "#/components/schemas/OpenAIResponseOutputMessageContentOutputText"
},
{
"$ref": "#/components/schemas/OpenAIResponseContentPartRefusal"
}
],
"discriminator": {
"propertyName": "type",
"mapping": {
"output_text": "#/components/schemas/OpenAIResponseOutputMessageContentOutputText",
"refusal": "#/components/schemas/OpenAIResponseContentPartRefusal"
}
}
},
"OpenAIResponseOutputMessageContentOutputText": {
"type": "object",
"properties": {
"text": {

View file

@ -6551,6 +6551,20 @@ components:
url_citation: '#/components/schemas/OpenAIResponseAnnotationCitation'
container_file_citation: '#/components/schemas/OpenAIResponseAnnotationContainerFileCitation'
file_path: '#/components/schemas/OpenAIResponseAnnotationFilePath'
OpenAIResponseContentPartRefusal:
type: object
properties:
type:
type: string
const: refusal
default: refusal
refusal:
type: string
additionalProperties: false
required:
- type
- refusal
title: OpenAIResponseContentPartRefusal
OpenAIResponseError:
type: object
properties:
@ -6972,6 +6986,15 @@ components:
mcp_list_tools: '#/components/schemas/OpenAIResponseOutputMessageMCPListTools'
mcp_approval_request: '#/components/schemas/OpenAIResponseMCPApprovalRequest'
OpenAIResponseOutputMessageContent:
oneOf:
- $ref: '#/components/schemas/OpenAIResponseOutputMessageContentOutputText'
- $ref: '#/components/schemas/OpenAIResponseContentPartRefusal'
discriminator:
propertyName: type
mapping:
output_text: '#/components/schemas/OpenAIResponseOutputMessageContentOutputText'
refusal: '#/components/schemas/OpenAIResponseContentPartRefusal'
"OpenAIResponseOutputMessageContentOutputText":
type: object
properties:
text:

View file

@ -5858,6 +5858,25 @@
}
}
},
"OpenAIResponseContentPartRefusal": {
"type": "object",
"properties": {
"type": {
"type": "string",
"const": "refusal",
"default": "refusal"
},
"refusal": {
"type": "string"
}
},
"additionalProperties": false,
"required": [
"type",
"refusal"
],
"title": "OpenAIResponseContentPartRefusal"
},
"OpenAIResponseInputMessageContent": {
"oneOf": [
{
@ -6001,6 +6020,23 @@
"description": "Corresponds to the various Message types in the Responses API. They are all under one type because the Responses API gives them all the same \"type\" value, and there is no way to tell them apart in certain scenarios."
},
"OpenAIResponseOutputMessageContent": {
"oneOf": [
{
"$ref": "#/components/schemas/OpenAIResponseOutputMessageContentOutputText"
},
{
"$ref": "#/components/schemas/OpenAIResponseContentPartRefusal"
}
],
"discriminator": {
"propertyName": "type",
"mapping": {
"output_text": "#/components/schemas/OpenAIResponseOutputMessageContentOutputText",
"refusal": "#/components/schemas/OpenAIResponseContentPartRefusal"
}
}
},
"OpenAIResponseOutputMessageContentOutputText": {
"type": "object",
"properties": {
"text": {

View file

@ -4416,6 +4416,20 @@ components:
url_citation: '#/components/schemas/OpenAIResponseAnnotationCitation'
container_file_citation: '#/components/schemas/OpenAIResponseAnnotationContainerFileCitation'
file_path: '#/components/schemas/OpenAIResponseAnnotationFilePath'
OpenAIResponseContentPartRefusal:
type: object
properties:
type:
type: string
const: refusal
default: refusal
refusal:
type: string
additionalProperties: false
required:
- type
- refusal
title: OpenAIResponseContentPartRefusal
OpenAIResponseInputMessageContent:
oneOf:
- $ref: '#/components/schemas/OpenAIResponseInputMessageContentText'
@ -4515,6 +4529,15 @@ components:
under one type because the Responses API gives them all the same "type" value,
and there is no way to tell them apart in certain scenarios.
OpenAIResponseOutputMessageContent:
oneOf:
- $ref: '#/components/schemas/OpenAIResponseOutputMessageContentOutputText'
- $ref: '#/components/schemas/OpenAIResponseContentPartRefusal'
discriminator:
propertyName: type
mapping:
output_text: '#/components/schemas/OpenAIResponseOutputMessageContentOutputText'
refusal: '#/components/schemas/OpenAIResponseContentPartRefusal'
"OpenAIResponseOutputMessageContentOutputText":
type: object
properties:
text:

View file

@ -7867,6 +7867,25 @@
}
}
},
"OpenAIResponseContentPartRefusal": {
"type": "object",
"properties": {
"type": {
"type": "string",
"const": "refusal",
"default": "refusal"
},
"refusal": {
"type": "string"
}
},
"additionalProperties": false,
"required": [
"type",
"refusal"
],
"title": "OpenAIResponseContentPartRefusal"
},
"OpenAIResponseInputMessageContent": {
"oneOf": [
{
@ -8010,6 +8029,23 @@
"description": "Corresponds to the various Message types in the Responses API. They are all under one type because the Responses API gives them all the same \"type\" value, and there is no way to tell them apart in certain scenarios."
},
"OpenAIResponseOutputMessageContent": {
"oneOf": [
{
"$ref": "#/components/schemas/OpenAIResponseOutputMessageContentOutputText"
},
{
"$ref": "#/components/schemas/OpenAIResponseContentPartRefusal"
}
],
"discriminator": {
"propertyName": "type",
"mapping": {
"output_text": "#/components/schemas/OpenAIResponseOutputMessageContentOutputText",
"refusal": "#/components/schemas/OpenAIResponseContentPartRefusal"
}
}
},
"OpenAIResponseOutputMessageContentOutputText": {
"type": "object",
"properties": {
"text": {

View file

@ -5861,6 +5861,20 @@ components:
url_citation: '#/components/schemas/OpenAIResponseAnnotationCitation'
container_file_citation: '#/components/schemas/OpenAIResponseAnnotationContainerFileCitation'
file_path: '#/components/schemas/OpenAIResponseAnnotationFilePath'
OpenAIResponseContentPartRefusal:
type: object
properties:
type:
type: string
const: refusal
default: refusal
refusal:
type: string
additionalProperties: false
required:
- type
- refusal
title: OpenAIResponseContentPartRefusal
OpenAIResponseInputMessageContent:
oneOf:
- $ref: '#/components/schemas/OpenAIResponseInputMessageContentText'
@ -5960,6 +5974,15 @@ components:
under one type because the Responses API gives them all the same "type" value,
and there is no way to tell them apart in certain scenarios.
OpenAIResponseOutputMessageContent:
oneOf:
- $ref: '#/components/schemas/OpenAIResponseOutputMessageContentOutputText'
- $ref: '#/components/schemas/OpenAIResponseContentPartRefusal'
discriminator:
propertyName: type
mapping:
output_text: '#/components/schemas/OpenAIResponseOutputMessageContentOutputText'
refusal: '#/components/schemas/OpenAIResponseContentPartRefusal'
"OpenAIResponseOutputMessageContentOutputText":
type: object
properties:
text:

View file

@ -131,8 +131,14 @@ class OpenAIResponseOutputMessageContentOutputText(BaseModel):
annotations: list[OpenAIResponseAnnotations] = Field(default_factory=list)
@json_schema_type
class OpenAIResponseContentPartRefusal(BaseModel):
type: Literal["refusal"] = "refusal"
refusal: str
OpenAIResponseOutputMessageContent = Annotated[
OpenAIResponseOutputMessageContentOutputText,
OpenAIResponseOutputMessageContentOutputText | OpenAIResponseContentPartRefusal,
Field(discriminator="type"),
]
register_schema(OpenAIResponseOutputMessageContent, name="OpenAIResponseOutputMessageContent")
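With OpenAIResponseOutputMessageContent now a discriminated union, the "type" field decides which model Pydantic instantiates. A minimal sketch of that behavior, assuming Pydantic v2's TypeAdapter and the models defined above:

from pydantic import TypeAdapter

from llama_stack.apis.agents.openai_responses import (
    OpenAIResponseContentPartRefusal,
    OpenAIResponseOutputMessageContent,
)

adapter = TypeAdapter(OpenAIResponseOutputMessageContent)

# The "type" discriminator selects the new refusal variant...
part = adapter.validate_python({"type": "refusal", "refusal": "Content blocked by safety shields"})
assert isinstance(part, OpenAIResponseContentPartRefusal)

# ...while output_text payloads still resolve to the existing text model.
text_part = adapter.validate_python({"type": "output_text", "text": "hello", "annotations": []})
assert text_part.type == "output_text"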

View file

@ -88,6 +88,7 @@ class MetaReferenceAgentsImpl(Agents):
tool_runtime_api=self.tool_runtime_api,
responses_store=self.responses_store,
vector_io_api=self.vector_io_api,
safety_api=self.safety_api,
)
async def create_agent(

View file

@ -15,20 +15,25 @@ from llama_stack.apis.agents.openai_responses import (
ListOpenAIResponseInputItem,
ListOpenAIResponseObject,
OpenAIDeleteResponseObject,
OpenAIResponseContentPartRefusal,
OpenAIResponseInput,
OpenAIResponseInputMessageContentText,
OpenAIResponseInputTool,
OpenAIResponseMessage,
OpenAIResponseObject,
OpenAIResponseObjectStream,
OpenAIResponseObjectStreamResponseCompleted,
OpenAIResponseObjectStreamResponseCreated,
OpenAIResponseText,
OpenAIResponseTextFormat,
)
from llama_stack.apis.inference import (
Inference,
Message,
OpenAIMessageParam,
OpenAISystemMessageParam,
)
from llama_stack.apis.safety import Safety
from llama_stack.apis.tools import ToolGroups, ToolRuntime
from llama_stack.apis.vector_io import VectorIO
from llama_stack.log import get_logger
@ -37,12 +42,16 @@ from llama_stack.providers.utils.responses.responses_store import (
_OpenAIResponseObjectWithInputAndMessages,
)
from ..safety import SafetyException
from .streaming import StreamingResponseOrchestrator
from .tool_executor import ToolExecutor
from .types import ChatCompletionContext, ToolContext
from .utils import (
convert_openai_to_inference_messages,
convert_response_input_to_chat_messages,
convert_response_text_to_chat_response_format,
extract_shield_ids,
run_multiple_shields,
)
logger = get_logger(name=__name__, category="openai_responses")
@ -61,12 +70,14 @@ class OpenAIResponsesImpl:
tool_runtime_api: ToolRuntime,
responses_store: ResponsesStore,
vector_io_api: VectorIO, # VectorIO
safety_api: Safety,
):
self.inference_api = inference_api
self.tool_groups_api = tool_groups_api
self.tool_runtime_api = tool_runtime_api
self.responses_store = responses_store
self.vector_io_api = vector_io_api
self.safety_api = safety_api
self.tool_executor = ToolExecutor(
tool_groups_api=tool_groups_api,
tool_runtime_api=tool_runtime_api,
@ -217,9 +228,7 @@ class OpenAIResponsesImpl:
stream = bool(stream)
text = OpenAIResponseText(format=OpenAIResponseTextFormat(type="text")) if text is None else text
# Shields parameter received via extra_body - not yet implemented
if shields is not None:
raise NotImplementedError("Shields parameter is not yet implemented in the meta-reference provider")
shield_ids = extract_shield_ids(shields) if shields else []
stream_gen = self._create_streaming_response(
input=input,
@ -231,6 +240,7 @@ class OpenAIResponsesImpl:
text=text,
tools=tools,
max_infer_iters=max_infer_iters,
shield_ids=shield_ids,
)
if stream:
@ -264,6 +274,42 @@ class OpenAIResponsesImpl:
raise ValueError("The response stream never reached a terminal state")
return final_response
async def _check_input_safety(
self, messages: list[Message], shield_ids: list[str]
) -> OpenAIResponseContentPartRefusal | None:
"""Validate input messages against shields. Returns refusal content if violation found."""
try:
await run_multiple_shields(self.safety_api, messages, shield_ids)
except SafetyException as e:
logger.info(f"Input shield violation: {e.violation.user_message}")
return OpenAIResponseContentPartRefusal(
refusal=e.violation.user_message or "Content blocked by safety shields"
)
async def _create_refusal_response_events(
self, refusal_content: OpenAIResponseContentPartRefusal, response_id: str, created_at: int, model: str
) -> AsyncIterator[OpenAIResponseObjectStream]:
"""Create and yield refusal response events following the established streaming pattern."""
# Create initial response and yield created event
initial_response = OpenAIResponseObject(
id=response_id,
created_at=created_at,
model=model,
status="in_progress",
output=[],
)
yield OpenAIResponseObjectStreamResponseCreated(response=initial_response)
# Create completed refusal response using OpenAIResponseContentPartRefusal
refusal_response = OpenAIResponseObject(
id=response_id,
created_at=created_at,
model=model,
status="completed",
output=[OpenAIResponseMessage(role="assistant", content=[refusal_content], type="message")],
)
yield OpenAIResponseObjectStreamResponseCompleted(response=refusal_response)
async def _create_streaming_response(
self,
input: str | list[OpenAIResponseInput],
@ -275,6 +321,7 @@ class OpenAIResponsesImpl:
text: OpenAIResponseText | None = None,
tools: list[OpenAIResponseInputTool] | None = None,
max_infer_iters: int | None = 10,
shield_ids: list[str] | None = None,
) -> AsyncIterator[OpenAIResponseObjectStream]:
# Input preprocessing
all_input, messages, tool_context = await self._process_input_with_previous_response(
@ -282,8 +329,23 @@ class OpenAIResponsesImpl:
)
await self._prepend_instructions(messages, instructions)
# Input safety validation hook - validates messages before streaming orchestrator starts
if shield_ids:
input_messages = convert_openai_to_inference_messages(messages)
input_refusal = await self._check_input_safety(input_messages, shield_ids)
if input_refusal:
# Return refusal response immediately
response_id = f"resp-{uuid.uuid4()}"
created_at = int(time.time())
async for refusal_event in self._create_refusal_response_events(
input_refusal, response_id, created_at, model
):
yield refusal_event
return
# Structured outputs
response_format = await convert_response_text_to_chat_response_format(text)
response_format = convert_response_text_to_chat_response_format(text)
ctx = ChatCompletionContext(
model=model,
@ -307,8 +369,11 @@ class OpenAIResponsesImpl:
text=text,
max_infer_iters=max_infer_iters,
tool_executor=self.tool_executor,
safety_api=self.safety_api,
shield_ids=shield_ids,
)
# Output safety validation hook - delegated to streaming orchestrator for real-time validation
# Stream the response
final_response = None
failed_response = None
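For reference, the refusal path implemented above emits exactly two stream events: a response.created whose response has status "in_progress" and an empty output list, followed by a response.completed whose only output is an assistant message holding the refusal content part. An illustrative consumer (the response id, timestamp, and model name are invented):

from llama_stack.apis.agents.openai_responses import OpenAIResponseContentPartRefusal


async def collect_refusal_events(impl):  # impl: a wired-up OpenAIResponsesImpl
    refusal = OpenAIResponseContentPartRefusal(refusal="Content blocked by safety shields")
    events = []
    async for event in impl._create_refusal_response_events(
        refusal, response_id="resp-example", created_at=0, model="example-model"
    ):
        events.append(event)
    # events[0].response.status == "in_progress" with no output yet;
    # events[1].response.output[0].content[0] is the refusal part.
    return events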

View file

@ -14,9 +14,11 @@ from llama_stack.apis.agents.openai_responses import (
MCPListToolsTool,
OpenAIResponseContentPartOutputText,
OpenAIResponseContentPartRefusal,
OpenAIResponseError,
OpenAIResponseInputTool,
OpenAIResponseInputToolMCP,
OpenAIResponseMCPApprovalRequest,
OpenAIResponseMessage,
OpenAIResponseObject,
OpenAIResponseObjectStream,
OpenAIResponseObjectStreamResponseCompleted,
@ -52,8 +54,14 @@ from llama_stack.apis.inference import (
from llama_stack.log import get_logger
from llama_stack.providers.utils.telemetry import tracing
from ..safety import SafetyException
from .types import ChatCompletionContext, ChatCompletionResult
from .utils import convert_chat_choice_to_response_message, is_function_tool_call
from .utils import (
convert_chat_choice_to_response_message,
convert_openai_to_inference_messages,
is_function_tool_call,
run_multiple_shields,
)
logger = get_logger(name=__name__, category="agents::meta_reference")
@ -89,6 +97,8 @@ class StreamingResponseOrchestrator:
text: OpenAIResponseText,
max_infer_iters: int,
tool_executor, # Will be the tool execution logic from the main class
safety_api,
shield_ids: list[str] | None = None,
):
self.inference_api = inference_api
self.ctx = ctx
@ -97,6 +107,8 @@ class StreamingResponseOrchestrator:
self.text = text
self.max_infer_iters = max_infer_iters
self.tool_executor = tool_executor
self.safety_api = safety_api
self.shield_ids = shield_ids or []
self.sequence_number = 0
# Store MCP tool mapping that gets built during tool processing
self.mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP] = ctx.tool_context.previous_tools or {}
@ -104,6 +116,43 @@ class StreamingResponseOrchestrator:
self.final_messages: list[OpenAIMessageParam] = []
# mapping for annotations
self.citation_files: dict[str, str] = {}
# Track accumulated text for shield validation
self.accumulated_text = ""
# Track if we've sent a refusal response
self.violation_detected = False
async def _check_output_stream_safety(self, text_delta: str) -> str | None:
"""Check streaming text content against shields. Returns violation message if blocked."""
if not self.shield_ids:
return None
self.accumulated_text += text_delta
# Run the shields once the accumulated text exceeds 50 characters or the delta ends at a word/sentence boundary
if len(self.accumulated_text) > 50 or text_delta.endswith((" ", "\n", ".", "!", "?")):
temp_messages = [{"role": "assistant", "content": self.accumulated_text}]
messages = convert_openai_to_inference_messages(temp_messages)
try:
await run_multiple_shields(self.safety_api, messages, self.shield_ids)
except SafetyException as e:
logger.info(f"Output shield violation: {e.violation.user_message}")
return e.violation.user_message or "Generated content blocked by safety shields"
async def _create_refusal_response(self, violation_message: str) -> OpenAIResponseObjectStream:
"""Create a refusal response to replace streaming content."""
refusal_content = OpenAIResponseContentPartRefusal(refusal=violation_message)
# Create a completed refusal response
refusal_response = OpenAIResponseObject(
id=self.response_id,
created_at=self.created_at,
model=self.ctx.model,
status="completed",
output=[OpenAIResponseMessage(role="assistant", content=[refusal_content], type="message")],
)
return OpenAIResponseObjectStreamResponseCompleted(response=refusal_response)
def _clone_outputs(self, outputs: list[OpenAIResponseOutput]) -> list[OpenAIResponseOutput]:
cloned: list[OpenAIResponseOutput] = []
@ -326,6 +375,15 @@ class StreamingResponseOrchestrator:
for chunk_choice in chunk.choices:
# Emit incremental text content as delta events
if chunk_choice.delta.content:
# Check output stream safety before yielding content
violation_message = await self._check_output_stream_safety(chunk_choice.delta.content)
if violation_message:
# Stop streaming and send refusal response
yield await self._create_refusal_response(violation_message)
self.violation_detected = True
# Return immediately - no further processing needed
return
# Emit content_part.added event for first text chunk
if not content_part_emitted:
content_part_emitted = True
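A worked illustration of the cadence _check_output_stream_safety implements above: every delta is appended to the buffer, but the shields only run once the buffer has grown past 50 characters or the current delta ends on a word or sentence boundary (the sample deltas are invented):

deltas = ["The quick", " brown fox", " jumps. "]
# "The quick"   -> buffer is 9 chars, delta ends in "k"       -> no shield call
# " brown fox"  -> buffer is 19 chars, delta ends in "x"      -> no shield call
# " jumps. "    -> delta ends with a space (boundary reached) -> shields run on
#                  "The quick brown fox jumps. "; a violation here would end the
#                  stream with a single refusal completion event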

View file

@ -7,6 +7,7 @@
import re
import uuid
from llama_stack.apis.agents.agents import ResponseShieldSpec
from llama_stack.apis.agents.openai_responses import (
OpenAIResponseAnnotationFileCitation,
OpenAIResponseInput,
@ -26,6 +27,8 @@ from llama_stack.apis.agents.openai_responses import (
OpenAIResponseText,
)
from llama_stack.apis.inference import (
CompletionMessage,
Message,
OpenAIAssistantMessageParam,
OpenAIChatCompletionContentPartImageParam,
OpenAIChatCompletionContentPartParam,
@ -44,7 +47,19 @@ from llama_stack.apis.inference import (
OpenAISystemMessageParam,
OpenAIToolMessageParam,
OpenAIUserMessageParam,
StopReason,
SystemMessage,
UserMessage,
)
from llama_stack.apis.safety import Safety
from llama_stack.log import get_logger
logger = get_logger(name=__name__, category="openai_responses_utils")
# ============================================================================
# Message and Content Conversion Functions
# ============================================================================
async def convert_chat_choice_to_response_message(
@ -171,7 +186,7 @@ async def convert_response_input_to_chat_messages(
pass
else:
content = await convert_response_content_to_chat_content(input_item.content)
message_type = await get_message_type_by_role(input_item.role)
message_type = get_message_type_by_role(input_item.role)
if message_type is None:
raise ValueError(
f"Llama Stack OpenAI Responses does not yet support message role '{input_item.role}' in this context"
@ -240,7 +255,8 @@ async def convert_response_text_to_chat_response_format(
raise ValueError(f"Unsupported text format: {text.format}")
async def get_message_type_by_role(role: str):
async def get_message_type_by_role(role: str) -> type[OpenAIMessageParam] | None:
"""Get the appropriate OpenAI message parameter type for a given role."""
role_to_type = {
"user": OpenAIUserMessageParam,
"system": OpenAISystemMessageParam,
@ -307,3 +323,90 @@ def is_function_tool_call(
if t.type == "function" and t.name == tool_call.function.name:
return True
return False
# ============================================================================
# Safety and Shield Validation Functions
# ============================================================================
async def run_multiple_shields(safety_api: Safety, messages: list[Message], shield_ids: list[str]) -> None:
"""Run multiple shields against messages and raise SafetyException for violations."""
if not shield_ids or not messages:
return
for shield_id in shield_ids:
response = await safety_api.run_shield(
shield_id=shield_id,
messages=messages,
params={},
)
if response.violation and response.violation.violation_level.name == "ERROR":
from ..safety import SafetyException
raise SafetyException(response.violation)
def extract_shield_ids(shields: list | None) -> list[str]:
"""Extract shield IDs from shields parameter, handling both string IDs and ResponseShieldSpec objects."""
if not shields:
return []
shield_ids = []
for shield in shields:
if isinstance(shield, str):
shield_ids.append(shield)
elif isinstance(shield, ResponseShieldSpec):
shield_ids.append(shield.type)
else:
logger.warning(f"Unknown shield format: {shield}")
return shield_ids
def extract_text_content(content: str | list | None) -> str | None:
"""Extract text content from OpenAI message content (string or complex structure)."""
if isinstance(content, str):
return content
elif isinstance(content, list):
# Handle complex content - extract text parts only
text_parts = []
for part in content:
if hasattr(part, "text"):
text_parts.append(part.text)
elif hasattr(part, "type") and part.type == "refusal":
# Skip refusal parts - don't validate them again
continue
return " ".join(text_parts) if text_parts else None
return None
def convert_openai_to_inference_messages(openai_messages: list) -> list[Message]:
"""Convert OpenAI messages to inference API Message format."""
safety_messages = []
for msg in openai_messages:
# Handle both object attributes and dictionary keys
if hasattr(msg, "content") and hasattr(msg, "role"):
text_content = extract_text_content(msg.content)
role = msg.role
elif isinstance(msg, dict) and "content" in msg and "role" in msg:
text_content = extract_text_content(msg["content"])
role = msg["role"]
else:
continue
if text_content:
# Create appropriate message subclass based on role
if role == "user":
safety_messages.append(UserMessage(content=text_content))
elif role == "system":
safety_messages.append(SystemMessage(content=text_content))
elif role == "assistant":
safety_messages.append(
CompletionMessage(
content=text_content,
stop_reason=StopReason.end_of_turn, # Default for safety validation
)
)
# Note: Skip "tool" role messages as they're not typically validated by shields
return safety_messages
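A small usage sketch for the new helpers (the import path is the one used by the unit tests below; shield names are illustrative):

from llama_stack.apis.agents.agents import ResponseShieldSpec
from llama_stack.apis.inference import OpenAIUserMessageParam
from llama_stack.providers.inline.agents.meta_reference.responses.utils import (
    convert_openai_to_inference_messages,
    extract_shield_ids,
)

# Plain IDs and ResponseShieldSpec objects both normalize to a flat list of IDs.
shield_ids = extract_shield_ids(["llama-guard", ResponseShieldSpec(type="content-filter")])
assert shield_ids == ["llama-guard", "content-filter"]

# OpenAI-style messages become inference Message objects that shields can consume.
messages = convert_openai_to_inference_messages(
    [OpenAIUserMessageParam(role="user", content="Hello world")]
)
assert messages[0].content == "Hello world"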

View file

@ -247,12 +247,17 @@ class LlamaGuardShield:
self.safety_categories = safety_categories
def check_unsafe_response(self, response: str) -> str | None:
# Check for "unsafe\n<code>" format
match = re.match(r"^unsafe\n(.*)$", response)
if match:
# extracts the unsafe code
extracted = match.group(1)
return extracted
# Check for direct category code format (e.g., "S1", "S2", etc.)
if re.match(r"^S\d+$", response):
return response
return None
def get_safety_categories(self) -> list[str]:
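Illustrative inputs and outputs for the extended check_unsafe_response above (category codes are examples only):

# "unsafe\nS1" -> "S1"   classic two-line Llama Guard verdict
# "S9"         -> "S9"   bare category code, now handled directly
# "safe"       -> None   no violation reported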

View file

@ -297,3 +297,130 @@ def test_function_call_output_response_with_none_arguments(openai_client, client
assert response.output[0].type == "function_call"
assert response.output[0].arguments == "{}"
_ = response.output[0].call_id
@pytest.mark.parametrize("stream", [True, False])
def test_input_safety_shields(compat_client, text_model_id, stream):
"""Test that input safety shields block inappropriate content."""
if not isinstance(compat_client, OpenAI):
pytest.skip("OpenAI client is required until responses API exists in llama-stack-client")
# Test with shield that should block inappropriate content
# Note: This test assumes shields are configured in the test environment
inappropriate_message = "How can I hurt someone?"
response = compat_client.responses.create(
model=text_model_id,
input=[
{
"role": "user",
"content": inappropriate_message,
}
],
stream=stream,
extra_body={"shields": ["llama-guard"]}, # Shields passed via extra_body
)
if stream:
# Input violations cause immediate refusal - only one completion chunk
final_response = None
for chunk in response:
if chunk.type == "response.completed":
final_response = chunk.response
break
assert final_response is not None
response_to_check = final_response
else:
response_to_check = response
# Check if the response contains a refusal (indicating shield blocked content)
assert response_to_check.output[0].type == "message"
message = response_to_check.output[0]
# Shield should ALWAYS block inappropriate input - expect refusal response
assert len(message.content) > 0, "Content should not be empty"
content_item = message.content[0]
assert content_item.type == "refusal", f"Expected refusal content, but got: {content_item.type}"
assert len(content_item.refusal.strip()) > 0, "Refusal message should not be empty"
@pytest.mark.parametrize("stream", [True, False])
def test_output_safety_shields(compat_client, text_model_id, stream):
"""Test that output safety shields can block generated content."""
if not isinstance(compat_client, OpenAI):
pytest.skip("OpenAI client is required until responses API exists in llama-stack-client")
# Use a prompt that should trigger the output shield: asking the model to
# "Describe violent content." is expected to produce output that the shield
# blocks, so the test asserts a refusal response.
response = compat_client.responses.create(
model=text_model_id,
input=[
{
"role": "user",
"content": "Describe violent content.",
}
],
stream=stream,
extra_body={"shields": ["llama-guard"]}, # Output shield validation
)
if stream:
# For streaming, violations cause stream to end with refusal completion
last_chunk = None
for chunk in response:
last_chunk = chunk
assert last_chunk.type == "response.completed", f"Expected final chunk to be completion, got {last_chunk.type}"
response_to_check = last_chunk.response
else:
response_to_check = response
# The output shield should have blocked the generated content, so the message must carry a refusal part
assert response_to_check.output[0].type == "message"
message = response_to_check.output[0]
assert len(message.content) > 0, "Message should have content"
content_item = message.content[0]
assert content_item.type == "refusal", f"Content type should be 'refusal', got {content_item.type}"
def test_shields_with_tools(compat_client, text_model_id):
"""Test that shields work correctly when tools are present."""
if not isinstance(compat_client, OpenAI):
pytest.skip("OpenAI client is required until responses API exists in llama-stack-client")
response = compat_client.responses.create(
model=text_model_id,
input=[
{
"role": "user",
"content": "What's the weather like? Please help me in a safe and appropriate way.",
}
],
tools=[
{
"type": "function",
"name": "get_weather",
"description": "Get the weather in a given city",
"parameters": {
"type": "object",
"properties": {
"city": {"type": "string", "description": "The city to get the weather for"},
},
},
}
],
extra_body={"shields": ["llama-guard"]},
stream=False,
)
# Verify response completes successfully with tools and shields
assert response.id is not None
assert len(response.output) > 0
# Response should be either a function call or a message
output_type = response.output[0].type
assert output_type in ["function_call", "message"]

View file

@ -18,6 +18,7 @@ from openai.types.chat.chat_completion_chunk import (
from llama_stack.apis.agents import Order
from llama_stack.apis.agents.openai_responses import (
ListOpenAIResponseInputItem,
OpenAIResponseContentPartRefusal,
OpenAIResponseInputMessageContentText,
OpenAIResponseInputToolFunction,
OpenAIResponseInputToolMCP,
@ -38,8 +39,11 @@ from llama_stack.apis.inference import (
OpenAIResponseFormatJSONObject,
OpenAIResponseFormatJSONSchema,
OpenAIUserMessageParam,
UserMessage,
)
from llama_stack.apis.tools.tools import ListToolDefsResponse, ToolDef, ToolGroups, ToolInvocationResult, ToolRuntime
from llama_stack.apis.safety import SafetyViolation, ViolationLevel
from llama_stack.apis.tools.tools import ToolDef, ToolGroups, ToolInvocationResult, ToolRuntime
from llama_stack.core.access_control.access_control import default_policy
from llama_stack.core.datatypes import ResponsesStoreConfig
from llama_stack.providers.inline.agents.meta_reference.responses.openai_responses import (
@ -83,9 +87,20 @@ def mock_vector_io_api():
return vector_io_api
@pytest.fixture
def mock_safety_api():
safety_api = AsyncMock()
return safety_api
@pytest.fixture
def openai_responses_impl(
mock_inference_api, mock_tool_groups_api, mock_tool_runtime_api, mock_responses_store, mock_vector_io_api
mock_inference_api,
mock_tool_groups_api,
mock_tool_runtime_api,
mock_responses_store,
mock_vector_io_api,
mock_safety_api,
):
return OpenAIResponsesImpl(
inference_api=mock_inference_api,
@ -93,6 +108,7 @@ def openai_responses_impl(
tool_runtime_api=mock_tool_runtime_api,
responses_store=mock_responses_store,
vector_io_api=mock_vector_io_api,
safety_api=mock_safety_api,
)
@ -1066,3 +1082,57 @@ async def test_create_openai_response_with_invalid_text_format(openai_responses_
model=model,
text=OpenAIResponseText(format={"type": "invalid"}),
)
# ============================================================================
# Shield Validation Tests
# ============================================================================
async def test_check_input_safety_no_violation(openai_responses_impl):
"""Test input shield validation with no violations."""
messages = [UserMessage(content="Hello world")]
shield_ids = ["llama-guard"]
# Mock successful shield validation (no violation)
mock_response = AsyncMock()
mock_response.violation = None
openai_responses_impl.safety_api.run_shield.return_value = mock_response
result = await openai_responses_impl._check_input_safety(messages, shield_ids)
assert result is None
openai_responses_impl.safety_api.run_shield.assert_called_once_with(
shield_id="llama-guard", messages=messages, params={}
)
async def test_check_input_safety_with_violation(openai_responses_impl):
"""Test input shield validation with safety violation."""
messages = [UserMessage(content="Harmful content")]
shield_ids = ["llama-guard"]
# Mock shield violation
violation = SafetyViolation(
violation_level=ViolationLevel.ERROR, user_message="Content violates safety guidelines", metadata={}
)
mock_response = AsyncMock()
mock_response.violation = violation
openai_responses_impl.safety_api.run_shield.return_value = mock_response
result = await openai_responses_impl._check_input_safety(messages, shield_ids)
assert isinstance(result, OpenAIResponseContentPartRefusal)
assert result.refusal == "Content violates safety guidelines"
assert result.type == "refusal"
async def test_check_input_safety_empty_inputs(openai_responses_impl):
"""Test input shield validation with empty inputs."""
# Test empty shield_ids
result = await openai_responses_impl._check_input_safety([UserMessage(content="test")], [])
assert result is None
# Test empty messages
result = await openai_responses_impl._check_input_safety([], ["llama-guard"])
assert result is None

View file

@ -0,0 +1,256 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from unittest.mock import AsyncMock, MagicMock
import pytest
from llama_stack.apis.agents.agents import ResponseShieldSpec
from llama_stack.apis.inference import (
CompletionMessage,
StopReason,
SystemMessage,
UserMessage,
)
from llama_stack.providers.inline.agents.meta_reference.responses.openai_responses import (
OpenAIResponsesImpl,
)
from llama_stack.providers.inline.agents.meta_reference.responses.utils import (
convert_openai_to_inference_messages,
extract_shield_ids,
extract_text_content,
)
@pytest.fixture
def mock_apis():
"""Create mock APIs for testing."""
return {
"inference_api": AsyncMock(),
"tool_groups_api": AsyncMock(),
"tool_runtime_api": AsyncMock(),
"responses_store": AsyncMock(),
"vector_io_api": AsyncMock(),
"safety_api": AsyncMock(),
}
@pytest.fixture
def responses_impl(mock_apis):
"""Create OpenAIResponsesImpl instance with mocked dependencies."""
return OpenAIResponsesImpl(**mock_apis)
# ============================================================================
# Shield ID Extraction Tests
# ============================================================================
def test_extract_shield_ids_from_strings(responses_impl):
"""Test extraction from simple string shield IDs."""
shields = ["llama-guard", "content-filter", "nsfw-detector"]
result = extract_shield_ids(shields)
assert result == ["llama-guard", "content-filter", "nsfw-detector"]
def test_extract_shield_ids_from_objects(responses_impl):
"""Test extraction from ResponseShieldSpec objects."""
shields = [
ResponseShieldSpec(type="llama-guard"),
ResponseShieldSpec(type="content-filter"),
]
result = extract_shield_ids(shields)
assert result == ["llama-guard", "content-filter"]
def test_extract_shield_ids_mixed_formats(responses_impl):
"""Test extraction from mixed string and object formats."""
shields = [
"llama-guard",
ResponseShieldSpec(type="content-filter"),
"nsfw-detector",
]
result = extract_shield_ids(shields)
assert result == ["llama-guard", "content-filter", "nsfw-detector"]
def test_extract_shield_ids_none_input(responses_impl):
"""Test extraction with None input."""
result = extract_shield_ids(None)
assert result == []
def test_extract_shield_ids_empty_list(responses_impl):
"""Test extraction with empty list."""
result = extract_shield_ids([])
assert result == []
def test_extract_shield_ids_unknown_format(responses_impl, caplog):
"""Test extraction with unknown shield format logs warning."""
# Create an object that's neither string nor ResponseShieldSpec
unknown_object = {"invalid": "format"} # Plain dict, not ResponseShieldSpec
shields = ["valid-shield", unknown_object, "another-shield"]
result = extract_shield_ids(shields)
assert result == ["valid-shield", "another-shield"]
assert "Unknown shield format" in caplog.text
# ============================================================================
# Text Content Extraction Tests
# ============================================================================
def test_extract_text_content_string(responses_impl):
"""Test extraction from simple string content."""
content = "Hello world"
result = extract_text_content(content)
assert result == "Hello world"
def test_extract_text_content_list_with_text(responses_impl):
"""Test extraction from list content with text parts."""
content = [
MagicMock(text="Hello "),
MagicMock(text="world"),
]
result = extract_text_content(content)
assert result == "Hello world"
def test_extract_text_content_list_with_refusal(responses_impl):
"""Test extraction skips refusal parts."""
# Create text parts
text_part1 = MagicMock()
text_part1.text = "Hello"
text_part2 = MagicMock()
text_part2.text = "world"
# Create refusal part (no text attribute)
refusal_part = MagicMock()
refusal_part.type = "refusal"
refusal_part.refusal = "Blocked"
del refusal_part.text # Remove text attribute
content = [text_part1, refusal_part, text_part2]
result = extract_text_content(content)
assert result == "Hello world"
def test_extract_text_content_empty_list(responses_impl):
"""Test extraction from empty list returns None."""
content = []
result = extract_text_content(content)
assert result is None
def test_extract_text_content_no_text_parts(responses_impl):
"""Test extraction with no text parts returns None."""
# Create image part (no text attribute)
image_part = MagicMock()
image_part.type = "image"
image_part.image_url = "http://example.com"
# Create refusal part (no text attribute)
refusal_part = MagicMock()
refusal_part.type = "refusal"
refusal_part.refusal = "Blocked"
# Explicitly remove text attributes to simulate non-text parts
if hasattr(image_part, "text"):
delattr(image_part, "text")
if hasattr(refusal_part, "text"):
delattr(refusal_part, "text")
content = [image_part, refusal_part]
result = extract_text_content(content)
assert result is None
def test_extract_text_content_none_input(responses_impl):
"""Test extraction with None input returns None."""
result = extract_text_content(None)
assert result is None
# ============================================================================
# Message Conversion Tests
# ============================================================================
def test_convert_user_message(responses_impl):
"""Test conversion of user message."""
openai_msg = MagicMock(role="user", content="Hello world")
result = convert_openai_to_inference_messages([openai_msg])
assert len(result) == 1
assert isinstance(result[0], UserMessage)
assert result[0].content == "Hello world"
def test_convert_system_message(responses_impl):
"""Test conversion of system message."""
openai_msg = MagicMock(role="system", content="You are helpful")
result = convert_openai_to_inference_messages([openai_msg])
assert len(result) == 1
assert isinstance(result[0], SystemMessage)
assert result[0].content == "You are helpful"
def test_convert_assistant_message(responses_impl):
"""Test conversion of assistant message."""
openai_msg = MagicMock(role="assistant", content="I can help")
result = convert_openai_to_inference_messages([openai_msg])
assert len(result) == 1
assert isinstance(result[0], CompletionMessage)
assert result[0].content == "I can help"
assert result[0].stop_reason == StopReason.end_of_turn
def test_convert_tool_message_skipped(responses_impl):
"""Test that tool messages are skipped."""
openai_msg = MagicMock(role="tool", content="Tool result")
result = convert_openai_to_inference_messages([openai_msg])
assert len(result) == 0
def test_convert_complex_content(responses_impl):
"""Test conversion with complex content structure."""
openai_msg = MagicMock(
role="user",
content=[
MagicMock(text="Analyze this: "),
MagicMock(text="important content"),
],
)
result = convert_openai_to_inference_messages([openai_msg])
assert len(result) == 1
assert isinstance(result[0], UserMessage)
assert result[0].content == "Analyze this: important content"
def test_convert_empty_content_skipped(responses_impl):
"""Test that messages with no extractable content are skipped."""
openai_msg = MagicMock(role="user", content=[])
result = convert_openai_to_inference_messages([openai_msg])
assert len(result) == 0
def test_convert_assistant_message_dict_format(responses_impl):
"""Test conversion of assistant message in dictionary format."""
dict_msg = {"role": "assistant", "content": "Violent content refers to media, materials, or expressions"}
result = convert_openai_to_inference_messages([dict_msg])
assert len(result) == 1
assert isinstance(result[0], CompletionMessage)
assert result[0].content == "Violent content refers to media, materials, or expressions"
assert result[0].stop_reason == StopReason.end_of_turn