Some lightweight cleanup and renaming for bedrock safety adapter

This commit is contained in:
Ashwin Bharambe 2024-09-24 19:27:03 -07:00
parent a2465f3f9c
commit f45705cd10
5 changed files with 76 additions and 78 deletions

View file

@ -160,7 +160,7 @@ class StackConfigure(Subcommand):
f.write(yaml.dump(to_write, sort_keys=False))
cprint(
f"> YAML configuration has been written to {run_config_file}.",
f"> YAML configuration has been written to `{run_config_file}`.",
color="blue",
)

View file

@ -7,12 +7,12 @@
from typing import Any
from .config import BedrockShieldConfig
from .config import BedrockSafetyConfig
async def get_adapter_impl(config: BedrockShieldConfig, _deps) -> Any:
from .bedrock import BedrockShieldAdapter
async def get_adapter_impl(config: BedrockSafetyConfig, _deps) -> Any:
from .bedrock import BedrockSafetyAdapter
impl = BedrockShieldAdapter(config)
impl = BedrockSafetyAdapter(config)
await impl.initialize()
return impl
return impl

View file

@ -5,99 +5,105 @@
# the root directory of this source tree.
from typing import Any, AsyncGenerator, Dict
from .config import BedrockShieldConfig
import traceback
import asyncio
from enum import Enum
from typing import List
from pydantic import BaseModel, validator
from typing import Any, Dict, List
from .config import BedrockSafetyConfig
from llama_stack.apis.safety import * # noqa
from llama_models.llama3.api.datatypes import * # noqa: F403
import boto3
import json
import logging
import boto3
logger = logging.getLogger(__name__)
class BedrockShieldAdapter(Safety):
def __init__(self, config: BedrockShieldConfig) -> None:
class BedrockSafetyAdapter(Safety):
def __init__(self, config: BedrockSafetyConfig) -> None:
self.config = config
async def initialize(self) -> None:
if not self.config.aws_profile:
raise RuntimeError(
f"Missing boto_client aws_profile in model info::{self.config}"
)
try:
if not self.config.aws_profile:
raise RuntimeError(f"Missing boto_client aws_profile in model info::{self.config}")
print(f"initializing with profile --- > {self.config}::")
self.boto_client_profile = self.config.aws_profile
self.boto_client = boto3.Session(profile_name=self.boto_client_profile).client('bedrock-runtime')
self.boto_client = boto3.Session(
profile_name=self.boto_client_profile
).client("bedrock-runtime")
except Exception as e:
import traceback
traceback.print_exc()
raise RuntimeError(f"Error initializing BedrockSafetyAdapter: {e}") from e
async def shutdown(self) -> None:
pass
async def run_shield(self, shield_type: str, messages: List[Message], params: Dict[str, Any] = None) -> RunShieldResponse:
""" This is the implementation for the bedrock guardrails. The input to the guardrails is to be of this format
```content = [
{
"text": {
"text": "Is the AB503 Product a better investment than the S&P 500?"
}
async def run_shield(
self, shield_type: str, messages: List[Message], params: Dict[str, Any] = None
) -> RunShieldResponse:
"""This is the implementation for the bedrock guardrails. The input to the guardrails is to be of this format
```content = [
{
"text": {
"text": "Is the AB503 Product a better investment than the S&P 500?"
}
]```
However the incoming messages are of this type UserMessage(content=....) coming from
https://github.com/meta-llama/llama-models/blob/main/models/llama3/api/datatypes.py
}
]```
However the incoming messages are of this type UserMessage(content=....) coming from
https://github.com/meta-llama/llama-models/blob/main/models/llama3/api/datatypes.py
They contain content, role . For now we will extract the content and default the "qualifiers": ["query"]
They contain content, role . For now we will extract the content and default the "qualifiers": ["query"]
"""
ret_violation = None
try:
logger.debug(f"run_shield::{params}::messages={messages}")
if not 'guardrailIdentifier' in params:
raise RuntimeError(f"Error running request for BedrockGaurdrails:Missing GuardrailID in request")
if not 'guardrailVersion' in params:
raise RuntimeError(f"Error running request for BedrockGaurdrails:Missing guardrailVersion in request")
#- convert the messages into format Bedrock expects
if "guardrailIdentifier" not in params:
raise RuntimeError(
"Error running request for BedrockGaurdrails:Missing GuardrailID in request"
)
if "guardrailVersion" not in params:
raise RuntimeError(
"Error running request for BedrockGaurdrails:Missing guardrailVersion in request"
)
# - convert the messages into format Bedrock expects
content_messages = []
for message in messages:
content_messages.append({"text": {"text": message.content}})
logger.debug(f"run_shield::final:messages::{json.dumps(content_messages, indent=2)}:")
logger.debug(
f"run_shield::final:messages::{json.dumps(content_messages, indent=2)}:"
)
response = self.boto_client.apply_guardrail(
guardrailIdentifier=params.get('guardrailIdentifier'),
guardrailVersion=params.get('guardrailVersion'),
source='OUTPUT', # or 'INPUT' depending on your use case
content=content_messages
guardrailIdentifier=params.get("guardrailIdentifier"),
guardrailVersion=params.get("guardrailVersion"),
source="OUTPUT", # or 'INPUT' depending on your use case
content=content_messages,
)
logger.debug(f"run_shield:: response: {response}::")
if response['action'] == 'GUARDRAIL_INTERVENED':
user_message=""
metadata={}
for output in response['outputs']:
if response["action"] == "GUARDRAIL_INTERVENED":
user_message = ""
metadata = {}
for output in response["outputs"]:
# guardrails returns a list - however for this implementation we will leverage the last values
user_message=output['text']
for assessment in response['assessments']:
user_message = output["text"]
for assessment in response["assessments"]:
# guardrails returns a list - however for this implementation we will leverage the last values
metadata = dict(assessment)
ret_violation = SafetyViolation(
user_message=user_message,
metadata = dict(assessment)
return SafetyViolation(
user_message=user_message,
violation_level=ViolationLevel.ERROR,
metadata=metadata
metadata=metadata,
)
except:
except Exception:
error_str = traceback.format_exc()
print(f"Error in apply_guardrails:{error_str}:: RETURNING None !!!!!")
logger.error(f"Error in apply_guardrails:{error_str}:: RETURNING None !!!!!")
#raise RuntimeError(f"Error running request for BedrockGaurdrails: {error_str}:")
return ret_violation
logger.error(
f"Error in apply_guardrails:{error_str}:: RETURNING None !!!!!"
)
return None

View file

@ -4,21 +4,13 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Optional
from llama_models.schema_utils import json_schema_type
from pydantic import BaseModel, Field
import boto3
@json_schema_type
class BedrockShieldConfig(BaseModel):
class BedrockSafetyConfig(BaseModel):
"""Configuration information for a guardrail that you want to use in the request."""
aws_profile: Optional[str] = Field(
default='default',
aws_profile: str = Field(
default="default",
description="The profile on the machine having valid aws credentials. This will ensure separation of creation to invocation",
)

View file

@ -41,10 +41,10 @@ def available_providers() -> List[ProviderSpec]:
remote_provider_spec(
api=Api.safety,
adapter=AdapterSpec(
adapter_id="bedrock_guardrails",
pip_packages=['boto3',],
adapter_id="bedrock",
pip_packages=["boto3"],
module="llama_stack.providers.adapters.safety.bedrock",
config_class="llama_stack.providers.adapters.safety.bedrock.config.BedrockShieldConfig",
config_class="llama_stack.providers.adapters.safety.bedrock.BedrockSafetyConfig",
),
),
remote_provider_spec(