mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-30 07:39:38 +00:00
working bedrock safety
This commit is contained in:
parent
a4fd91fe51
commit
df76c9b484
7 changed files with 63 additions and 44 deletions
|
@ -39,7 +39,7 @@ class RunShieldResponse(BaseModel):
|
|||
|
||||
|
||||
class ShieldStore(Protocol):
|
||||
def get_shield(self, identifier: str) -> ShieldDef: ...
|
||||
async def get_shield(self, identifier: str) -> ShieldDef: ...
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
|
@ -48,5 +48,5 @@ class Safety(Protocol):
|
|||
|
||||
@webmethod(route="/safety/run_shield")
|
||||
async def run_shield(
|
||||
self, shield_type: str, messages: List[Message], params: Dict[str, Any] = None
|
||||
self, identifier: str, messages: List[Message], params: Dict[str, Any] = None
|
||||
) -> RunShieldResponse: ...
|
||||
|
|
|
@ -46,7 +46,7 @@ class Shields(Protocol):
|
|||
async def list_shields(self) -> List[ShieldDefWithProvider]: ...
|
||||
|
||||
@webmethod(route="/shields/get", method="GET")
|
||||
async def get_shield(self, shield_type: str) -> Optional[ShieldDefWithProvider]: ...
|
||||
async def get_shield(self, identifier: str) -> Optional[ShieldDefWithProvider]: ...
|
||||
|
||||
@webmethod(route="/shields/register", method="POST")
|
||||
async def register_shield(self, shield: ShieldDefWithProvider) -> None: ...
|
||||
|
|
|
@ -154,12 +154,12 @@ class SafetyRouter(Safety):
|
|||
|
||||
async def run_shield(
|
||||
self,
|
||||
shield_type: str,
|
||||
identifier: str,
|
||||
messages: List[Message],
|
||||
params: Dict[str, Any] = None,
|
||||
) -> RunShieldResponse:
|
||||
return await self.routing_table.get_provider_impl(shield_type).run_shield(
|
||||
shield_type=shield_type,
|
||||
return await self.routing_table.get_provider_impl(identifier).run_shield(
|
||||
identifier=identifier,
|
||||
messages=messages,
|
||||
params=params,
|
||||
)
|
||||
|
|
|
@ -204,8 +204,8 @@ class ShieldsRoutingTable(CommonRoutingTableImpl, Shields):
|
|||
async def list_shields(self) -> List[ShieldDef]:
|
||||
return await self.get_all_with_type("shield")
|
||||
|
||||
async def get_shield(self, shield_type: str) -> Optional[ShieldDefWithProvider]:
|
||||
return await self.get_object_by_identifier(shield_type)
|
||||
async def get_shield(self, identifier: str) -> Optional[ShieldDefWithProvider]:
|
||||
return await self.get_object_by_identifier(identifier)
|
||||
|
||||
async def register_shield(self, shield: ShieldDefWithProvider) -> None:
|
||||
await self.register_object(shield)
|
||||
|
|
|
@ -16,6 +16,8 @@ from llama_models.llama3.api.tokenizer import Tokenizer
|
|||
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
|
||||
|
||||
from llama_stack.apis.inference import * # noqa: F403
|
||||
import os
|
||||
|
||||
from llama_stack.providers.adapters.inference.bedrock.config import BedrockConfig
|
||||
|
||||
|
||||
|
@ -443,8 +445,10 @@ def _create_bedrock_client(config: BedrockConfig) -> BaseClient:
|
|||
retries_config = {
|
||||
k: v
|
||||
for k, v in dict(
|
||||
total_max_attempts=config.total_max_attempts,
|
||||
mode=config.retry_mode,
|
||||
total_max_attempts=os.environ.get(
|
||||
"AWS_MAX_ATTEMPTS", config.total_max_attempts
|
||||
),
|
||||
mode=os.environ.get("AWS_RETRY_MODE", config.retry_mode),
|
||||
).items()
|
||||
if v is not None
|
||||
}
|
||||
|
@ -452,7 +456,7 @@ def _create_bedrock_client(config: BedrockConfig) -> BaseClient:
|
|||
config_args = {
|
||||
k: v
|
||||
for k, v in dict(
|
||||
region_name=config.region_name,
|
||||
region_name=os.environ.get("AWS_DEFAULT_REGION", config.region_name),
|
||||
retries=retries_config if retries_config else None,
|
||||
connect_timeout=config.connect_timeout,
|
||||
read_timeout=config.read_timeout,
|
||||
|
@ -463,17 +467,21 @@ def _create_bedrock_client(config: BedrockConfig) -> BaseClient:
|
|||
boto3_config = Config(**config_args)
|
||||
|
||||
session_args = {
|
||||
k: v
|
||||
for k, v in dict(
|
||||
aws_access_key_id=config.aws_access_key_id,
|
||||
aws_secret_access_key=config.aws_secret_access_key,
|
||||
aws_session_token=config.aws_session_token,
|
||||
region_name=config.region_name,
|
||||
profile_name=config.profile_name,
|
||||
).items()
|
||||
if v is not None
|
||||
"aws_access_key_id": os.environ.get(
|
||||
"AWS_ACCESS_KEY_ID", config.aws_access_key_id
|
||||
),
|
||||
"aws_secret_access_key": os.environ.get(
|
||||
"AWS_SECRET_ACCESS_KEY", config.aws_secret_access_key
|
||||
),
|
||||
"aws_session_token": os.environ.get(
|
||||
"AWS_SESSION_TOKEN", config.aws_session_token
|
||||
),
|
||||
"region_name": os.environ.get("AWS_DEFAULT_REGION", config.region_name),
|
||||
"profile_name": os.environ.get("AWS_PROFILE", config.profile_name),
|
||||
}
|
||||
|
||||
boto3_session = boto3.session.Session(**session_args)
|
||||
# Remove None values
|
||||
session_args = {k: v for k, v in session_args.items() if v is not None}
|
||||
|
||||
boto3_session = boto3.session.Session(**session_args)
|
||||
return boto3_session.client("bedrock-runtime", config=boto3_config)
|
||||
|
|
|
@ -6,6 +6,7 @@
|
|||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
|
||||
from typing import Any, Dict, List
|
||||
|
||||
|
@ -27,20 +28,25 @@ BEDROCK_SUPPORTED_SHIELDS = [
|
|||
|
||||
|
||||
def _create_bedrock_client(config: BedrockSafetyConfig, name: str):
|
||||
# Use environment variables by default, fall back to config values
|
||||
session_args = {
|
||||
k: v
|
||||
for k, v in dict(
|
||||
aws_access_key_id=config.aws_access_key_id,
|
||||
aws_secret_access_key=config.aws_secret_access_key,
|
||||
aws_session_token=config.aws_session_token,
|
||||
region_name=config.region_name,
|
||||
profile_name=config.profile_name,
|
||||
).items()
|
||||
if v is not None
|
||||
"aws_access_key_id": os.environ.get(
|
||||
"AWS_ACCESS_KEY_ID", config.aws_access_key_id
|
||||
),
|
||||
"aws_secret_access_key": os.environ.get(
|
||||
"AWS_SECRET_ACCESS_KEY", config.aws_secret_access_key
|
||||
),
|
||||
"aws_session_token": os.environ.get(
|
||||
"AWS_SESSION_TOKEN", config.aws_session_token
|
||||
),
|
||||
"region_name": os.environ.get("AWS_DEFAULT_REGION", config.region_name),
|
||||
"profile_name": os.environ.get("AWS_PROFILE", config.profile_name),
|
||||
}
|
||||
|
||||
boto3_session = boto3.session.Session(**session_args)
|
||||
# Remove None values
|
||||
session_args = {k: v for k, v in session_args.items() if v is not None}
|
||||
|
||||
boto3_session = boto3.session.Session(**session_args)
|
||||
return boto3_session.client(name)
|
||||
|
||||
|
||||
|
@ -77,15 +83,16 @@ class BedrockSafetyAdapter(Safety, ShieldsProtocolPrivate):
|
|||
"guardrailVersion": guardrail["version"],
|
||||
},
|
||||
)
|
||||
self.registered_shields.append(shield_def)
|
||||
shields.append(shield_def)
|
||||
return shields
|
||||
|
||||
async def run_shield(
|
||||
self, shield_type: str, messages: List[Message], params: Dict[str, Any] = None
|
||||
self, identifier: str, messages: List[Message], params: Dict[str, Any] = None
|
||||
) -> RunShieldResponse:
|
||||
shield_def = await self.shield_store.get_shield(shield_type)
|
||||
shield_def = await self.shield_store.get_shield(identifier)
|
||||
if not shield_def:
|
||||
raise ValueError(f"Unknown shield {shield_type}")
|
||||
raise ValueError(f"Unknown shield {identifier}")
|
||||
|
||||
"""This is the implementation for the bedrock guardrails. The input to the guardrails is to be of this format
|
||||
```content = [
|
||||
|
@ -128,10 +135,14 @@ class BedrockSafetyAdapter(Safety, ShieldsProtocolPrivate):
|
|||
# guardrails returns a list - however for this implementation we will leverage the last values
|
||||
metadata = dict(assessment)
|
||||
|
||||
return SafetyViolation(
|
||||
user_message=user_message,
|
||||
violation_level=ViolationLevel.ERROR,
|
||||
metadata=metadata,
|
||||
return RunShieldResponse(
|
||||
violations=[
|
||||
SafetyViolation(
|
||||
user_message=user_message,
|
||||
violation_level=ViolationLevel.ERROR,
|
||||
metadata=metadata,
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
return None
|
||||
return RunShieldResponse(violations=[])
|
||||
|
|
|
@ -32,18 +32,18 @@ class ShieldRunnerMixin:
|
|||
self.output_shields = output_shields
|
||||
|
||||
async def run_multiple_shields(
|
||||
self, messages: List[Message], shield_types: List[str]
|
||||
self, messages: List[Message], identifiers: List[str]
|
||||
) -> None:
|
||||
responses = await asyncio.gather(
|
||||
*[
|
||||
self.safety_api.run_shield(
|
||||
shield_type=shield_type,
|
||||
identifier=identifier,
|
||||
messages=messages,
|
||||
)
|
||||
for shield_type in shield_types
|
||||
for identifier in identifiers
|
||||
]
|
||||
)
|
||||
for shield_type, response in zip(shield_types, responses):
|
||||
for identifier, response in zip(identifiers, responses):
|
||||
if not response.violation:
|
||||
continue
|
||||
|
||||
|
@ -52,6 +52,6 @@ class ShieldRunnerMixin:
|
|||
raise SafetyException(violation)
|
||||
elif violation.violation_level == ViolationLevel.WARN:
|
||||
cprint(
|
||||
f"[Warn]{shield_type} raised a warning",
|
||||
f"[Warn]{identifier} raised a warning",
|
||||
color="red",
|
||||
)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue