mirror of https://github.com/meta-llama/llama-stack.git
improve user message

parent b5c08c72a7
commit f8861bc480

2 changed files with 18 additions and 20 deletions
File 1 of 2 — the streaming response orchestrator module (path not shown in this view):

@@ -69,7 +69,6 @@ from llama_stack.log import get_logger
 from llama_stack.providers.utils.inference.prompt_adapter import interleaved_content_as_str
 from llama_stack.providers.utils.telemetry import tracing
 
-from ..safety import SafetyException
 from .types import ChatCompletionContext, ChatCompletionResult
 from .utils import (
     convert_chat_choice_to_response_message,
@@ -140,11 +139,10 @@ class StreamingResponseOrchestrator:
         if not self.guardrail_ids or not text:
             return None
 
-        try:
-            await run_multiple_guardrails(self.safety_api, text, self.guardrail_ids)
-        except SafetyException as e:
-            logger.info(f"{context.capitalize()} guardrail violation: {e.violation.user_message}")
-            return e.violation.user_message or f"{context.capitalize()} blocked by safety guardrails"
+        violation_message = await run_multiple_guardrails(self.safety_api, text, self.guardrail_ids)
+        if violation_message:
+            logger.info(f"{context.capitalize()} guardrail violation: {violation_message}")
+            return violation_message
 
     async def _create_refusal_response(self, violation_message: str) -> OpenAIResponseObjectStream:
         """Create a refusal response to replace streaming content."""
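The hunk above replaces exception-driven control flow with a plain return value: the orchestrator no longer catches SafetyException; it inspects the string returned by run_multiple_guardrails and short-circuits when it is non-empty. Below is a minimal, self-contained sketch of that pattern, assuming nothing about the real orchestrator: GuardrailRunner, check_guardrails, and the sample strings are illustrative stand-ins, not the actual Safety API.

import asyncio
import logging

logger = logging.getLogger(__name__)


class GuardrailRunner:
    # Stand-in for run_multiple_guardrails: returns a violation message or None.
    async def run(self, text: str) -> str | None:
        # Pretend anything mentioning "attack" gets flagged by moderation.
        if "attack" in text.lower():
            return "Content blocked by safety guardrails (flagged for: violence)"
        return None


async def check_guardrails(runner: GuardrailRunner, text: str, context: str) -> str | None:
    # Mirrors the new control flow: no try/except, just inspect the return value.
    if not text:
        return None

    violation_message = await runner.run(text)
    if violation_message:
        logger.info(f"{context.capitalize()} guardrail violation: {violation_message}")
        return violation_message
    return None


if __name__ == "__main__":
    runner = GuardrailRunner()
    print(asyncio.run(check_guardrails(runner, "how would I plan an attack?", "input")))  # violation message
    print(asyncio.run(check_guardrails(runner, "hello there", "input")))  # None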
File 2 of 2 — the responses utils module (path not shown in this view):

@@ -47,9 +47,7 @@ from llama_stack.apis.inference import (
     OpenAIToolMessageParam,
     OpenAIUserMessageParam,
 )
-from llama_stack.apis.safety import Safety, SafetyViolation, ViolationLevel
+from llama_stack.apis.safety import Safety
 
-from ..safety import SafetyException
-
 
 async def convert_chat_choice_to_response_message(
@@ -315,10 +313,10 @@ def is_function_tool_call(
     return False
 
 
-async def run_multiple_guardrails(safety_api: Safety, messages: str, guardrail_ids: list[str]) -> None:
-    """Run multiple guardrails against messages and raise SafetyException for violations."""
+async def run_multiple_guardrails(safety_api: Safety, messages: str, guardrail_ids: list[str]) -> str | None:
+    """Run multiple guardrails against messages and return violation message if blocked."""
     if not guardrail_ids or not messages:
-        return
+        return None
 
     # Look up shields to get their provider_resource_id (actual model ID)
     model_ids = []
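The signature change above is the heart of the refactor: run_multiple_guardrails now returns str | None instead of raising. The sketch below contrasts the two calling conventions from a caller's point of view; run_guardrails_old, run_guardrails_new, and the local SafetyException class are hypothetical stand-ins used only to illustrate the contract change, not llama-stack code.

import asyncio


class SafetyException(Exception):
    # Stand-in for a provider exception carrying a user-facing message.
    def __init__(self, user_message: str):
        super().__init__(user_message)
        self.user_message = user_message


async def run_guardrails_old(text: str) -> None:
    # Old contract: raise on violation, return nothing otherwise.
    if "forbidden" in text:
        raise SafetyException("Content flagged by moderation")


async def run_guardrails_new(text: str) -> str | None:
    # New contract: return the violation message, or None when the text is clean.
    if "forbidden" in text:
        return "Content blocked by safety guardrails"
    return None


async def main() -> None:
    text = "some forbidden request"

    # Old call site: control flow driven by exceptions.
    try:
        await run_guardrails_old(text)
    except SafetyException as e:
        print(f"blocked (old style): {e.user_message}")

    # New call site: control flow driven by the return value.
    if (message := await run_guardrails_new(text)) is not None:
        print(f"blocked (new style): {message}")


asyncio.run(main())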
@@ -335,19 +333,21 @@ async def run_multiple_guardrails(safety_api: Safety, messages: str, guardrail_i
             raise ValueError(f"No shield found with identifier '{guardrail_id}'")
 
     guardrail_tasks = [safety_api.run_moderation(messages, model=model_id) for model_id in model_ids]
 
     responses = await asyncio.gather(*guardrail_tasks)
 
     for response in responses:
-        # Check if any of the results are flagged
         for result in response.results:
             if result.flagged:
-                violation = SafetyViolation(
-                    violation_level=ViolationLevel.ERROR,
-                    user_message="Content flagged by moderation",
-                    metadata={"categories": result.categories},
-                )
-                raise SafetyException(violation)
+                message = result.user_message or "Content blocked by safety guardrails"
+                flagged_categories = [cat for cat, flagged in result.categories.items() if flagged]
+                violation_type = result.metadata.get("violation_type", []) if result.metadata else []
+
+                if flagged_categories:
+                    message += f" (flagged for: {', '.join(flagged_categories)})"
+                if violation_type:
+                    message += f" (violation type: {', '.join(violation_type)})"
+                return message
 
 
 def extract_guardrail_ids(guardrails: list | None) -> list[str]:
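To see what the improved user message looks like end to end, here is a small, self-contained re-creation of the message-building branch added above. ModerationResult and the sample values are hypothetical stand-ins for the Safety API's moderation output; only the message-building lines mirror the diff.

from dataclasses import dataclass, field


@dataclass
class ModerationResult:
    # Minimal stand-in for one entry in a moderation response's .results.
    flagged: bool
    user_message: str | None = None
    categories: dict[str, bool] = field(default_factory=dict)
    metadata: dict | None = None


def build_violation_message(result: ModerationResult) -> str:
    # Re-creation of the message building introduced by this commit (illustrative only).
    message = result.user_message or "Content blocked by safety guardrails"
    flagged_categories = [cat for cat, flagged in result.categories.items() if flagged]
    violation_type = result.metadata.get("violation_type", []) if result.metadata else []

    if flagged_categories:
        message += f" (flagged for: {', '.join(flagged_categories)})"
    if violation_type:
        message += f" (violation type: {', '.join(violation_type)})"
    return message


# Example: a result flagged for two categories, with a violation type in metadata.
sample = ModerationResult(
    flagged=True,
    categories={"violence": True, "hate": False, "self-harm": True},
    metadata={"violation_type": ["S1"]},
)
print(build_violation_message(sample))
# -> Content blocked by safety guardrails (flagged for: violence, self-harm) (violation type: S1)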