Update bedrock

This commit is contained in:
Ashwin Bharambe 2024-10-10 10:19:06 -07:00
parent fe0dabe596
commit a33cafc2fe
2 changed files with 63 additions and 60 deletions

View file

@ -13,6 +13,7 @@ import boto3
from llama_stack.apis.safety import * # noqa
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.providers.datatypes import ShieldsProtocolPrivate
from .config import BedrockSafetyConfig
@ -25,7 +26,7 @@ BEDROCK_SUPPORTED_SHIELDS = [
]
class BedrockSafetyAdapter(Safety):
class BedrockSafetyAdapter(Safety, ShieldsProtocolPrivate):
def __init__(self, config: BedrockSafetyConfig) -> None:
if not config.aws_profile:
raise ValueError(f"Missing boto_client aws_profile in model info::{config}")
@ -45,19 +46,15 @@ class BedrockSafetyAdapter(Safety):
pass
async def register_shield(self, shield: ShieldDef) -> None:
if shield.type not in BEDROCK_SUPPORTED_SHIELDS:
raise ValueError(f"Unsupported safety shield type: {shield.type}")
raise ValueError("Registering dynamic shields is not supported")
shield_params = shield.params
if "guardrailIdentifier" not in shield_params:
raise ValueError(
"Error running request for BedrockGaurdrails:Missing GuardrailID in request"
)
if "guardrailVersion" not in shield_params:
raise ValueError(
"Error running request for BedrockGaurdrails:Missing guardrailVersion in request"
)
async def list_shields(self) -> List[ShieldDef]:
raise NotImplementedError(
"""
`list_shields` not implemented; this should read all guardrails from
bedrock and populate guardrailId and guardrailVersion in the ShieldDef.
"""
)
async def run_shield(
self, shield_type: str, messages: List[Message], params: Dict[str, Any] = None