bug fix for safety violation

This commit is contained in:
Xi Yan 2024-09-23 18:17:15 -07:00
parent 70fb70a71c
commit e5bdd6615a
3 changed files with 15 additions and 12 deletions

View file

@ -34,7 +34,7 @@ class ShieldRunnerMixin:
async def run_multiple_shields(
self, messages: List[Message], shields: List[str]
) -> None:
responses = await asyncio.gather(
await asyncio.gather(
*[
self.safety_api.run_shield(
shield_type=shield_type,
@ -43,13 +43,3 @@ class ShieldRunnerMixin:
for shield_type in shields
]
)
for shield, r in zip(shields, responses):
if r.violation:
if shield.on_violation_action == OnViolationAction.RAISE:
raise SafetyException(r)
elif shield.on_violation_action == OnViolationAction.WARN:
cprint(
f"[Warn]{shield.__class__.__name__} raised a warning",
color="red",
)