forked from phoenix-oss/llama-stack-mirror
add bedrock distribution code (#358)
* add bedrock distribution code * fix linter error * add bedrock shields support * linter fixes * working bedrock safety * change to return only one violation * remove env var reading * refereshable boto credentials * remove env vars * address raghu's feedback * fix session_ttl passing --------- Co-authored-by: Dinesh Yeduguru <dineshyv@fb.com>
This commit is contained in:
parent
6ebd553da5
commit
093c9f1987
16 changed files with 429 additions and 135 deletions
|
@ -32,18 +32,18 @@ class ShieldRunnerMixin:
|
|||
self.output_shields = output_shields
|
||||
|
||||
async def run_multiple_shields(
|
||||
self, messages: List[Message], shield_types: List[str]
|
||||
self, messages: List[Message], identifiers: List[str]
|
||||
) -> None:
|
||||
responses = await asyncio.gather(
|
||||
*[
|
||||
self.safety_api.run_shield(
|
||||
shield_type=shield_type,
|
||||
identifier=identifier,
|
||||
messages=messages,
|
||||
)
|
||||
for shield_type in shield_types
|
||||
for identifier in identifiers
|
||||
]
|
||||
)
|
||||
for shield_type, response in zip(shield_types, responses):
|
||||
for identifier, response in zip(identifiers, responses):
|
||||
if not response.violation:
|
||||
continue
|
||||
|
||||
|
@ -52,6 +52,6 @@ class ShieldRunnerMixin:
|
|||
raise SafetyException(violation)
|
||||
elif violation.violation_level == ViolationLevel.WARN:
|
||||
cprint(
|
||||
f"[Warn]{shield_type} raised a warning",
|
||||
f"[Warn]{identifier} raised a warning",
|
||||
color="red",
|
||||
)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue