Some lightweight cleanup and renaming for bedrock safety adapter

This commit is contained in:
Ashwin Bharambe 2024-09-24 19:27:03 -07:00
parent a2465f3f9c
commit f45705cd10
5 changed files with 76 additions and 78 deletions

View file

@ -160,7 +160,7 @@ class StackConfigure(Subcommand):
f.write(yaml.dump(to_write, sort_keys=False)) f.write(yaml.dump(to_write, sort_keys=False))
cprint( cprint(
f"> YAML configuration has been written to {run_config_file}.", f"> YAML configuration has been written to `{run_config_file}`.",
color="blue", color="blue",
) )

View file

@ -7,12 +7,12 @@
from typing import Any from typing import Any
from .config import BedrockShieldConfig from .config import BedrockSafetyConfig
async def get_adapter_impl(config: BedrockShieldConfig, _deps) -> Any: async def get_adapter_impl(config: BedrockSafetyConfig, _deps) -> Any:
from .bedrock import BedrockShieldAdapter from .bedrock import BedrockSafetyAdapter
impl = BedrockShieldAdapter(config) impl = BedrockSafetyAdapter(config)
await impl.initialize() await impl.initialize()
return impl return impl

View file

@ -5,43 +5,46 @@
# the root directory of this source tree. # the root directory of this source tree.
from typing import Any, AsyncGenerator, Dict
from .config import BedrockShieldConfig
import traceback import traceback
import asyncio from typing import Any, Dict, List
from enum import Enum
from typing import List from .config import BedrockSafetyConfig
from pydantic import BaseModel, validator
from llama_stack.apis.safety import * # noqa from llama_stack.apis.safety import * # noqa
from llama_models.llama3.api.datatypes import * # noqa: F403 from llama_models.llama3.api.datatypes import * # noqa: F403
import boto3
import json import json
import logging import logging
import boto3
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class BedrockShieldAdapter(Safety):
def __init__(self, config: BedrockShieldConfig) -> None: class BedrockSafetyAdapter(Safety):
def __init__(self, config: BedrockSafetyConfig) -> None:
self.config = config self.config = config
async def initialize(self) -> None: async def initialize(self) -> None:
try:
if not self.config.aws_profile: if not self.config.aws_profile:
raise RuntimeError(f"Missing boto_client aws_profile in model info::{self.config}") raise RuntimeError(
f"Missing boto_client aws_profile in model info::{self.config}"
)
try:
print(f"initializing with profile --- > {self.config}::") print(f"initializing with profile --- > {self.config}::")
self.boto_client_profile = self.config.aws_profile self.boto_client_profile = self.config.aws_profile
self.boto_client = boto3.Session(profile_name=self.boto_client_profile).client('bedrock-runtime') self.boto_client = boto3.Session(
profile_name=self.boto_client_profile
).client("bedrock-runtime")
except Exception as e: except Exception as e:
import traceback
traceback.print_exc()
raise RuntimeError(f"Error initializing BedrockSafetyAdapter: {e}") from e raise RuntimeError(f"Error initializing BedrockSafetyAdapter: {e}") from e
async def shutdown(self) -> None: async def shutdown(self) -> None:
pass pass
async def run_shield(self, shield_type: str, messages: List[Message], params: Dict[str, Any] = None) -> RunShieldResponse: async def run_shield(
self, shield_type: str, messages: List[Message], params: Dict[str, Any] = None
) -> RunShieldResponse:
"""This is the implementation for the bedrock guardrails. The input to the guardrails is to be of this format """This is the implementation for the bedrock guardrails. The input to the guardrails is to be of this format
```content = [ ```content = [
{ {
@ -55,49 +58,52 @@ class BedrockShieldAdapter(Safety):
They contain content, role . For now we will extract the content and default the "qualifiers": ["query"] They contain content, role . For now we will extract the content and default the "qualifiers": ["query"]
""" """
ret_violation = None
try: try:
logger.debug(f"run_shield::{params}::messages={messages}") logger.debug(f"run_shield::{params}::messages={messages}")
if not 'guardrailIdentifier' in params: if "guardrailIdentifier" not in params:
raise RuntimeError(f"Error running request for BedrockGaurdrails:Missing GuardrailID in request") raise RuntimeError(
"Error running request for BedrockGaurdrails:Missing GuardrailID in request"
)
if not 'guardrailVersion' in params: if "guardrailVersion" not in params:
raise RuntimeError(f"Error running request for BedrockGaurdrails:Missing guardrailVersion in request") raise RuntimeError(
"Error running request for BedrockGaurdrails:Missing guardrailVersion in request"
)
# - convert the messages into format Bedrock expects # - convert the messages into format Bedrock expects
content_messages = [] content_messages = []
for message in messages: for message in messages:
content_messages.append({"text": {"text": message.content}}) content_messages.append({"text": {"text": message.content}})
logger.debug(f"run_shield::final:messages::{json.dumps(content_messages, indent=2)}:") logger.debug(
f"run_shield::final:messages::{json.dumps(content_messages, indent=2)}:"
)
response = self.boto_client.apply_guardrail( response = self.boto_client.apply_guardrail(
guardrailIdentifier=params.get('guardrailIdentifier'), guardrailIdentifier=params.get("guardrailIdentifier"),
guardrailVersion=params.get('guardrailVersion'), guardrailVersion=params.get("guardrailVersion"),
source='OUTPUT', # or 'INPUT' depending on your use case source="OUTPUT", # or 'INPUT' depending on your use case
content=content_messages content=content_messages,
) )
logger.debug(f"run_shield:: response: {response}::") logger.debug(f"run_shield:: response: {response}::")
if response['action'] == 'GUARDRAIL_INTERVENED': if response["action"] == "GUARDRAIL_INTERVENED":
user_message = "" user_message = ""
metadata = {} metadata = {}
for output in response['outputs']: for output in response["outputs"]:
# guardrails returns a list - however for this implementation we will leverage the last values # guardrails returns a list - however for this implementation we will leverage the last values
user_message=output['text'] user_message = output["text"]
for assessment in response['assessments']: for assessment in response["assessments"]:
# guardrails returns a list - however for this implementation we will leverage the last values # guardrails returns a list - however for this implementation we will leverage the last values
metadata = dict(assessment) metadata = dict(assessment)
ret_violation = SafetyViolation( return SafetyViolation(
user_message=user_message, user_message=user_message,
violation_level=ViolationLevel.ERROR, violation_level=ViolationLevel.ERROR,
metadata=metadata metadata=metadata,
) )
except: except Exception:
error_str = traceback.format_exc() error_str = traceback.format_exc()
print(f"Error in apply_guardrails:{error_str}:: RETURNING None !!!!!") logger.error(
logger.error(f"Error in apply_guardrails:{error_str}:: RETURNING None !!!!!") f"Error in apply_guardrails:{error_str}:: RETURNING None !!!!!"
#raise RuntimeError(f"Error running request for BedrockGaurdrails: {error_str}:") )
return ret_violation
return None

View file

@ -4,21 +4,13 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
from typing import Optional
from llama_models.schema_utils import json_schema_type
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
import boto3
@json_schema_type class BedrockSafetyConfig(BaseModel):
class BedrockShieldConfig(BaseModel):
"""Configuration information for a guardrail that you want to use in the request.""" """Configuration information for a guardrail that you want to use in the request."""
aws_profile: Optional[str] = Field( aws_profile: str = Field(
default='default', default="default",
description="The profile on the machine having valid aws credentials. This will ensure separation of creation to invocation", description="The profile on the machine having valid aws credentials. This will ensure separation of creation to invocation",
) )

View file

@ -41,10 +41,10 @@ def available_providers() -> List[ProviderSpec]:
remote_provider_spec( remote_provider_spec(
api=Api.safety, api=Api.safety,
adapter=AdapterSpec( adapter=AdapterSpec(
adapter_id="bedrock_guardrails", adapter_id="bedrock",
pip_packages=['boto3',], pip_packages=["boto3"],
module="llama_stack.providers.adapters.safety.bedrock", module="llama_stack.providers.adapters.safety.bedrock",
config_class="llama_stack.providers.adapters.safety.bedrock.config.BedrockShieldConfig", config_class="llama_stack.providers.adapters.safety.bedrock.BedrockSafetyConfig",
), ),
), ),
remote_provider_spec( remote_provider_spec(