Fix ShieldType Union equality bug

This commit is contained in:
dltn 2024-08-18 19:13:15 -07:00
parent 5e072d0780
commit f502716cf7

View file

@ -73,30 +73,34 @@ class MetaReferenceSafetyImpl(Safety):
return RunShieldResponse(responses=responses) return RunShieldResponse(responses=responses)
def shield_type_equals(a: ShieldType, b: ShieldType):
return a == b or a == b.value
def shield_config_to_shield( def shield_config_to_shield(
sc: ShieldDefinition, safety_config: SafetyConfig sc: ShieldDefinition, safety_config: SafetyConfig
) -> ShieldBase: ) -> ShieldBase:
if sc.shield_type == BuiltinShield.llama_guard: if shield_type_equals(sc.shield_type, BuiltinShield.llama_guard):
assert ( assert (
safety_config.llama_guard_shield is not None safety_config.llama_guard_shield is not None
), "Cannot use LlamaGuardShield since not present in config" ), "Cannot use LlamaGuardShield since not present in config"
model_dir = resolve_and_get_path(safety_config.llama_guard_shield.model) model_dir = resolve_and_get_path(safety_config.llama_guard_shield.model)
return LlamaGuardShield.instance(model_dir=model_dir) return LlamaGuardShield.instance(model_dir=model_dir)
elif sc.shield_type == BuiltinShield.jailbreak_shield: elif shield_type_equals(sc.shield_type, BuiltinShield.jailbreak_shield):
assert ( assert (
safety_config.prompt_guard_shield is not None safety_config.prompt_guard_shield is not None
), "Cannot use Jailbreak Shield since Prompt Guard not present in config" ), "Cannot use Jailbreak Shield since Prompt Guard not present in config"
model_dir = resolve_and_get_path(safety_config.prompt_guard_shield.model) model_dir = resolve_and_get_path(safety_config.prompt_guard_shield.model)
return JailbreakShield.instance(model_dir) return JailbreakShield.instance(model_dir)
elif sc.shield_type == BuiltinShield.injection_shield: elif shield_type_equals(sc.shield_type, BuiltinShield.injection_shield):
assert ( assert (
safety_config.prompt_guard_shield is not None safety_config.prompt_guard_shield is not None
), "Cannot use PromptGuardShield since not present in config" ), "Cannot use PromptGuardShield since not present in config"
model_dir = resolve_and_get_path(safety_config.prompt_guard_shield.model) model_dir = resolve_and_get_path(safety_config.prompt_guard_shield.model)
return InjectionShield.instance(model_dir) return InjectionShield.instance(model_dir)
elif sc.shield_type == BuiltinShield.code_scanner_guard: elif shield_type_equals(sc.shield_type, BuiltinShield.code_scanner_guard):
return CodeScannerShield.instance() return CodeScannerShield.instance()
elif sc.shield_type == BuiltinShield.third_party_shield: elif shield_type_equals(sc.shield_type, BuiltinShield.third_party_shield):
return ThirdPartyShield.instance() return ThirdPartyShield.instance()
else: else:
raise ValueError(f"Unknown shield type: {sc.shield_type}") raise ValueError(f"Unknown shield type: {sc.shield_type}")