From df76c9b4845e353888d8c78d547642565767e670 Mon Sep 17 00:00:00 2001 From: Dinesh Yeduguru Date: Tue, 5 Nov 2024 14:48:25 -0800 Subject: [PATCH] working bedrock safety --- llama_stack/apis/safety/safety.py | 4 +- llama_stack/apis/shields/shields.py | 2 +- llama_stack/distribution/routers/routers.py | 6 +-- .../distribution/routers/routing_tables.py | 4 +- .../adapters/inference/bedrock/bedrock.py | 34 +++++++++----- .../adapters/safety/bedrock/bedrock.py | 47 ++++++++++++------- .../impls/meta_reference/agents/safety.py | 10 ++-- 7 files changed, 63 insertions(+), 44 deletions(-) diff --git a/llama_stack/apis/safety/safety.py b/llama_stack/apis/safety/safety.py index f3615dc4b..0b74fd259 100644 --- a/llama_stack/apis/safety/safety.py +++ b/llama_stack/apis/safety/safety.py @@ -39,7 +39,7 @@ class RunShieldResponse(BaseModel): class ShieldStore(Protocol): - def get_shield(self, identifier: str) -> ShieldDef: ... + async def get_shield(self, identifier: str) -> ShieldDef: ... @runtime_checkable @@ -48,5 +48,5 @@ class Safety(Protocol): @webmethod(route="/safety/run_shield") async def run_shield( - self, shield_type: str, messages: List[Message], params: Dict[str, Any] = None + self, identifier: str, messages: List[Message], params: Dict[str, Any] = None ) -> RunShieldResponse: ... diff --git a/llama_stack/apis/shields/shields.py b/llama_stack/apis/shields/shields.py index 7c8e3939a..fd5634442 100644 --- a/llama_stack/apis/shields/shields.py +++ b/llama_stack/apis/shields/shields.py @@ -46,7 +46,7 @@ class Shields(Protocol): async def list_shields(self) -> List[ShieldDefWithProvider]: ... @webmethod(route="/shields/get", method="GET") - async def get_shield(self, shield_type: str) -> Optional[ShieldDefWithProvider]: ... + async def get_shield(self, identifier: str) -> Optional[ShieldDefWithProvider]: ... @webmethod(route="/shields/register", method="POST") async def register_shield(self, shield: ShieldDefWithProvider) -> None: ... diff --git a/llama_stack/distribution/routers/routers.py b/llama_stack/distribution/routers/routers.py index 348d8449d..760dbaf2f 100644 --- a/llama_stack/distribution/routers/routers.py +++ b/llama_stack/distribution/routers/routers.py @@ -154,12 +154,12 @@ class SafetyRouter(Safety): async def run_shield( self, - shield_type: str, + identifier: str, messages: List[Message], params: Dict[str, Any] = None, ) -> RunShieldResponse: - return await self.routing_table.get_provider_impl(shield_type).run_shield( - shield_type=shield_type, + return await self.routing_table.get_provider_impl(identifier).run_shield( + identifier=identifier, messages=messages, params=params, ) diff --git a/llama_stack/distribution/routers/routing_tables.py b/llama_stack/distribution/routers/routing_tables.py index 6297182bc..bcf125bec 100644 --- a/llama_stack/distribution/routers/routing_tables.py +++ b/llama_stack/distribution/routers/routing_tables.py @@ -204,8 +204,8 @@ class ShieldsRoutingTable(CommonRoutingTableImpl, Shields): async def list_shields(self) -> List[ShieldDef]: return await self.get_all_with_type("shield") - async def get_shield(self, shield_type: str) -> Optional[ShieldDefWithProvider]: - return await self.get_object_by_identifier(shield_type) + async def get_shield(self, identifier: str) -> Optional[ShieldDefWithProvider]: + return await self.get_object_by_identifier(identifier) async def register_shield(self, shield: ShieldDefWithProvider) -> None: await self.register_object(shield) diff --git a/llama_stack/providers/adapters/inference/bedrock/bedrock.py b/llama_stack/providers/adapters/inference/bedrock/bedrock.py index caf886c0b..83f799cbd 100644 --- a/llama_stack/providers/adapters/inference/bedrock/bedrock.py +++ b/llama_stack/providers/adapters/inference/bedrock/bedrock.py @@ -16,6 +16,8 @@ from llama_models.llama3.api.tokenizer import Tokenizer from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper from llama_stack.apis.inference import * # noqa: F403 +import os + from llama_stack.providers.adapters.inference.bedrock.config import BedrockConfig @@ -443,8 +445,10 @@ def _create_bedrock_client(config: BedrockConfig) -> BaseClient: retries_config = { k: v for k, v in dict( - total_max_attempts=config.total_max_attempts, - mode=config.retry_mode, + total_max_attempts=os.environ.get( + "AWS_MAX_ATTEMPTS", config.total_max_attempts + ), + mode=os.environ.get("AWS_RETRY_MODE", config.retry_mode), ).items() if v is not None } @@ -452,7 +456,7 @@ def _create_bedrock_client(config: BedrockConfig) -> BaseClient: config_args = { k: v for k, v in dict( - region_name=config.region_name, + region_name=os.environ.get("AWS_DEFAULT_REGION", config.region_name), retries=retries_config if retries_config else None, connect_timeout=config.connect_timeout, read_timeout=config.read_timeout, @@ -463,17 +467,21 @@ def _create_bedrock_client(config: BedrockConfig) -> BaseClient: boto3_config = Config(**config_args) session_args = { - k: v - for k, v in dict( - aws_access_key_id=config.aws_access_key_id, - aws_secret_access_key=config.aws_secret_access_key, - aws_session_token=config.aws_session_token, - region_name=config.region_name, - profile_name=config.profile_name, - ).items() - if v is not None + "aws_access_key_id": os.environ.get( + "AWS_ACCESS_KEY_ID", config.aws_access_key_id + ), + "aws_secret_access_key": os.environ.get( + "AWS_SECRET_ACCESS_KEY", config.aws_secret_access_key + ), + "aws_session_token": os.environ.get( + "AWS_SESSION_TOKEN", config.aws_session_token + ), + "region_name": os.environ.get("AWS_DEFAULT_REGION", config.region_name), + "profile_name": os.environ.get("AWS_PROFILE", config.profile_name), } - boto3_session = boto3.session.Session(**session_args) + # Remove None values + session_args = {k: v for k, v in session_args.items() if v is not None} + boto3_session = boto3.session.Session(**session_args) return boto3_session.client("bedrock-runtime", config=boto3_config) diff --git a/llama_stack/providers/adapters/safety/bedrock/bedrock.py b/llama_stack/providers/adapters/safety/bedrock/bedrock.py index e22fb1130..54ace0cda 100644 --- a/llama_stack/providers/adapters/safety/bedrock/bedrock.py +++ b/llama_stack/providers/adapters/safety/bedrock/bedrock.py @@ -6,6 +6,7 @@ import json import logging +import os from typing import Any, Dict, List @@ -27,20 +28,25 @@ BEDROCK_SUPPORTED_SHIELDS = [ def _create_bedrock_client(config: BedrockSafetyConfig, name: str): + # Use environment variables by default, fall back to config values session_args = { - k: v - for k, v in dict( - aws_access_key_id=config.aws_access_key_id, - aws_secret_access_key=config.aws_secret_access_key, - aws_session_token=config.aws_session_token, - region_name=config.region_name, - profile_name=config.profile_name, - ).items() - if v is not None + "aws_access_key_id": os.environ.get( + "AWS_ACCESS_KEY_ID", config.aws_access_key_id + ), + "aws_secret_access_key": os.environ.get( + "AWS_SECRET_ACCESS_KEY", config.aws_secret_access_key + ), + "aws_session_token": os.environ.get( + "AWS_SESSION_TOKEN", config.aws_session_token + ), + "region_name": os.environ.get("AWS_DEFAULT_REGION", config.region_name), + "profile_name": os.environ.get("AWS_PROFILE", config.profile_name), } - boto3_session = boto3.session.Session(**session_args) + # Remove None values + session_args = {k: v for k, v in session_args.items() if v is not None} + boto3_session = boto3.session.Session(**session_args) return boto3_session.client(name) @@ -77,15 +83,16 @@ class BedrockSafetyAdapter(Safety, ShieldsProtocolPrivate): "guardrailVersion": guardrail["version"], }, ) + self.registered_shields.append(shield_def) shields.append(shield_def) return shields async def run_shield( - self, shield_type: str, messages: List[Message], params: Dict[str, Any] = None + self, identifier: str, messages: List[Message], params: Dict[str, Any] = None ) -> RunShieldResponse: - shield_def = await self.shield_store.get_shield(shield_type) + shield_def = await self.shield_store.get_shield(identifier) if not shield_def: - raise ValueError(f"Unknown shield {shield_type}") + raise ValueError(f"Unknown shield {identifier}") """This is the implementation for the bedrock guardrails. The input to the guardrails is to be of this format ```content = [ @@ -128,10 +135,14 @@ class BedrockSafetyAdapter(Safety, ShieldsProtocolPrivate): # guardrails returns a list - however for this implementation we will leverage the last values metadata = dict(assessment) - return SafetyViolation( - user_message=user_message, - violation_level=ViolationLevel.ERROR, - metadata=metadata, + return RunShieldResponse( + violations=[ + SafetyViolation( + user_message=user_message, + violation_level=ViolationLevel.ERROR, + metadata=metadata, + ) + ] ) - return None + return RunShieldResponse(violations=[]) diff --git a/llama_stack/providers/impls/meta_reference/agents/safety.py b/llama_stack/providers/impls/meta_reference/agents/safety.py index fb5821f6a..915ddd303 100644 --- a/llama_stack/providers/impls/meta_reference/agents/safety.py +++ b/llama_stack/providers/impls/meta_reference/agents/safety.py @@ -32,18 +32,18 @@ class ShieldRunnerMixin: self.output_shields = output_shields async def run_multiple_shields( - self, messages: List[Message], shield_types: List[str] + self, messages: List[Message], identifiers: List[str] ) -> None: responses = await asyncio.gather( *[ self.safety_api.run_shield( - shield_type=shield_type, + identifier=identifier, messages=messages, ) - for shield_type in shield_types + for identifier in identifiers ] ) - for shield_type, response in zip(shield_types, responses): + for identifier, response in zip(identifiers, responses): if not response.violation: continue @@ -52,6 +52,6 @@ class ShieldRunnerMixin: raise SafetyException(violation) elif violation.violation_level == ViolationLevel.WARN: cprint( - f"[Warn]{shield_type} raised a warning", + f"[Warn]{identifier} raised a warning", color="red", )