From e09401805f74e05d8623247a6e5ef40fb5d24cdd Mon Sep 17 00:00:00 2001 From: Swapna Lekkala Date: Fri, 10 Oct 2025 07:12:51 -0700 Subject: [PATCH] feat: Add responses and safety impl with extra body --- docs/static/deprecated-llama-stack-spec.html | 36 +++ docs/static/deprecated-llama-stack-spec.yaml | 23 ++ docs/static/llama-stack-spec.html | 36 +++ docs/static/llama-stack-spec.yaml | 23 ++ docs/static/stainless-llama-stack-spec.html | 36 +++ docs/static/stainless-llama-stack-spec.yaml | 23 ++ llama_stack/apis/agents/openai_responses.py | 8 +- .../inline/agents/meta_reference/agents.py | 1 + .../responses/openai_responses.py | 73 ++++- .../meta_reference/responses/streaming.py | 60 +++- .../agents/meta_reference/responses/utils.py | 107 +++++++- .../inline/safety/llama_guard/llama_guard.py | 5 + .../agents/test_openai_responses.py | 127 +++++++++ .../meta_reference/test_openai_responses.py | 72 ++++- .../test_responses_safety_utils.py | 256 ++++++++++++++++++ 15 files changed, 877 insertions(+), 9 deletions(-) create mode 100644 tests/unit/providers/agents/meta_reference/test_responses_safety_utils.py diff --git a/docs/static/deprecated-llama-stack-spec.html b/docs/static/deprecated-llama-stack-spec.html index 2fa339eeb..54ce9c2a8 100644 --- a/docs/static/deprecated-llama-stack-spec.html +++ b/docs/static/deprecated-llama-stack-spec.html @@ -8821,6 +8821,25 @@ } } }, + "OpenAIResponseContentPartRefusal": { + "type": "object", + "properties": { + "type": { + "type": "string", + "const": "refusal", + "default": "refusal" + }, + "refusal": { + "type": "string" + } + }, + "additionalProperties": false, + "required": [ + "type", + "refusal" + ], + "title": "OpenAIResponseContentPartRefusal" + }, "OpenAIResponseError": { "type": "object", "properties": { @@ -9395,6 +9414,23 @@ } }, "OpenAIResponseOutputMessageContent": { + "oneOf": [ + { + "$ref": "#/components/schemas/OpenAIResponseOutputMessageContentOutputText" + }, + { + "$ref": "#/components/schemas/OpenAIResponseContentPartRefusal" + } + ], + "discriminator": { + "propertyName": "type", + "mapping": { + "output_text": "#/components/schemas/OpenAIResponseOutputMessageContentOutputText", + "refusal": "#/components/schemas/OpenAIResponseContentPartRefusal" + } + } + }, + "OpenAIResponseOutputMessageContentOutputText": { "type": "object", "properties": { "text": { diff --git a/docs/static/deprecated-llama-stack-spec.yaml b/docs/static/deprecated-llama-stack-spec.yaml index 98af89fa8..a8c121fa0 100644 --- a/docs/static/deprecated-llama-stack-spec.yaml +++ b/docs/static/deprecated-llama-stack-spec.yaml @@ -6551,6 +6551,20 @@ components: url_citation: '#/components/schemas/OpenAIResponseAnnotationCitation' container_file_citation: '#/components/schemas/OpenAIResponseAnnotationContainerFileCitation' file_path: '#/components/schemas/OpenAIResponseAnnotationFilePath' + OpenAIResponseContentPartRefusal: + type: object + properties: + type: + type: string + const: refusal + default: refusal + refusal: + type: string + additionalProperties: false + required: + - type + - refusal + title: OpenAIResponseContentPartRefusal OpenAIResponseError: type: object properties: @@ -6972,6 +6986,15 @@ components: mcp_list_tools: '#/components/schemas/OpenAIResponseOutputMessageMCPListTools' mcp_approval_request: '#/components/schemas/OpenAIResponseMCPApprovalRequest' OpenAIResponseOutputMessageContent: + oneOf: + - $ref: '#/components/schemas/OpenAIResponseOutputMessageContentOutputText' + - $ref: '#/components/schemas/OpenAIResponseContentPartRefusal' + 
discriminator: + propertyName: type + mapping: + output_text: '#/components/schemas/OpenAIResponseOutputMessageContentOutputText' + refusal: '#/components/schemas/OpenAIResponseContentPartRefusal' + "OpenAIResponseOutputMessageContentOutputText": type: object properties: text: diff --git a/docs/static/llama-stack-spec.html b/docs/static/llama-stack-spec.html index 1064c1433..8dc82c4cc 100644 --- a/docs/static/llama-stack-spec.html +++ b/docs/static/llama-stack-spec.html @@ -5858,6 +5858,25 @@ } } }, + "OpenAIResponseContentPartRefusal": { + "type": "object", + "properties": { + "type": { + "type": "string", + "const": "refusal", + "default": "refusal" + }, + "refusal": { + "type": "string" + } + }, + "additionalProperties": false, + "required": [ + "type", + "refusal" + ], + "title": "OpenAIResponseContentPartRefusal" + }, "OpenAIResponseInputMessageContent": { "oneOf": [ { @@ -6001,6 +6020,23 @@ "description": "Corresponds to the various Message types in the Responses API. They are all under one type because the Responses API gives them all the same \"type\" value, and there is no way to tell them apart in certain scenarios." }, "OpenAIResponseOutputMessageContent": { + "oneOf": [ + { + "$ref": "#/components/schemas/OpenAIResponseOutputMessageContentOutputText" + }, + { + "$ref": "#/components/schemas/OpenAIResponseContentPartRefusal" + } + ], + "discriminator": { + "propertyName": "type", + "mapping": { + "output_text": "#/components/schemas/OpenAIResponseOutputMessageContentOutputText", + "refusal": "#/components/schemas/OpenAIResponseContentPartRefusal" + } + } + }, + "OpenAIResponseOutputMessageContentOutputText": { "type": "object", "properties": { "text": { diff --git a/docs/static/llama-stack-spec.yaml b/docs/static/llama-stack-spec.yaml index f36d69e3a..96db79a7e 100644 --- a/docs/static/llama-stack-spec.yaml +++ b/docs/static/llama-stack-spec.yaml @@ -4416,6 +4416,20 @@ components: url_citation: '#/components/schemas/OpenAIResponseAnnotationCitation' container_file_citation: '#/components/schemas/OpenAIResponseAnnotationContainerFileCitation' file_path: '#/components/schemas/OpenAIResponseAnnotationFilePath' + OpenAIResponseContentPartRefusal: + type: object + properties: + type: + type: string + const: refusal + default: refusal + refusal: + type: string + additionalProperties: false + required: + - type + - refusal + title: OpenAIResponseContentPartRefusal OpenAIResponseInputMessageContent: oneOf: - $ref: '#/components/schemas/OpenAIResponseInputMessageContentText' @@ -4515,6 +4529,15 @@ components: under one type because the Responses API gives them all the same "type" value, and there is no way to tell them apart in certain scenarios. 
OpenAIResponseOutputMessageContent: + oneOf: + - $ref: '#/components/schemas/OpenAIResponseOutputMessageContentOutputText' + - $ref: '#/components/schemas/OpenAIResponseContentPartRefusal' + discriminator: + propertyName: type + mapping: + output_text: '#/components/schemas/OpenAIResponseOutputMessageContentOutputText' + refusal: '#/components/schemas/OpenAIResponseContentPartRefusal' + "OpenAIResponseOutputMessageContentOutputText": type: object properties: text: diff --git a/docs/static/stainless-llama-stack-spec.html b/docs/static/stainless-llama-stack-spec.html index 25fa2bc03..4c21d43be 100644 --- a/docs/static/stainless-llama-stack-spec.html +++ b/docs/static/stainless-llama-stack-spec.html @@ -7867,6 +7867,25 @@ } } }, + "OpenAIResponseContentPartRefusal": { + "type": "object", + "properties": { + "type": { + "type": "string", + "const": "refusal", + "default": "refusal" + }, + "refusal": { + "type": "string" + } + }, + "additionalProperties": false, + "required": [ + "type", + "refusal" + ], + "title": "OpenAIResponseContentPartRefusal" + }, "OpenAIResponseInputMessageContent": { "oneOf": [ { @@ -8010,6 +8029,23 @@ "description": "Corresponds to the various Message types in the Responses API. They are all under one type because the Responses API gives them all the same \"type\" value, and there is no way to tell them apart in certain scenarios." }, "OpenAIResponseOutputMessageContent": { + "oneOf": [ + { + "$ref": "#/components/schemas/OpenAIResponseOutputMessageContentOutputText" + }, + { + "$ref": "#/components/schemas/OpenAIResponseContentPartRefusal" + } + ], + "discriminator": { + "propertyName": "type", + "mapping": { + "output_text": "#/components/schemas/OpenAIResponseOutputMessageContentOutputText", + "refusal": "#/components/schemas/OpenAIResponseContentPartRefusal" + } + } + }, + "OpenAIResponseOutputMessageContentOutputText": { "type": "object", "properties": { "text": { diff --git a/docs/static/stainless-llama-stack-spec.yaml b/docs/static/stainless-llama-stack-spec.yaml index df0112be7..11b35d42b 100644 --- a/docs/static/stainless-llama-stack-spec.yaml +++ b/docs/static/stainless-llama-stack-spec.yaml @@ -5861,6 +5861,20 @@ components: url_citation: '#/components/schemas/OpenAIResponseAnnotationCitation' container_file_citation: '#/components/schemas/OpenAIResponseAnnotationContainerFileCitation' file_path: '#/components/schemas/OpenAIResponseAnnotationFilePath' + OpenAIResponseContentPartRefusal: + type: object + properties: + type: + type: string + const: refusal + default: refusal + refusal: + type: string + additionalProperties: false + required: + - type + - refusal + title: OpenAIResponseContentPartRefusal OpenAIResponseInputMessageContent: oneOf: - $ref: '#/components/schemas/OpenAIResponseInputMessageContentText' @@ -5960,6 +5974,15 @@ components: under one type because the Responses API gives them all the same "type" value, and there is no way to tell them apart in certain scenarios. 
OpenAIResponseOutputMessageContent: + oneOf: + - $ref: '#/components/schemas/OpenAIResponseOutputMessageContentOutputText' + - $ref: '#/components/schemas/OpenAIResponseContentPartRefusal' + discriminator: + propertyName: type + mapping: + output_text: '#/components/schemas/OpenAIResponseOutputMessageContentOutputText' + refusal: '#/components/schemas/OpenAIResponseContentPartRefusal' + "OpenAIResponseOutputMessageContentOutputText": type: object properties: text: diff --git a/llama_stack/apis/agents/openai_responses.py b/llama_stack/apis/agents/openai_responses.py index a1ce134b6..19e6fd0f9 100644 --- a/llama_stack/apis/agents/openai_responses.py +++ b/llama_stack/apis/agents/openai_responses.py @@ -131,8 +131,14 @@ class OpenAIResponseOutputMessageContentOutputText(BaseModel): annotations: list[OpenAIResponseAnnotations] = Field(default_factory=list) +@json_schema_type +class OpenAIResponseContentPartRefusal(BaseModel): + type: Literal["refusal"] = "refusal" + refusal: str + + OpenAIResponseOutputMessageContent = Annotated[ - OpenAIResponseOutputMessageContentOutputText, + OpenAIResponseOutputMessageContentOutputText | OpenAIResponseContentPartRefusal, Field(discriminator="type"), ] register_schema(OpenAIResponseOutputMessageContent, name="OpenAIResponseOutputMessageContent") diff --git a/llama_stack/providers/inline/agents/meta_reference/agents.py b/llama_stack/providers/inline/agents/meta_reference/agents.py index cfaf56a34..55e221c9f 100644 --- a/llama_stack/providers/inline/agents/meta_reference/agents.py +++ b/llama_stack/providers/inline/agents/meta_reference/agents.py @@ -88,6 +88,7 @@ class MetaReferenceAgentsImpl(Agents): tool_runtime_api=self.tool_runtime_api, responses_store=self.responses_store, vector_io_api=self.vector_io_api, + safety_api=self.safety_api, ) async def create_agent( diff --git a/llama_stack/providers/inline/agents/meta_reference/responses/openai_responses.py b/llama_stack/providers/inline/agents/meta_reference/responses/openai_responses.py index fabe46f43..83a321eae 100644 --- a/llama_stack/providers/inline/agents/meta_reference/responses/openai_responses.py +++ b/llama_stack/providers/inline/agents/meta_reference/responses/openai_responses.py @@ -15,20 +15,25 @@ from llama_stack.apis.agents.openai_responses import ( ListOpenAIResponseInputItem, ListOpenAIResponseObject, OpenAIDeleteResponseObject, + OpenAIResponseContentPartRefusal, OpenAIResponseInput, OpenAIResponseInputMessageContentText, OpenAIResponseInputTool, OpenAIResponseMessage, OpenAIResponseObject, OpenAIResponseObjectStream, + OpenAIResponseObjectStreamResponseCompleted, + OpenAIResponseObjectStreamResponseCreated, OpenAIResponseText, OpenAIResponseTextFormat, ) from llama_stack.apis.inference import ( Inference, + Message, OpenAIMessageParam, OpenAISystemMessageParam, ) +from llama_stack.apis.safety import Safety from llama_stack.apis.tools import ToolGroups, ToolRuntime from llama_stack.apis.vector_io import VectorIO from llama_stack.log import get_logger @@ -37,12 +42,16 @@ from llama_stack.providers.utils.responses.responses_store import ( _OpenAIResponseObjectWithInputAndMessages, ) +from ..safety import SafetyException from .streaming import StreamingResponseOrchestrator from .tool_executor import ToolExecutor from .types import ChatCompletionContext, ToolContext from .utils import ( + convert_openai_to_inference_messages, convert_response_input_to_chat_messages, convert_response_text_to_chat_response_format, + extract_shield_ids, + run_multiple_shields, ) logger = 
get_logger(name=__name__, category="openai_responses") @@ -61,12 +70,14 @@ class OpenAIResponsesImpl: tool_runtime_api: ToolRuntime, responses_store: ResponsesStore, vector_io_api: VectorIO, # VectorIO + safety_api: Safety, ): self.inference_api = inference_api self.tool_groups_api = tool_groups_api self.tool_runtime_api = tool_runtime_api self.responses_store = responses_store self.vector_io_api = vector_io_api + self.safety_api = safety_api self.tool_executor = ToolExecutor( tool_groups_api=tool_groups_api, tool_runtime_api=tool_runtime_api, @@ -217,9 +228,7 @@ class OpenAIResponsesImpl: stream = bool(stream) text = OpenAIResponseText(format=OpenAIResponseTextFormat(type="text")) if text is None else text - # Shields parameter received via extra_body - not yet implemented - if shields is not None: - raise NotImplementedError("Shields parameter is not yet implemented in the meta-reference provider") + shield_ids = extract_shield_ids(shields) if shields else [] stream_gen = self._create_streaming_response( input=input, @@ -231,6 +240,7 @@ class OpenAIResponsesImpl: text=text, tools=tools, max_infer_iters=max_infer_iters, + shield_ids=shield_ids, ) if stream: @@ -264,6 +274,42 @@ class OpenAIResponsesImpl: raise ValueError("The response stream never reached a terminal state") return final_response + async def _check_input_safety( + self, messages: list[Message], shield_ids: list[str] + ) -> OpenAIResponseContentPartRefusal | None: + """Validate input messages against shields. Returns refusal content if violation found.""" + try: + await run_multiple_shields(self.safety_api, messages, shield_ids) + except SafetyException as e: + logger.info(f"Input shield violation: {e.violation.user_message}") + return OpenAIResponseContentPartRefusal( + refusal=e.violation.user_message or "Content blocked by safety shields" + ) + + async def _create_refusal_response_events( + self, refusal_content: OpenAIResponseContentPartRefusal, response_id: str, created_at: int, model: str + ) -> AsyncIterator[OpenAIResponseObjectStream]: + """Create and yield refusal response events following the established streaming pattern.""" + # Create initial response and yield created event + initial_response = OpenAIResponseObject( + id=response_id, + created_at=created_at, + model=model, + status="in_progress", + output=[], + ) + yield OpenAIResponseObjectStreamResponseCreated(response=initial_response) + + # Create completed refusal response using OpenAIResponseContentPartRefusal + refusal_response = OpenAIResponseObject( + id=response_id, + created_at=created_at, + model=model, + status="completed", + output=[OpenAIResponseMessage(role="assistant", content=[refusal_content], type="message")], + ) + yield OpenAIResponseObjectStreamResponseCompleted(response=refusal_response) + async def _create_streaming_response( self, input: str | list[OpenAIResponseInput], @@ -275,6 +321,7 @@ class OpenAIResponsesImpl: text: OpenAIResponseText | None = None, tools: list[OpenAIResponseInputTool] | None = None, max_infer_iters: int | None = 10, + shield_ids: list[str] | None = None, ) -> AsyncIterator[OpenAIResponseObjectStream]: # Input preprocessing all_input, messages, tool_context = await self._process_input_with_previous_response( @@ -282,8 +329,23 @@ class OpenAIResponsesImpl: ) await self._prepend_instructions(messages, instructions) + # Input safety validation hook - validates messages before streaming orchestrator starts + if shield_ids: + input_messages = convert_openai_to_inference_messages(messages) + input_refusal = await 
self._check_input_safety(input_messages, shield_ids) + if input_refusal: + # Return refusal response immediately + response_id = f"resp-{uuid.uuid4()}" + created_at = int(time.time()) + + async for refusal_event in self._create_refusal_response_events( + input_refusal, response_id, created_at, model + ): + yield refusal_event + return + # Structured outputs - response_format = await convert_response_text_to_chat_response_format(text) + response_format = convert_response_text_to_chat_response_format(text) ctx = ChatCompletionContext( model=model, @@ -307,8 +369,11 @@ class OpenAIResponsesImpl: text=text, max_infer_iters=max_infer_iters, tool_executor=self.tool_executor, + safety_api=self.safety_api, + shield_ids=shield_ids, ) + # Output safety validation hook - delegated to streaming orchestrator for real-time validation # Stream the response final_response = None failed_response = None diff --git a/llama_stack/providers/inline/agents/meta_reference/responses/streaming.py b/llama_stack/providers/inline/agents/meta_reference/responses/streaming.py index e4f2e7228..c95a2e732 100644 --- a/llama_stack/providers/inline/agents/meta_reference/responses/streaming.py +++ b/llama_stack/providers/inline/agents/meta_reference/responses/streaming.py @@ -14,9 +14,11 @@ from llama_stack.apis.agents.openai_responses import ( MCPListToolsTool, OpenAIResponseContentPartOutputText, OpenAIResponseError, + OpenAIResponseContentPartRefusal, OpenAIResponseInputTool, OpenAIResponseInputToolMCP, OpenAIResponseMCPApprovalRequest, + OpenAIResponseMessage, OpenAIResponseObject, OpenAIResponseObjectStream, OpenAIResponseObjectStreamResponseCompleted, @@ -52,8 +54,14 @@ from llama_stack.apis.inference import ( from llama_stack.log import get_logger from llama_stack.providers.utils.telemetry import tracing +from ..safety import SafetyException from .types import ChatCompletionContext, ChatCompletionResult -from .utils import convert_chat_choice_to_response_message, is_function_tool_call +from .utils import ( + convert_chat_choice_to_response_message, + convert_openai_to_inference_messages, + is_function_tool_call, + run_multiple_shields, +) logger = get_logger(name=__name__, category="agents::meta_reference") @@ -89,6 +97,8 @@ class StreamingResponseOrchestrator: text: OpenAIResponseText, max_infer_iters: int, tool_executor, # Will be the tool execution logic from the main class + safety_api, + shield_ids: list[str] | None = None, ): self.inference_api = inference_api self.ctx = ctx @@ -97,6 +107,8 @@ class StreamingResponseOrchestrator: self.text = text self.max_infer_iters = max_infer_iters self.tool_executor = tool_executor + self.safety_api = safety_api + self.shield_ids = shield_ids or [] self.sequence_number = 0 # Store MCP tool mapping that gets built during tool processing self.mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP] = ctx.tool_context.previous_tools or {} @@ -104,6 +116,43 @@ class StreamingResponseOrchestrator: self.final_messages: list[OpenAIMessageParam] = [] # mapping for annotations self.citation_files: dict[str, str] = {} + # Track accumulated text for shield validation + self.accumulated_text = "" + # Track if we've sent a refusal response + self.violation_detected = False + + async def _check_output_stream_safety(self, text_delta: str) -> str | None: + """Check streaming text content against shields. 
Returns violation message if blocked.""" + if not self.shield_ids: + return None + + self.accumulated_text += text_delta + + # Check accumulated text periodically for violations (every 50 characters or at word boundaries) + if len(self.accumulated_text) > 50 or text_delta.endswith((" ", "\n", ".", "!", "?")): + temp_messages = [{"role": "assistant", "content": self.accumulated_text}] + messages = convert_openai_to_inference_messages(temp_messages) + + try: + await run_multiple_shields(self.safety_api, messages, self.shield_ids) + except SafetyException as e: + logger.info(f"Output shield violation: {e.violation.user_message}") + return e.violation.user_message or "Generated content blocked by safety shields" + + async def _create_refusal_response(self, violation_message: str) -> OpenAIResponseObjectStream: + """Create a refusal response to replace streaming content.""" + refusal_content = OpenAIResponseContentPartRefusal(refusal=violation_message) + + # Create a completed refusal response + refusal_response = OpenAIResponseObject( + id=self.response_id, + created_at=self.created_at, + model=self.ctx.model, + status="completed", + output=[OpenAIResponseMessage(role="assistant", content=[refusal_content], type="message")], + ) + + return OpenAIResponseObjectStreamResponseCompleted(response=refusal_response) def _clone_outputs(self, outputs: list[OpenAIResponseOutput]) -> list[OpenAIResponseOutput]: cloned: list[OpenAIResponseOutput] = [] @@ -326,6 +375,15 @@ class StreamingResponseOrchestrator: for chunk_choice in chunk.choices: # Emit incremental text content as delta events if chunk_choice.delta.content: + # Check output stream safety before yielding content + violation_message = await self._check_output_stream_safety(chunk_choice.delta.content) + if violation_message: + # Stop streaming and send refusal response + yield await self._create_refusal_response(violation_message) + self.violation_detected = True + # Return immediately - no further processing needed + return + # Emit content_part.added event for first text chunk if not content_part_emitted: content_part_emitted = True diff --git a/llama_stack/providers/inline/agents/meta_reference/responses/utils.py b/llama_stack/providers/inline/agents/meta_reference/responses/utils.py index e67e9bdca..5ce917834 100644 --- a/llama_stack/providers/inline/agents/meta_reference/responses/utils.py +++ b/llama_stack/providers/inline/agents/meta_reference/responses/utils.py @@ -7,6 +7,7 @@ import re import uuid +from llama_stack.apis.agents.agents import ResponseShieldSpec from llama_stack.apis.agents.openai_responses import ( OpenAIResponseAnnotationFileCitation, OpenAIResponseInput, @@ -26,6 +27,8 @@ from llama_stack.apis.agents.openai_responses import ( OpenAIResponseText, ) from llama_stack.apis.inference import ( + CompletionMessage, + Message, OpenAIAssistantMessageParam, OpenAIChatCompletionContentPartImageParam, OpenAIChatCompletionContentPartParam, @@ -44,7 +47,19 @@ from llama_stack.apis.inference import ( OpenAISystemMessageParam, OpenAIToolMessageParam, OpenAIUserMessageParam, + StopReason, + SystemMessage, + UserMessage, ) +from llama_stack.apis.safety import Safety +from llama_stack.log import get_logger + +logger = get_logger(name=__name__, category="openai_responses_utils") + + +# ============================================================================ +# Message and Content Conversion Functions +# ============================================================================ async def convert_chat_choice_to_response_message( @@ 
-171,7 +186,7 @@ async def convert_response_input_to_chat_messages( pass else: content = await convert_response_content_to_chat_content(input_item.content) - message_type = await get_message_type_by_role(input_item.role) + message_type = get_message_type_by_role(input_item.role) if message_type is None: raise ValueError( f"Llama Stack OpenAI Responses does not yet support message role '{input_item.role}' in this context" @@ -240,7 +255,8 @@ async def convert_response_text_to_chat_response_format( raise ValueError(f"Unsupported text format: {text.format}") -async def get_message_type_by_role(role: str): +async def get_message_type_by_role(role: str) -> type[OpenAIMessageParam] | None: + """Get the appropriate OpenAI message parameter type for a given role.""" role_to_type = { "user": OpenAIUserMessageParam, "system": OpenAISystemMessageParam, @@ -307,3 +323,90 @@ def is_function_tool_call( if t.type == "function" and t.name == tool_call.function.name: return True return False + + +# ============================================================================ +# Safety and Shield Validation Functions +# ============================================================================ + + +async def run_multiple_shields(safety_api: Safety, messages: list[Message], shield_ids: list[str]) -> None: + """Run multiple shields against messages and raise SafetyException for violations.""" + if not shield_ids or not messages: + return + + for shield_id in shield_ids: + response = await safety_api.run_shield( + shield_id=shield_id, + messages=messages, + params={}, + ) + if response.violation and response.violation.violation_level.name == "ERROR": + from ..safety import SafetyException + + raise SafetyException(response.violation) + + +def extract_shield_ids(shields: list | None) -> list[str]: + """Extract shield IDs from shields parameter, handling both string IDs and ResponseShieldSpec objects.""" + if not shields: + return [] + + shield_ids = [] + for shield in shields: + if isinstance(shield, str): + shield_ids.append(shield) + elif isinstance(shield, ResponseShieldSpec): + shield_ids.append(shield.type) + else: + logger.warning(f"Unknown shield format: {shield}") + + return shield_ids + + +def extract_text_content(content: str | list | None) -> str | None: + """Extract text content from OpenAI message content (string or complex structure).""" + if isinstance(content, str): + return content + elif isinstance(content, list): + # Handle complex content - extract text parts only + text_parts = [] + for part in content: + if hasattr(part, "text"): + text_parts.append(part.text) + elif hasattr(part, "type") and part.type == "refusal": + # Skip refusal parts - don't validate them again + continue + return " ".join(text_parts) if text_parts else None + return None + + +def convert_openai_to_inference_messages(openai_messages: list) -> list[Message]: + """Convert OpenAI messages to inference API Message format.""" + safety_messages = [] + for msg in openai_messages: + # Handle both object attributes and dictionary keys + if hasattr(msg, "content") and hasattr(msg, "role"): + text_content = extract_text_content(msg.content) + role = msg.role + elif isinstance(msg, dict) and "content" in msg and "role" in msg: + text_content = extract_text_content(msg["content"]) + role = msg["role"] + else: + continue + + if text_content: + # Create appropriate message subclass based on role + if role == "user": + safety_messages.append(UserMessage(content=text_content)) + elif role == "system": + 
safety_messages.append(SystemMessage(content=text_content)) + elif role == "assistant": + safety_messages.append( + CompletionMessage( + content=text_content, + stop_reason=StopReason.end_of_turn, # Default for safety validation + ) + ) + # Note: Skip "tool" role messages as they're not typically validated by shields + return safety_messages diff --git a/llama_stack/providers/inline/safety/llama_guard/llama_guard.py b/llama_stack/providers/inline/safety/llama_guard/llama_guard.py index 206182343..3bf9bcca4 100644 --- a/llama_stack/providers/inline/safety/llama_guard/llama_guard.py +++ b/llama_stack/providers/inline/safety/llama_guard/llama_guard.py @@ -247,12 +247,17 @@ class LlamaGuardShield: self.safety_categories = safety_categories def check_unsafe_response(self, response: str) -> str | None: + # Check for "unsafe\n" format match = re.match(r"^unsafe\n(.*)$", response) if match: # extracts the unsafe code extracted = match.group(1) return extracted + # Check for direct category code format (e.g., "S1", "S2", etc.) + if re.match(r"^S\d+$", response): + return response + return None def get_safety_categories(self) -> list[str]: diff --git a/tests/integration/agents/test_openai_responses.py b/tests/integration/agents/test_openai_responses.py index 6648257e6..ad9dcc920 100644 --- a/tests/integration/agents/test_openai_responses.py +++ b/tests/integration/agents/test_openai_responses.py @@ -297,3 +297,130 @@ def test_function_call_output_response_with_none_arguments(openai_client, client assert response.output[0].type == "function_call" assert response.output[0].arguments == "{}" _ = response.output[0].call_id + + +@pytest.mark.parametrize("stream", [True, False]) +def test_input_safety_shields(compat_client, text_model_id, stream): + """Test that input safety shields block inappropriate content.""" + if not isinstance(compat_client, OpenAI): + pytest.skip("OpenAI client is required until responses API exists in llama-stack-client") + + # Test with shield that should block inappropriate content + # Note: This test assumes shields are configured in the test environment + inappropriate_message = "How can I hurt someone?" 
+ + response = compat_client.responses.create( + model=text_model_id, + input=[ + { + "role": "user", + "content": inappropriate_message, + } + ], + stream=stream, + extra_body={"shields": ["llama-guard"]}, # Shields passed via extra_body + ) + + if stream: + # Input violations cause immediate refusal - only one completion chunk + final_response = None + for chunk in response: + if chunk.type == "response.completed": + final_response = chunk.response + break + + assert final_response is not None + response_to_check = final_response + else: + response_to_check = response + + # Check if the response contains a refusal (indicating shield blocked content) + assert response_to_check.output[0].type == "message" + message = response_to_check.output[0] + + # Shield should ALWAYS block inappropriate input - expect refusal response + assert len(message.content) > 0, "Content should not be empty" + content_item = message.content[0] + assert content_item.type == "refusal", f"Expected refusal content, but got: {content_item.type}" + assert len(content_item.refusal.strip()) > 0, "Refusal message should not be empty" + + +@pytest.mark.parametrize("stream", [True, False]) +def test_output_safety_shields(compat_client, text_model_id, stream): + """Test that output safety shields can block generated content.""" + if not isinstance(compat_client, OpenAI): + pytest.skip("OpenAI client is required until responses API exists in llama-stack-client") + + # Use an ambiguous prompt that could potentially generate content requiring shield evaluation + # This prompt: "Describe self-defense techniques." + # Could result in: + # 1. Educational safety content → shield allows output → normal response + # 2. Violent techniques content → shield blocks output → refusal response + response = compat_client.responses.create( + model=text_model_id, + input=[ + { + "role": "user", + "content": "Describe violent content.", + } + ], + stream=stream, + extra_body={"shields": ["llama-guard"]}, # Output shield validation + ) + + if stream: + # For streaming, violations cause stream to end with refusal completion + last_chunk = None + for chunk in response: + last_chunk = chunk + + assert last_chunk.type == "response.completed", f"Expected final chunk to be completion, got {last_chunk.type}" + response_to_check = last_chunk.response + else: + response_to_check = response + # Verify we get a proper response (this test mainly verifies the shield integration works) + assert response_to_check.output[0].type == "message" + message = response_to_check.output[0] + + assert len(message.content) > 0, "Message should have content" + content_item = message.content[0] + assert content_item.type == "refusal", f"Content type should be 'refusal', got {content_item.type}" + + +def test_shields_with_tools(compat_client, text_model_id): + """Test that shields work correctly when tools are present.""" + if not isinstance(compat_client, OpenAI): + pytest.skip("OpenAI client is required until responses API exists in llama-stack-client") + + response = compat_client.responses.create( + model=text_model_id, + input=[ + { + "role": "user", + "content": "What's the weather like? 
Please help me in a safe and appropriate way.", + } + ], + tools=[ + { + "type": "function", + "name": "get_weather", + "description": "Get the weather in a given city", + "parameters": { + "type": "object", + "properties": { + "city": {"type": "string", "description": "The city to get the weather for"}, + }, + }, + } + ], + extra_body={"shields": ["llama-guard"]}, + stream=False, + ) + + # Verify response completes successfully with tools and shields + assert response.id is not None + assert len(response.output) > 0 + + # Response should be either a function call or a message + output_type = response.output[0].type + assert output_type in ["function_call", "message"] diff --git a/tests/unit/providers/agents/meta_reference/test_openai_responses.py b/tests/unit/providers/agents/meta_reference/test_openai_responses.py index 2ff586a08..072061192 100644 --- a/tests/unit/providers/agents/meta_reference/test_openai_responses.py +++ b/tests/unit/providers/agents/meta_reference/test_openai_responses.py @@ -18,6 +18,7 @@ from openai.types.chat.chat_completion_chunk import ( from llama_stack.apis.agents import Order from llama_stack.apis.agents.openai_responses import ( ListOpenAIResponseInputItem, + OpenAIResponseContentPartRefusal, OpenAIResponseInputMessageContentText, OpenAIResponseInputToolFunction, OpenAIResponseInputToolMCP, @@ -38,8 +39,11 @@ from llama_stack.apis.inference import ( OpenAIResponseFormatJSONObject, OpenAIResponseFormatJSONSchema, OpenAIUserMessageParam, + UserMessage, ) from llama_stack.apis.tools.tools import ListToolDefsResponse, ToolDef, ToolGroups, ToolInvocationResult, ToolRuntime +from llama_stack.apis.safety import SafetyViolation, ViolationLevel +from llama_stack.apis.tools.tools import ToolDef, ToolGroups, ToolInvocationResult, ToolRuntime from llama_stack.core.access_control.access_control import default_policy from llama_stack.core.datatypes import ResponsesStoreConfig from llama_stack.providers.inline.agents.meta_reference.responses.openai_responses import ( @@ -83,9 +87,20 @@ def mock_vector_io_api(): return vector_io_api +@pytest.fixture +def mock_safety_api(): + safety_api = AsyncMock() + return safety_api + + @pytest.fixture def openai_responses_impl( - mock_inference_api, mock_tool_groups_api, mock_tool_runtime_api, mock_responses_store, mock_vector_io_api + mock_inference_api, + mock_tool_groups_api, + mock_tool_runtime_api, + mock_responses_store, + mock_vector_io_api, + mock_safety_api, ): return OpenAIResponsesImpl( inference_api=mock_inference_api, @@ -93,6 +108,7 @@ def openai_responses_impl( tool_runtime_api=mock_tool_runtime_api, responses_store=mock_responses_store, vector_io_api=mock_vector_io_api, + safety_api=mock_safety_api, ) @@ -1066,3 +1082,57 @@ async def test_create_openai_response_with_invalid_text_format(openai_responses_ model=model, text=OpenAIResponseText(format={"type": "invalid"}), ) + + +# ============================================================================ +# Shield Validation Tests +# ============================================================================ + + +async def test_check_input_safety_no_violation(openai_responses_impl): + """Test input shield validation with no violations.""" + messages = [UserMessage(content="Hello world")] + shield_ids = ["llama-guard"] + + # Mock successful shield validation (no violation) + mock_response = AsyncMock() + mock_response.violation = None + openai_responses_impl.safety_api.run_shield.return_value = mock_response + + result = await 
openai_responses_impl._check_input_safety(messages, shield_ids) + + assert result is None + openai_responses_impl.safety_api.run_shield.assert_called_once_with( + shield_id="llama-guard", messages=messages, params={} + ) + + +async def test_check_input_safety_with_violation(openai_responses_impl): + """Test input shield validation with safety violation.""" + messages = [UserMessage(content="Harmful content")] + shield_ids = ["llama-guard"] + + # Mock shield violation + violation = SafetyViolation( + violation_level=ViolationLevel.ERROR, user_message="Content violates safety guidelines", metadata={} + ) + mock_response = AsyncMock() + mock_response.violation = violation + openai_responses_impl.safety_api.run_shield.return_value = mock_response + + result = await openai_responses_impl._check_input_safety(messages, shield_ids) + + assert isinstance(result, OpenAIResponseContentPartRefusal) + assert result.refusal == "Content violates safety guidelines" + assert result.type == "refusal" + + +async def test_check_input_safety_empty_inputs(openai_responses_impl): + """Test input shield validation with empty inputs.""" + # Test empty shield_ids + result = await openai_responses_impl._check_input_safety([UserMessage(content="test")], []) + assert result is None + + # Test empty messages + result = await openai_responses_impl._check_input_safety([], ["llama-guard"]) + assert result is None diff --git a/tests/unit/providers/agents/meta_reference/test_responses_safety_utils.py b/tests/unit/providers/agents/meta_reference/test_responses_safety_utils.py new file mode 100644 index 000000000..9c0fac8e3 --- /dev/null +++ b/tests/unit/providers/agents/meta_reference/test_responses_safety_utils.py @@ -0,0 +1,256 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
+ +from unittest.mock import AsyncMock, MagicMock + +import pytest + +from llama_stack.apis.agents.agents import ResponseShieldSpec +from llama_stack.apis.inference import ( + CompletionMessage, + StopReason, + SystemMessage, + UserMessage, +) +from llama_stack.providers.inline.agents.meta_reference.responses.openai_responses import ( + OpenAIResponsesImpl, +) +from llama_stack.providers.inline.agents.meta_reference.responses.utils import ( + convert_openai_to_inference_messages, + extract_shield_ids, + extract_text_content, +) + + +@pytest.fixture +def mock_apis(): + """Create mock APIs for testing.""" + return { + "inference_api": AsyncMock(), + "tool_groups_api": AsyncMock(), + "tool_runtime_api": AsyncMock(), + "responses_store": AsyncMock(), + "vector_io_api": AsyncMock(), + "safety_api": AsyncMock(), + } + + +@pytest.fixture +def responses_impl(mock_apis): + """Create OpenAIResponsesImpl instance with mocked dependencies.""" + return OpenAIResponsesImpl(**mock_apis) + + +# ============================================================================ +# Shield ID Extraction Tests +# ============================================================================ + + +def test_extract_shield_ids_from_strings(responses_impl): + """Test extraction from simple string shield IDs.""" + shields = ["llama-guard", "content-filter", "nsfw-detector"] + result = extract_shield_ids(shields) + assert result == ["llama-guard", "content-filter", "nsfw-detector"] + + +def test_extract_shield_ids_from_objects(responses_impl): + """Test extraction from ResponseShieldSpec objects.""" + shields = [ + ResponseShieldSpec(type="llama-guard"), + ResponseShieldSpec(type="content-filter"), + ] + result = extract_shield_ids(shields) + assert result == ["llama-guard", "content-filter"] + + +def test_extract_shield_ids_mixed_formats(responses_impl): + """Test extraction from mixed string and object formats.""" + shields = [ + "llama-guard", + ResponseShieldSpec(type="content-filter"), + "nsfw-detector", + ] + result = extract_shield_ids(shields) + assert result == ["llama-guard", "content-filter", "nsfw-detector"] + + +def test_extract_shield_ids_none_input(responses_impl): + """Test extraction with None input.""" + result = extract_shield_ids(None) + assert result == [] + + +def test_extract_shield_ids_empty_list(responses_impl): + """Test extraction with empty list.""" + result = extract_shield_ids([]) + assert result == [] + + +def test_extract_shield_ids_unknown_format(responses_impl, caplog): + """Test extraction with unknown shield format logs warning.""" + # Create an object that's neither string nor ResponseShieldSpec + unknown_object = {"invalid": "format"} # Plain dict, not ResponseShieldSpec + shields = ["valid-shield", unknown_object, "another-shield"] + result = extract_shield_ids(shields) + assert result == ["valid-shield", "another-shield"] + assert "Unknown shield format" in caplog.text + + +# ============================================================================ +# Text Content Extraction Tests +# ============================================================================ + + +def test_extract_text_content_string(responses_impl): + """Test extraction from simple string content.""" + content = "Hello world" + result = extract_text_content(content) + assert result == "Hello world" + + +def test_extract_text_content_list_with_text(responses_impl): + """Test extraction from list content with text parts.""" + content = [ + MagicMock(text="Hello "), + MagicMock(text="world"), + ] + result = 
extract_text_content(content) + assert result == "Hello world" + + +def test_extract_text_content_list_with_refusal(responses_impl): + """Test extraction skips refusal parts.""" + # Create text parts + text_part1 = MagicMock() + text_part1.text = "Hello" + + text_part2 = MagicMock() + text_part2.text = "world" + + # Create refusal part (no text attribute) + refusal_part = MagicMock() + refusal_part.type = "refusal" + refusal_part.refusal = "Blocked" + del refusal_part.text # Remove text attribute + + content = [text_part1, refusal_part, text_part2] + result = extract_text_content(content) + assert result == "Hello world" + + +def test_extract_text_content_empty_list(responses_impl): + """Test extraction from empty list returns None.""" + content = [] + result = extract_text_content(content) + assert result is None + + +def test_extract_text_content_no_text_parts(responses_impl): + """Test extraction with no text parts returns None.""" + # Create image part (no text attribute) + image_part = MagicMock() + image_part.type = "image" + image_part.image_url = "http://example.com" + + # Create refusal part (no text attribute) + refusal_part = MagicMock() + refusal_part.type = "refusal" + refusal_part.refusal = "Blocked" + + # Explicitly remove text attributes to simulate non-text parts + if hasattr(image_part, "text"): + delattr(image_part, "text") + if hasattr(refusal_part, "text"): + delattr(refusal_part, "text") + + content = [image_part, refusal_part] + result = extract_text_content(content) + assert result is None + + +def test_extract_text_content_none_input(responses_impl): + """Test extraction with None input returns None.""" + result = extract_text_content(None) + assert result is None + + +# ============================================================================ +# Message Conversion Tests +# ============================================================================ + + +def test_convert_user_message(responses_impl): + """Test conversion of user message.""" + openai_msg = MagicMock(role="user", content="Hello world") + result = convert_openai_to_inference_messages([openai_msg]) + + assert len(result) == 1 + assert isinstance(result[0], UserMessage) + assert result[0].content == "Hello world" + + +def test_convert_system_message(responses_impl): + """Test conversion of system message.""" + openai_msg = MagicMock(role="system", content="You are helpful") + result = convert_openai_to_inference_messages([openai_msg]) + + assert len(result) == 1 + assert isinstance(result[0], SystemMessage) + assert result[0].content == "You are helpful" + + +def test_convert_assistant_message(responses_impl): + """Test conversion of assistant message.""" + openai_msg = MagicMock(role="assistant", content="I can help") + result = convert_openai_to_inference_messages([openai_msg]) + + assert len(result) == 1 + assert isinstance(result[0], CompletionMessage) + assert result[0].content == "I can help" + assert result[0].stop_reason == StopReason.end_of_turn + + +def test_convert_tool_message_skipped(responses_impl): + """Test that tool messages are skipped.""" + openai_msg = MagicMock(role="tool", content="Tool result") + result = convert_openai_to_inference_messages([openai_msg]) + + assert len(result) == 0 + + +def test_convert_complex_content(responses_impl): + """Test conversion with complex content structure.""" + openai_msg = MagicMock( + role="user", + content=[ + MagicMock(text="Analyze this: "), + MagicMock(text="important content"), + ], + ) + result = 
convert_openai_to_inference_messages([openai_msg]) + + assert len(result) == 1 + assert isinstance(result[0], UserMessage) + assert result[0].content == "Analyze this: important content" + + +def test_convert_empty_content_skipped(responses_impl): + """Test that messages with no extractable content are skipped.""" + openai_msg = MagicMock(role="user", content=[]) + result = convert_openai_to_inference_messages([openai_msg]) + + assert len(result) == 0 + + +def test_convert_assistant_message_dict_format(responses_impl): + """Test conversion of assistant message in dictionary format.""" + dict_msg = {"role": "assistant", "content": "Violent content refers to media, materials, or expressions"} + result = convert_openai_to_inference_messages([dict_msg]) + + assert len(result) == 1 + assert isinstance(result[0], CompletionMessage) + assert result[0].content == "Violent content refers to media, materials, or expressions" + assert result[0].stop_reason == StopReason.end_of_turn
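
# ---------------------------------------------------------------------------
# Usage sketch (not part of the patch): how a client exercises the new
# `shields` extra_body parameter, mirroring the integration tests added in
# tests/integration/agents/test_openai_responses.py. Assumptions: a running
# Llama Stack server exposing its OpenAI-compatible endpoint, a registered
# "llama-guard" shield, and an available text model; the base URL and model id
# below are illustrative placeholders, not values defined by this patch.
# ---------------------------------------------------------------------------
# from openai import OpenAI
#
# client = OpenAI(base_url="http://localhost:8321/v1/openai/v1", api_key="none")
#
# response = client.responses.create(
#     model="meta-llama/Llama-3.2-3B-Instruct",      # illustrative model id
#     input=[{"role": "user", "content": "How can I hurt someone?"}],
#     stream=False,
#     extra_body={"shields": ["llama-guard"]},        # shields passed via extra_body
# )
#
# # When an input shield fires, the response carries a single assistant message
# # whose first content part is an OpenAIResponseContentPartRefusal.
# content = response.output[0].content[0]
# if content.type == "refusal":
#     print("blocked:", content.refusal)
# else:
#     print("allowed:", content.text)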