Update the meta reference safety implementation to match new API

This commit is contained in:
Ashwin Bharambe 2024-09-20 14:17:44 -07:00 committed by Xi Yan
parent 7e40eead4e
commit 82ddd851c8
11 changed files with 115 additions and 130 deletions

View file

@ -4,14 +4,14 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import asyncio
from llama_models.sku_list import resolve_model
from llama_stack.distribution.utils.model_utils import model_local_dir
from llama_stack.apis.safety import * # noqa
from llama_stack.apis.safety import * # noqa: F403
from llama_models.llama3.api.datatypes import * # noqa: F403
from .config import MetaReferenceShieldType, SafetyConfig
from .config import SafetyConfig
from .shields import (
CodeScannerShield,
InjectionShield,
@ -19,10 +19,16 @@ from .shields import (
LlamaGuardShield,
PromptGuardShield,
ShieldBase,
ThirdPartyShield,
)
def resolve_and_get_path(model_name: str) -> str:
model = resolve_model(model_name)
assert model is not None, f"Could not resolve model {model_name}"
model_dir = model_local_dir(model.descriptor())
return model_dir
class MetaReferenceSafetyImpl(Safety):
def __init__(self, config: SafetyConfig) -> None:
self.config = config
@ -45,45 +51,56 @@ class MetaReferenceSafetyImpl(Safety):
async def run_shield(
self,
shield_type: ShieldType,
shield_type: str,
messages: List[Message],
params: Dict[str, Any] = None,
) -> RunShieldResponse:
assert shield_type in [
"llama_guard",
"prompt_guard",
], f"Unknown shield {shield_type}"
available_shields = [v.value for v in MetaReferenceShieldType]
assert shield_type in available_shields, f"Unknown shield {shield_type}"
raise NotImplementedError()
shield = self.get_shield_impl(MetaReferenceShieldType(shield_type))
messages = messages.copy()
# some shields like llama-guard require the first message to be a user message
# since this might be a tool call, first role might not be user
if len(messages) > 0 and messages[0].role != Role.user.value:
messages[0] = UserMessage(content=messages[0].content)
def shield_type_equals(a: ShieldType, b: ShieldType):
return a == b or a == b.value
# TODO: we can refactor ShieldBase, etc. to be inline with the API types
res = await shield.run(messages)
violation = None
if res.is_violation:
violation = SafetyViolation(
violation_level=ViolationLevel.ERROR,
user_message=res.violation_return_message,
metadata={
"violation_type": res.violation_type,
},
)
return RunShieldResponse(violation=violation)
def shield_config_to_shield(
sc: ShieldDefinition, safety_config: SafetyConfig
) -> ShieldBase:
if shield_type_equals(sc.shield_type, BuiltinShield.llama_guard):
assert (
safety_config.llama_guard_shield is not None
), "Cannot use LlamaGuardShield since not present in config"
model_dir = resolve_and_get_path(safety_config.llama_guard_shield.model)
return LlamaGuardShield.instance(model_dir=model_dir)
elif shield_type_equals(sc.shield_type, BuiltinShield.jailbreak_shield):
assert (
safety_config.prompt_guard_shield is not None
), "Cannot use Jailbreak Shield since Prompt Guard not present in config"
model_dir = resolve_and_get_path(safety_config.prompt_guard_shield.model)
return JailbreakShield.instance(model_dir)
elif shield_type_equals(sc.shield_type, BuiltinShield.injection_shield):
assert (
safety_config.prompt_guard_shield is not None
), "Cannot use PromptGuardShield since not present in config"
model_dir = resolve_and_get_path(safety_config.prompt_guard_shield.model)
return InjectionShield.instance(model_dir)
elif shield_type_equals(sc.shield_type, BuiltinShield.code_scanner_guard):
return CodeScannerShield.instance()
elif shield_type_equals(sc.shield_type, BuiltinShield.third_party_shield):
return ThirdPartyShield.instance()
else:
raise ValueError(f"Unknown shield type: {sc.shield_type}")
def get_shield_impl(self, typ: MetaReferenceShieldType) -> ShieldBase:
cfg = self.config
if typ == MetaReferenceShieldType.llama_guard:
assert (
cfg.llama_guard_shield is not None
), "Cannot use LlamaGuardShield since not present in config"
model_dir = resolve_and_get_path(cfg.llama_guard_shield.model)
return LlamaGuardShield.instance(model_dir=model_dir)
elif typ == MetaReferenceShieldType.jailbreak_shield:
assert (
cfg.prompt_guard_shield is not None
), "Cannot use Jailbreak Shield since Prompt Guard not present in config"
model_dir = resolve_and_get_path(cfg.prompt_guard_shield.model)
return JailbreakShield.instance(model_dir)
elif typ == MetaReferenceShieldType.injection_shield:
assert (
cfg.prompt_guard_shield is not None
), "Cannot use PromptGuardShield since not present in config"
model_dir = resolve_and_get_path(cfg.prompt_guard_shield.model)
return InjectionShield.instance(model_dir)
elif typ == MetaReferenceShieldType.code_scanner_guard:
return CodeScannerShield.instance()
else:
raise ValueError(f"Unknown shield type: {typ}")