diff --git a/llama_toolchain/safety/meta_reference/config.py b/llama_toolchain/safety/meta_reference/config.py index 7cb32b3c1..7369ec539 100644 --- a/llama_toolchain/safety/meta_reference/config.py +++ b/llama_toolchain/safety/meta_reference/config.py @@ -10,14 +10,14 @@ from pydantic import BaseModel class LlamaGuardShieldConfig(BaseModel): - model_dir: str + model: str excluded_categories: List[str] disable_input_check: bool = False disable_output_check: bool = False class PromptGuardShieldConfig(BaseModel): - model_dir: str + model: str class SafetyConfig(BaseModel): diff --git a/llama_toolchain/safety/meta_reference/safety.py b/llama_toolchain/safety/meta_reference/safety.py index 93b986c12..60d16dbf1 100644 --- a/llama_toolchain/safety/meta_reference/safety.py +++ b/llama_toolchain/safety/meta_reference/safety.py @@ -5,13 +5,15 @@ # the root directory of this source tree. import asyncio - from typing import Dict +from llama_models.sku_list import resolve_model + +from llama_toolchain.common.model_utils import model_local_dir from llama_toolchain.distribution.datatypes import Api, ProviderSpec +from llama_toolchain.safety.api import * # noqa from .config import SafetyConfig -from llama_toolchain.safety.api import * # noqa from .shields import ( CodeScannerShield, InjectionShield, @@ -31,6 +33,13 @@ async def get_provider_impl(config: SafetyConfig, _deps: Dict[Api, ProviderSpec] return impl +def resolve_and_get_path(model_name: str) -> str: + model = resolve_model(model_name) + assert model is not None, f"Could not resolve model {model_name}" + model_dir = model_local_dir(model) + return model_dir + + class MetaReferenceSafetyImpl(Safety): def __init__(self, config: SafetyConfig) -> None: @@ -39,8 +48,9 @@ class MetaReferenceSafetyImpl(Safety): async def initialize(self) -> None: shield_cfg = self.config.llama_guard_shield if shield_cfg is not None: + model_dir = resolve_and_get_path(shield_cfg.model) _ = LlamaGuardShield.instance( - model_dir=shield_cfg.model_dir, + model_dir=model_dir, excluded_categories=shield_cfg.excluded_categories, disable_input_check=shield_cfg.disable_input_check, disable_output_check=shield_cfg.disable_output_check, @@ -48,7 +58,8 @@ class MetaReferenceSafetyImpl(Safety): shield_cfg = self.config.prompt_guard_shield if shield_cfg is not None: - _ = PromptGuardShield.instance(shield_cfg.model_dir) + model_dir = resolve_and_get_path(shield_cfg.model) + _ = PromptGuardShield.instance(model_dir) async def run_shields( self, @@ -70,19 +81,20 @@ def shield_config_to_shield( assert ( safety_config.llama_guard_shield is not None ), "Cannot use LlamaGuardShield since not present in config" - return LlamaGuardShield.instance( - model_dir=safety_config.llama_guard_shield.model_dir - ) + model_dir = resolve_and_get_path(safety_config.llama_guard_shield.model) + return LlamaGuardShield.instance(model_dir=model_dir) elif sc.shield_type == BuiltinShield.jailbreak_shield: assert ( safety_config.prompt_guard_shield is not None ), "Cannot use Jailbreak Shield since Prompt Guard not present in config" - return JailbreakShield.instance(safety_config.prompt_guard_shield.model_dir) + model_dir = resolve_and_get_path(safety_config.prompt_guard_shield.model) + return JailbreakShield.instance(model_dir) elif sc.shield_type == BuiltinShield.injection_shield: assert ( safety_config.prompt_guard_shield is not None ), "Cannot use PromptGuardShield since not present in config" - return InjectionShield.instance(safety_config.prompt_guard_shield.model_dir) + model_dir = resolve_and_get_path(safety_config.prompt_guard_shield.model) + return InjectionShield.instance(model_dir) elif sc.shield_type == BuiltinShield.code_scanner_guard: return CodeScannerShield.instance() elif sc.shield_type == BuiltinShield.third_party_shield: