forked from phoenix-oss/llama-stack-mirror
bug fix for safety violation
This commit is contained in:
parent
70fb70a71c
commit
e5bdd6615a
3 changed files with 15 additions and 12 deletions
|
@ -130,7 +130,7 @@ class ChatAgent(ShieldRunnerMixin):
|
||||||
# CompletionMessage itself in the ShieldResponse
|
# CompletionMessage itself in the ShieldResponse
|
||||||
messages.append(
|
messages.append(
|
||||||
CompletionMessage(
|
CompletionMessage(
|
||||||
content=violation.user_message,
|
content=step.violation.user_message,
|
||||||
stop_reason=StopReason.end_of_turn,
|
stop_reason=StopReason.end_of_turn,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
|
@ -34,7 +34,7 @@ class ShieldRunnerMixin:
|
||||||
async def run_multiple_shields(
|
async def run_multiple_shields(
|
||||||
self, messages: List[Message], shields: List[str]
|
self, messages: List[Message], shields: List[str]
|
||||||
) -> None:
|
) -> None:
|
||||||
responses = await asyncio.gather(
|
await asyncio.gather(
|
||||||
*[
|
*[
|
||||||
self.safety_api.run_shield(
|
self.safety_api.run_shield(
|
||||||
shield_type=shield_type,
|
shield_type=shield_type,
|
||||||
|
@ -43,13 +43,3 @@ class ShieldRunnerMixin:
|
||||||
for shield_type in shields
|
for shield_type in shields
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
for shield, r in zip(shields, responses):
|
|
||||||
if r.violation:
|
|
||||||
if shield.on_violation_action == OnViolationAction.RAISE:
|
|
||||||
raise SafetyException(r)
|
|
||||||
elif shield.on_violation_action == OnViolationAction.WARN:
|
|
||||||
cprint(
|
|
||||||
f"[Warn]{shield.__class__.__name__} raised a warning",
|
|
||||||
color="red",
|
|
||||||
)
|
|
||||||
|
|
|
@ -10,6 +10,11 @@ from llama_stack.distribution.utils.model_utils import model_local_dir
|
||||||
from llama_stack.apis.safety import * # noqa: F403
|
from llama_stack.apis.safety import * # noqa: F403
|
||||||
from llama_models.llama3.api.datatypes import * # noqa: F403
|
from llama_models.llama3.api.datatypes import * # noqa: F403
|
||||||
|
|
||||||
|
from llama_stack.providers.impls.meta_reference.agents.safety import SafetyException
|
||||||
|
from llama_stack.providers.impls.meta_reference.safety.shields.base import (
|
||||||
|
OnViolationAction,
|
||||||
|
)
|
||||||
|
|
||||||
from .config import MetaReferenceShieldType, SafetyConfig
|
from .config import MetaReferenceShieldType, SafetyConfig
|
||||||
|
|
||||||
from .shields import (
|
from .shields import (
|
||||||
|
@ -78,6 +83,14 @@ class MetaReferenceSafetyImpl(Safety):
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if shield.on_violation_action == OnViolationAction.RAISE:
|
||||||
|
raise SafetyException(violation)
|
||||||
|
elif shield.on_violation_action == OnViolationAction.WARN:
|
||||||
|
cprint(
|
||||||
|
f"[Warn]{shield.__class__.__name__} raised a warning",
|
||||||
|
color="red",
|
||||||
|
)
|
||||||
|
|
||||||
return RunShieldResponse(violation=violation)
|
return RunShieldResponse(violation=violation)
|
||||||
|
|
||||||
def get_shield_impl(self, typ: MetaReferenceShieldType) -> ShieldBase:
|
def get_shield_impl(self, typ: MetaReferenceShieldType) -> ShieldBase:
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue