mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-06-28 02:53:30 +00:00
Another attempt at a proper bugfix for safety violations
This commit is contained in:
parent
e5bdd6615a
commit
c9005e95ed
2 changed files with 21 additions and 14 deletions
|
@ -32,14 +32,26 @@ class ShieldRunnerMixin:
|
||||||
self.output_shields = output_shields
|
self.output_shields = output_shields
|
||||||
|
|
||||||
async def run_multiple_shields(
    self, messages: List[Message], shield_types: List[str]
) -> None:
    """Run every requested shield over ``messages`` and act on violations.

    All shields are awaited concurrently through ``asyncio.gather`` and
    their responses are then inspected one by one.

    Args:
        messages: Messages passed unchanged to each shield.
        shield_types: Shield identifiers, each forwarded to
            ``self.safety_api.run_shield`` as ``shield_type``.

    Raises:
        SafetyException: For the first inspected response whose violation
            has ``ViolationLevel.ERROR``.
    """
    responses = await asyncio.gather(
        *[
            self.safety_api.run_shield(
                shield_type=shield_type,
                messages=messages,
            )
            for shield_type in shield_types
        ]
    )

    # gather() returns results in the order the awaitables were passed, so
    # zipping against ``shield_types`` (the actual parameter name) pairs each
    # response with the shield that produced it.
    for shield_type, response in zip(shield_types, responses):
        if not response.violation:
            continue

        violation = response.violation
        if violation.violation_level == ViolationLevel.ERROR:
            raise SafetyException(violation)
        elif violation.violation_level == ViolationLevel.WARN:
            # Warnings are reported but do not stop processing.
            cprint(
                f"[Warn]{shield_type} raised a warning",
                color="red",
            )
|
||||||
|
|
|
@ -10,7 +10,6 @@ from llama_stack.distribution.utils.model_utils import model_local_dir
|
||||||
from llama_stack.apis.safety import * # noqa: F403
|
from llama_stack.apis.safety import * # noqa: F403
|
||||||
from llama_models.llama3.api.datatypes import * # noqa: F403
|
from llama_models.llama3.api.datatypes import * # noqa: F403
|
||||||
|
|
||||||
from llama_stack.providers.impls.meta_reference.agents.safety import SafetyException
|
|
||||||
from llama_stack.providers.impls.meta_reference.safety.shields.base import (
|
from llama_stack.providers.impls.meta_reference.safety.shields.base import (
|
||||||
OnViolationAction,
|
OnViolationAction,
|
||||||
)
|
)
|
||||||
|
@ -74,23 +73,19 @@ class MetaReferenceSafetyImpl(Safety):
|
||||||
# TODO: we can refactor ShieldBase, etc. to be inline with the API types
|
# TODO: we can refactor ShieldBase, etc. to be inline with the API types
|
||||||
res = await shield.run(messages)
|
res = await shield.run(messages)
|
||||||
violation = None
|
violation = None
|
||||||
if res.is_violation:
|
if res.is_violation and shield.on_violation_action != OnViolationAction.IGNORE:
|
||||||
violation = SafetyViolation(
|
violation = SafetyViolation(
|
||||||
violation_level=ViolationLevel.ERROR,
|
violation_level=(
|
||||||
|
ViolationLevel.ERROR
|
||||||
|
if shield.on_violation_action == OnViolationAction.RAISE
|
||||||
|
else ViolationLevel.WARN
|
||||||
|
),
|
||||||
user_message=res.violation_return_message,
|
user_message=res.violation_return_message,
|
||||||
metadata={
|
metadata={
|
||||||
"violation_type": res.violation_type,
|
"violation_type": res.violation_type,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
if shield.on_violation_action == OnViolationAction.RAISE:
|
|
||||||
raise SafetyException(violation)
|
|
||||||
elif shield.on_violation_action == OnViolationAction.WARN:
|
|
||||||
cprint(
|
|
||||||
f"[Warn]{shield.__class__.__name__} raised a warning",
|
|
||||||
color="red",
|
|
||||||
)
|
|
||||||
|
|
||||||
return RunShieldResponse(violation=violation)
|
return RunShieldResponse(violation=violation)
|
||||||
|
|
||||||
def get_shield_impl(self, typ: MetaReferenceShieldType) -> ShieldBase:
|
def get_shield_impl(self, typ: MetaReferenceShieldType) -> ShieldBase:
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue