mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-06-28 10:54:19 +00:00
Fix ShieldType Union equality bug
This commit is contained in:
parent
5e072d0780
commit
f502716cf7
1 changed files with 9 additions and 5 deletions
|
@ -73,30 +73,34 @@ class MetaReferenceSafetyImpl(Safety):
|
||||||
return RunShieldResponse(responses=responses)
|
return RunShieldResponse(responses=responses)
|
||||||
|
|
||||||
|
|
||||||
|
def shield_type_equals(a: ShieldType, b: ShieldType):
|
||||||
|
return a == b or a == b.value
|
||||||
|
|
||||||
|
|
||||||
def shield_config_to_shield(
|
def shield_config_to_shield(
|
||||||
sc: ShieldDefinition, safety_config: SafetyConfig
|
sc: ShieldDefinition, safety_config: SafetyConfig
|
||||||
) -> ShieldBase:
|
) -> ShieldBase:
|
||||||
if sc.shield_type == BuiltinShield.llama_guard:
|
if shield_type_equals(sc.shield_type, BuiltinShield.llama_guard):
|
||||||
assert (
|
assert (
|
||||||
safety_config.llama_guard_shield is not None
|
safety_config.llama_guard_shield is not None
|
||||||
), "Cannot use LlamaGuardShield since not present in config"
|
), "Cannot use LlamaGuardShield since not present in config"
|
||||||
model_dir = resolve_and_get_path(safety_config.llama_guard_shield.model)
|
model_dir = resolve_and_get_path(safety_config.llama_guard_shield.model)
|
||||||
return LlamaGuardShield.instance(model_dir=model_dir)
|
return LlamaGuardShield.instance(model_dir=model_dir)
|
||||||
elif sc.shield_type == BuiltinShield.jailbreak_shield:
|
elif shield_type_equals(sc.shield_type, BuiltinShield.jailbreak_shield):
|
||||||
assert (
|
assert (
|
||||||
safety_config.prompt_guard_shield is not None
|
safety_config.prompt_guard_shield is not None
|
||||||
), "Cannot use Jailbreak Shield since Prompt Guard not present in config"
|
), "Cannot use Jailbreak Shield since Prompt Guard not present in config"
|
||||||
model_dir = resolve_and_get_path(safety_config.prompt_guard_shield.model)
|
model_dir = resolve_and_get_path(safety_config.prompt_guard_shield.model)
|
||||||
return JailbreakShield.instance(model_dir)
|
return JailbreakShield.instance(model_dir)
|
||||||
elif sc.shield_type == BuiltinShield.injection_shield:
|
elif shield_type_equals(sc.shield_type, BuiltinShield.injection_shield):
|
||||||
assert (
|
assert (
|
||||||
safety_config.prompt_guard_shield is not None
|
safety_config.prompt_guard_shield is not None
|
||||||
), "Cannot use PromptGuardShield since not present in config"
|
), "Cannot use PromptGuardShield since not present in config"
|
||||||
model_dir = resolve_and_get_path(safety_config.prompt_guard_shield.model)
|
model_dir = resolve_and_get_path(safety_config.prompt_guard_shield.model)
|
||||||
return InjectionShield.instance(model_dir)
|
return InjectionShield.instance(model_dir)
|
||||||
elif sc.shield_type == BuiltinShield.code_scanner_guard:
|
elif shield_type_equals(sc.shield_type, BuiltinShield.code_scanner_guard):
|
||||||
return CodeScannerShield.instance()
|
return CodeScannerShield.instance()
|
||||||
elif sc.shield_type == BuiltinShield.third_party_shield:
|
elif shield_type_equals(sc.shield_type, BuiltinShield.third_party_shield):
|
||||||
return ThirdPartyShield.instance()
|
return ThirdPartyShield.instance()
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unknown shield type: {sc.shield_type}")
|
raise ValueError(f"Unknown shield type: {sc.shield_type}")
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue