forked from phoenix-oss/llama-stack-mirror
Initial commit
This commit is contained in:
commit
5d5acc8ed5
81 changed files with 4458 additions and 0 deletions
52
llama_toolchain/safety/shields/shield_runner.py
Normal file
52
llama_toolchain/safety/shields/shield_runner.py
Normal file
|
@ -0,0 +1,52 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import asyncio
|
||||
from typing import List
|
||||
|
||||
from llama_models.llama3_1.api.datatypes import Message, Role
|
||||
|
||||
from .base import OnViolationAction, ShieldBase, ShieldResponse
|
||||
|
||||
|
||||
class SafetyException(Exception): # noqa: N818
|
||||
def __init__(self, response: ShieldResponse):
|
||||
self.response = response
|
||||
super().__init__(response.violation_return_message)
|
||||
|
||||
|
||||
class ShieldRunnerMixin:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
input_shields: List[ShieldBase] = None,
|
||||
output_shields: List[ShieldBase] = None,
|
||||
):
|
||||
self.input_shields = input_shields
|
||||
self.output_shields = output_shields
|
||||
|
||||
async def run_shields(
|
||||
self, messages: List[Message], shields: List[ShieldBase]
|
||||
) -> List[ShieldResponse]:
|
||||
# some shields like llama-guard require the first message to be a user message
|
||||
# since this might be a tool call, first role might not be user
|
||||
if len(messages) > 0 and messages[0].role != Role.user.value:
|
||||
# TODO(ashwin): we need to change the type of the message, this kind of modification
|
||||
# is no longer appropriate
|
||||
messages[0].role = Role.user.value
|
||||
|
||||
results = await asyncio.gather(*[s.run(messages) for s in shields])
|
||||
for shield, r in zip(shields, results):
|
||||
if r.is_violation:
|
||||
if shield.on_violation_action == OnViolationAction.RAISE:
|
||||
raise SafetyException(r)
|
||||
elif shield.on_violation_action == OnViolationAction.WARN:
|
||||
cprint(
|
||||
f"[Warn]{shield.__class__.__name__} raised a warning",
|
||||
color="red",
|
||||
)
|
||||
|
||||
return results
|
Loading…
Add table
Add a link
Reference in a new issue