From 059e50b389ebf02b54b5719e30a480c54adf6d3d Mon Sep 17 00:00:00 2001 From: rsgrewal-aws <102243526+rsgrewal-aws@users.noreply.github.com> Date: Tue, 24 Sep 2024 19:16:55 -0700 Subject: [PATCH] [aws-bedrock] Support for Bedrock Safety adapter (#96) --- .../adapters/safety/bedrock/__init__.py | 18 +++ .../adapters/safety/bedrock/bedrock.py | 103 ++++++++++++++++++ .../adapters/safety/bedrock/config.py | 24 ++++ llama_stack/providers/registry/safety.py | 9 ++ 4 files changed, 154 insertions(+) create mode 100644 llama_stack/providers/adapters/safety/bedrock/__init__.py create mode 100644 llama_stack/providers/adapters/safety/bedrock/bedrock.py create mode 100644 llama_stack/providers/adapters/safety/bedrock/config.py diff --git a/llama_stack/providers/adapters/safety/bedrock/__init__.py b/llama_stack/providers/adapters/safety/bedrock/__init__.py new file mode 100644 index 000000000..0b10015a1 --- /dev/null +++ b/llama_stack/providers/adapters/safety/bedrock/__init__.py @@ -0,0 +1,18 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + + +from typing import Any + +from .config import BedrockShieldConfig + + +async def get_adapter_impl(config: BedrockShieldConfig, _deps) -> Any: + from .bedrock import BedrockShieldAdapter + + impl = BedrockShieldAdapter(config) + await impl.initialize() + return impl \ No newline at end of file diff --git a/llama_stack/providers/adapters/safety/bedrock/bedrock.py b/llama_stack/providers/adapters/safety/bedrock/bedrock.py new file mode 100644 index 000000000..2de91c2ab --- /dev/null +++ b/llama_stack/providers/adapters/safety/bedrock/bedrock.py @@ -0,0 +1,103 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + + +from typing import Any, AsyncGenerator, Dict +from .config import BedrockShieldConfig +import traceback +import asyncio +from enum import Enum +from typing import List +from pydantic import BaseModel, validator +from llama_stack.apis.safety import * # noqa +from llama_models.llama3.api.datatypes import * # noqa: F403 +import boto3 +import json +import logging + +logger = logging.getLogger(__name__) + +class BedrockShieldAdapter(Safety): + def __init__(self, config: BedrockShieldConfig) -> None: + self.config = config + + + async def initialize(self) -> None: + try: + if not self.config.aws_profile: + raise RuntimeError(f"Missing boto_client aws_profile in model info::{self.config}") + print(f"initializing with profile --- > {self.config}::") + self.boto_client_profile = self.config.aws_profile + self.boto_client = boto3.Session(profile_name=self.boto_client_profile).client('bedrock-runtime') + except Exception as e: + import traceback + + traceback.print_exc() + raise RuntimeError(f"Error initializing BedrockSafetyAdapter: {e}") from e + + async def shutdown(self) -> None: + pass + + async def run_shield(self, shield_type: str, messages: List[Message], params: Dict[str, Any] = None) -> RunShieldResponse: + """ This is the implementation for the bedrock guardrails. The input to the guardrails is to be of this format + ```content = [ + { + "text": { + "text": "Is the AB503 Product a better investment than the S&P 500?" + } + } + ]``` + However the incoming messages are of this type UserMessage(content=....) coming from + https://github.com/meta-llama/llama-models/blob/main/models/llama3/api/datatypes.py + + They contain content, role . For now we will extract the content and default the "qualifiers": ["query"] + """ + ret_violation = None + try: + logger.debug(f"run_shield::{params}::messages={messages}") + if not 'guardrailIdentifier' in params: + raise RuntimeError(f"Error running request for BedrockGaurdrails:Missing GuardrailID in request") + + if not 'guardrailVersion' in params: + raise RuntimeError(f"Error running request for BedrockGaurdrails:Missing guardrailVersion in request") + + #- convert the messages into format Bedrock expects + content_messages = [] + for message in messages: + content_messages.append({"text": {"text": message.content}}) + logger.debug(f"run_shield::final:messages::{json.dumps(content_messages, indent=2)}:") + + response = self.boto_client.apply_guardrail( + guardrailIdentifier=params.get('guardrailIdentifier'), + guardrailVersion=params.get('guardrailVersion'), + source='OUTPUT', # or 'INPUT' depending on your use case + content=content_messages + ) + logger.debug(f"run_shield:: response: {response}::") + if response['action'] == 'GUARDRAIL_INTERVENED': + user_message="" + metadata={} + for output in response['outputs']: + # guardrails returns a list - however for this implementation we will leverage the last values + user_message=output['text'] + for assessment in response['assessments']: + # guardrails returns a list - however for this implementation we will leverage the last values + metadata = dict(assessment) + ret_violation = SafetyViolation( + user_message=user_message, + violation_level=ViolationLevel.ERROR, + metadata=metadata + ) + + except: + error_str = traceback.format_exc() + print(f"Error in apply_guardrails:{error_str}:: RETURNING None !!!!!") + logger.error(f"Error in apply_guardrails:{error_str}:: RETURNING None !!!!!") + #raise RuntimeError(f"Error running request for BedrockGaurdrails: {error_str}:") + + return ret_violation + + diff --git a/llama_stack/providers/adapters/safety/bedrock/config.py b/llama_stack/providers/adapters/safety/bedrock/config.py new file mode 100644 index 000000000..69c4a9609 --- /dev/null +++ b/llama_stack/providers/adapters/safety/bedrock/config.py @@ -0,0 +1,24 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from typing import Optional + +from llama_models.schema_utils import json_schema_type +from pydantic import BaseModel, Field +import boto3 + + +@json_schema_type +class BedrockShieldConfig(BaseModel): + """Configuration information for a guardrail that you want to use in the request.""" + + aws_profile: Optional[str] = Field( + default='default', + description="The profile on the machine having valid aws credentials. This will ensure separation of creation to invocation", + ) + + + diff --git a/llama_stack/providers/registry/safety.py b/llama_stack/providers/registry/safety.py index 0a012b1df..202690264 100644 --- a/llama_stack/providers/registry/safety.py +++ b/llama_stack/providers/registry/safety.py @@ -40,6 +40,15 @@ def available_providers() -> List[ProviderSpec]: config_class="llama_stack.providers.adapters.safety.sample.SampleConfig", ), ), + remote_provider_spec( + api=Api.safety, + adapter=AdapterSpec( + adapter_id="bedrock_guardrails", + pip_packages=['boto3',], + module="llama_stack.providers.adapters.safety.bedrock", + config_class="llama_stack.providers.adapters.safety.bedrock.config.BedrockShieldConfig", + ), + ), remote_provider_spec( api=Api.safety, adapter=AdapterSpec(