mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-12 12:06:04 +00:00
address comments
This commit is contained in:
parent
74be622f57
commit
9a4d3d7576
1 changed files with 5 additions and 13 deletions
|
|
@ -134,16 +134,6 @@ class StreamingResponseOrchestrator:
|
|||
# Track if we've sent a refusal response
|
||||
self.violation_detected = False
|
||||
|
||||
async def _apply_guardrails(self, text: str, context: str = "content") -> str | None:
|
||||
"""Apply guardrails to text content. Returns violation message if blocked."""
|
||||
if not self.guardrail_ids or not text:
|
||||
return None
|
||||
|
||||
violation_message = await run_multiple_guardrails(self.safety_api, text, self.guardrail_ids)
|
||||
if violation_message:
|
||||
logger.info(f"{context.capitalize()} guardrail violation: {violation_message}")
|
||||
return violation_message
|
||||
|
||||
async def _create_refusal_response(self, violation_message: str) -> OpenAIResponseObjectStream:
|
||||
"""Create a refusal response to replace streaming content."""
|
||||
refusal_content = OpenAIResponseContentPartRefusal(refusal=violation_message)
|
||||
|
|
@ -205,8 +195,9 @@ class StreamingResponseOrchestrator:
|
|||
# Input safety validation - check messages before processing
|
||||
if self.guardrail_ids:
|
||||
combined_text = interleaved_content_as_str([msg.content for msg in self.ctx.messages])
|
||||
input_violation_message = await self._apply_guardrails(combined_text, "input")
|
||||
input_violation_message = await run_multiple_guardrails(self.safety_api, combined_text, self.guardrail_ids)
|
||||
if input_violation_message:
|
||||
logger.info(f"Input guardrail violation: {input_violation_message}")
|
||||
# Return refusal response immediately
|
||||
yield await self._create_refusal_response(input_violation_message)
|
||||
return
|
||||
|
|
@ -715,10 +706,11 @@ class StreamingResponseOrchestrator:
|
|||
) + tool_call.function.arguments
|
||||
|
||||
# Output Safety Validation for a chunk
|
||||
if chat_response_content:
|
||||
if self.guardrail_ids:
|
||||
accumulated_text = "".join(chat_response_content)
|
||||
violation_message = await self._apply_guardrails(accumulated_text, "output")
|
||||
violation_message = await run_multiple_guardrails(self.safety_api, accumulated_text, self.guardrail_ids)
|
||||
if violation_message:
|
||||
logger.info(f"Output guardrail violation: {violation_message}")
|
||||
yield await self._create_refusal_response(violation_message)
|
||||
self.violation_detected = True
|
||||
return
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue