From ef591b44c8193388b07aa623e56820d0e0723e36 Mon Sep 17 00:00:00 2001
From: Ashwin Bharambe
Date: Thu, 3 Oct 2024 09:17:11 -0700
Subject: [PATCH] fix prompt guard

---
 .../impls/meta_reference/safety/config.py |  3 +--
 .../impls/meta_reference/safety/safety.py | 20 +++++++++----------
 2 files changed, 10 insertions(+), 13 deletions(-)

diff --git a/llama_stack/providers/impls/meta_reference/safety/config.py b/llama_stack/providers/impls/meta_reference/safety/config.py
index 36428078d..1233f07b4 100644
--- a/llama_stack/providers/impls/meta_reference/safety/config.py
+++ b/llama_stack/providers/impls/meta_reference/safety/config.py
@@ -47,8 +47,7 @@ class LlamaGuardShieldConfig(BaseModel):
         return model
 
 
-class PromptGuardShieldConfig(BaseModel):
-    model: str = "Prompt-Guard-86M"
+class PromptGuardShieldConfig(BaseModel): ...
 
 
 class SafetyConfig(BaseModel):
diff --git a/llama_stack/providers/impls/meta_reference/safety/safety.py b/llama_stack/providers/impls/meta_reference/safety/safety.py
index f02574f19..5f8372e42 100644
--- a/llama_stack/providers/impls/meta_reference/safety/safety.py
+++ b/llama_stack/providers/impls/meta_reference/safety/safety.py
@@ -29,12 +29,16 @@ from .shields import (
     ShieldBase,
 )
 
+PROMPT_GUARD_MODEL = "Prompt-Guard-86M"
+
 
 def resolve_and_get_path(model_name: str) -> str:
+    if model_name == PROMPT_GUARD_MODEL:
+        return model_local_dir(model_name)
+
     model = resolve_model(model_name)
     assert model is not None, f"Could not resolve model {model_name}"
-    model_dir = model_local_dir(model.descriptor())
-    return model_dir
+    return model_local_dir(model.descriptor())
 
 
 class MetaReferenceSafetyImpl(Safety, RoutableProvider):
@@ -45,7 +49,7 @@ class MetaReferenceSafetyImpl(Safety, RoutableProvider):
     async def initialize(self) -> None:
         shield_cfg = self.config.prompt_guard_shield
         if shield_cfg is not None:
-            model_dir = resolve_and_get_path(shield_cfg.model)
+            model_dir = resolve_and_get_path(PROMPT_GUARD_MODEL)
             _ = PromptGuardShield.instance(model_dir)
 
     async def shutdown(self) -> None:
@@ -108,16 +112,10 @@ class MetaReferenceSafetyImpl(Safety, RoutableProvider):
                 disable_output_check=cfg.disable_output_check,
             )
         elif typ == MetaReferenceShieldType.jailbreak_shield:
-            assert (
-                cfg.prompt_guard_shield is not None
-            ), "Cannot use Jailbreak Shield since Prompt Guard not present in config"
-            model_dir = resolve_and_get_path(cfg.prompt_guard_shield.model)
+            model_dir = resolve_and_get_path(PROMPT_GUARD_MODEL)
             return JailbreakShield.instance(model_dir)
         elif typ == MetaReferenceShieldType.injection_shield:
-            assert (
-                cfg.prompt_guard_shield is not None
-            ), "Cannot use PromptGuardShield since not present in config"
-            model_dir = resolve_and_get_path(cfg.prompt_guard_shield.model)
+            model_dir = resolve_and_get_path(PROMPT_GUARD_MODEL)
             return InjectionShield.instance(model_dir)
         elif typ == MetaReferenceShieldType.code_scanner_guard:
             return CodeScannerShield.instance()