diff --git a/llama_stack/providers/inline/agents/meta_reference/agents.py b/llama_stack/providers/inline/agents/meta_reference/agents.py
index 398a6245d..5ddfd374a 100644
--- a/llama_stack/providers/inline/agents/meta_reference/agents.py
+++ b/llama_stack/providers/inline/agents/meta_reference/agents.py
@@ -91,6 +91,7 @@ class MetaReferenceAgentsImpl(Agents):
             tool_runtime_api=self.tool_runtime_api,
             responses_store=self.responses_store,
             vector_io_api=self.vector_io_api,
+            safety_api=self.safety_api,
             conversations_api=self.conversations_api,
         )
diff --git a/llama_stack/providers/inline/agents/meta_reference/responses/openai_responses.py b/llama_stack/providers/inline/agents/meta_reference/responses/openai_responses.py
index 56f9d6e04..36908987b 100644
--- a/llama_stack/providers/inline/agents/meta_reference/responses/openai_responses.py
+++ b/llama_stack/providers/inline/agents/meta_reference/responses/openai_responses.py
@@ -34,6 +34,7 @@ from llama_stack.apis.conversations import Conversations
 from llama_stack.apis.conversations.conversations import ConversationItem
 from llama_stack.apis.inference import (
     Inference,
+    Message,
     OpenAIMessageParam,
     OpenAISystemMessageParam,
 )
@@ -46,10 +47,12 @@ from llama_stack.providers.utils.responses.responses_store import (
     _OpenAIResponseObjectWithInputAndMessages,
 )
 
+from ..safety import SafetyException
 from .streaming import StreamingResponseOrchestrator
 from .tool_executor import ToolExecutor
 from .types import ChatCompletionContext, ToolContext
 from .utils import (
+    convert_openai_to_inference_messages,
     convert_response_input_to_chat_messages,
     convert_response_text_to_chat_response_format,
     extract_shield_ids,
@@ -71,6 +74,7 @@ class OpenAIResponsesImpl:
         tool_runtime_api: ToolRuntime,
         responses_store: ResponsesStore,
         vector_io_api: VectorIO,  # VectorIO
+        safety_api: Safety,
         conversations_api: Conversations,
     ):
@@ -79,6 +83,7 @@ class OpenAIResponsesImpl:
         self.tool_runtime_api = tool_runtime_api
         self.responses_store = responses_store
         self.vector_io_api = vector_io_api
+        self.safety_api = safety_api
         self.conversations_api = conversations_api
         self.tool_executor = ToolExecutor(
@@ -339,6 +344,21 @@
         )
         await self._prepend_instructions(messages, instructions)
 
+        # Input safety validation hook - validates messages before the streaming orchestrator starts
+        if shield_ids:
+            input_messages = convert_openai_to_inference_messages(messages)
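+            # _check_input_safety returns an OpenAIResponseContentPartRefusal when a
+            # shield reports a violation and None when every shield passes (see the
+            # test_check_input_safety_* unit tests below)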
+            input_refusal = await self._check_input_safety(input_messages, shield_ids)
+            if input_refusal:
+                # Return refusal response immediately
+                response_id = f"resp-{uuid.uuid4()}"
+                created_at = int(time.time())
+
+                async for refusal_event in self._create_refusal_response_events(
+                    input_refusal, response_id, created_at, model
+                ):
+                    yield refusal_event
+                return
+
         # Structured outputs
         response_format = await convert_response_text_to_chat_response_format(text)
diff --git a/llama_stack/providers/inline/agents/meta_reference/responses/streaming.py b/llama_stack/providers/inline/agents/meta_reference/responses/streaming.py
index c30c3d6a2..f8a43ce58 100644
--- a/llama_stack/providers/inline/agents/meta_reference/responses/streaming.py
+++ b/llama_stack/providers/inline/agents/meta_reference/responses/streaming.py
@@ -467,6 +467,15 @@ class StreamingResponseOrchestrator:
             for chunk_choice in chunk.choices:
                 # Emit incremental text content as delta events
                 if chunk_choice.delta.content:
+                    # Check output stream safety before yielding content
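+                    # Note: _check_output_stream_safety is awaited on every text delta,
+                    # so each streamed chunk pays for a safety check before it is emitted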
+                    violation_message = await self._check_output_stream_safety(chunk_choice.delta.content)
+                    if violation_message:
+                        # Stop streaming and send refusal response
+                        yield await self._create_refusal_response(violation_message)
+                        self.violation_detected = True
+                        # Return immediately - no further processing needed
+                        return
+
                     # Emit content_part.added event for first text chunk
                     if not content_part_emitted:
                         content_part_emitted = True
diff --git a/llama_stack/providers/inline/agents/meta_reference/responses/utils.py b/llama_stack/providers/inline/agents/meta_reference/responses/utils.py
index 00ec52443..84800e85a 100644
--- a/llama_stack/providers/inline/agents/meta_reference/responses/utils.py
+++ b/llama_stack/providers/inline/agents/meta_reference/responses/utils.py
@@ -47,6 +47,9 @@ from llama_stack.apis.inference import (
     OpenAISystemMessageParam,
     OpenAIToolMessageParam,
     OpenAIUserMessageParam,
+    StopReason,
+    SystemMessage,
+    UserMessage,
 )
 
 from llama_stack.apis.safety import Safety
@@ -244,7 +247,7 @@ async def convert_response_text_to_chat_response_format(
     raise ValueError(f"Unsupported text format: {text.format}")
 
 
-def get_message_type_by_role(role: str) -> type[OpenAIMessageParam] | None:
+async def get_message_type_by_role(role: str) -> type[OpenAIMessageParam] | None:
     """Get the appropriate OpenAI message parameter type for a given role."""
     role_to_type = {
         "user": OpenAIUserMessageParam,
diff --git a/tests/unit/providers/agents/meta_reference/test_openai_responses.py b/tests/unit/providers/agents/meta_reference/test_openai_responses.py
index 2f8e4521b..0a46c1c44 100644
--- a/tests/unit/providers/agents/meta_reference/test_openai_responses.py
+++ b/tests/unit/providers/agents/meta_reference/test_openai_responses.py
@@ -18,6 +18,7 @@ from openai.types.chat.chat_completion_chunk import (
 from llama_stack.apis.agents import Order
 from llama_stack.apis.agents.openai_responses import (
     ListOpenAIResponseInputItem,
+    OpenAIResponseContentPartRefusal,
     OpenAIResponseInputMessageContentText,
     OpenAIResponseInputToolFunction,
     OpenAIResponseInputToolMCP,
@@ -38,7 +39,9 @@ from llama_stack.apis.inference import (
     OpenAIResponseFormatJSONObject,
     OpenAIResponseFormatJSONSchema,
     OpenAIUserMessageParam,
+    UserMessage,
 )
+from llama_stack.apis.safety import SafetyViolation, ViolationLevel
 from llama_stack.apis.tools.tools import ListToolDefsResponse, ToolDef, ToolGroups, ToolInvocationResult, ToolRuntime
 from llama_stack.core.access_control.access_control import default_policy
 from llama_stack.core.datatypes import ResponsesStoreConfig
@@ -84,6 +87,9 @@ def mock_vector_io_api():
 
 
 @pytest.fixture
+def mock_safety_api():
+    safety_api = AsyncMock()
+    return safety_api
+
+
+@pytest.fixture
 def mock_conversations_api():
     """Mock conversations API for testing."""
     mock_api = AsyncMock()
@@ -103,6 +109,7 @@ def openai_responses_impl(
     mock_inference_api,
     mock_tool_runtime_api,
     mock_responses_store,
     mock_vector_io_api,
+    mock_safety_api,
     mock_conversations_api,
 ):
@@ -112,6 +119,7 @@
         tool_runtime_api=mock_tool_runtime_api,
         responses_store=mock_responses_store,
         vector_io_api=mock_vector_io_api,
+        safety_api=mock_safety_api,
         conversations_api=mock_conversations_api,
     )
@@ -1090,3 +1098,52 @@ async def test_create_openai_response_with_invalid_text_format(openai_responses_
         model=model,
         text=OpenAIResponseText(format={"type": "invalid"}),
     )
+
+
+async def test_check_input_safety_no_violation(openai_responses_impl):
+    """Test input shield validation with no violations."""
+    messages = [UserMessage(content="Hello world")]
+    shield_ids = ["llama-guard"]
+
+    # Mock successful shield validation (no violation)
+    mock_response = AsyncMock()
+    mock_response.violation = None
+    openai_responses_impl.safety_api.run_shield.return_value = mock_response
+
+    result = await openai_responses_impl._check_input_safety(messages, shield_ids)
+
+    assert result is None
+    openai_responses_impl.safety_api.run_shield.assert_called_once_with(
+        shield_id="llama-guard", messages=messages, params={}
+    )
+
+
+async def test_check_input_safety_with_violation(openai_responses_impl):
+    """Test input shield validation with a safety violation."""
+    messages = [UserMessage(content="Harmful content")]
+    shield_ids = ["llama-guard"]
+
+    # Mock shield violation
+    violation = SafetyViolation(
+        violation_level=ViolationLevel.ERROR, user_message="Content violates safety guidelines", metadata={}
+    )
+    mock_response = AsyncMock()
+    mock_response.violation = violation
+    openai_responses_impl.safety_api.run_shield.return_value = mock_response
+
+    result = await openai_responses_impl._check_input_safety(messages, shield_ids)
+
+    assert isinstance(result, OpenAIResponseContentPartRefusal)
+    assert result.refusal == "Content violates safety guidelines"
+    assert result.type == "refusal"
+
+
+async def test_check_input_safety_empty_inputs(openai_responses_impl):
+    """Test input shield validation with empty inputs."""
+    # Test empty shield_ids
+    result = await openai_responses_impl._check_input_safety([UserMessage(content="test")], [])
+    assert result is None
+
+    # Test empty messages
+    result = await openai_responses_impl._check_input_safety([], ["llama-guard"])
+    assert result is None