forked from phoenix-oss/llama-stack-mirror
Another attempt at a proper bugfix for safety violations
This commit is contained in:
parent
e5bdd6615a
commit
c9005e95ed
2 changed files with 21 additions and 14 deletions
|
@@ -10,7 +10,6 @@ from llama_stack.distribution.utils.model_utils import model_local_dir
|
|||
from llama_stack.apis.safety import * # noqa: F403
|
||||
from llama_models.llama3.api.datatypes import * # noqa: F403
|
||||
|
||||
from llama_stack.providers.impls.meta_reference.agents.safety import SafetyException
|
||||
from llama_stack.providers.impls.meta_reference.safety.shields.base import (
|
||||
OnViolationAction,
|
||||
)
|
||||
|
@@ -74,23 +73,19 @@ class MetaReferenceSafetyImpl(Safety):
|
|||
# TODO: we can refactor ShieldBase, etc. to be inline with the API types
|
||||
res = await shield.run(messages)
|
||||
violation = None
|
||||
if res.is_violation:
|
||||
if res.is_violation and shield.on_violation_action != OnViolationAction.IGNORE:
|
||||
violation = SafetyViolation(
|
||||
violation_level=ViolationLevel.ERROR,
|
||||
violation_level=(
|
||||
ViolationLevel.ERROR
|
||||
if shield.on_violation_action == OnViolationAction.RAISE
|
||||
else ViolationLevel.WARN
|
||||
),
|
||||
user_message=res.violation_return_message,
|
||||
metadata={
|
||||
"violation_type": res.violation_type,
|
||||
},
|
||||
)
|
||||
|
||||
if shield.on_violation_action == OnViolationAction.RAISE:
|
||||
raise SafetyException(violation)
|
||||
elif shield.on_violation_action == OnViolationAction.WARN:
|
||||
cprint(
|
||||
f"[Warn]{shield.__class__.__name__} raised a warning",
|
||||
color="red",
|
||||
)
|
||||
|
||||
return RunShieldResponse(violation=violation)
|
||||
|
||||
def get_shield_impl(self, typ: MetaReferenceShieldType) -> ShieldBase:
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue