Resource oriented design for shields (#399)

* init

* working bedrock tests

* bedrock test for inference fixes

* use env vars for bedrock guardrail vars

* add register in meta reference

* use correct shield impl in meta ref

* dont add together fixture

* right naming

* minor updates

* improved registration flow

* address feedback

---------

Co-authored-by: Dinesh Yeduguru <dineshyv@fb.com>
This commit is contained in:
Dinesh Yeduguru 2024-11-08 12:16:11 -08:00 committed by GitHub
parent 7ee9f8d8ac
commit d800a16acd
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
20 changed files with 262 additions and 124 deletions

View file

@ -21,6 +21,7 @@ from .prompt_guard import InjectionShield, JailbreakShield, PromptGuardShield
PROMPT_GUARD_MODEL = "Prompt-Guard-86M"
SUPPORTED_SHIELDS = [ShieldType.llama_guard, ShieldType.prompt_guard]
class MetaReferenceSafetyImpl(Safety, ShieldsProtocolPrivate):
@ -30,9 +31,9 @@ class MetaReferenceSafetyImpl(Safety, ShieldsProtocolPrivate):
self.available_shields = []
if config.llama_guard_shield:
self.available_shields.append(ShieldType.llama_guard.value)
self.available_shields.append(ShieldType.llama_guard)
if config.enable_prompt_guard:
self.available_shields.append(ShieldType.prompt_guard.value)
self.available_shields.append(ShieldType.prompt_guard)
async def initialize(self) -> None:
if self.config.enable_prompt_guard:
@ -42,30 +43,21 @@ class MetaReferenceSafetyImpl(Safety, ShieldsProtocolPrivate):
async def shutdown(self) -> None:
pass
async def register_shield(self, shield: ShieldDef) -> None:
raise ValueError("Registering dynamic shields is not supported")
async def list_shields(self) -> List[ShieldDef]:
return [
ShieldDef(
identifier=shield_type,
shield_type=shield_type,
params={},
)
for shield_type in self.available_shields
]
async def register_shield(self, shield: Shield) -> None:
if shield.shield_type not in self.available_shields:
raise ValueError(f"Shield type {shield.shield_type} not supported")
async def run_shield(
self,
identifier: str,
shield_id: str,
messages: List[Message],
params: Dict[str, Any] = None,
) -> RunShieldResponse:
shield_def = await self.shield_store.get_shield(identifier)
if not shield_def:
raise ValueError(f"Unknown shield {identifier}")
shield = await self.shield_store.get_shield(shield_id)
if not shield:
raise ValueError(f"Shield {shield_id} not found")
shield = self.get_shield_impl(shield_def)
shield_impl = self.get_shield_impl(shield)
messages = messages.copy()
# some shields like llama-guard require the first message to be a user message
@ -74,13 +66,16 @@ class MetaReferenceSafetyImpl(Safety, ShieldsProtocolPrivate):
messages[0] = UserMessage(content=messages[0].content)
# TODO: we can refactor ShieldBase, etc. to be inline with the API types
res = await shield.run(messages)
res = await shield_impl.run(messages)
violation = None
if res.is_violation and shield.on_violation_action != OnViolationAction.IGNORE:
if (
res.is_violation
and shield_impl.on_violation_action != OnViolationAction.IGNORE
):
violation = SafetyViolation(
violation_level=(
ViolationLevel.ERROR
if shield.on_violation_action == OnViolationAction.RAISE
if shield_impl.on_violation_action == OnViolationAction.RAISE
else ViolationLevel.WARN
),
user_message=res.violation_return_message,
@ -91,15 +86,15 @@ class MetaReferenceSafetyImpl(Safety, ShieldsProtocolPrivate):
return RunShieldResponse(violation=violation)
def get_shield_impl(self, shield: ShieldDef) -> ShieldBase:
if shield.shield_type == ShieldType.llama_guard.value:
def get_shield_impl(self, shield: Shield) -> ShieldBase:
if shield.shield_type == ShieldType.llama_guard:
cfg = self.config.llama_guard_shield
return LlamaGuardShield(
model=cfg.model,
inference_api=self.inference_api,
excluded_categories=cfg.excluded_categories,
)
elif shield.shield_type == ShieldType.prompt_guard.value:
elif shield.shield_type == ShieldType.prompt_guard:
model_dir = model_local_dir(PROMPT_GUARD_MODEL)
subtype = shield.params.get("prompt_guard_type", "injection")
if subtype == "injection":