Another attempt at a proper bugfix for safety violations

This commit is contained in:
Ashwin Bharambe 2024-09-23 19:06:30 -07:00
parent e5bdd6615a
commit c9005e95ed
2 changed files with 21 additions and 14 deletions

View file

@@ -10,7 +10,6 @@ from llama_stack.distribution.utils.model_utils import model_local_dir
from llama_stack.apis.safety import * # noqa: F403
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.providers.impls.meta_reference.agents.safety import SafetyException
from llama_stack.providers.impls.meta_reference.safety.shields.base import (
OnViolationAction,
)
@@ -74,23 +73,19 @@ class MetaReferenceSafetyImpl(Safety):
# TODO: we can refactor ShieldBase, etc. to be inline with the API types
res = await shield.run(messages)
violation = None
if res.is_violation:
if res.is_violation and shield.on_violation_action != OnViolationAction.IGNORE:
violation = SafetyViolation(
violation_level=ViolationLevel.ERROR,
violation_level=(
ViolationLevel.ERROR
if shield.on_violation_action == OnViolationAction.RAISE
else ViolationLevel.WARN
),
user_message=res.violation_return_message,
metadata={
"violation_type": res.violation_type,
},
)
if shield.on_violation_action == OnViolationAction.RAISE:
raise SafetyException(violation)
elif shield.on_violation_action == OnViolationAction.WARN:
cprint(
f"[Warn]{shield.__class__.__name__} raised a warning",
color="red",
)
return RunShieldResponse(violation=violation)
def get_shield_impl(self, typ: MetaReferenceShieldType) -> ShieldBase: