mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-30 23:51:00 +00:00
working bedrock safety
This commit is contained in:
parent
a4fd91fe51
commit
df76c9b484
7 changed files with 63 additions and 44 deletions
|
@@ -39,7 +39,7 @@ class RunShieldResponse(BaseModel):
|
||||||
|
|
||||||
|
|
||||||
class ShieldStore(Protocol):
|
class ShieldStore(Protocol):
|
||||||
def get_shield(self, identifier: str) -> ShieldDef: ...
|
async def get_shield(self, identifier: str) -> ShieldDef: ...
|
||||||
|
|
||||||
|
|
||||||
@runtime_checkable
|
@runtime_checkable
|
||||||
|
@@ -48,5 +48,5 @@ class Safety(Protocol):
|
||||||
|
|
||||||
@webmethod(route="/safety/run_shield")
|
@webmethod(route="/safety/run_shield")
|
||||||
async def run_shield(
|
async def run_shield(
|
||||||
self, shield_type: str, messages: List[Message], params: Dict[str, Any] = None
|
self, identifier: str, messages: List[Message], params: Dict[str, Any] = None
|
||||||
) -> RunShieldResponse: ...
|
) -> RunShieldResponse: ...
|
||||||
|
|
|
@@ -46,7 +46,7 @@ class Shields(Protocol):
|
||||||
async def list_shields(self) -> List[ShieldDefWithProvider]: ...
|
async def list_shields(self) -> List[ShieldDefWithProvider]: ...
|
||||||
|
|
||||||
@webmethod(route="/shields/get", method="GET")
|
@webmethod(route="/shields/get", method="GET")
|
||||||
async def get_shield(self, shield_type: str) -> Optional[ShieldDefWithProvider]: ...
|
async def get_shield(self, identifier: str) -> Optional[ShieldDefWithProvider]: ...
|
||||||
|
|
||||||
@webmethod(route="/shields/register", method="POST")
|
@webmethod(route="/shields/register", method="POST")
|
||||||
async def register_shield(self, shield: ShieldDefWithProvider) -> None: ...
|
async def register_shield(self, shield: ShieldDefWithProvider) -> None: ...
|
||||||
|
|
|
@@ -154,12 +154,12 @@ class SafetyRouter(Safety):
|
||||||
|
|
||||||
async def run_shield(
|
async def run_shield(
|
||||||
self,
|
self,
|
||||||
shield_type: str,
|
identifier: str,
|
||||||
messages: List[Message],
|
messages: List[Message],
|
||||||
params: Dict[str, Any] = None,
|
params: Dict[str, Any] = None,
|
||||||
) -> RunShieldResponse:
|
) -> RunShieldResponse:
|
||||||
return await self.routing_table.get_provider_impl(shield_type).run_shield(
|
return await self.routing_table.get_provider_impl(identifier).run_shield(
|
||||||
shield_type=shield_type,
|
identifier=identifier,
|
||||||
messages=messages,
|
messages=messages,
|
||||||
params=params,
|
params=params,
|
||||||
)
|
)
|
||||||
|
|
|
@@ -204,8 +204,8 @@ class ShieldsRoutingTable(CommonRoutingTableImpl, Shields):
|
||||||
async def list_shields(self) -> List[ShieldDef]:
|
async def list_shields(self) -> List[ShieldDef]:
|
||||||
return await self.get_all_with_type("shield")
|
return await self.get_all_with_type("shield")
|
||||||
|
|
||||||
async def get_shield(self, shield_type: str) -> Optional[ShieldDefWithProvider]:
|
async def get_shield(self, identifier: str) -> Optional[ShieldDefWithProvider]:
|
||||||
return await self.get_object_by_identifier(shield_type)
|
return await self.get_object_by_identifier(identifier)
|
||||||
|
|
||||||
async def register_shield(self, shield: ShieldDefWithProvider) -> None:
|
async def register_shield(self, shield: ShieldDefWithProvider) -> None:
|
||||||
await self.register_object(shield)
|
await self.register_object(shield)
|
||||||
|
|
|
@@ -16,6 +16,8 @@ from llama_models.llama3.api.tokenizer import Tokenizer
|
||||||
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
|
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
|
||||||
|
|
||||||
from llama_stack.apis.inference import * # noqa: F403
|
from llama_stack.apis.inference import * # noqa: F403
|
||||||
|
import os
|
||||||
|
|
||||||
from llama_stack.providers.adapters.inference.bedrock.config import BedrockConfig
|
from llama_stack.providers.adapters.inference.bedrock.config import BedrockConfig
|
||||||
|
|
||||||
|
|
||||||
|
@@ -443,8 +445,10 @@ def _create_bedrock_client(config: BedrockConfig) -> BaseClient:
|
||||||
retries_config = {
|
retries_config = {
|
||||||
k: v
|
k: v
|
||||||
for k, v in dict(
|
for k, v in dict(
|
||||||
total_max_attempts=config.total_max_attempts,
|
total_max_attempts=os.environ.get(
|
||||||
mode=config.retry_mode,
|
"AWS_MAX_ATTEMPTS", config.total_max_attempts
|
||||||
|
),
|
||||||
|
mode=os.environ.get("AWS_RETRY_MODE", config.retry_mode),
|
||||||
).items()
|
).items()
|
||||||
if v is not None
|
if v is not None
|
||||||
}
|
}
|
||||||
|
@@ -452,7 +456,7 @@ def _create_bedrock_client(config: BedrockConfig) -> BaseClient:
|
||||||
config_args = {
|
config_args = {
|
||||||
k: v
|
k: v
|
||||||
for k, v in dict(
|
for k, v in dict(
|
||||||
region_name=config.region_name,
|
region_name=os.environ.get("AWS_DEFAULT_REGION", config.region_name),
|
||||||
retries=retries_config if retries_config else None,
|
retries=retries_config if retries_config else None,
|
||||||
connect_timeout=config.connect_timeout,
|
connect_timeout=config.connect_timeout,
|
||||||
read_timeout=config.read_timeout,
|
read_timeout=config.read_timeout,
|
||||||
|
@@ -463,17 +467,21 @@ def _create_bedrock_client(config: BedrockConfig) -> BaseClient:
|
||||||
boto3_config = Config(**config_args)
|
boto3_config = Config(**config_args)
|
||||||
|
|
||||||
session_args = {
|
session_args = {
|
||||||
k: v
|
"aws_access_key_id": os.environ.get(
|
||||||
for k, v in dict(
|
"AWS_ACCESS_KEY_ID", config.aws_access_key_id
|
||||||
aws_access_key_id=config.aws_access_key_id,
|
),
|
||||||
aws_secret_access_key=config.aws_secret_access_key,
|
"aws_secret_access_key": os.environ.get(
|
||||||
aws_session_token=config.aws_session_token,
|
"AWS_SECRET_ACCESS_KEY", config.aws_secret_access_key
|
||||||
region_name=config.region_name,
|
),
|
||||||
profile_name=config.profile_name,
|
"aws_session_token": os.environ.get(
|
||||||
).items()
|
"AWS_SESSION_TOKEN", config.aws_session_token
|
||||||
if v is not None
|
),
|
||||||
|
"region_name": os.environ.get("AWS_DEFAULT_REGION", config.region_name),
|
||||||
|
"profile_name": os.environ.get("AWS_PROFILE", config.profile_name),
|
||||||
}
|
}
|
||||||
|
|
||||||
boto3_session = boto3.session.Session(**session_args)
|
# Remove None values
|
||||||
|
session_args = {k: v for k, v in session_args.items() if v is not None}
|
||||||
|
|
||||||
|
boto3_session = boto3.session.Session(**session_args)
|
||||||
return boto3_session.client("bedrock-runtime", config=boto3_config)
|
return boto3_session.client("bedrock-runtime", config=boto3_config)
|
||||||
|
|
|
@@ -6,6 +6,7 @@
|
||||||
|
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
|
import os
|
||||||
|
|
||||||
from typing import Any, Dict, List
|
from typing import Any, Dict, List
|
||||||
|
|
||||||
|
@@ -27,20 +28,25 @@ BEDROCK_SUPPORTED_SHIELDS = [
|
||||||
|
|
||||||
|
|
||||||
def _create_bedrock_client(config: BedrockSafetyConfig, name: str):
|
def _create_bedrock_client(config: BedrockSafetyConfig, name: str):
|
||||||
|
# Use environment variables by default, fall back to config values
|
||||||
session_args = {
|
session_args = {
|
||||||
k: v
|
"aws_access_key_id": os.environ.get(
|
||||||
for k, v in dict(
|
"AWS_ACCESS_KEY_ID", config.aws_access_key_id
|
||||||
aws_access_key_id=config.aws_access_key_id,
|
),
|
||||||
aws_secret_access_key=config.aws_secret_access_key,
|
"aws_secret_access_key": os.environ.get(
|
||||||
aws_session_token=config.aws_session_token,
|
"AWS_SECRET_ACCESS_KEY", config.aws_secret_access_key
|
||||||
region_name=config.region_name,
|
),
|
||||||
profile_name=config.profile_name,
|
"aws_session_token": os.environ.get(
|
||||||
).items()
|
"AWS_SESSION_TOKEN", config.aws_session_token
|
||||||
if v is not None
|
),
|
||||||
|
"region_name": os.environ.get("AWS_DEFAULT_REGION", config.region_name),
|
||||||
|
"profile_name": os.environ.get("AWS_PROFILE", config.profile_name),
|
||||||
}
|
}
|
||||||
|
|
||||||
boto3_session = boto3.session.Session(**session_args)
|
# Remove None values
|
||||||
|
session_args = {k: v for k, v in session_args.items() if v is not None}
|
||||||
|
|
||||||
|
boto3_session = boto3.session.Session(**session_args)
|
||||||
return boto3_session.client(name)
|
return boto3_session.client(name)
|
||||||
|
|
||||||
|
|
||||||
|
@@ -77,15 +83,16 @@ class BedrockSafetyAdapter(Safety, ShieldsProtocolPrivate):
|
||||||
"guardrailVersion": guardrail["version"],
|
"guardrailVersion": guardrail["version"],
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
self.registered_shields.append(shield_def)
|
||||||
shields.append(shield_def)
|
shields.append(shield_def)
|
||||||
return shields
|
return shields
|
||||||
|
|
||||||
async def run_shield(
|
async def run_shield(
|
||||||
self, shield_type: str, messages: List[Message], params: Dict[str, Any] = None
|
self, identifier: str, messages: List[Message], params: Dict[str, Any] = None
|
||||||
) -> RunShieldResponse:
|
) -> RunShieldResponse:
|
||||||
shield_def = await self.shield_store.get_shield(shield_type)
|
shield_def = await self.shield_store.get_shield(identifier)
|
||||||
if not shield_def:
|
if not shield_def:
|
||||||
raise ValueError(f"Unknown shield {shield_type}")
|
raise ValueError(f"Unknown shield {identifier}")
|
||||||
|
|
||||||
"""This is the implementation for the bedrock guardrails. The input to the guardrails is to be of this format
|
"""This is the implementation for the bedrock guardrails. The input to the guardrails is to be of this format
|
||||||
```content = [
|
```content = [
|
||||||
|
@@ -128,10 +135,14 @@ class BedrockSafetyAdapter(Safety, ShieldsProtocolPrivate):
|
||||||
# guardrails returns a list - however for this implementation we will leverage the last values
|
# guardrails returns a list - however for this implementation we will leverage the last values
|
||||||
metadata = dict(assessment)
|
metadata = dict(assessment)
|
||||||
|
|
||||||
return SafetyViolation(
|
return RunShieldResponse(
|
||||||
user_message=user_message,
|
violations=[
|
||||||
violation_level=ViolationLevel.ERROR,
|
SafetyViolation(
|
||||||
metadata=metadata,
|
user_message=user_message,
|
||||||
|
violation_level=ViolationLevel.ERROR,
|
||||||
|
metadata=metadata,
|
||||||
|
)
|
||||||
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
return None
|
return RunShieldResponse(violations=[])
|
||||||
|
|
|
@@ -32,18 +32,18 @@ class ShieldRunnerMixin:
|
||||||
self.output_shields = output_shields
|
self.output_shields = output_shields
|
||||||
|
|
||||||
async def run_multiple_shields(
|
async def run_multiple_shields(
|
||||||
self, messages: List[Message], shield_types: List[str]
|
self, messages: List[Message], identifiers: List[str]
|
||||||
) -> None:
|
) -> None:
|
||||||
responses = await asyncio.gather(
|
responses = await asyncio.gather(
|
||||||
*[
|
*[
|
||||||
self.safety_api.run_shield(
|
self.safety_api.run_shield(
|
||||||
shield_type=shield_type,
|
identifier=identifier,
|
||||||
messages=messages,
|
messages=messages,
|
||||||
)
|
)
|
||||||
for shield_type in shield_types
|
for identifier in identifiers
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
for shield_type, response in zip(shield_types, responses):
|
for identifier, response in zip(identifiers, responses):
|
||||||
if not response.violation:
|
if not response.violation:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
@@ -52,6 +52,6 @@ class ShieldRunnerMixin:
|
||||||
raise SafetyException(violation)
|
raise SafetyException(violation)
|
||||||
elif violation.violation_level == ViolationLevel.WARN:
|
elif violation.violation_level == ViolationLevel.WARN:
|
||||||
cprint(
|
cprint(
|
||||||
f"[Warn]{shield_type} raised a warning",
|
f"[Warn]{identifier} raised a warning",
|
||||||
color="red",
|
color="red",
|
||||||
)
|
)
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue