llama-stack-mirror/llama_toolchain/safety/meta_reference/safety.py
2024-08-18 19:13:15 -07:00

106 lines
3.9 KiB
Python

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import asyncio
from typing import Dict
from llama_models.sku_list import resolve_model
from llama_toolchain.common.model_utils import model_local_dir
from llama_toolchain.distribution.datatypes import Api, ProviderSpec
from llama_toolchain.safety.api import * # noqa
from .config import SafetyConfig
from .shields import (
CodeScannerShield,
InjectionShield,
JailbreakShield,
LlamaGuardShield,
PromptGuardShield,
ShieldBase,
ThirdPartyShield,
)
async def get_provider_impl(config: SafetyConfig, _deps: Dict[Api, ProviderSpec]):
    """Build and initialize the meta-reference safety provider.

    Args:
        config: Provider configuration; must be a ``SafetyConfig`` instance.
        _deps: Mapping from API to provider spec. Not read by this function.

    Returns:
        A ``MetaReferenceSafetyImpl`` whose ``initialize()`` has already been
        awaited.

    Raises:
        AssertionError: If ``config`` is not a ``SafetyConfig``.
    """
    assert isinstance(config, SafetyConfig), f"Unexpected config type: {type(config)}"

    provider = MetaReferenceSafetyImpl(config)
    await provider.initialize()
    return provider
def resolve_and_get_path(model_name: str) -> str:
    """Return the local directory for the model named *model_name*.

    The name is looked up with ``resolve_model`` and the resulting model's
    descriptor is handed to ``model_local_dir``.

    Raises:
        AssertionError: If ``resolve_model`` returns ``None`` for the name.
    """
    resolved = resolve_model(model_name)
    assert resolved is not None, f"Could not resolve model {model_name}"
    return model_local_dir(resolved.descriptor())
class MetaReferenceSafetyImpl(Safety):
    """``Safety`` implementation that runs the shields defined in ``.shields``."""

    def __init__(self, config: SafetyConfig) -> None:
        # Kept so run_shields() can map shield definitions to shield objects.
        self.config = config

    async def initialize(self) -> None:
        """Instantiate the shields that are present in the configuration.

        ``LlamaGuardShield.instance`` and ``PromptGuardShield.instance`` are
        called for their side effects only; the returned objects are discarded.
        NOTE(review): presumably ``instance`` caches a shared object so later
        calls reuse it -- confirm against ``.shields``.
        """
        llama_guard_cfg = self.config.llama_guard_shield
        if llama_guard_cfg is not None:
            LlamaGuardShield.instance(
                model_dir=resolve_and_get_path(llama_guard_cfg.model),
                excluded_categories=llama_guard_cfg.excluded_categories,
                disable_input_check=llama_guard_cfg.disable_input_check,
                disable_output_check=llama_guard_cfg.disable_output_check,
            )

        prompt_guard_cfg = self.config.prompt_guard_shield
        if prompt_guard_cfg is not None:
            PromptGuardShield.instance(resolve_and_get_path(prompt_guard_cfg.model))

    async def run_shields(
        self,
        request: RunShieldRequest,
    ) -> RunShieldResponse:
        """Run every shield named in *request* over ``request.messages``.

        All shield objects are built before any of them is run, so a bad
        shield definition fails the request before any ``run`` call is made.
        The ``run`` calls are then awaited together via ``asyncio.gather``,
        which returns results in the order of ``request.shields``.
        """
        shields = []
        for definition in request.shields:
            shields.append(shield_config_to_shield(definition, self.config))

        results = await asyncio.gather(
            *(shield.run(request.messages) for shield in shields)
        )
        return RunShieldResponse(responses=results)
def shield_type_equals(a: ShieldType, b: ShieldType):
    """Return True when *a* and *b* denote the same shield type.

    Either argument may be an enum member or the raw value of one (the
    original comparison against ``b.value`` implies raw values are expected).
    The check is symmetric: a raw value matches a member on either side, and
    two differing raw values compare unequal instead of raising
    ``AttributeError`` on the missing ``.value`` attribute.

    Args:
        a: Shield type, as an enum member or its raw value.
        b: Shield type, as an enum member or its raw value.

    Returns:
        ``True`` if the two are equal directly, or if one equals the other's
        ``.value``; ``False`` otherwise.
    """
    if a == b:
        return True
    # getattr with a default lets objects without ``.value`` (e.g. plain
    # strings) fall back to themselves rather than raising.
    return a == getattr(b, "value", b) or getattr(a, "value", a) == b
def shield_config_to_shield(
    sc: ShieldDefinition, safety_config: SafetyConfig
) -> ShieldBase:
    """Map a shield definition to the shield object that implements it.

    Args:
        sc: Shield definition; only ``sc.shield_type`` is read here.
        safety_config: Provider configuration supplying the model names for
            the Llama Guard and Prompt Guard backed shields.

    Returns:
        The result of the matching shield class's ``instance(...)`` call.

    Raises:
        AssertionError: If the requested shield needs a config section
            (``llama_guard_shield`` / ``prompt_guard_shield``) that is ``None``.
        ValueError: If ``sc.shield_type`` matches no ``BuiltinShield`` handled
            below.
    """
    requested = sc.shield_type

    if shield_type_equals(requested, BuiltinShield.llama_guard):
        assert (
            safety_config.llama_guard_shield is not None
        ), "Cannot use LlamaGuardShield since not present in config"
        return LlamaGuardShield.instance(
            model_dir=resolve_and_get_path(safety_config.llama_guard_shield.model)
        )

    # The jailbreak and injection shields both load the Prompt Guard model.
    if shield_type_equals(requested, BuiltinShield.jailbreak_shield):
        assert (
            safety_config.prompt_guard_shield is not None
        ), "Cannot use Jailbreak Shield since Prompt Guard not present in config"
        return JailbreakShield.instance(
            resolve_and_get_path(safety_config.prompt_guard_shield.model)
        )

    if shield_type_equals(requested, BuiltinShield.injection_shield):
        assert (
            safety_config.prompt_guard_shield is not None
        ), "Cannot use PromptGuardShield since not present in config"
        return InjectionShield.instance(
            resolve_and_get_path(safety_config.prompt_guard_shield.model)
        )

    # The remaining shields take no model directory.
    if shield_type_equals(requested, BuiltinShield.code_scanner_guard):
        return CodeScannerShield.instance()

    if shield_type_equals(requested, BuiltinShield.third_party_shield):
        return ThirdPartyShield.instance()

    raise ValueError(f"Unknown shield type: {sc.shield_type}")