mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-17 23:17:17 +00:00
working bedrock safety
This commit is contained in:
parent
a4fd91fe51
commit
df76c9b484
7 changed files with 63 additions and 44 deletions
|
@ -6,6 +6,7 @@
|
|||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
|
||||
from typing import Any, Dict, List
|
||||
|
||||
|
@ -27,20 +28,25 @@ BEDROCK_SUPPORTED_SHIELDS = [
|
|||
|
||||
|
||||
def _create_bedrock_client(config: BedrockSafetyConfig, name: str):
|
||||
# Use environment variables by default, fall back to config values
|
||||
session_args = {
|
||||
k: v
|
||||
for k, v in dict(
|
||||
aws_access_key_id=config.aws_access_key_id,
|
||||
aws_secret_access_key=config.aws_secret_access_key,
|
||||
aws_session_token=config.aws_session_token,
|
||||
region_name=config.region_name,
|
||||
profile_name=config.profile_name,
|
||||
).items()
|
||||
if v is not None
|
||||
"aws_access_key_id": os.environ.get(
|
||||
"AWS_ACCESS_KEY_ID", config.aws_access_key_id
|
||||
),
|
||||
"aws_secret_access_key": os.environ.get(
|
||||
"AWS_SECRET_ACCESS_KEY", config.aws_secret_access_key
|
||||
),
|
||||
"aws_session_token": os.environ.get(
|
||||
"AWS_SESSION_TOKEN", config.aws_session_token
|
||||
),
|
||||
"region_name": os.environ.get("AWS_DEFAULT_REGION", config.region_name),
|
||||
"profile_name": os.environ.get("AWS_PROFILE", config.profile_name),
|
||||
}
|
||||
|
||||
boto3_session = boto3.session.Session(**session_args)
|
||||
# Remove None values
|
||||
session_args = {k: v for k, v in session_args.items() if v is not None}
|
||||
|
||||
boto3_session = boto3.session.Session(**session_args)
|
||||
return boto3_session.client(name)
|
||||
|
||||
|
||||
|
@ -77,15 +83,16 @@ class BedrockSafetyAdapter(Safety, ShieldsProtocolPrivate):
|
|||
"guardrailVersion": guardrail["version"],
|
||||
},
|
||||
)
|
||||
self.registered_shields.append(shield_def)
|
||||
shields.append(shield_def)
|
||||
return shields
|
||||
|
||||
async def run_shield(
|
||||
self, shield_type: str, messages: List[Message], params: Dict[str, Any] = None
|
||||
self, identifier: str, messages: List[Message], params: Dict[str, Any] = None
|
||||
) -> RunShieldResponse:
|
||||
shield_def = await self.shield_store.get_shield(shield_type)
|
||||
shield_def = await self.shield_store.get_shield(identifier)
|
||||
if not shield_def:
|
||||
raise ValueError(f"Unknown shield {shield_type}")
|
||||
raise ValueError(f"Unknown shield {identifier}")
|
||||
|
||||
"""This is the implementation for the bedrock guardrails. The input to the guardrails is to be of this format
|
||||
```content = [
|
||||
|
@ -128,10 +135,14 @@ class BedrockSafetyAdapter(Safety, ShieldsProtocolPrivate):
|
|||
# guardrails returns a list - however for this implementation we will leverage the last values
|
||||
metadata = dict(assessment)
|
||||
|
||||
return SafetyViolation(
|
||||
user_message=user_message,
|
||||
violation_level=ViolationLevel.ERROR,
|
||||
metadata=metadata,
|
||||
return RunShieldResponse(
|
||||
violations=[
|
||||
SafetyViolation(
|
||||
user_message=user_message,
|
||||
violation_level=ViolationLevel.ERROR,
|
||||
metadata=metadata,
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
return None
|
||||
return RunShieldResponse(violations=[])
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue