forked from phoenix-oss/llama-stack-mirror
bug fix for safety violation
This commit is contained in:
parent
70fb70a71c
commit
e5bdd6615a
3 changed files with 15 additions and 12 deletions
|
@ -130,7 +130,7 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
# CompletionMessage itself in the ShieldResponse
|
||||
messages.append(
|
||||
CompletionMessage(
|
||||
content=violation.user_message,
|
||||
content=step.violation.user_message,
|
||||
stop_reason=StopReason.end_of_turn,
|
||||
)
|
||||
)
|
||||
|
|
|
@ -34,7 +34,7 @@ class ShieldRunnerMixin:
|
|||
async def run_multiple_shields(
|
||||
self, messages: List[Message], shields: List[str]
|
||||
) -> None:
|
||||
responses = await asyncio.gather(
|
||||
await asyncio.gather(
|
||||
*[
|
||||
self.safety_api.run_shield(
|
||||
shield_type=shield_type,
|
||||
|
@ -43,13 +43,3 @@ class ShieldRunnerMixin:
|
|||
for shield_type in shields
|
||||
]
|
||||
)
|
||||
|
||||
for shield, r in zip(shields, responses):
|
||||
if r.violation:
|
||||
if shield.on_violation_action == OnViolationAction.RAISE:
|
||||
raise SafetyException(r)
|
||||
elif shield.on_violation_action == OnViolationAction.WARN:
|
||||
cprint(
|
||||
f"[Warn]{shield.__class__.__name__} raised a warning",
|
||||
color="red",
|
||||
)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue