From f45705cd105d458b9e3ce8a3873fc5b2749bea77 Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Tue, 24 Sep 2024 19:27:03 -0700 Subject: [PATCH] Some lightweight cleanup and renaming for bedrock safety adapter --- llama_stack/cli/stack/configure.py | 2 +- .../adapters/safety/bedrock/__init__.py | 10 +- .../adapters/safety/bedrock/bedrock.py | 122 +++++++++--------- .../adapters/safety/bedrock/config.py | 14 +- llama_stack/providers/registry/safety.py | 6 +- 5 files changed, 76 insertions(+), 78 deletions(-) diff --git a/llama_stack/cli/stack/configure.py b/llama_stack/cli/stack/configure.py index 58f383a37..135962d4d 100644 --- a/llama_stack/cli/stack/configure.py +++ b/llama_stack/cli/stack/configure.py @@ -160,7 +160,7 @@ class StackConfigure(Subcommand): f.write(yaml.dump(to_write, sort_keys=False)) cprint( - f"> YAML configuration has been written to {run_config_file}.", + f"> YAML configuration has been written to `{run_config_file}`.", color="blue", ) diff --git a/llama_stack/providers/adapters/safety/bedrock/__init__.py b/llama_stack/providers/adapters/safety/bedrock/__init__.py index 0b10015a1..c602156a6 100644 --- a/llama_stack/providers/adapters/safety/bedrock/__init__.py +++ b/llama_stack/providers/adapters/safety/bedrock/__init__.py @@ -7,12 +7,12 @@ from typing import Any -from .config import BedrockShieldConfig +from .config import BedrockSafetyConfig -async def get_adapter_impl(config: BedrockShieldConfig, _deps) -> Any: - from .bedrock import BedrockShieldAdapter +async def get_adapter_impl(config: BedrockSafetyConfig, _deps) -> Any: + from .bedrock import BedrockSafetyAdapter - impl = BedrockShieldAdapter(config) + impl = BedrockSafetyAdapter(config) await impl.initialize() - return impl \ No newline at end of file + return impl diff --git a/llama_stack/providers/adapters/safety/bedrock/bedrock.py b/llama_stack/providers/adapters/safety/bedrock/bedrock.py index 2de91c2ab..a3acda1ce 100644 --- a/llama_stack/providers/adapters/safety/bedrock/bedrock.py +++ b/llama_stack/providers/adapters/safety/bedrock/bedrock.py @@ -5,99 +5,105 @@ # the root directory of this source tree. -from typing import Any, AsyncGenerator, Dict -from .config import BedrockShieldConfig import traceback -import asyncio -from enum import Enum -from typing import List -from pydantic import BaseModel, validator +from typing import Any, Dict, List + +from .config import BedrockSafetyConfig from llama_stack.apis.safety import * # noqa from llama_models.llama3.api.datatypes import * # noqa: F403 -import boto3 import json import logging +import boto3 + + logger = logging.getLogger(__name__) -class BedrockShieldAdapter(Safety): - def __init__(self, config: BedrockShieldConfig) -> None: + +class BedrockSafetyAdapter(Safety): + def __init__(self, config: BedrockSafetyConfig) -> None: self.config = config - async def initialize(self) -> None: + if not self.config.aws_profile: + raise RuntimeError( + f"Missing boto_client aws_profile in model info::{self.config}" + ) + try: - if not self.config.aws_profile: - raise RuntimeError(f"Missing boto_client aws_profile in model info::{self.config}") print(f"initializing with profile --- > {self.config}::") self.boto_client_profile = self.config.aws_profile - self.boto_client = boto3.Session(profile_name=self.boto_client_profile).client('bedrock-runtime') + self.boto_client = boto3.Session( + profile_name=self.boto_client_profile + ).client("bedrock-runtime") except Exception as e: - import traceback - - traceback.print_exc() raise RuntimeError(f"Error initializing BedrockSafetyAdapter: {e}") from e async def shutdown(self) -> None: pass - async def run_shield(self, shield_type: str, messages: List[Message], params: Dict[str, Any] = None) -> RunShieldResponse: - """ This is the implementation for the bedrock guardrails. The input to the guardrails is to be of this format - ```content = [ - { - "text": { - "text": "Is the AB503 Product a better investment than the S&P 500?" - } + async def run_shield( + self, shield_type: str, messages: List[Message], params: Dict[str, Any] = None + ) -> RunShieldResponse: + """This is the implementation for the bedrock guardrails. The input to the guardrails is to be of this format + ```content = [ + { + "text": { + "text": "Is the AB503 Product a better investment than the S&P 500?" } - ]``` - However the incoming messages are of this type UserMessage(content=....) coming from - https://github.com/meta-llama/llama-models/blob/main/models/llama3/api/datatypes.py + } + ]``` + However the incoming messages are of this type UserMessage(content=....) coming from + https://github.com/meta-llama/llama-models/blob/main/models/llama3/api/datatypes.py - They contain content, role . For now we will extract the content and default the "qualifiers": ["query"] + They contain content, role . For now we will extract the content and default the "qualifiers": ["query"] """ - ret_violation = None try: logger.debug(f"run_shield::{params}::messages={messages}") - if not 'guardrailIdentifier' in params: - raise RuntimeError(f"Error running request for BedrockGaurdrails:Missing GuardrailID in request") - - if not 'guardrailVersion' in params: - raise RuntimeError(f"Error running request for BedrockGaurdrails:Missing guardrailVersion in request") - - #- convert the messages into format Bedrock expects + if "guardrailIdentifier" not in params: + raise RuntimeError( + "Error running request for BedrockGaurdrails:Missing GuardrailID in request" + ) + + if "guardrailVersion" not in params: + raise RuntimeError( + "Error running request for BedrockGaurdrails:Missing guardrailVersion in request" + ) + + # - convert the messages into format Bedrock expects content_messages = [] for message in messages: content_messages.append({"text": {"text": message.content}}) - logger.debug(f"run_shield::final:messages::{json.dumps(content_messages, indent=2)}:") + logger.debug( + f"run_shield::final:messages::{json.dumps(content_messages, indent=2)}:" + ) response = self.boto_client.apply_guardrail( - guardrailIdentifier=params.get('guardrailIdentifier'), - guardrailVersion=params.get('guardrailVersion'), - source='OUTPUT', # or 'INPUT' depending on your use case - content=content_messages + guardrailIdentifier=params.get("guardrailIdentifier"), + guardrailVersion=params.get("guardrailVersion"), + source="OUTPUT", # or 'INPUT' depending on your use case + content=content_messages, ) logger.debug(f"run_shield:: response: {response}::") - if response['action'] == 'GUARDRAIL_INTERVENED': - user_message="" - metadata={} - for output in response['outputs']: + if response["action"] == "GUARDRAIL_INTERVENED": + user_message = "" + metadata = {} + for output in response["outputs"]: # guardrails returns a list - however for this implementation we will leverage the last values - user_message=output['text'] - for assessment in response['assessments']: + user_message = output["text"] + for assessment in response["assessments"]: # guardrails returns a list - however for this implementation we will leverage the last values - metadata = dict(assessment) - ret_violation = SafetyViolation( - user_message=user_message, + metadata = dict(assessment) + return SafetyViolation( + user_message=user_message, violation_level=ViolationLevel.ERROR, - metadata=metadata + metadata=metadata, ) - - except: + + except Exception: error_str = traceback.format_exc() - print(f"Error in apply_guardrails:{error_str}:: RETURNING None !!!!!") - logger.error(f"Error in apply_guardrails:{error_str}:: RETURNING None !!!!!") - #raise RuntimeError(f"Error running request for BedrockGaurdrails: {error_str}:") - - return ret_violation - + logger.error( + f"Error in apply_guardrails:{error_str}:: RETURNING None !!!!!" + ) + return None diff --git a/llama_stack/providers/adapters/safety/bedrock/config.py b/llama_stack/providers/adapters/safety/bedrock/config.py index 69c4a9609..2a8585262 100644 --- a/llama_stack/providers/adapters/safety/bedrock/config.py +++ b/llama_stack/providers/adapters/safety/bedrock/config.py @@ -4,21 +4,13 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from typing import Optional - -from llama_models.schema_utils import json_schema_type from pydantic import BaseModel, Field -import boto3 -@json_schema_type -class BedrockShieldConfig(BaseModel): +class BedrockSafetyConfig(BaseModel): """Configuration information for a guardrail that you want to use in the request.""" - aws_profile: Optional[str] = Field( - default='default', + aws_profile: str = Field( + default="default", description="The profile on the machine having valid aws credentials. This will ensure separation of creation to invocation", ) - - - diff --git a/llama_stack/providers/registry/safety.py b/llama_stack/providers/registry/safety.py index 09aed4982..1f353912b 100644 --- a/llama_stack/providers/registry/safety.py +++ b/llama_stack/providers/registry/safety.py @@ -41,10 +41,10 @@ def available_providers() -> List[ProviderSpec]: remote_provider_spec( api=Api.safety, adapter=AdapterSpec( - adapter_id="bedrock_guardrails", - pip_packages=['boto3',], + adapter_id="bedrock", + pip_packages=["boto3"], module="llama_stack.providers.adapters.safety.bedrock", - config_class="llama_stack.providers.adapters.safety.bedrock.config.BedrockShieldConfig", + config_class="llama_stack.providers.adapters.safety.bedrock.BedrockSafetyConfig", ), ), remote_provider_spec(