Another attempt at a proper bugfix for safety violations

This commit is contained in:
Ashwin Bharambe 2024-09-23 19:06:30 -07:00
parent e5bdd6615a
commit c9005e95ed
2 changed files with 21 additions and 14 deletions

View file

@ -32,14 +32,26 @@ class ShieldRunnerMixin:
self.output_shields = output_shields
async def run_multiple_shields(
self, messages: List[Message], shields: List[str]
self, messages: List[Message], shield_types: List[str]
) -> None:
await asyncio.gather(
responses = await asyncio.gather(
*[
self.safety_api.run_shield(
shield_type=shield_type,
messages=messages,
)
for shield_type in shields
for shield_type in shield_types
]
)
for shield_type, response in zip(shields, responses):
if not response.violation:
continue
violation = response.violation
if violation.violation_level == ViolationLevel.ERROR:
raise SafetyException(violation)
elif violation.violation_level == ViolationLevel.WARN:
cprint(
f"[Warn]{shield_type} raised a warning",
color="red",
)