forked from phoenix-oss/llama-stack-mirror
Some lightweight cleanup and renaming for bedrock safety adapter
This commit is contained in:
parent
a2465f3f9c
commit
f45705cd10
5 changed files with 76 additions and 78 deletions
|
@ -160,7 +160,7 @@ class StackConfigure(Subcommand):
|
||||||
f.write(yaml.dump(to_write, sort_keys=False))
|
f.write(yaml.dump(to_write, sort_keys=False))
|
||||||
|
|
||||||
cprint(
|
cprint(
|
||||||
f"> YAML configuration has been written to {run_config_file}.",
|
f"> YAML configuration has been written to `{run_config_file}`.",
|
||||||
color="blue",
|
color="blue",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -7,12 +7,12 @@
|
||||||
|
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from .config import BedrockShieldConfig
|
from .config import BedrockSafetyConfig
|
||||||
|
|
||||||
|
|
||||||
async def get_adapter_impl(config: BedrockShieldConfig, _deps) -> Any:
|
async def get_adapter_impl(config: BedrockSafetyConfig, _deps) -> Any:
|
||||||
from .bedrock import BedrockShieldAdapter
|
from .bedrock import BedrockSafetyAdapter
|
||||||
|
|
||||||
impl = BedrockShieldAdapter(config)
|
impl = BedrockSafetyAdapter(config)
|
||||||
await impl.initialize()
|
await impl.initialize()
|
||||||
return impl
|
return impl
|
|
@ -5,44 +5,47 @@
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
|
||||||
from typing import Any, AsyncGenerator, Dict
|
|
||||||
from .config import BedrockShieldConfig
|
|
||||||
import traceback
|
import traceback
|
||||||
import asyncio
|
from typing import Any, Dict, List
|
||||||
from enum import Enum
|
|
||||||
from typing import List
|
from .config import BedrockSafetyConfig
|
||||||
from pydantic import BaseModel, validator
|
|
||||||
from llama_stack.apis.safety import * # noqa
|
from llama_stack.apis.safety import * # noqa
|
||||||
from llama_models.llama3.api.datatypes import * # noqa: F403
|
from llama_models.llama3.api.datatypes import * # noqa: F403
|
||||||
import boto3
|
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
|
import boto3
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
class BedrockShieldAdapter(Safety):
|
|
||||||
def __init__(self, config: BedrockShieldConfig) -> None:
|
class BedrockSafetyAdapter(Safety):
|
||||||
|
def __init__(self, config: BedrockSafetyConfig) -> None:
|
||||||
self.config = config
|
self.config = config
|
||||||
|
|
||||||
|
|
||||||
async def initialize(self) -> None:
|
async def initialize(self) -> None:
|
||||||
try:
|
|
||||||
if not self.config.aws_profile:
|
if not self.config.aws_profile:
|
||||||
raise RuntimeError(f"Missing boto_client aws_profile in model info::{self.config}")
|
raise RuntimeError(
|
||||||
|
f"Missing boto_client aws_profile in model info::{self.config}"
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
print(f"initializing with profile --- > {self.config}::")
|
print(f"initializing with profile --- > {self.config}::")
|
||||||
self.boto_client_profile = self.config.aws_profile
|
self.boto_client_profile = self.config.aws_profile
|
||||||
self.boto_client = boto3.Session(profile_name=self.boto_client_profile).client('bedrock-runtime')
|
self.boto_client = boto3.Session(
|
||||||
|
profile_name=self.boto_client_profile
|
||||||
|
).client("bedrock-runtime")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
import traceback
|
|
||||||
|
|
||||||
traceback.print_exc()
|
|
||||||
raise RuntimeError(f"Error initializing BedrockSafetyAdapter: {e}") from e
|
raise RuntimeError(f"Error initializing BedrockSafetyAdapter: {e}") from e
|
||||||
|
|
||||||
async def shutdown(self) -> None:
|
async def shutdown(self) -> None:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
async def run_shield(self, shield_type: str, messages: List[Message], params: Dict[str, Any] = None) -> RunShieldResponse:
|
async def run_shield(
|
||||||
""" This is the implementation for the bedrock guardrails. The input to the guardrails is to be of this format
|
self, shield_type: str, messages: List[Message], params: Dict[str, Any] = None
|
||||||
|
) -> RunShieldResponse:
|
||||||
|
"""This is the implementation for the bedrock guardrails. The input to the guardrails is to be of this format
|
||||||
```content = [
|
```content = [
|
||||||
{
|
{
|
||||||
"text": {
|
"text": {
|
||||||
|
@ -55,49 +58,52 @@ class BedrockShieldAdapter(Safety):
|
||||||
|
|
||||||
They contain content, role . For now we will extract the content and default the "qualifiers": ["query"]
|
They contain content, role . For now we will extract the content and default the "qualifiers": ["query"]
|
||||||
"""
|
"""
|
||||||
ret_violation = None
|
|
||||||
try:
|
try:
|
||||||
logger.debug(f"run_shield::{params}::messages={messages}")
|
logger.debug(f"run_shield::{params}::messages={messages}")
|
||||||
if not 'guardrailIdentifier' in params:
|
if "guardrailIdentifier" not in params:
|
||||||
raise RuntimeError(f"Error running request for BedrockGaurdrails:Missing GuardrailID in request")
|
raise RuntimeError(
|
||||||
|
"Error running request for BedrockGaurdrails:Missing GuardrailID in request"
|
||||||
|
)
|
||||||
|
|
||||||
if not 'guardrailVersion' in params:
|
if "guardrailVersion" not in params:
|
||||||
raise RuntimeError(f"Error running request for BedrockGaurdrails:Missing guardrailVersion in request")
|
raise RuntimeError(
|
||||||
|
"Error running request for BedrockGaurdrails:Missing guardrailVersion in request"
|
||||||
|
)
|
||||||
|
|
||||||
#- convert the messages into format Bedrock expects
|
# - convert the messages into format Bedrock expects
|
||||||
content_messages = []
|
content_messages = []
|
||||||
for message in messages:
|
for message in messages:
|
||||||
content_messages.append({"text": {"text": message.content}})
|
content_messages.append({"text": {"text": message.content}})
|
||||||
logger.debug(f"run_shield::final:messages::{json.dumps(content_messages, indent=2)}:")
|
logger.debug(
|
||||||
|
f"run_shield::final:messages::{json.dumps(content_messages, indent=2)}:"
|
||||||
|
)
|
||||||
|
|
||||||
response = self.boto_client.apply_guardrail(
|
response = self.boto_client.apply_guardrail(
|
||||||
guardrailIdentifier=params.get('guardrailIdentifier'),
|
guardrailIdentifier=params.get("guardrailIdentifier"),
|
||||||
guardrailVersion=params.get('guardrailVersion'),
|
guardrailVersion=params.get("guardrailVersion"),
|
||||||
source='OUTPUT', # or 'INPUT' depending on your use case
|
source="OUTPUT", # or 'INPUT' depending on your use case
|
||||||
content=content_messages
|
content=content_messages,
|
||||||
)
|
)
|
||||||
logger.debug(f"run_shield:: response: {response}::")
|
logger.debug(f"run_shield:: response: {response}::")
|
||||||
if response['action'] == 'GUARDRAIL_INTERVENED':
|
if response["action"] == "GUARDRAIL_INTERVENED":
|
||||||
user_message=""
|
user_message = ""
|
||||||
metadata={}
|
metadata = {}
|
||||||
for output in response['outputs']:
|
for output in response["outputs"]:
|
||||||
# guardrails returns a list - however for this implementation we will leverage the last values
|
# guardrails returns a list - however for this implementation we will leverage the last values
|
||||||
user_message=output['text']
|
user_message = output["text"]
|
||||||
for assessment in response['assessments']:
|
for assessment in response["assessments"]:
|
||||||
# guardrails returns a list - however for this implementation we will leverage the last values
|
# guardrails returns a list - however for this implementation we will leverage the last values
|
||||||
metadata = dict(assessment)
|
metadata = dict(assessment)
|
||||||
ret_violation = SafetyViolation(
|
return SafetyViolation(
|
||||||
user_message=user_message,
|
user_message=user_message,
|
||||||
violation_level=ViolationLevel.ERROR,
|
violation_level=ViolationLevel.ERROR,
|
||||||
metadata=metadata
|
metadata=metadata,
|
||||||
)
|
)
|
||||||
|
|
||||||
except:
|
except Exception:
|
||||||
error_str = traceback.format_exc()
|
error_str = traceback.format_exc()
|
||||||
print(f"Error in apply_guardrails:{error_str}:: RETURNING None !!!!!")
|
logger.error(
|
||||||
logger.error(f"Error in apply_guardrails:{error_str}:: RETURNING None !!!!!")
|
f"Error in apply_guardrails:{error_str}:: RETURNING None !!!!!"
|
||||||
#raise RuntimeError(f"Error running request for BedrockGaurdrails: {error_str}:")
|
)
|
||||||
|
|
||||||
return ret_violation
|
|
||||||
|
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
|
@ -4,21 +4,13 @@
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
from llama_models.schema_utils import json_schema_type
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
import boto3
|
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
class BedrockSafetyConfig(BaseModel):
|
||||||
class BedrockShieldConfig(BaseModel):
|
|
||||||
"""Configuration information for a guardrail that you want to use in the request."""
|
"""Configuration information for a guardrail that you want to use in the request."""
|
||||||
|
|
||||||
aws_profile: Optional[str] = Field(
|
aws_profile: str = Field(
|
||||||
default='default',
|
default="default",
|
||||||
description="The profile on the machine having valid aws credentials. This will ensure separation of creation to invocation",
|
description="The profile on the machine having valid aws credentials. This will ensure separation of creation to invocation",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -41,10 +41,10 @@ def available_providers() -> List[ProviderSpec]:
|
||||||
remote_provider_spec(
|
remote_provider_spec(
|
||||||
api=Api.safety,
|
api=Api.safety,
|
||||||
adapter=AdapterSpec(
|
adapter=AdapterSpec(
|
||||||
adapter_id="bedrock_guardrails",
|
adapter_id="bedrock",
|
||||||
pip_packages=['boto3',],
|
pip_packages=["boto3"],
|
||||||
module="llama_stack.providers.adapters.safety.bedrock",
|
module="llama_stack.providers.adapters.safety.bedrock",
|
||||||
config_class="llama_stack.providers.adapters.safety.bedrock.config.BedrockShieldConfig",
|
config_class="llama_stack.providers.adapters.safety.bedrock.BedrockSafetyConfig",
|
||||||
),
|
),
|
||||||
),
|
),
|
||||||
remote_provider_spec(
|
remote_provider_spec(
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue