From 874206baebbcf39d3b5b742ea646f94dd1863de1 Mon Sep 17 00:00:00 2001 From: Dinesh Yeduguru Date: Thu, 7 Nov 2024 20:59:16 -0800 Subject: [PATCH] add register in meta reference --- .../providers/inline/safety/meta_reference/safety.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/llama_stack/providers/inline/safety/meta_reference/safety.py b/llama_stack/providers/inline/safety/meta_reference/safety.py index 9093dcad6..ce7eaf591 100644 --- a/llama_stack/providers/inline/safety/meta_reference/safety.py +++ b/llama_stack/providers/inline/safety/meta_reference/safety.py @@ -30,9 +30,9 @@ class MetaReferenceSafetyImpl(Safety, ShieldsProtocolPrivate): self.available_shields = [] if config.llama_guard_shield: - self.available_shields.append(ShieldType.llama_guard.value) + self.available_shields.append(ShieldType.llama_guard) if config.enable_prompt_guard: - self.available_shields.append(ShieldType.prompt_guard.value) + self.available_shields.append(ShieldType.prompt_guard) async def initialize(self) -> None: if self.config.enable_prompt_guard: @@ -43,7 +43,8 @@ class MetaReferenceSafetyImpl(Safety, ShieldsProtocolPrivate): pass async def register_shield(self, shield: Shield) -> None: - raise ValueError("Registering dynamic shields is not supported") + if shield.shield_type not in self.available_shields: + raise ValueError(f"Shield type {shield.shield_type} not supported") async def run_shield( self, @@ -79,14 +80,14 @@ class MetaReferenceSafetyImpl(Safety, ShieldsProtocolPrivate): return RunShieldResponse(violation=violation) def get_shield_impl(self, shield: Shield) -> ShieldBase: - if shield.shield_type == ShieldType.llama_guard.value: + if shield.shield_type == ShieldType.llama_guard: cfg = self.config.llama_guard_shield return LlamaGuardShield( model=cfg.model, inference_api=self.inference_api, excluded_categories=cfg.excluded_categories, ) - elif shield.shield_type == ShieldType.prompt_guard.value: + elif shield.shield_type == ShieldType.prompt_guard: model_dir = model_local_dir(PROMPT_GUARD_MODEL) subtype = shield.params.get("prompt_guard_type", "injection") if subtype == "injection":