diff --git a/llama_stack/providers/adapters/safety/bedrock/bedrock.py b/llama_stack/providers/adapters/safety/bedrock/bedrock.py index f36515471..e22fb1130 100644 --- a/llama_stack/providers/adapters/safety/bedrock/bedrock.py +++ b/llama_stack/providers/adapters/safety/bedrock/bedrock.py @@ -25,7 +25,8 @@ BEDROCK_SUPPORTED_SHIELDS = [ ShieldType.generic_content_shield.value, ] -def _create_bedrock_client(config: BedrockSafetyConfig, name: str) : + +def _create_bedrock_client(config: BedrockSafetyConfig, name: str): session_args = { k: v for k, v in dict( @@ -50,7 +51,9 @@ class BedrockSafetyAdapter(Safety, ShieldsProtocolPrivate): async def initialize(self) -> None: try: - self.bedrock_runtime_client = _create_bedrock_client(self.config, "bedrock-runtime") + self.bedrock_runtime_client = _create_bedrock_client( + self.config, "bedrock-runtime" + ) self.bedrock_client = _create_bedrock_client(self.config, "bedrock") except Exception as e: raise RuntimeError("Error initializing BedrockSafetyAdapter") from e @@ -69,12 +72,14 @@ class BedrockSafetyAdapter(Safety, ShieldsProtocolPrivate): shield_def = ShieldDef( identifier=guardrail["id"], shield_type=ShieldType.generic_content_shield.value, - params={"guardrailIdentifier": guardrail["id"], "guardrailVersion": guardrail["version"]}, + params={ + "guardrailIdentifier": guardrail["id"], + "guardrailVersion": guardrail["version"], + }, ) shields.append(shield_def) return shields - async def run_shield( self, shield_type: str, messages: List[Message], params: Dict[str, Any] = None ) -> RunShieldResponse: diff --git a/llama_stack/providers/adapters/safety/bedrock/config.py b/llama_stack/providers/adapters/safety/bedrock/config.py index 7a01d08fb..afa83f366 100644 --- a/llama_stack/providers/adapters/safety/bedrock/config.py +++ b/llama_stack/providers/adapters/safety/bedrock/config.py @@ -4,9 +4,11 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from pydantic import BaseModel, Field from typing import Optional +from pydantic import BaseModel, Field + + class BedrockSafetyConfig(BaseModel): """Configuration information for a guardrail that you want to use in the request."""