Another attempt at a proper bugfix for safety violations

This commit is contained in:
Ashwin Bharambe 2024-09-23 19:06:30 -07:00
parent e5bdd6615a
commit c9005e95ed
2 changed files with 21 additions and 14 deletions

View file

@ -32,14 +32,26 @@ class ShieldRunnerMixin:
self.output_shields = output_shields
async def run_multiple_shields(
    self, messages: List[Message], shield_types: List[str]
) -> None:
    """Run the requested shields concurrently and act on any violation.

    Each shield type is dispatched through ``self.safety_api.run_shield``
    with the same ``messages``; the responses are then inspected in the
    order the shield types were given.

    Args:
        messages: Messages passed unchanged to every shield.
        shield_types: Identifiers of the shields to run.

    Raises:
        SafetyException: If a response carries a violation whose
            ``violation_level`` is ``ViolationLevel.ERROR``.
    """
    # asyncio.gather returns results in the same order as the awaitables
    # it was given, so `responses` lines up index-for-index with
    # `shield_types`.
    responses = await asyncio.gather(
        *[
            self.safety_api.run_shield(
                shield_type=shield_type,
                messages=messages,
            )
            for shield_type in shield_types
        ]
    )

    # Pair each response with the parameter actually in scope
    # (`shield_types`); zipping over the pre-rename name `shields` would
    # raise NameError here.
    for shield_type, response in zip(shield_types, responses):
        violation = response.violation
        if not violation:
            continue

        if violation.violation_level == ViolationLevel.ERROR:
            # The first ERROR-level violation aborts processing.
            raise SafetyException(violation)
        elif violation.violation_level == ViolationLevel.WARN:
            # WARN-level violations are reported but do not stop the run.
            cprint(
                f"[Warn]{shield_type} raised a warning",
                color="red",
            )

View file

@ -10,7 +10,6 @@ from llama_stack.distribution.utils.model_utils import model_local_dir
from llama_stack.apis.safety import * # noqa: F403
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.providers.impls.meta_reference.agents.safety import SafetyException
from llama_stack.providers.impls.meta_reference.safety.shields.base import (
OnViolationAction,
)
@ -74,23 +73,19 @@ class MetaReferenceSafetyImpl(Safety):
# TODO: we can refactor ShieldBase, etc. to be inline with the API types
res = await shield.run(messages)
violation = None
if res.is_violation:
if res.is_violation and shield.on_violation_action != OnViolationAction.IGNORE:
violation = SafetyViolation(
violation_level=ViolationLevel.ERROR,
violation_level=(
ViolationLevel.ERROR
if shield.on_violation_action == OnViolationAction.RAISE
else ViolationLevel.WARN
),
user_message=res.violation_return_message,
metadata={
"violation_type": res.violation_type,
},
)
if shield.on_violation_action == OnViolationAction.RAISE:
raise SafetyException(violation)
elif shield.on_violation_action == OnViolationAction.WARN:
cprint(
f"[Warn]{shield.__class__.__name__} raised a warning",
color="red",
)
return RunShieldResponse(violation=violation)
def get_shield_impl(self, typ: MetaReferenceShieldType) -> ShieldBase: