Another attempt at a proper bugfix for safety violations

This commit is contained in:
Ashwin Bharambe 2024-09-23 19:06:30 -07:00
parent e5bdd6615a
commit c9005e95ed
2 changed files with 21 additions and 14 deletions

View file

@ -32,14 +32,26 @@ class ShieldRunnerMixin:
self.output_shields = output_shields self.output_shields = output_shields
async def run_multiple_shields( async def run_multiple_shields(
self, messages: List[Message], shields: List[str] self, messages: List[Message], shield_types: List[str]
) -> None: ) -> None:
await asyncio.gather( responses = await asyncio.gather(
*[ *[
self.safety_api.run_shield( self.safety_api.run_shield(
shield_type=shield_type, shield_type=shield_type,
messages=messages, messages=messages,
) )
for shield_type in shields for shield_type in shield_types
] ]
) )
for shield_type, response in zip(shields, responses):
if not response.violation:
continue
violation = response.violation
if violation.violation_level == ViolationLevel.ERROR:
raise SafetyException(violation)
elif violation.violation_level == ViolationLevel.WARN:
cprint(
f"[Warn]{shield_type} raised a warning",
color="red",
)

View file

@ -10,7 +10,6 @@ from llama_stack.distribution.utils.model_utils import model_local_dir
from llama_stack.apis.safety import * # noqa: F403 from llama_stack.apis.safety import * # noqa: F403
from llama_models.llama3.api.datatypes import * # noqa: F403 from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.providers.impls.meta_reference.agents.safety import SafetyException
from llama_stack.providers.impls.meta_reference.safety.shields.base import ( from llama_stack.providers.impls.meta_reference.safety.shields.base import (
OnViolationAction, OnViolationAction,
) )
@ -74,23 +73,19 @@ class MetaReferenceSafetyImpl(Safety):
# TODO: we can refactor ShieldBase, etc. to be inline with the API types # TODO: we can refactor ShieldBase, etc. to be inline with the API types
res = await shield.run(messages) res = await shield.run(messages)
violation = None violation = None
if res.is_violation: if res.is_violation and shield.on_violation_action != OnViolationAction.IGNORE:
violation = SafetyViolation( violation = SafetyViolation(
violation_level=ViolationLevel.ERROR, violation_level=(
ViolationLevel.ERROR
if shield.on_violation_action == OnViolationAction.RAISE
else ViolationLevel.WARN
),
user_message=res.violation_return_message, user_message=res.violation_return_message,
metadata={ metadata={
"violation_type": res.violation_type, "violation_type": res.violation_type,
}, },
) )
if shield.on_violation_action == OnViolationAction.RAISE:
raise SafetyException(violation)
elif shield.on_violation_action == OnViolationAction.WARN:
cprint(
f"[Warn]{shield.__class__.__name__} raised a warning",
color="red",
)
return RunShieldResponse(violation=violation) return RunShieldResponse(violation=violation)
def get_shield_impl(self, typ: MetaReferenceShieldType) -> ShieldBase: def get_shield_impl(self, typ: MetaReferenceShieldType) -> ShieldBase: