mirror of https://github.com/meta-llama/llama-stack.git
synced 2025-12-12 12:06:04 +00:00

improve user message

parent b5c08c72a7
commit f8861bc480

2 changed files with 18 additions and 20 deletions
@@ -69,7 +69,6 @@ from llama_stack.log import get_logger
 from llama_stack.providers.utils.inference.prompt_adapter import interleaved_content_as_str
 from llama_stack.providers.utils.telemetry import tracing

-from ..safety import SafetyException
 from .types import ChatCompletionContext, ChatCompletionResult
 from .utils import (
     convert_chat_choice_to_response_message,
@@ -140,11 +139,10 @@ class StreamingResponseOrchestrator:
         if not self.guardrail_ids or not text:
             return None

-        try:
-            await run_multiple_guardrails(self.safety_api, text, self.guardrail_ids)
-        except SafetyException as e:
-            logger.info(f"{context.capitalize()} guardrail violation: {e.violation.user_message}")
-            return e.violation.user_message or f"{context.capitalize()} blocked by safety guardrails"
+        violation_message = await run_multiple_guardrails(self.safety_api, text, self.guardrail_ids)
+        if violation_message:
+            logger.info(f"{context.capitalize()} guardrail violation: {violation_message}")
+            return violation_message

     async def _create_refusal_response(self, violation_message: str) -> OpenAIResponseObjectStream:
         """Create a refusal response to replace streaming content."""
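Taken together, the streaming-side change drops the try/except around the guardrail call and instead branches on the helper's return value. Below is a minimal, self-contained sketch of that flow; FakeSafetyAPI, check_guardrails, and the simplified run_multiple_guardrails are illustrative stand-ins, not the actual llama-stack classes.

    import asyncio
    import logging

    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger(__name__)


    class FakeSafetyAPI:
        # Hypothetical stand-in for the Safety API used by the orchestrator.
        async def moderate(self, text: str) -> str | None:
            return "Content blocked by safety guardrails" if "bad" in text else None


    async def run_multiple_guardrails(safety_api: FakeSafetyAPI, text: str, guardrail_ids: list[str]) -> str | None:
        # Stand-in for the helper changed in this commit: return a violation
        # message when content is flagged, None otherwise.
        return await safety_api.moderate(text)


    async def check_guardrails(safety_api: FakeSafetyAPI, text: str, guardrail_ids: list[str], context: str) -> str | None:
        # Mirrors the new orchestrator flow: no try/except, just a truthiness check.
        if not guardrail_ids or not text:
            return None
        violation_message = await run_multiple_guardrails(safety_api, text, guardrail_ids)
        if violation_message:
            logger.info(f"{context.capitalize()} guardrail violation: {violation_message}")
            return violation_message
        return None


    print(asyncio.run(check_guardrails(FakeSafetyAPI(), "something bad", ["llama-guard"], "input")))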
@@ -47,9 +47,7 @@ from llama_stack.apis.inference import (
     OpenAIToolMessageParam,
     OpenAIUserMessageParam,
 )
-from llama_stack.apis.safety import Safety, SafetyViolation, ViolationLevel
-
-from ..safety import SafetyException
+from llama_stack.apis.safety import Safety


 async def convert_chat_choice_to_response_message(
@@ -315,10 +313,10 @@ def is_function_tool_call(
     return False


-async def run_multiple_guardrails(safety_api: Safety, messages: str, guardrail_ids: list[str]) -> None:
-    """Run multiple guardrails against messages and raise SafetyException for violations."""
+async def run_multiple_guardrails(safety_api: Safety, messages: str, guardrail_ids: list[str]) -> str | None:
+    """Run multiple guardrails against messages and return violation message if blocked."""
     if not guardrail_ids or not messages:
-        return
+        return None

     # Look up shields to get their provider_resource_id (actual model ID)
     model_ids = []
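The "Look up shields" comment refers to resolving each guardrail identifier to its shield's provider_resource_id before running moderation. That lookup code is not part of this hunk; the sketch below only illustrates its likely shape with a hypothetical in-memory registry (Shield, SHIELDS, and resolve_model_ids are illustrative names, not llama-stack APIs). The ValueError matches the one visible at the start of the next hunk.

    # Hypothetical sketch of the shield lookup: map each guardrail_id to the
    # shield's provider_resource_id (the actual moderation model ID).
    from dataclasses import dataclass


    @dataclass
    class Shield:
        identifier: str
        provider_resource_id: str


    SHIELDS = {
        "llama-guard": Shield("llama-guard", "meta-llama/Llama-Guard-3-8B"),
    }


    def resolve_model_ids(guardrail_ids: list[str]) -> list[str]:
        model_ids = []
        for guardrail_id in guardrail_ids:
            shield = SHIELDS.get(guardrail_id)
            if shield is None:
                raise ValueError(f"No shield found with identifier '{guardrail_id}'")
            model_ids.append(shield.provider_resource_id)
        return model_ids


    print(resolve_model_ids(["llama-guard"]))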
|
@ -335,19 +333,21 @@ async def run_multiple_guardrails(safety_api: Safety, messages: str, guardrail_i
|
|||
raise ValueError(f"No shield found with identifier '{guardrail_id}'")
|
||||
|
||||
guardrail_tasks = [safety_api.run_moderation(messages, model=model_id) for model_id in model_ids]
|
||||
|
||||
responses = await asyncio.gather(*guardrail_tasks)
|
||||
|
||||
for response in responses:
|
||||
# Check if any of the results are flagged
|
||||
for result in response.results:
|
||||
if result.flagged:
|
||||
violation = SafetyViolation(
|
||||
violation_level=ViolationLevel.ERROR,
|
||||
user_message="Content flagged by moderation",
|
||||
metadata={"categories": result.categories},
|
||||
)
|
||||
raise SafetyException(violation)
|
||||
message = result.user_message or "Content blocked by safety guardrails"
|
||||
flagged_categories = [cat for cat, flagged in result.categories.items() if flagged]
|
||||
violation_type = result.metadata.get("violation_type", []) if result.metadata else []
|
||||
|
||||
if flagged_categories:
|
||||
message += f" (flagged for: {', '.join(flagged_categories)})"
|
||||
if violation_type:
|
||||
message += f" (violation type: {', '.join(violation_type)})"
|
||||
|
||||
return message
|
||||
|
||||
|
||||
def extract_guardrail_ids(guardrails: list | None) -> list[str]:
|
||||
|
|
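Instead of constructing a SafetyViolation and raising SafetyException, the new branch assembles a user-facing message directly from the moderation result and returns it. Below is a self-contained sketch of just that message-building step, assuming a hypothetical ModerationResult stand-in that models only the fields used here (flagged, user_message, categories, metadata).

    # ModerationResult is a hypothetical stand-in for the moderation response
    # objects; build_violation_message replays the new message construction.
    from dataclasses import dataclass, field


    @dataclass
    class ModerationResult:
        flagged: bool
        user_message: str | None = None
        categories: dict[str, bool] = field(default_factory=dict)
        metadata: dict | None = None


    def build_violation_message(result: ModerationResult) -> str | None:
        if not result.flagged:
            return None
        message = result.user_message or "Content blocked by safety guardrails"
        flagged_categories = [cat for cat, flagged in result.categories.items() if flagged]
        violation_type = result.metadata.get("violation_type", []) if result.metadata else []
        if flagged_categories:
            message += f" (flagged for: {', '.join(flagged_categories)})"
        if violation_type:
            message += f" (violation type: {', '.join(violation_type)})"
        return message


    result = ModerationResult(flagged=True, categories={"violence": True, "hate": False})
    print(build_violation_message(result))
    # -> "Content blocked by safety guardrails (flagged for: violence)"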