add register in meta reference

This commit is contained in:
Dinesh Yeduguru 2024-11-07 20:59:16 -08:00
parent 932b524449
commit 874206baeb

View file

@ -30,9 +30,9 @@ class MetaReferenceSafetyImpl(Safety, ShieldsProtocolPrivate):
self.available_shields = []
if config.llama_guard_shield:
self.available_shields.append(ShieldType.llama_guard.value)
self.available_shields.append(ShieldType.llama_guard)
if config.enable_prompt_guard:
self.available_shields.append(ShieldType.prompt_guard.value)
self.available_shields.append(ShieldType.prompt_guard)
async def initialize(self) -> None:
if self.config.enable_prompt_guard:
@ -43,7 +43,8 @@ class MetaReferenceSafetyImpl(Safety, ShieldsProtocolPrivate):
pass
async def register_shield(self, shield: Shield) -> None:
raise ValueError("Registering dynamic shields is not supported")
if shield.shield_type not in self.available_shields:
raise ValueError(f"Shield type {shield.shield_type} not supported")
async def run_shield(
self,
@ -79,14 +80,14 @@ class MetaReferenceSafetyImpl(Safety, ShieldsProtocolPrivate):
return RunShieldResponse(violation=violation)
def get_shield_impl(self, shield: Shield) -> ShieldBase:
if shield.shield_type == ShieldType.llama_guard.value:
if shield.shield_type == ShieldType.llama_guard:
cfg = self.config.llama_guard_shield
return LlamaGuardShield(
model=cfg.model,
inference_api=self.inference_api,
excluded_categories=cfg.excluded_categories,
)
elif shield.shield_type == ShieldType.prompt_guard.value:
elif shield.shield_type == ShieldType.prompt_guard:
model_dir = model_local_dir(PROMPT_GUARD_MODEL)
subtype = shield.params.get("prompt_guard_type", "injection")
if subtype == "injection":