From 9a4d3d7576bb433a03ac9acba8de768bb3e9fe6a Mon Sep 17 00:00:00 2001 From: Swapna Lekkala Date: Tue, 14 Oct 2025 13:30:42 -0700 Subject: [PATCH] address comments --- .../meta_reference/responses/streaming.py | 18 +++++------------- 1 file changed, 5 insertions(+), 13 deletions(-) diff --git a/llama_stack/providers/inline/agents/meta_reference/responses/streaming.py b/llama_stack/providers/inline/agents/meta_reference/responses/streaming.py index c4e170452..c23714617 100644 --- a/llama_stack/providers/inline/agents/meta_reference/responses/streaming.py +++ b/llama_stack/providers/inline/agents/meta_reference/responses/streaming.py @@ -134,16 +134,6 @@ class StreamingResponseOrchestrator: # Track if we've sent a refusal response self.violation_detected = False - async def _apply_guardrails(self, text: str, context: str = "content") -> str | None: - """Apply guardrails to text content. Returns violation message if blocked.""" - if not self.guardrail_ids or not text: - return None - - violation_message = await run_multiple_guardrails(self.safety_api, text, self.guardrail_ids) - if violation_message: - logger.info(f"{context.capitalize()} guardrail violation: {violation_message}") - return violation_message - async def _create_refusal_response(self, violation_message: str) -> OpenAIResponseObjectStream: """Create a refusal response to replace streaming content.""" refusal_content = OpenAIResponseContentPartRefusal(refusal=violation_message) @@ -205,8 +195,9 @@ class StreamingResponseOrchestrator: # Input safety validation - check messages before processing if self.guardrail_ids: combined_text = interleaved_content_as_str([msg.content for msg in self.ctx.messages]) - input_violation_message = await self._apply_guardrails(combined_text, "input") + input_violation_message = await run_multiple_guardrails(self.safety_api, combined_text, self.guardrail_ids) if input_violation_message: + logger.info(f"Input guardrail violation: {input_violation_message}") # Return refusal response immediately yield await self._create_refusal_response(input_violation_message) return @@ -715,10 +706,11 @@ class StreamingResponseOrchestrator: ) + tool_call.function.arguments # Output Safety Validation for a chunk - if chat_response_content: + if self.guardrail_ids: accumulated_text = "".join(chat_response_content) - violation_message = await self._apply_guardrails(accumulated_text, "output") + violation_message = await run_multiple_guardrails(self.safety_api, accumulated_text, self.guardrail_ids) if violation_message: + logger.info(f"Output guardrail violation: {violation_message}") yield await self._create_refusal_response(violation_message) self.violation_detected = True return