mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-16 14:57:20 +00:00
add bedrock shields support
This commit is contained in:
parent
ccd60dc29d
commit
7176338ca6
4 changed files with 116 additions and 18 deletions
|
@ -25,20 +25,33 @@ BEDROCK_SUPPORTED_SHIELDS = [
|
|||
ShieldType.generic_content_shield.value,
|
||||
]
|
||||
|
||||
def _create_bedrock_client(config: BedrockSafetyConfig, name: str) :
|
||||
session_args = {
|
||||
k: v
|
||||
for k, v in dict(
|
||||
aws_access_key_id=config.aws_access_key_id,
|
||||
aws_secret_access_key=config.aws_secret_access_key,
|
||||
aws_session_token=config.aws_session_token,
|
||||
region_name=config.region_name,
|
||||
profile_name=config.profile_name,
|
||||
).items()
|
||||
if v is not None
|
||||
}
|
||||
|
||||
boto3_session = boto3.session.Session(**session_args)
|
||||
|
||||
return boto3_session.client(name)
|
||||
|
||||
|
||||
class BedrockSafetyAdapter(Safety, ShieldsProtocolPrivate):
|
||||
def __init__(self, config: BedrockSafetyConfig) -> None:
|
||||
if not config.aws_profile:
|
||||
raise ValueError(f"Missing boto_client aws_profile in model info::{config}")
|
||||
self.config = config
|
||||
self.registered_shields = []
|
||||
|
||||
async def initialize(self) -> None:
|
||||
try:
|
||||
print(f"initializing with profile --- > {self.config}")
|
||||
self.boto_client = boto3.Session(
|
||||
profile_name=self.config.aws_profile
|
||||
).client("bedrock-runtime")
|
||||
self.bedrock_runtime_client = _create_bedrock_client(self.config, "bedrock-runtime")
|
||||
self.bedrock_client = _create_bedrock_client(self.config, "bedrock")
|
||||
except Exception as e:
|
||||
raise RuntimeError("Error initializing BedrockSafetyAdapter") from e
|
||||
|
||||
|
@ -49,12 +62,18 @@ class BedrockSafetyAdapter(Safety, ShieldsProtocolPrivate):
|
|||
raise ValueError("Registering dynamic shields is not supported")
|
||||
|
||||
async def list_shields(self) -> List[ShieldDef]:
|
||||
raise NotImplementedError(
|
||||
"""
|
||||
`list_shields` not implemented; this should read all guardrails from
|
||||
bedrock and populate guardrailId and guardrailVersion in the ShieldDef.
|
||||
"""
|
||||
)
|
||||
response = self.bedrock_client.list_guardrails()
|
||||
shields = []
|
||||
for guardrail in response["guardrails"]:
|
||||
# populate the shield def with the guardrail id and version
|
||||
shield_def = ShieldDef(
|
||||
identifier=guardrail["id"],
|
||||
shield_type=ShieldType.generic_content_shield.value,
|
||||
params={"guardrailIdentifier": guardrail["id"], "guardrailVersion": guardrail["version"]},
|
||||
)
|
||||
shields.append(shield_def)
|
||||
return shields
|
||||
|
||||
|
||||
async def run_shield(
|
||||
self, shield_type: str, messages: List[Message], params: Dict[str, Any] = None
|
||||
|
@ -88,7 +107,7 @@ class BedrockSafetyAdapter(Safety, ShieldsProtocolPrivate):
|
|||
f"run_shield::final:messages::{json.dumps(content_messages, indent=2)}:"
|
||||
)
|
||||
|
||||
response = self.boto_client.apply_guardrail(
|
||||
response = self.bedrock_runtime_client.apply_guardrail(
|
||||
guardrailIdentifier=shield_params["guardrailIdentifier"],
|
||||
guardrailVersion=shield_params["guardrailVersion"],
|
||||
source="OUTPUT", # or 'INPUT' depending on your use case
|
||||
|
|
|
@ -5,12 +5,29 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from typing import Optional
|
||||
|
||||
class BedrockSafetyConfig(BaseModel):
|
||||
"""Configuration information for a guardrail that you want to use in the request."""
|
||||
|
||||
aws_profile: str = Field(
|
||||
aws_access_key_id: Optional[str] = Field(
|
||||
default=None,
|
||||
description="The AWS access key to use. Default use environment variable: AWS_ACCESS_KEY_ID",
|
||||
)
|
||||
aws_secret_access_key: Optional[str] = Field(
|
||||
default=None,
|
||||
description="The AWS secret access key to use. Default use environment variable: AWS_SECRET_ACCESS_KEY",
|
||||
)
|
||||
aws_session_token: Optional[str] = Field(
|
||||
default=None,
|
||||
description="The AWS session token to use. Default use environment variable: AWS_SESSION_TOKEN",
|
||||
)
|
||||
region_name: Optional[str] = Field(
|
||||
default=None,
|
||||
description="The default AWS Region to use, for example, us-west-1 or us-west-2."
|
||||
"Default use environment variable: AWS_DEFAULT_REGION",
|
||||
)
|
||||
profile_name: str = Field(
|
||||
default="default",
|
||||
description="The profile on the machine having valid aws credentials. This will ensure separation of creation to invocation",
|
||||
)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue