working bedrock safety

This commit is contained in:
Dinesh Yeduguru 2024-11-05 14:48:25 -08:00
parent a4fd91fe51
commit df76c9b484
7 changed files with 63 additions and 44 deletions

View file

@@ -6,6 +6,7 @@
import json
import logging
import os
from typing import Any, Dict, List
@@ -27,20 +28,25 @@ BEDROCK_SUPPORTED_SHIELDS = [
def _create_bedrock_client(config: BedrockSafetyConfig, name: str):
# Use environment variables by default, fall back to config values
session_args = {
k: v
for k, v in dict(
aws_access_key_id=config.aws_access_key_id,
aws_secret_access_key=config.aws_secret_access_key,
aws_session_token=config.aws_session_token,
region_name=config.region_name,
profile_name=config.profile_name,
).items()
if v is not None
"aws_access_key_id": os.environ.get(
"AWS_ACCESS_KEY_ID", config.aws_access_key_id
),
"aws_secret_access_key": os.environ.get(
"AWS_SECRET_ACCESS_KEY", config.aws_secret_access_key
),
"aws_session_token": os.environ.get(
"AWS_SESSION_TOKEN", config.aws_session_token
),
"region_name": os.environ.get("AWS_DEFAULT_REGION", config.region_name),
"profile_name": os.environ.get("AWS_PROFILE", config.profile_name),
}
boto3_session = boto3.session.Session(**session_args)
# Remove None values
session_args = {k: v for k, v in session_args.items() if v is not None}
boto3_session = boto3.session.Session(**session_args)
return boto3_session.client(name)
@@ -77,15 +83,16 @@ class BedrockSafetyAdapter(Safety, ShieldsProtocolPrivate):
"guardrailVersion": guardrail["version"],
},
)
self.registered_shields.append(shield_def)
shields.append(shield_def)
return shields
async def run_shield(
self, shield_type: str, messages: List[Message], params: Dict[str, Any] = None
self, identifier: str, messages: List[Message], params: Dict[str, Any] = None
) -> RunShieldResponse:
shield_def = await self.shield_store.get_shield(shield_type)
shield_def = await self.shield_store.get_shield(identifier)
if not shield_def:
raise ValueError(f"Unknown shield {shield_type}")
raise ValueError(f"Unknown shield {identifier}")
"""This is the implementation for the bedrock guardrails. The input to the guardrails is to be of this format
```content = [
@@ -128,10 +135,14 @@ class BedrockSafetyAdapter(Safety, ShieldsProtocolPrivate):
# guardrails returns a list - however for this implementation we will leverage the last values
metadata = dict(assessment)
return SafetyViolation(
user_message=user_message,
violation_level=ViolationLevel.ERROR,
metadata=metadata,
return RunShieldResponse(
violations=[
SafetyViolation(
user_message=user_message,
violation_level=ViolationLevel.ERROR,
metadata=metadata,
)
]
)
return None
return RunShieldResponse(violations=[])