From e5bdd6615af32fc8488826b349736fa33f9e676a Mon Sep 17 00:00:00 2001 From: Xi Yan Date: Mon, 23 Sep 2024 18:17:15 -0700 Subject: [PATCH] bug fix for safety violation --- .../impls/meta_reference/agents/agent_instance.py | 2 +- .../providers/impls/meta_reference/agents/safety.py | 12 +----------- .../providers/impls/meta_reference/safety/safety.py | 13 +++++++++++++ 3 files changed, 15 insertions(+), 12 deletions(-) diff --git a/llama_stack/providers/impls/meta_reference/agents/agent_instance.py b/llama_stack/providers/impls/meta_reference/agents/agent_instance.py index 0ac26a857..797a1bc7f 100644 --- a/llama_stack/providers/impls/meta_reference/agents/agent_instance.py +++ b/llama_stack/providers/impls/meta_reference/agents/agent_instance.py @@ -130,7 +130,7 @@ class ChatAgent(ShieldRunnerMixin): # CompletionMessage itself in the ShieldResponse messages.append( CompletionMessage( - content=violation.user_message, + content=step.violation.user_message, stop_reason=StopReason.end_of_turn, ) ) diff --git a/llama_stack/providers/impls/meta_reference/agents/safety.py b/llama_stack/providers/impls/meta_reference/agents/safety.py index 44d47b16c..e7c982181 100644 --- a/llama_stack/providers/impls/meta_reference/agents/safety.py +++ b/llama_stack/providers/impls/meta_reference/agents/safety.py @@ -34,7 +34,7 @@ class ShieldRunnerMixin: async def run_multiple_shields( self, messages: List[Message], shields: List[str] ) -> None: - responses = await asyncio.gather( + await asyncio.gather( *[ self.safety_api.run_shield( shield_type=shield_type, @@ -43,13 +43,3 @@ class ShieldRunnerMixin: for shield_type in shields ] ) - - for shield, r in zip(shields, responses): - if r.violation: - if shield.on_violation_action == OnViolationAction.RAISE: - raise SafetyException(r) - elif shield.on_violation_action == OnViolationAction.WARN: - cprint( - f"[Warn]{shield.__class__.__name__} raised a warning", - color="red", - ) diff --git a/llama_stack/providers/impls/meta_reference/safety/safety.py b/llama_stack/providers/impls/meta_reference/safety/safety.py index 6eccf47a5..e5c42b45c 100644 --- a/llama_stack/providers/impls/meta_reference/safety/safety.py +++ b/llama_stack/providers/impls/meta_reference/safety/safety.py @@ -10,6 +10,11 @@ from llama_stack.distribution.utils.model_utils import model_local_dir from llama_stack.apis.safety import * # noqa: F403 from llama_models.llama3.api.datatypes import * # noqa: F403 +from llama_stack.providers.impls.meta_reference.agents.safety import SafetyException +from llama_stack.providers.impls.meta_reference.safety.shields.base import ( + OnViolationAction, +) + from .config import MetaReferenceShieldType, SafetyConfig from .shields import ( @@ -78,6 +83,14 @@ class MetaReferenceSafetyImpl(Safety): }, ) + if shield.on_violation_action == OnViolationAction.RAISE: + raise SafetyException(violation) + elif shield.on_violation_action == OnViolationAction.WARN: + cprint( + f"[Warn]{shield.__class__.__name__} raised a warning", + color="red", + ) + return RunShieldResponse(violation=violation) def get_shield_impl(self, typ: MetaReferenceShieldType) -> ShieldBase: