linter fixes

This commit is contained in:
Dinesh Yeduguru 2024-11-05 10:10:24 -08:00
parent 7176338ca6
commit a4fd91fe51
2 changed files with 12 additions and 5 deletions

View file

@ -25,7 +25,8 @@ BEDROCK_SUPPORTED_SHIELDS = [
ShieldType.generic_content_shield.value, ShieldType.generic_content_shield.value,
] ]
def _create_bedrock_client(config: BedrockSafetyConfig, name: str) :
def _create_bedrock_client(config: BedrockSafetyConfig, name: str):
session_args = { session_args = {
k: v k: v
for k, v in dict( for k, v in dict(
@ -50,7 +51,9 @@ class BedrockSafetyAdapter(Safety, ShieldsProtocolPrivate):
async def initialize(self) -> None: async def initialize(self) -> None:
try: try:
self.bedrock_runtime_client = _create_bedrock_client(self.config, "bedrock-runtime") self.bedrock_runtime_client = _create_bedrock_client(
self.config, "bedrock-runtime"
)
self.bedrock_client = _create_bedrock_client(self.config, "bedrock") self.bedrock_client = _create_bedrock_client(self.config, "bedrock")
except Exception as e: except Exception as e:
raise RuntimeError("Error initializing BedrockSafetyAdapter") from e raise RuntimeError("Error initializing BedrockSafetyAdapter") from e
@ -69,12 +72,14 @@ class BedrockSafetyAdapter(Safety, ShieldsProtocolPrivate):
shield_def = ShieldDef( shield_def = ShieldDef(
identifier=guardrail["id"], identifier=guardrail["id"],
shield_type=ShieldType.generic_content_shield.value, shield_type=ShieldType.generic_content_shield.value,
params={"guardrailIdentifier": guardrail["id"], "guardrailVersion": guardrail["version"]}, params={
"guardrailIdentifier": guardrail["id"],
"guardrailVersion": guardrail["version"],
},
) )
shields.append(shield_def) shields.append(shield_def)
return shields return shields
async def run_shield( async def run_shield(
self, shield_type: str, messages: List[Message], params: Dict[str, Any] = None self, shield_type: str, messages: List[Message], params: Dict[str, Any] = None
) -> RunShieldResponse: ) -> RunShieldResponse:

View file

@ -4,9 +4,11 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
from pydantic import BaseModel, Field
from typing import Optional from typing import Optional
from pydantic import BaseModel, Field
class BedrockSafetyConfig(BaseModel): class BedrockSafetyConfig(BaseModel):
"""Configuration information for a guardrail that you want to use in the request.""" """Configuration information for a guardrail that you want to use in the request."""