safety API cleanup part 1

Sample adapter implementation for Bedrock implementation of Guardrails
This commit is contained in:
Ashwin Bharambe 2024-09-20 10:57:26 -07:00
parent 90a59fd89b
commit 93e4ef3829
7 changed files with 130 additions and 78 deletions

View file

@ -23,13 +23,6 @@ from .shields import (
)
def resolve_and_get_path(model_name: str) -> str:
model = resolve_model(model_name)
assert model is not None, f"Could not resolve model {model_name}"
model_dir = model_local_dir(model.descriptor())
return model_dir
class MetaReferenceSafetyImpl(Safety):
def __init__(self, config: SafetyConfig) -> None:
self.config = config
@ -50,16 +43,17 @@ class MetaReferenceSafetyImpl(Safety):
model_dir = resolve_and_get_path(shield_cfg.model)
_ = PromptGuardShield.instance(model_dir)
async def run_shields(
async def run_shield(
self,
shield_type: ShieldType,
messages: List[Message],
shields: List[ShieldDefinition],
) -> RunShieldResponse:
shields = [shield_config_to_shield(c, self.config) for c in shields]
assert shield_type in [
"llama_guard",
"prompt_guard",
], f"Unknown shield {shield_type}"
responses = await asyncio.gather(*[shield.run(messages) for shield in shields])
return RunShieldResponse(responses=responses)
raise NotImplementedError()
def shield_type_equals(a: ShieldType, b: ShieldType):