forked from phoenix-oss/llama-stack-mirror
bug fix for safety violation
This commit is contained in:
parent
70fb70a71c
commit
e5bdd6615a
3 changed files with 15 additions and 12 deletions
|
@ -10,6 +10,11 @@ from llama_stack.distribution.utils.model_utils import model_local_dir
|
|||
from llama_stack.apis.safety import * # noqa: F403
|
||||
from llama_models.llama3.api.datatypes import * # noqa: F403
|
||||
|
||||
from llama_stack.providers.impls.meta_reference.agents.safety import SafetyException
|
||||
from llama_stack.providers.impls.meta_reference.safety.shields.base import (
|
||||
OnViolationAction,
|
||||
)
|
||||
|
||||
from .config import MetaReferenceShieldType, SafetyConfig
|
||||
|
||||
from .shields import (
|
||||
|
@ -78,6 +83,14 @@ class MetaReferenceSafetyImpl(Safety):
|
|||
},
|
||||
)
|
||||
|
||||
if shield.on_violation_action == OnViolationAction.RAISE:
|
||||
raise SafetyException(violation)
|
||||
elif shield.on_violation_action == OnViolationAction.WARN:
|
||||
cprint(
|
||||
f"[Warn]{shield.__class__.__name__} raised a warning",
|
||||
color="red",
|
||||
)
|
||||
|
||||
return RunShieldResponse(violation=violation)
|
||||
|
||||
def get_shield_impl(self, typ: MetaReferenceShieldType) -> ShieldBase:
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue