diff --git a/llama_toolchain/safety/meta_reference/safety.py b/llama_toolchain/safety/meta_reference/safety.py index c669eed2f..8f63b14f2 100644 --- a/llama_toolchain/safety/meta_reference/safety.py +++ b/llama_toolchain/safety/meta_reference/safety.py @@ -73,30 +73,34 @@ class MetaReferenceSafetyImpl(Safety): return RunShieldResponse(responses=responses) +def shield_type_equals(a: ShieldType, b: ShieldType): + return a == b or a == b.value + + def shield_config_to_shield( sc: ShieldDefinition, safety_config: SafetyConfig ) -> ShieldBase: - if sc.shield_type == BuiltinShield.llama_guard: + if shield_type_equals(sc.shield_type, BuiltinShield.llama_guard): assert ( safety_config.llama_guard_shield is not None ), "Cannot use LlamaGuardShield since not present in config" model_dir = resolve_and_get_path(safety_config.llama_guard_shield.model) return LlamaGuardShield.instance(model_dir=model_dir) - elif sc.shield_type == BuiltinShield.jailbreak_shield: + elif shield_type_equals(sc.shield_type, BuiltinShield.jailbreak_shield): assert ( safety_config.prompt_guard_shield is not None ), "Cannot use Jailbreak Shield since Prompt Guard not present in config" model_dir = resolve_and_get_path(safety_config.prompt_guard_shield.model) return JailbreakShield.instance(model_dir) - elif sc.shield_type == BuiltinShield.injection_shield: + elif shield_type_equals(sc.shield_type, BuiltinShield.injection_shield): assert ( safety_config.prompt_guard_shield is not None ), "Cannot use PromptGuardShield since not present in config" model_dir = resolve_and_get_path(safety_config.prompt_guard_shield.model) return InjectionShield.instance(model_dir) - elif sc.shield_type == BuiltinShield.code_scanner_guard: + elif shield_type_equals(sc.shield_type, BuiltinShield.code_scanner_guard): return CodeScannerShield.instance() - elif sc.shield_type == BuiltinShield.third_party_shield: + elif shield_type_equals(sc.shield_type, BuiltinShield.third_party_shield): return ThirdPartyShield.instance() else: raise ValueError(f"Unknown shield type: {sc.shield_type}")