diff --git a/llama_stack/providers/inline/agents/meta_reference/responses/streaming.py b/llama_stack/providers/inline/agents/meta_reference/responses/streaming.py
index 798393e28..c4e170452 100644
--- a/llama_stack/providers/inline/agents/meta_reference/responses/streaming.py
+++ b/llama_stack/providers/inline/agents/meta_reference/responses/streaming.py
@@ -69,7 +69,6 @@ from llama_stack.log import get_logger
 from llama_stack.providers.utils.inference.prompt_adapter import interleaved_content_as_str
 from llama_stack.providers.utils.telemetry import tracing
 
-from ..safety import SafetyException
 from .types import ChatCompletionContext, ChatCompletionResult
 from .utils import (
     convert_chat_choice_to_response_message,
@@ -140,11 +139,10 @@ class StreamingResponseOrchestrator:
         if not self.guardrail_ids or not text:
             return None
 
-        try:
-            await run_multiple_guardrails(self.safety_api, text, self.guardrail_ids)
-        except SafetyException as e:
-            logger.info(f"{context.capitalize()} guardrail violation: {e.violation.user_message}")
-            return e.violation.user_message or f"{context.capitalize()} blocked by safety guardrails"
+        violation_message = await run_multiple_guardrails(self.safety_api, text, self.guardrail_ids)
+        if violation_message:
+            logger.info(f"{context.capitalize()} guardrail violation: {violation_message}")
+            return violation_message
 
     async def _create_refusal_response(self, violation_message: str) -> OpenAIResponseObjectStream:
         """Create a refusal response to replace streaming content."""
diff --git a/llama_stack/providers/inline/agents/meta_reference/responses/utils.py b/llama_stack/providers/inline/agents/meta_reference/responses/utils.py
index 756e8405e..0a7538292 100644
--- a/llama_stack/providers/inline/agents/meta_reference/responses/utils.py
+++ b/llama_stack/providers/inline/agents/meta_reference/responses/utils.py
@@ -47,9 +47,7 @@ from llama_stack.apis.inference import (
     OpenAIToolMessageParam,
     OpenAIUserMessageParam,
 )
-from llama_stack.apis.safety import Safety, SafetyViolation, ViolationLevel
-
-from ..safety import SafetyException
+from llama_stack.apis.safety import Safety
 
 
 async def convert_chat_choice_to_response_message(
@@ -315,10 +313,10 @@ def is_function_tool_call(
     return False
 
 
-async def run_multiple_guardrails(safety_api: Safety, messages: str, guardrail_ids: list[str]) -> None:
-    """Run multiple guardrails against messages and raise SafetyException for violations."""
+async def run_multiple_guardrails(safety_api: Safety, messages: str, guardrail_ids: list[str]) -> str | None:
+    """Run multiple guardrails against messages and return violation message if blocked."""
     if not guardrail_ids or not messages:
-        return
+        return None
 
     # Look up shields to get their provider_resource_id (actual model ID)
     model_ids = []
@@ -335,19 +333,21 @@ async def run_multiple_guardrails(safety_api: Safety, messages: str, guardrail_i
             raise ValueError(f"No shield found with identifier '{guardrail_id}'")
 
     guardrail_tasks = [safety_api.run_moderation(messages, model=model_id) for model_id in model_ids]
-
     responses = await asyncio.gather(*guardrail_tasks)
 
     for response in responses:
-        # Check if any of the results are flagged
         for result in response.results:
             if result.flagged:
-                violation = SafetyViolation(
-                    violation_level=ViolationLevel.ERROR,
-                    user_message="Content flagged by moderation",
-                    metadata={"categories": result.categories},
-                )
-                raise SafetyException(violation)
+                message = result.user_message or "Content blocked by safety guardrails"
+                flagged_categories = [cat for cat, flagged in result.categories.items() if flagged]
+                violation_type = result.metadata.get("violation_type", []) if result.metadata else []
+
+                if flagged_categories:
+                    message += f" (flagged for: {', '.join(flagged_categories)})"
+                if violation_type:
+                    message += f" (violation type: {', '.join(violation_type)})"
+
+                return message
 
 
 def extract_guardrail_ids(guardrails: list | None) -> list[str]:
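For reviewers, here is a minimal, self-contained sketch of the new return-value contract introduced by this diff: the caller now checks an optional violation message instead of catching `SafetyException`. The `fake_run_multiple_guardrails` stub and the hard-coded "violence" category below are illustrative stand-ins, not part of the codebase.

```python
import asyncio


# Illustrative stand-in for run_multiple_guardrails: returns a violation
# message when any moderation result is flagged, otherwise None.
async def fake_run_multiple_guardrails(text: str, guardrail_ids: list[str]) -> str | None:
    if not guardrail_ids or not text:
        return None
    # Pretend one guardrail flagged the text for "violence".
    flagged_categories = ["violence"]
    message = "Content blocked by safety guardrails"
    if flagged_categories:
        message += f" (flagged for: {', '.join(flagged_categories)})"
    return message


async def check_text(text: str) -> str | None:
    # New pattern: inspect the optional return value instead of
    # wrapping the call in try/except SafetyException.
    violation_message = await fake_run_multiple_guardrails(text, ["llama-guard"])
    if violation_message:
        print(f"Input guardrail violation: {violation_message}")
        return violation_message
    return None


if __name__ == "__main__":
    asyncio.run(check_text("some user input"))
```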