From c9005e95ed602bca74c348b5251c78ce5d3e362c Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Mon, 23 Sep 2024 19:06:30 -0700 Subject: [PATCH] Another attempt at a proper bugfix for safety violations --- .../impls/meta_reference/agents/safety.py | 18 +++++++++++++++--- .../impls/meta_reference/safety/safety.py | 17 ++++++----------- 2 files changed, 21 insertions(+), 14 deletions(-) diff --git a/llama_stack/providers/impls/meta_reference/agents/safety.py b/llama_stack/providers/impls/meta_reference/agents/safety.py index e7c982181..b3aa53728 100644 --- a/llama_stack/providers/impls/meta_reference/agents/safety.py +++ b/llama_stack/providers/impls/meta_reference/agents/safety.py @@ -32,14 +32,26 @@ class ShieldRunnerMixin: self.output_shields = output_shields async def run_multiple_shields( - self, messages: List[Message], shields: List[str] + self, messages: List[Message], shield_types: List[str] ) -> None: - await asyncio.gather( + responses = await asyncio.gather( *[ self.safety_api.run_shield( shield_type=shield_type, messages=messages, ) - for shield_type in shields + for shield_type in shield_types ] ) + for shield_type, response in zip(shields, responses): + if not response.violation: + continue + + violation = response.violation + if violation.violation_level == ViolationLevel.ERROR: + raise SafetyException(violation) + elif violation.violation_level == ViolationLevel.WARN: + cprint( + f"[Warn]{shield_type} raised a warning", + color="red", + ) diff --git a/llama_stack/providers/impls/meta_reference/safety/safety.py b/llama_stack/providers/impls/meta_reference/safety/safety.py index e5c42b45c..6cf8a79d2 100644 --- a/llama_stack/providers/impls/meta_reference/safety/safety.py +++ b/llama_stack/providers/impls/meta_reference/safety/safety.py @@ -10,7 +10,6 @@ from llama_stack.distribution.utils.model_utils import model_local_dir from llama_stack.apis.safety import * # noqa: F403 from llama_models.llama3.api.datatypes import * # noqa: F403 -from llama_stack.providers.impls.meta_reference.agents.safety import SafetyException from llama_stack.providers.impls.meta_reference.safety.shields.base import ( OnViolationAction, ) @@ -74,23 +73,19 @@ class MetaReferenceSafetyImpl(Safety): # TODO: we can refactor ShieldBase, etc. to be inline with the API types res = await shield.run(messages) violation = None - if res.is_violation: + if res.is_violation and shield.on_violation_action != OnViolationAction.IGNORE: violation = SafetyViolation( - violation_level=ViolationLevel.ERROR, + violation_level=( + ViolationLevel.ERROR + if shield.on_violation_action == OnViolationAction.RAISE + else ViolationLevel.WARN + ), user_message=res.violation_return_message, metadata={ "violation_type": res.violation_type, }, ) - if shield.on_violation_action == OnViolationAction.RAISE: - raise SafetyException(violation) - elif shield.on_violation_action == OnViolationAction.WARN: - cprint( - f"[Warn]{shield.__class__.__name__} raised a warning", - color="red", - ) - return RunShieldResponse(violation=violation) def get_shield_impl(self, typ: MetaReferenceShieldType) -> ShieldBase: