mirror of https://github.com/meta-llama/llama-stack.git
synced 2025-12-15 22:12:40 +00:00

commit b5c08c72a7 ("clean")
parent da07772480

4 changed files with 33 additions and 141 deletions
@@ -64,7 +64,6 @@ from llama_stack.apis.inference import (
     OpenAIChatCompletionToolCall,
     OpenAIChoice,
     OpenAIMessageParam,
-    OpenAIUserMessageParam,
 )
 from llama_stack.log import get_logger
 from llama_stack.providers.utils.inference.prompt_adapter import interleaved_content_as_str
@@ -136,33 +135,16 @@ class StreamingResponseOrchestrator:
         # Track if we've sent a refusal response
         self.violation_detected = False
 
-    async def _check_input_safety(
-        self, messages: list[OpenAIUserMessageParam]
-    ) -> OpenAIResponseContentPartRefusal | None:
-        """Validate input messages against guardrails. Returns refusal content if violation found."""
-        combined_text = interleaved_content_as_str([msg.content for msg in messages])
-
-        if not combined_text:
+    async def _apply_guardrails(self, text: str, context: str = "content") -> str | None:
+        """Apply guardrails to text content. Returns violation message if blocked."""
+        if not self.guardrail_ids or not text:
             return None
 
         try:
-            await run_multiple_guardrails(self.safety_api, combined_text, self.guardrail_ids)
+            await run_multiple_guardrails(self.safety_api, text, self.guardrail_ids)
         except SafetyException as e:
-            logger.info(f"Input guardrail violation: {e.violation.user_message}")
-            return OpenAIResponseContentPartRefusal(
-                refusal=e.violation.user_message or "Content blocked by safety guardrails"
-            )
-
-    async def _check_output_stream_chunk_safety(self, accumulated_text: str) -> str | None:
-        """Check accumulated streaming text content against guardrails. Returns violation message if blocked."""
-        if not self.guardrail_ids or not accumulated_text:
-            return None
-
-        try:
-            await run_multiple_guardrails(self.safety_api, accumulated_text, self.guardrail_ids)
-        except SafetyException as e:
-            logger.info(f"Output guardrail violation: {e.violation.user_message}")
-            return e.violation.user_message or "Generated content blocked by safety guardrails"
+            logger.info(f"{context.capitalize()} guardrail violation: {e.violation.user_message}")
+            return e.violation.user_message or f"{context.capitalize()} blocked by safety guardrails"
 
     async def _create_refusal_response(self, violation_message: str) -> OpenAIResponseObjectStream:
         """Create a refusal response to replace streaming content."""
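Because the hunk interleaves the two removed helpers with their replacement, the consolidated method is easier to read in one piece. Reconstructed from the lines above (the trailing return None on the success path is implicit in the original and assumed here):

async def _apply_guardrails(self, text: str, context: str = "content") -> str | None:
    """Apply guardrails to text content. Returns violation message if blocked."""
    if not self.guardrail_ids or not text:
        return None

    try:
        # Run every configured guardrail against the text; a violation raises SafetyException.
        await run_multiple_guardrails(self.safety_api, text, self.guardrail_ids)
    except SafetyException as e:
        # The context label ("input"/"output") only changes the log line and the fallback wording.
        logger.info(f"{context.capitalize()} guardrail violation: {e.violation.user_message}")
        return e.violation.user_message or f"{context.capitalize()} blocked by safety guardrails"
    return None  # success path: no violation (implicit in the original)

One helper now serves both call sites, which keeps the early-exit check (no guardrails configured, or empty text) and the violation wording consistent between input and output validation.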
@@ -224,10 +206,11 @@ class StreamingResponseOrchestrator:
 
         # Input safety validation - check messages before processing
         if self.guardrail_ids:
-            input_refusal = await self._check_input_safety(self.ctx.messages)
-            if input_refusal:
+            combined_text = interleaved_content_as_str([msg.content for msg in self.ctx.messages])
+            input_violation_message = await self._apply_guardrails(combined_text, "input")
+            if input_violation_message:
                 # Return refusal response immediately
-                yield await self._create_refusal_response(input_refusal.refusal)
+                yield await self._create_refusal_response(input_violation_message)
                 return
 
         async for stream_event in self._process_tools(output_messages):
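After the change, the input gate in the streaming entry point reads approximately as below (reconstructed from the hunk; indentation relative to the enclosing method is assumed). The refusal is yielded before any inference work starts:

# Input safety validation - check messages before processing
if self.guardrail_ids:
    # Flatten all message contents into one string for a single guardrail pass.
    combined_text = interleaved_content_as_str([msg.content for msg in self.ctx.messages])
    input_violation_message = await self._apply_guardrails(combined_text, "input")
    if input_violation_message:
        # Return refusal response immediately
        yield await self._create_refusal_response(input_violation_message)
        return

This replaces the old flow in which _check_input_safety did the flattening internally and returned an OpenAIResponseContentPartRefusal; the caller now owns the flattening and receives a plain violation string.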
@@ -733,10 +716,10 @@ class StreamingResponseOrchestrator:
                             response_tool_call.function.arguments or ""
                         ) + tool_call.function.arguments
 
-            # Safety check after processing all choices in this chunk
+            # Output Safety Validation for a chunk
             if chat_response_content:
                 accumulated_text = "".join(chat_response_content)
-                violation_message = await self._check_output_stream_chunk_safety(accumulated_text)
+                violation_message = await self._apply_guardrails(accumulated_text, "output")
                 if violation_message:
                     yield await self._create_refusal_response(violation_message)
                     self.violation_detected = True
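The per-chunk output check is the second call site of the shared helper. Reconstructed post-change (the enclosing chunk loop and chat_response_content, a list accumulating the streamed text deltas, are assumed from context):

# Output Safety Validation for a chunk
if chat_response_content:
    # Check the text accumulated so far, not just the latest delta, so a
    # violation that spans chunk boundaries is still caught.
    accumulated_text = "".join(chat_response_content)
    violation_message = await self._apply_guardrails(accumulated_text, "output")
    if violation_message:
        yield await self._create_refusal_response(violation_message)
        self.violation_detected = True  # per the field's comment: records that a refusal was sent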
@@ -365,20 +365,3 @@ def extract_guardrail_ids(guardrails: list | None) -> list[str]:
             raise ValueError(f"Unknown guardrail format: {guardrail}, expected str or ResponseGuardrailSpec")
 
     return guardrail_ids
-
-
-def extract_text_content(content: str | list | None) -> str | None:
-    """Extract text content from OpenAI message content (string or complex structure)."""
-    if isinstance(content, str):
-        return content
-    elif isinstance(content, list):
-        # Handle complex content - extract text parts only
-        text_parts = []
-        for part in content:
-            if hasattr(part, "text"):
-                text_parts.append(part.text)
-            elif hasattr(part, "type") and part.type == "refusal":
-                # Skip refusal parts - don't validate them again
-                continue
-        return " ".join(text_parts) if text_parts else None
-    return None
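With both call sites now flattening content via interleaved_content_as_str before calling _apply_guardrails, the local extract_text_content helper presumably has no remaining callers and is deleted. For context, the surviving extract_guardrail_ids (whose tail is visible above) normalizes mixed guardrail specs into plain id strings; a minimal sketch consistent with its error message, with the ResponseGuardrailSpec shape assumed:

from dataclasses import dataclass


@dataclass
class ResponseGuardrailSpec:
    type: str  # assumed: the spec object carries its guardrail id in `type`


def extract_guardrail_ids(guardrails: list | None) -> list[str]:
    """Normalize a mixed list of id strings and spec objects into plain guardrail ids."""
    guardrail_ids: list[str] = []
    for guardrail in guardrails or []:
        if isinstance(guardrail, str):
            guardrail_ids.append(guardrail)
        elif isinstance(guardrail, ResponseGuardrailSpec):
            guardrail_ids.append(guardrail.type)
        else:
            raise ValueError(f"Unknown guardrail format: {guardrail}, expected str or ResponseGuardrailSpec")

    return guardrail_ids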