[aws-bedrock] Support for Bedrock Safety adapter (#96)

This commit is contained in:
rsgrewal-aws 2024-09-24 19:16:55 -07:00 committed by GitHub
parent b85d675c6f
commit 059e50b389
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 154 additions and 0 deletions

View file

@ -0,0 +1,18 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any
from .config import BedrockShieldConfig
async def get_adapter_impl(config: BedrockShieldConfig, _deps) -> Any:
from .bedrock import BedrockShieldAdapter
impl = BedrockShieldAdapter(config)
await impl.initialize()
return impl

View file

@ -0,0 +1,103 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, AsyncGenerator, Dict
from .config import BedrockShieldConfig
import traceback
import asyncio
from enum import Enum
from typing import List
from pydantic import BaseModel, validator
from llama_stack.apis.safety import * # noqa
from llama_models.llama3.api.datatypes import * # noqa: F403
import boto3
import json
import logging
logger = logging.getLogger(__name__)
class BedrockShieldAdapter(Safety):
def __init__(self, config: BedrockShieldConfig) -> None:
self.config = config
async def initialize(self) -> None:
try:
if not self.config.aws_profile:
raise RuntimeError(f"Missing boto_client aws_profile in model info::{self.config}")
print(f"initializing with profile --- > {self.config}::")
self.boto_client_profile = self.config.aws_profile
self.boto_client = boto3.Session(profile_name=self.boto_client_profile).client('bedrock-runtime')
except Exception as e:
import traceback
traceback.print_exc()
raise RuntimeError(f"Error initializing BedrockSafetyAdapter: {e}") from e
async def shutdown(self) -> None:
pass
async def run_shield(self, shield_type: str, messages: List[Message], params: Dict[str, Any] = None) -> RunShieldResponse:
""" This is the implementation for the bedrock guardrails. The input to the guardrails is to be of this format
```content = [
{
"text": {
"text": "Is the AB503 Product a better investment than the S&P 500?"
}
}
]```
However the incoming messages are of this type UserMessage(content=....) coming from
https://github.com/meta-llama/llama-models/blob/main/models/llama3/api/datatypes.py
They contain content, role . For now we will extract the content and default the "qualifiers": ["query"]
"""
ret_violation = None
try:
logger.debug(f"run_shield::{params}::messages={messages}")
if not 'guardrailIdentifier' in params:
raise RuntimeError(f"Error running request for BedrockGaurdrails:Missing GuardrailID in request")
if not 'guardrailVersion' in params:
raise RuntimeError(f"Error running request for BedrockGaurdrails:Missing guardrailVersion in request")
#- convert the messages into format Bedrock expects
content_messages = []
for message in messages:
content_messages.append({"text": {"text": message.content}})
logger.debug(f"run_shield::final:messages::{json.dumps(content_messages, indent=2)}:")
response = self.boto_client.apply_guardrail(
guardrailIdentifier=params.get('guardrailIdentifier'),
guardrailVersion=params.get('guardrailVersion'),
source='OUTPUT', # or 'INPUT' depending on your use case
content=content_messages
)
logger.debug(f"run_shield:: response: {response}::")
if response['action'] == 'GUARDRAIL_INTERVENED':
user_message=""
metadata={}
for output in response['outputs']:
# guardrails returns a list - however for this implementation we will leverage the last values
user_message=output['text']
for assessment in response['assessments']:
# guardrails returns a list - however for this implementation we will leverage the last values
metadata = dict(assessment)
ret_violation = SafetyViolation(
user_message=user_message,
violation_level=ViolationLevel.ERROR,
metadata=metadata
)
except:
error_str = traceback.format_exc()
print(f"Error in apply_guardrails:{error_str}:: RETURNING None !!!!!")
logger.error(f"Error in apply_guardrails:{error_str}:: RETURNING None !!!!!")
#raise RuntimeError(f"Error running request for BedrockGaurdrails: {error_str}:")
return ret_violation

View file

@ -0,0 +1,24 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Optional
from llama_models.schema_utils import json_schema_type
from pydantic import BaseModel, Field
import boto3
@json_schema_type
class BedrockShieldConfig(BaseModel):
"""Configuration information for a guardrail that you want to use in the request."""
aws_profile: Optional[str] = Field(
default='default',
description="The profile on the machine having valid aws credentials. This will ensure separation of creation to invocation",
)

View file

@ -40,6 +40,15 @@ def available_providers() -> List[ProviderSpec]:
config_class="llama_stack.providers.adapters.safety.sample.SampleConfig",
),
),
remote_provider_spec(
api=Api.safety,
adapter=AdapterSpec(
adapter_id="bedrock_guardrails",
pip_packages=['boto3',],
module="llama_stack.providers.adapters.safety.bedrock",
config_class="llama_stack.providers.adapters.safety.bedrock.config.BedrockShieldConfig",
),
),
remote_provider_spec(
api=Api.safety,
adapter=AdapterSpec(