working bedrock safety

This commit is contained in:
Dinesh Yeduguru 2024-11-05 14:48:25 -08:00
parent a4fd91fe51
commit df76c9b484
7 changed files with 63 additions and 44 deletions

View file

@ -39,7 +39,7 @@ class RunShieldResponse(BaseModel):
class ShieldStore(Protocol): class ShieldStore(Protocol):
def get_shield(self, identifier: str) -> ShieldDef: ... async def get_shield(self, identifier: str) -> ShieldDef: ...
@runtime_checkable @runtime_checkable
@ -48,5 +48,5 @@ class Safety(Protocol):
@webmethod(route="/safety/run_shield") @webmethod(route="/safety/run_shield")
async def run_shield( async def run_shield(
self, shield_type: str, messages: List[Message], params: Dict[str, Any] = None self, identifier: str, messages: List[Message], params: Dict[str, Any] = None
) -> RunShieldResponse: ... ) -> RunShieldResponse: ...

View file

@ -46,7 +46,7 @@ class Shields(Protocol):
async def list_shields(self) -> List[ShieldDefWithProvider]: ... async def list_shields(self) -> List[ShieldDefWithProvider]: ...
@webmethod(route="/shields/get", method="GET") @webmethod(route="/shields/get", method="GET")
async def get_shield(self, shield_type: str) -> Optional[ShieldDefWithProvider]: ... async def get_shield(self, identifier: str) -> Optional[ShieldDefWithProvider]: ...
@webmethod(route="/shields/register", method="POST") @webmethod(route="/shields/register", method="POST")
async def register_shield(self, shield: ShieldDefWithProvider) -> None: ... async def register_shield(self, shield: ShieldDefWithProvider) -> None: ...

View file

@ -154,12 +154,12 @@ class SafetyRouter(Safety):
async def run_shield( async def run_shield(
self, self,
shield_type: str, identifier: str,
messages: List[Message], messages: List[Message],
params: Dict[str, Any] = None, params: Dict[str, Any] = None,
) -> RunShieldResponse: ) -> RunShieldResponse:
return await self.routing_table.get_provider_impl(shield_type).run_shield( return await self.routing_table.get_provider_impl(identifier).run_shield(
shield_type=shield_type, identifier=identifier,
messages=messages, messages=messages,
params=params, params=params,
) )

View file

@ -204,8 +204,8 @@ class ShieldsRoutingTable(CommonRoutingTableImpl, Shields):
async def list_shields(self) -> List[ShieldDef]: async def list_shields(self) -> List[ShieldDef]:
return await self.get_all_with_type("shield") return await self.get_all_with_type("shield")
async def get_shield(self, shield_type: str) -> Optional[ShieldDefWithProvider]: async def get_shield(self, identifier: str) -> Optional[ShieldDefWithProvider]:
return await self.get_object_by_identifier(shield_type) return await self.get_object_by_identifier(identifier)
async def register_shield(self, shield: ShieldDefWithProvider) -> None: async def register_shield(self, shield: ShieldDefWithProvider) -> None:
await self.register_object(shield) await self.register_object(shield)

View file

@ -16,6 +16,8 @@ from llama_models.llama3.api.tokenizer import Tokenizer
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
from llama_stack.apis.inference import * # noqa: F403 from llama_stack.apis.inference import * # noqa: F403
import os
from llama_stack.providers.adapters.inference.bedrock.config import BedrockConfig from llama_stack.providers.adapters.inference.bedrock.config import BedrockConfig
@ -443,8 +445,10 @@ def _create_bedrock_client(config: BedrockConfig) -> BaseClient:
retries_config = { retries_config = {
k: v k: v
for k, v in dict( for k, v in dict(
total_max_attempts=config.total_max_attempts, total_max_attempts=os.environ.get(
mode=config.retry_mode, "AWS_MAX_ATTEMPTS", config.total_max_attempts
),
mode=os.environ.get("AWS_RETRY_MODE", config.retry_mode),
).items() ).items()
if v is not None if v is not None
} }
@ -452,7 +456,7 @@ def _create_bedrock_client(config: BedrockConfig) -> BaseClient:
config_args = { config_args = {
k: v k: v
for k, v in dict( for k, v in dict(
region_name=config.region_name, region_name=os.environ.get("AWS_DEFAULT_REGION", config.region_name),
retries=retries_config if retries_config else None, retries=retries_config if retries_config else None,
connect_timeout=config.connect_timeout, connect_timeout=config.connect_timeout,
read_timeout=config.read_timeout, read_timeout=config.read_timeout,
@ -463,17 +467,21 @@ def _create_bedrock_client(config: BedrockConfig) -> BaseClient:
boto3_config = Config(**config_args) boto3_config = Config(**config_args)
session_args = { session_args = {
k: v "aws_access_key_id": os.environ.get(
for k, v in dict( "AWS_ACCESS_KEY_ID", config.aws_access_key_id
aws_access_key_id=config.aws_access_key_id, ),
aws_secret_access_key=config.aws_secret_access_key, "aws_secret_access_key": os.environ.get(
aws_session_token=config.aws_session_token, "AWS_SECRET_ACCESS_KEY", config.aws_secret_access_key
region_name=config.region_name, ),
profile_name=config.profile_name, "aws_session_token": os.environ.get(
).items() "AWS_SESSION_TOKEN", config.aws_session_token
if v is not None ),
"region_name": os.environ.get("AWS_DEFAULT_REGION", config.region_name),
"profile_name": os.environ.get("AWS_PROFILE", config.profile_name),
} }
boto3_session = boto3.session.Session(**session_args) # Remove None values
session_args = {k: v for k, v in session_args.items() if v is not None}
boto3_session = boto3.session.Session(**session_args)
return boto3_session.client("bedrock-runtime", config=boto3_config) return boto3_session.client("bedrock-runtime", config=boto3_config)

View file

@ -6,6 +6,7 @@
import json import json
import logging import logging
import os
from typing import Any, Dict, List from typing import Any, Dict, List
@ -27,20 +28,25 @@ BEDROCK_SUPPORTED_SHIELDS = [
def _create_bedrock_client(config: BedrockSafetyConfig, name: str): def _create_bedrock_client(config: BedrockSafetyConfig, name: str):
# Use environment variables by default, fall back to config values
session_args = { session_args = {
k: v "aws_access_key_id": os.environ.get(
for k, v in dict( "AWS_ACCESS_KEY_ID", config.aws_access_key_id
aws_access_key_id=config.aws_access_key_id, ),
aws_secret_access_key=config.aws_secret_access_key, "aws_secret_access_key": os.environ.get(
aws_session_token=config.aws_session_token, "AWS_SECRET_ACCESS_KEY", config.aws_secret_access_key
region_name=config.region_name, ),
profile_name=config.profile_name, "aws_session_token": os.environ.get(
).items() "AWS_SESSION_TOKEN", config.aws_session_token
if v is not None ),
"region_name": os.environ.get("AWS_DEFAULT_REGION", config.region_name),
"profile_name": os.environ.get("AWS_PROFILE", config.profile_name),
} }
boto3_session = boto3.session.Session(**session_args) # Remove None values
session_args = {k: v for k, v in session_args.items() if v is not None}
boto3_session = boto3.session.Session(**session_args)
return boto3_session.client(name) return boto3_session.client(name)
@ -77,15 +83,16 @@ class BedrockSafetyAdapter(Safety, ShieldsProtocolPrivate):
"guardrailVersion": guardrail["version"], "guardrailVersion": guardrail["version"],
}, },
) )
self.registered_shields.append(shield_def)
shields.append(shield_def) shields.append(shield_def)
return shields return shields
async def run_shield( async def run_shield(
self, shield_type: str, messages: List[Message], params: Dict[str, Any] = None self, identifier: str, messages: List[Message], params: Dict[str, Any] = None
) -> RunShieldResponse: ) -> RunShieldResponse:
shield_def = await self.shield_store.get_shield(shield_type) shield_def = await self.shield_store.get_shield(identifier)
if not shield_def: if not shield_def:
raise ValueError(f"Unknown shield {shield_type}") raise ValueError(f"Unknown shield {identifier}")
"""This is the implementation for the bedrock guardrails. The input to the guardrails is to be of this format """This is the implementation for the bedrock guardrails. The input to the guardrails is to be of this format
```content = [ ```content = [
@ -128,10 +135,14 @@ class BedrockSafetyAdapter(Safety, ShieldsProtocolPrivate):
# guardrails returns a list - however for this implementation we will leverage the last values # guardrails returns a list - however for this implementation we will leverage the last values
metadata = dict(assessment) metadata = dict(assessment)
return SafetyViolation( return RunShieldResponse(
user_message=user_message, violations=[
violation_level=ViolationLevel.ERROR, SafetyViolation(
metadata=metadata, user_message=user_message,
violation_level=ViolationLevel.ERROR,
metadata=metadata,
)
]
) )
return None return RunShieldResponse(violations=[])

View file

@ -32,18 +32,18 @@ class ShieldRunnerMixin:
self.output_shields = output_shields self.output_shields = output_shields
async def run_multiple_shields( async def run_multiple_shields(
self, messages: List[Message], shield_types: List[str] self, messages: List[Message], identifiers: List[str]
) -> None: ) -> None:
responses = await asyncio.gather( responses = await asyncio.gather(
*[ *[
self.safety_api.run_shield( self.safety_api.run_shield(
shield_type=shield_type, identifier=identifier,
messages=messages, messages=messages,
) )
for shield_type in shield_types for identifier in identifiers
] ]
) )
for shield_type, response in zip(shield_types, responses): for identifier, response in zip(identifiers, responses):
if not response.violation: if not response.violation:
continue continue
@ -52,6 +52,6 @@ class ShieldRunnerMixin:
raise SafetyException(violation) raise SafetyException(violation)
elif violation.violation_level == ViolationLevel.WARN: elif violation.violation_level == ViolationLevel.WARN:
cprint( cprint(
f"[Warn]{shield_type} raised a warning", f"[Warn]{identifier} raised a warning",
color="red", color="red",
) )