add register in meta reference

This commit is contained in:
Dinesh Yeduguru 2024-11-07 20:59:16 -08:00
parent 932b524449
commit 874206baeb

View file

@ -30,9 +30,9 @@ class MetaReferenceSafetyImpl(Safety, ShieldsProtocolPrivate):
self.available_shields = [] self.available_shields = []
if config.llama_guard_shield: if config.llama_guard_shield:
self.available_shields.append(ShieldType.llama_guard.value) self.available_shields.append(ShieldType.llama_guard)
if config.enable_prompt_guard: if config.enable_prompt_guard:
self.available_shields.append(ShieldType.prompt_guard.value) self.available_shields.append(ShieldType.prompt_guard)
async def initialize(self) -> None: async def initialize(self) -> None:
if self.config.enable_prompt_guard: if self.config.enable_prompt_guard:
@ -43,7 +43,8 @@ class MetaReferenceSafetyImpl(Safety, ShieldsProtocolPrivate):
pass pass
async def register_shield(self, shield: Shield) -> None: async def register_shield(self, shield: Shield) -> None:
raise ValueError("Registering dynamic shields is not supported") if shield.shield_type not in self.available_shields:
raise ValueError(f"Shield type {shield.shield_type} not supported")
async def run_shield( async def run_shield(
self, self,
@ -79,14 +80,14 @@ class MetaReferenceSafetyImpl(Safety, ShieldsProtocolPrivate):
return RunShieldResponse(violation=violation) return RunShieldResponse(violation=violation)
def get_shield_impl(self, shield: Shield) -> ShieldBase: def get_shield_impl(self, shield: Shield) -> ShieldBase:
if shield.shield_type == ShieldType.llama_guard.value: if shield.shield_type == ShieldType.llama_guard:
cfg = self.config.llama_guard_shield cfg = self.config.llama_guard_shield
return LlamaGuardShield( return LlamaGuardShield(
model=cfg.model, model=cfg.model,
inference_api=self.inference_api, inference_api=self.inference_api,
excluded_categories=cfg.excluded_categories, excluded_categories=cfg.excluded_categories,
) )
elif shield.shield_type == ShieldType.prompt_guard.value: elif shield.shield_type == ShieldType.prompt_guard:
model_dir = model_local_dir(PROMPT_GUARD_MODEL) model_dir = model_local_dir(PROMPT_GUARD_MODEL)
subtype = shield.params.get("prompt_guard_type", "injection") subtype = shield.params.get("prompt_guard_type", "injection")
if subtype == "injection": if subtype == "injection":