working bedrock safety

This commit is contained in:
Dinesh Yeduguru 2024-11-05 14:48:25 -08:00
parent a4fd91fe51
commit df76c9b484
7 changed files with 63 additions and 44 deletions

View file

@@ -39,7 +39,7 @@ class RunShieldResponse(BaseModel):
class ShieldStore(Protocol):
def get_shield(self, identifier: str) -> ShieldDef: ...
async def get_shield(self, identifier: str) -> ShieldDef: ...
@runtime_checkable
@@ -48,5 +48,5 @@ class Safety(Protocol):
@webmethod(route="/safety/run_shield")
async def run_shield(
self, shield_type: str, messages: List[Message], params: Dict[str, Any] = None
self, identifier: str, messages: List[Message], params: Dict[str, Any] = None
) -> RunShieldResponse: ...

View file

@@ -46,7 +46,7 @@ class Shields(Protocol):
async def list_shields(self) -> List[ShieldDefWithProvider]: ...
@webmethod(route="/shields/get", method="GET")
async def get_shield(self, shield_type: str) -> Optional[ShieldDefWithProvider]: ...
async def get_shield(self, identifier: str) -> Optional[ShieldDefWithProvider]: ...
@webmethod(route="/shields/register", method="POST")
async def register_shield(self, shield: ShieldDefWithProvider) -> None: ...

View file

@@ -154,12 +154,12 @@ class SafetyRouter(Safety):
async def run_shield(
self,
shield_type: str,
identifier: str,
messages: List[Message],
params: Dict[str, Any] = None,
) -> RunShieldResponse:
return await self.routing_table.get_provider_impl(shield_type).run_shield(
shield_type=shield_type,
return await self.routing_table.get_provider_impl(identifier).run_shield(
identifier=identifier,
messages=messages,
params=params,
)

View file

@@ -204,8 +204,8 @@ class ShieldsRoutingTable(CommonRoutingTableImpl, Shields):
async def list_shields(self) -> List[ShieldDef]:
return await self.get_all_with_type("shield")
async def get_shield(self, shield_type: str) -> Optional[ShieldDefWithProvider]:
return await self.get_object_by_identifier(shield_type)
async def get_shield(self, identifier: str) -> Optional[ShieldDefWithProvider]:
return await self.get_object_by_identifier(identifier)
async def register_shield(self, shield: ShieldDefWithProvider) -> None:
await self.register_object(shield)

View file

@@ -16,6 +16,8 @@ from llama_models.llama3.api.tokenizer import Tokenizer
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
from llama_stack.apis.inference import * # noqa: F403
import os
from llama_stack.providers.adapters.inference.bedrock.config import BedrockConfig
@@ -443,8 +445,10 @@ def _create_bedrock_client(config: BedrockConfig) -> BaseClient:
retries_config = {
k: v
for k, v in dict(
total_max_attempts=config.total_max_attempts,
mode=config.retry_mode,
total_max_attempts=os.environ.get(
"AWS_MAX_ATTEMPTS", config.total_max_attempts
),
mode=os.environ.get("AWS_RETRY_MODE", config.retry_mode),
).items()
if v is not None
}
@@ -452,7 +456,7 @@ def _create_bedrock_client(config: BedrockConfig) -> BaseClient:
config_args = {
k: v
for k, v in dict(
region_name=config.region_name,
region_name=os.environ.get("AWS_DEFAULT_REGION", config.region_name),
retries=retries_config if retries_config else None,
connect_timeout=config.connect_timeout,
read_timeout=config.read_timeout,
@@ -463,17 +467,21 @@ def _create_bedrock_client(config: BedrockConfig) -> BaseClient:
boto3_config = Config(**config_args)
session_args = {
k: v
for k, v in dict(
aws_access_key_id=config.aws_access_key_id,
aws_secret_access_key=config.aws_secret_access_key,
aws_session_token=config.aws_session_token,
region_name=config.region_name,
profile_name=config.profile_name,
).items()
if v is not None
"aws_access_key_id": os.environ.get(
"AWS_ACCESS_KEY_ID", config.aws_access_key_id
),
"aws_secret_access_key": os.environ.get(
"AWS_SECRET_ACCESS_KEY", config.aws_secret_access_key
),
"aws_session_token": os.environ.get(
"AWS_SESSION_TOKEN", config.aws_session_token
),
"region_name": os.environ.get("AWS_DEFAULT_REGION", config.region_name),
"profile_name": os.environ.get("AWS_PROFILE", config.profile_name),
}
boto3_session = boto3.session.Session(**session_args)
# Remove None values
session_args = {k: v for k, v in session_args.items() if v is not None}
boto3_session = boto3.session.Session(**session_args)
return boto3_session.client("bedrock-runtime", config=boto3_config)

View file

@@ -6,6 +6,7 @@
import json
import logging
import os
from typing import Any, Dict, List
@@ -27,20 +28,25 @@ BEDROCK_SUPPORTED_SHIELDS = [
def _create_bedrock_client(config: BedrockSafetyConfig, name: str):
# Use environment variables by default, fall back to config values
session_args = {
k: v
for k, v in dict(
aws_access_key_id=config.aws_access_key_id,
aws_secret_access_key=config.aws_secret_access_key,
aws_session_token=config.aws_session_token,
region_name=config.region_name,
profile_name=config.profile_name,
).items()
if v is not None
"aws_access_key_id": os.environ.get(
"AWS_ACCESS_KEY_ID", config.aws_access_key_id
),
"aws_secret_access_key": os.environ.get(
"AWS_SECRET_ACCESS_KEY", config.aws_secret_access_key
),
"aws_session_token": os.environ.get(
"AWS_SESSION_TOKEN", config.aws_session_token
),
"region_name": os.environ.get("AWS_DEFAULT_REGION", config.region_name),
"profile_name": os.environ.get("AWS_PROFILE", config.profile_name),
}
boto3_session = boto3.session.Session(**session_args)
# Remove None values
session_args = {k: v for k, v in session_args.items() if v is not None}
boto3_session = boto3.session.Session(**session_args)
return boto3_session.client(name)
@@ -77,15 +83,16 @@ class BedrockSafetyAdapter(Safety, ShieldsProtocolPrivate):
"guardrailVersion": guardrail["version"],
},
)
self.registered_shields.append(shield_def)
shields.append(shield_def)
return shields
async def run_shield(
self, shield_type: str, messages: List[Message], params: Dict[str, Any] = None
self, identifier: str, messages: List[Message], params: Dict[str, Any] = None
) -> RunShieldResponse:
shield_def = await self.shield_store.get_shield(shield_type)
shield_def = await self.shield_store.get_shield(identifier)
if not shield_def:
raise ValueError(f"Unknown shield {shield_type}")
raise ValueError(f"Unknown shield {identifier}")
"""This is the implementation for the bedrock guardrails. The input to the guardrails is to be of this format
```content = [
@@ -128,10 +135,14 @@ class BedrockSafetyAdapter(Safety, ShieldsProtocolPrivate):
# Guardrails returns a list; however, this implementation uses only the last values.
metadata = dict(assessment)
return SafetyViolation(
user_message=user_message,
violation_level=ViolationLevel.ERROR,
metadata=metadata,
return RunShieldResponse(
violations=[
SafetyViolation(
user_message=user_message,
violation_level=ViolationLevel.ERROR,
metadata=metadata,
)
]
)
return None
return RunShieldResponse(violations=[])

View file

@@ -32,18 +32,18 @@ class ShieldRunnerMixin:
self.output_shields = output_shields
async def run_multiple_shields(
self, messages: List[Message], shield_types: List[str]
self, messages: List[Message], identifiers: List[str]
) -> None:
responses = await asyncio.gather(
*[
self.safety_api.run_shield(
shield_type=shield_type,
identifier=identifier,
messages=messages,
)
for shield_type in shield_types
for identifier in identifiers
]
)
for shield_type, response in zip(shield_types, responses):
for identifier, response in zip(identifiers, responses):
if not response.violation:
continue
@@ -52,6 +52,6 @@ class ShieldRunnerMixin:
raise SafetyException(violation)
elif violation.violation_level == ViolationLevel.WARN:
cprint(
f"[Warn]{shield_type} raised a warning",
f"[Warn]{identifier} raised a warning",
color="red",
)