forked from phoenix-oss/llama-stack-mirror
Another attempt at a proper bugfix for safety violations
This commit is contained in:
parent
e5bdd6615a
commit
c9005e95ed
2 changed files with 21 additions and 14 deletions
|
@ -32,14 +32,26 @@ class ShieldRunnerMixin:
|
|||
self.output_shields = output_shields
|
||||
|
||||
async def run_multiple_shields(
|
||||
self, messages: List[Message], shields: List[str]
|
||||
self, messages: List[Message], shield_types: List[str]
|
||||
) -> None:
|
||||
await asyncio.gather(
|
||||
responses = await asyncio.gather(
|
||||
*[
|
||||
self.safety_api.run_shield(
|
||||
shield_type=shield_type,
|
||||
messages=messages,
|
||||
)
|
||||
for shield_type in shields
|
||||
for shield_type in shield_types
|
||||
]
|
||||
)
|
||||
for shield_type, response in zip(shields, responses):
|
||||
if not response.violation:
|
||||
continue
|
||||
|
||||
violation = response.violation
|
||||
if violation.violation_level == ViolationLevel.ERROR:
|
||||
raise SafetyException(violation)
|
||||
elif violation.violation_level == ViolationLevel.WARN:
|
||||
cprint(
|
||||
f"[Warn]{shield_type} raised a warning",
|
||||
color="red",
|
||||
)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue