bug fix for safety violation

This commit is contained in:
Xi Yan 2024-09-23 18:17:15 -07:00
parent 70fb70a71c
commit e5bdd6615a
3 changed files with 15 additions and 12 deletions

View file

@ -130,7 +130,7 @@ class ChatAgent(ShieldRunnerMixin):
# CompletionMessage itself in the ShieldResponse # CompletionMessage itself in the ShieldResponse
messages.append( messages.append(
CompletionMessage( CompletionMessage(
content=violation.user_message, content=step.violation.user_message,
stop_reason=StopReason.end_of_turn, stop_reason=StopReason.end_of_turn,
) )
) )

View file

@ -34,7 +34,7 @@ class ShieldRunnerMixin:
async def run_multiple_shields( async def run_multiple_shields(
self, messages: List[Message], shields: List[str] self, messages: List[Message], shields: List[str]
) -> None: ) -> None:
responses = await asyncio.gather( await asyncio.gather(
*[ *[
self.safety_api.run_shield( self.safety_api.run_shield(
shield_type=shield_type, shield_type=shield_type,
@ -43,13 +43,3 @@ class ShieldRunnerMixin:
for shield_type in shields for shield_type in shields
] ]
) )
for shield, r in zip(shields, responses):
if r.violation:
if shield.on_violation_action == OnViolationAction.RAISE:
raise SafetyException(r)
elif shield.on_violation_action == OnViolationAction.WARN:
cprint(
f"[Warn]{shield.__class__.__name__} raised a warning",
color="red",
)

View file

@ -10,6 +10,11 @@ from llama_stack.distribution.utils.model_utils import model_local_dir
from llama_stack.apis.safety import * # noqa: F403 from llama_stack.apis.safety import * # noqa: F403
from llama_models.llama3.api.datatypes import * # noqa: F403 from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.providers.impls.meta_reference.agents.safety import SafetyException
from llama_stack.providers.impls.meta_reference.safety.shields.base import (
OnViolationAction,
)
from .config import MetaReferenceShieldType, SafetyConfig from .config import MetaReferenceShieldType, SafetyConfig
from .shields import ( from .shields import (
@ -78,6 +83,14 @@ class MetaReferenceSafetyImpl(Safety):
}, },
) )
if shield.on_violation_action == OnViolationAction.RAISE:
raise SafetyException(violation)
elif shield.on_violation_action == OnViolationAction.WARN:
cprint(
f"[Warn]{shield.__class__.__name__} raised a warning",
color="red",
)
return RunShieldResponse(violation=violation) return RunShieldResponse(violation=violation)
def get_shield_impl(self, typ: MetaReferenceShieldType) -> ShieldBase: def get_shield_impl(self, typ: MetaReferenceShieldType) -> ShieldBase: