From f502716cf73e80c09da4f4eb4befa11b00b26e46 Mon Sep 17 00:00:00 2001 From: dltn <6599399+dltn@users.noreply.github.com> Date: Sun, 18 Aug 2024 19:13:15 -0700 Subject: [PATCH] Fix ShieldType Union equality bug --- llama_toolchain/safety/meta_reference/safety.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/llama_toolchain/safety/meta_reference/safety.py b/llama_toolchain/safety/meta_reference/safety.py index c669eed2f..8f63b14f2 100644 --- a/llama_toolchain/safety/meta_reference/safety.py +++ b/llama_toolchain/safety/meta_reference/safety.py @@ -73,30 +73,34 @@ class MetaReferenceSafetyImpl(Safety): return RunShieldResponse(responses=responses) +def shield_type_equals(a: ShieldType, b: ShieldType): + return a == b or a == b.value + + def shield_config_to_shield( sc: ShieldDefinition, safety_config: SafetyConfig ) -> ShieldBase: - if sc.shield_type == BuiltinShield.llama_guard: + if shield_type_equals(sc.shield_type, BuiltinShield.llama_guard): assert ( safety_config.llama_guard_shield is not None ), "Cannot use LlamaGuardShield since not present in config" model_dir = resolve_and_get_path(safety_config.llama_guard_shield.model) return LlamaGuardShield.instance(model_dir=model_dir) - elif sc.shield_type == BuiltinShield.jailbreak_shield: + elif shield_type_equals(sc.shield_type, BuiltinShield.jailbreak_shield): assert ( safety_config.prompt_guard_shield is not None ), "Cannot use Jailbreak Shield since Prompt Guard not present in config" model_dir = resolve_and_get_path(safety_config.prompt_guard_shield.model) return JailbreakShield.instance(model_dir) - elif sc.shield_type == BuiltinShield.injection_shield: + elif shield_type_equals(sc.shield_type, BuiltinShield.injection_shield): assert ( safety_config.prompt_guard_shield is not None ), "Cannot use PromptGuardShield since not present in config" model_dir = resolve_and_get_path(safety_config.prompt_guard_shield.model) return InjectionShield.instance(model_dir) - elif sc.shield_type == BuiltinShield.code_scanner_guard: + elif shield_type_equals(sc.shield_type, BuiltinShield.code_scanner_guard): return CodeScannerShield.instance() - elif sc.shield_type == BuiltinShield.third_party_shield: + elif shield_type_equals(sc.shield_type, BuiltinShield.third_party_shield): return ThirdPartyShield.instance() else: raise ValueError(f"Unknown shield type: {sc.shield_type}")