forked from phoenix-oss/llama-stack-mirror
Resource oriented design for shields (#399)
* init * working bedrock tests * bedrock test for inference fixes * use env vars for bedrock guardrail vars * add register in meta reference * use correct shield impl in meta ref * dont add together fixture * right naming * minor updates * improved registration flow * address feedback --------- Co-authored-by: Dinesh Yeduguru <dineshyv@fb.com>
This commit is contained in:
parent
7ee9f8d8ac
commit
d800a16acd
20 changed files with 262 additions and 124 deletions
|
@ -21,6 +21,7 @@ from .prompt_guard import InjectionShield, JailbreakShield, PromptGuardShield
|
|||
|
||||
|
||||
PROMPT_GUARD_MODEL = "Prompt-Guard-86M"
|
||||
SUPPORTED_SHIELDS = [ShieldType.llama_guard, ShieldType.prompt_guard]
|
||||
|
||||
|
||||
class MetaReferenceSafetyImpl(Safety, ShieldsProtocolPrivate):
|
||||
|
@ -30,9 +31,9 @@ class MetaReferenceSafetyImpl(Safety, ShieldsProtocolPrivate):
|
|||
|
||||
self.available_shields = []
|
||||
if config.llama_guard_shield:
|
||||
self.available_shields.append(ShieldType.llama_guard.value)
|
||||
self.available_shields.append(ShieldType.llama_guard)
|
||||
if config.enable_prompt_guard:
|
||||
self.available_shields.append(ShieldType.prompt_guard.value)
|
||||
self.available_shields.append(ShieldType.prompt_guard)
|
||||
|
||||
async def initialize(self) -> None:
|
||||
if self.config.enable_prompt_guard:
|
||||
|
@ -42,30 +43,21 @@ class MetaReferenceSafetyImpl(Safety, ShieldsProtocolPrivate):
|
|||
async def shutdown(self) -> None:
|
||||
pass
|
||||
|
||||
async def register_shield(self, shield: ShieldDef) -> None:
|
||||
raise ValueError("Registering dynamic shields is not supported")
|
||||
|
||||
async def list_shields(self) -> List[ShieldDef]:
|
||||
return [
|
||||
ShieldDef(
|
||||
identifier=shield_type,
|
||||
shield_type=shield_type,
|
||||
params={},
|
||||
)
|
||||
for shield_type in self.available_shields
|
||||
]
|
||||
async def register_shield(self, shield: Shield) -> None:
|
||||
if shield.shield_type not in self.available_shields:
|
||||
raise ValueError(f"Shield type {shield.shield_type} not supported")
|
||||
|
||||
async def run_shield(
|
||||
self,
|
||||
identifier: str,
|
||||
shield_id: str,
|
||||
messages: List[Message],
|
||||
params: Dict[str, Any] = None,
|
||||
) -> RunShieldResponse:
|
||||
shield_def = await self.shield_store.get_shield(identifier)
|
||||
if not shield_def:
|
||||
raise ValueError(f"Unknown shield {identifier}")
|
||||
shield = await self.shield_store.get_shield(shield_id)
|
||||
if not shield:
|
||||
raise ValueError(f"Shield {shield_id} not found")
|
||||
|
||||
shield = self.get_shield_impl(shield_def)
|
||||
shield_impl = self.get_shield_impl(shield)
|
||||
|
||||
messages = messages.copy()
|
||||
# some shields like llama-guard require the first message to be a user message
|
||||
|
@ -74,13 +66,16 @@ class MetaReferenceSafetyImpl(Safety, ShieldsProtocolPrivate):
|
|||
messages[0] = UserMessage(content=messages[0].content)
|
||||
|
||||
# TODO: we can refactor ShieldBase, etc. to be inline with the API types
|
||||
res = await shield.run(messages)
|
||||
res = await shield_impl.run(messages)
|
||||
violation = None
|
||||
if res.is_violation and shield.on_violation_action != OnViolationAction.IGNORE:
|
||||
if (
|
||||
res.is_violation
|
||||
and shield_impl.on_violation_action != OnViolationAction.IGNORE
|
||||
):
|
||||
violation = SafetyViolation(
|
||||
violation_level=(
|
||||
ViolationLevel.ERROR
|
||||
if shield.on_violation_action == OnViolationAction.RAISE
|
||||
if shield_impl.on_violation_action == OnViolationAction.RAISE
|
||||
else ViolationLevel.WARN
|
||||
),
|
||||
user_message=res.violation_return_message,
|
||||
|
@ -91,15 +86,15 @@ class MetaReferenceSafetyImpl(Safety, ShieldsProtocolPrivate):
|
|||
|
||||
return RunShieldResponse(violation=violation)
|
||||
|
||||
def get_shield_impl(self, shield: ShieldDef) -> ShieldBase:
|
||||
if shield.shield_type == ShieldType.llama_guard.value:
|
||||
def get_shield_impl(self, shield: Shield) -> ShieldBase:
|
||||
if shield.shield_type == ShieldType.llama_guard:
|
||||
cfg = self.config.llama_guard_shield
|
||||
return LlamaGuardShield(
|
||||
model=cfg.model,
|
||||
inference_api=self.inference_api,
|
||||
excluded_categories=cfg.excluded_categories,
|
||||
)
|
||||
elif shield.shield_type == ShieldType.prompt_guard.value:
|
||||
elif shield.shield_type == ShieldType.prompt_guard:
|
||||
model_dir = model_local_dir(PROMPT_GUARD_MODEL)
|
||||
subtype = shield.params.get("prompt_guard_type", "injection")
|
||||
if subtype == "injection":
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue