This commit is contained in:
Zain Hasan 2024-09-24 16:21:54 -04:00
commit af1710af75
6 changed files with 56 additions and 26 deletions

View file

@@ -130,7 +130,7 @@ class ChatAgent(ShieldRunnerMixin):
# CompletionMessage itself in the ShieldResponse
messages.append(
CompletionMessage(
content=violation.user_message,
content=step.violation.user_message,
stop_reason=StopReason.end_of_turn,
)
)

View file

@@ -32,7 +32,7 @@ class ShieldRunnerMixin:
self.output_shields = output_shields
async def run_multiple_shields(
self, messages: List[Message], shields: List[str]
self, messages: List[Message], shield_types: List[str]
) -> None:
responses = await asyncio.gather(
*[
@@ -40,16 +40,18 @@ class ShieldRunnerMixin:
shield_type=shield_type,
messages=messages,
)
for shield_type in shields
for shield_type in shield_types
]
)
for shield_type, response in zip(shield_types, responses):
if not response.violation:
continue
for shield, r in zip(shields, responses):
if r.violation:
if shield.on_violation_action == OnViolationAction.RAISE:
raise SafetyException(r)
elif shield.on_violation_action == OnViolationAction.WARN:
cprint(
f"[Warn]{shield.__class__.__name__} raised a warning",
color="red",
)
violation = response.violation
if violation.violation_level == ViolationLevel.ERROR:
raise SafetyException(violation)
elif violation.violation_level == ViolationLevel.WARN:
cprint(
f"[Warn]{shield_type} raised a warning",
color="red",
)

View file

@@ -10,6 +10,10 @@ from llama_stack.distribution.utils.model_utils import model_local_dir
from llama_stack.apis.safety import * # noqa: F403
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.providers.impls.meta_reference.safety.shields.base import (
OnViolationAction,
)
from .config import MetaReferenceShieldType, SafetyConfig
from .shields import (
@@ -69,9 +73,13 @@ class MetaReferenceSafetyImpl(Safety):
# TODO: we can refactor ShieldBase, etc. to be inline with the API types
res = await shield.run(messages)
violation = None
if res.is_violation:
if res.is_violation and shield.on_violation_action != OnViolationAction.IGNORE:
violation = SafetyViolation(
violation_level=ViolationLevel.ERROR,
violation_level=(
ViolationLevel.ERROR
if shield.on_violation_action == OnViolationAction.RAISE
else ViolationLevel.WARN
),
user_message=res.violation_return_message,
metadata={
"violation_type": res.violation_type,

View file

@@ -1,5 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.