improve user message

This commit is contained in:
Swapna Lekkala 2025-10-13 15:57:43 -07:00
parent b5c08c72a7
commit f8861bc480
2 changed files with 18 additions and 20 deletions

View file

@ -69,7 +69,6 @@ from llama_stack.log import get_logger
from llama_stack.providers.utils.inference.prompt_adapter import interleaved_content_as_str
from llama_stack.providers.utils.telemetry import tracing
from ..safety import SafetyException
from .types import ChatCompletionContext, ChatCompletionResult
from .utils import (
convert_chat_choice_to_response_message,
@ -140,11 +139,10 @@ class StreamingResponseOrchestrator:
if not self.guardrail_ids or not text:
return None
try:
await run_multiple_guardrails(self.safety_api, text, self.guardrail_ids)
except SafetyException as e:
logger.info(f"{context.capitalize()} guardrail violation: {e.violation.user_message}")
return e.violation.user_message or f"{context.capitalize()} blocked by safety guardrails"
violation_message = await run_multiple_guardrails(self.safety_api, text, self.guardrail_ids)
if violation_message:
logger.info(f"{context.capitalize()} guardrail violation: {violation_message}")
return violation_message
async def _create_refusal_response(self, violation_message: str) -> OpenAIResponseObjectStream:
"""Create a refusal response to replace streaming content."""

View file

@ -47,9 +47,7 @@ from llama_stack.apis.inference import (
OpenAIToolMessageParam,
OpenAIUserMessageParam,
)
from llama_stack.apis.safety import Safety, SafetyViolation, ViolationLevel
from ..safety import SafetyException
from llama_stack.apis.safety import Safety
async def convert_chat_choice_to_response_message(
@ -315,10 +313,10 @@ def is_function_tool_call(
return False
async def run_multiple_guardrails(safety_api: Safety, messages: str, guardrail_ids: list[str]) -> None:
"""Run multiple guardrails against messages and raise SafetyException for violations."""
async def run_multiple_guardrails(safety_api: Safety, messages: str, guardrail_ids: list[str]) -> str | None:
"""Run multiple guardrails against messages and return violation message if blocked."""
if not guardrail_ids or not messages:
return
return None
# Look up shields to get their provider_resource_id (actual model ID)
model_ids = []
@ -335,19 +333,21 @@ async def run_multiple_guardrails(safety_api: Safety, messages: str, guardrail_i
raise ValueError(f"No shield found with identifier '{guardrail_id}'")
guardrail_tasks = [safety_api.run_moderation(messages, model=model_id) for model_id in model_ids]
responses = await asyncio.gather(*guardrail_tasks)
for response in responses:
# Check if any of the results are flagged
for result in response.results:
if result.flagged:
violation = SafetyViolation(
violation_level=ViolationLevel.ERROR,
user_message="Content flagged by moderation",
metadata={"categories": result.categories},
)
raise SafetyException(violation)
message = result.user_message or "Content blocked by safety guardrails"
flagged_categories = [cat for cat, flagged in result.categories.items() if flagged]
violation_type = result.metadata.get("violation_type", []) if result.metadata else []
if flagged_categories:
message += f" (flagged for: {', '.join(flagged_categories)})"
if violation_type:
message += f" (violation type: {', '.join(violation_type)})"
return message
def extract_guardrail_ids(guardrails: list | None) -> list[str]: