Some lightweight cleanup and renaming for bedrock safety adapter

This commit is contained in:
Ashwin Bharambe 2024-09-24 19:27:03 -07:00
parent a2465f3f9c
commit f45705cd10
5 changed files with 76 additions and 78 deletions

View file

@ -5,99 +5,105 @@
# the root directory of this source tree.
from typing import Any, AsyncGenerator, Dict
from .config import BedrockShieldConfig
import traceback
import asyncio
from enum import Enum
from typing import List
from pydantic import BaseModel, validator
from typing import Any, Dict, List
from .config import BedrockSafetyConfig
from llama_stack.apis.safety import * # noqa
from llama_models.llama3.api.datatypes import * # noqa: F403
import boto3
import json
import logging
import boto3
logger = logging.getLogger(__name__)
class BedrockShieldAdapter(Safety):
def __init__(self, config: BedrockShieldConfig) -> None:
class BedrockSafetyAdapter(Safety):
def __init__(self, config: BedrockSafetyConfig) -> None:
self.config = config
async def initialize(self) -> None:
if not self.config.aws_profile:
raise RuntimeError(
f"Missing boto_client aws_profile in model info::{self.config}"
)
try:
if not self.config.aws_profile:
raise RuntimeError(f"Missing boto_client aws_profile in model info::{self.config}")
print(f"initializing with profile --- > {self.config}::")
self.boto_client_profile = self.config.aws_profile
self.boto_client = boto3.Session(profile_name=self.boto_client_profile).client('bedrock-runtime')
self.boto_client = boto3.Session(
profile_name=self.boto_client_profile
).client("bedrock-runtime")
except Exception as e:
import traceback
traceback.print_exc()
raise RuntimeError(f"Error initializing BedrockSafetyAdapter: {e}") from e
async def shutdown(self) -> None:
pass
async def run_shield(self, shield_type: str, messages: List[Message], params: Dict[str, Any] = None) -> RunShieldResponse:
""" This is the implementation for the bedrock guardrails. The input to the guardrails is to be of this format
```content = [
{
"text": {
"text": "Is the AB503 Product a better investment than the S&P 500?"
}
async def run_shield(
self, shield_type: str, messages: List[Message], params: Dict[str, Any] = None
) -> RunShieldResponse:
"""This is the implementation for the bedrock guardrails. The input to the guardrails is to be of this format
```content = [
{
"text": {
"text": "Is the AB503 Product a better investment than the S&P 500?"
}
]```
However the incoming messages are of this type UserMessage(content=....) coming from
https://github.com/meta-llama/llama-models/blob/main/models/llama3/api/datatypes.py
}
]```
However the incoming messages are of this type UserMessage(content=....) coming from
https://github.com/meta-llama/llama-models/blob/main/models/llama3/api/datatypes.py
They contain content, role . For now we will extract the content and default the "qualifiers": ["query"]
They contain content, role . For now we will extract the content and default the "qualifiers": ["query"]
"""
ret_violation = None
try:
logger.debug(f"run_shield::{params}::messages={messages}")
if not 'guardrailIdentifier' in params:
raise RuntimeError(f"Error running request for BedrockGaurdrails:Missing GuardrailID in request")
if not 'guardrailVersion' in params:
raise RuntimeError(f"Error running request for BedrockGaurdrails:Missing guardrailVersion in request")
#- convert the messages into format Bedrock expects
if "guardrailIdentifier" not in params:
raise RuntimeError(
"Error running request for BedrockGaurdrails:Missing GuardrailID in request"
)
if "guardrailVersion" not in params:
raise RuntimeError(
"Error running request for BedrockGaurdrails:Missing guardrailVersion in request"
)
# - convert the messages into format Bedrock expects
content_messages = []
for message in messages:
content_messages.append({"text": {"text": message.content}})
logger.debug(f"run_shield::final:messages::{json.dumps(content_messages, indent=2)}:")
logger.debug(
f"run_shield::final:messages::{json.dumps(content_messages, indent=2)}:"
)
response = self.boto_client.apply_guardrail(
guardrailIdentifier=params.get('guardrailIdentifier'),
guardrailVersion=params.get('guardrailVersion'),
source='OUTPUT', # or 'INPUT' depending on your use case
content=content_messages
guardrailIdentifier=params.get("guardrailIdentifier"),
guardrailVersion=params.get("guardrailVersion"),
source="OUTPUT", # or 'INPUT' depending on your use case
content=content_messages,
)
logger.debug(f"run_shield:: response: {response}::")
if response['action'] == 'GUARDRAIL_INTERVENED':
user_message=""
metadata={}
for output in response['outputs']:
if response["action"] == "GUARDRAIL_INTERVENED":
user_message = ""
metadata = {}
for output in response["outputs"]:
# guardrails returns a list - however for this implementation we will leverage the last values
user_message=output['text']
for assessment in response['assessments']:
user_message = output["text"]
for assessment in response["assessments"]:
# guardrails returns a list - however for this implementation we will leverage the last values
metadata = dict(assessment)
ret_violation = SafetyViolation(
user_message=user_message,
metadata = dict(assessment)
return SafetyViolation(
user_message=user_message,
violation_level=ViolationLevel.ERROR,
metadata=metadata
metadata=metadata,
)
except:
except Exception:
error_str = traceback.format_exc()
print(f"Error in apply_guardrails:{error_str}:: RETURNING None !!!!!")
logger.error(f"Error in apply_guardrails:{error_str}:: RETURNING None !!!!!")
#raise RuntimeError(f"Error running request for BedrockGaurdrails: {error_str}:")
return ret_violation
logger.error(
f"Error in apply_guardrails:{error_str}:: RETURNING None !!!!!"
)
return None