update safety to use model sku ids and not model dirs

This commit is contained in:
Hardik Shah 2024-08-06 17:10:01 -07:00
parent a0e61a3c7a
commit 2a9bdb208b
2 changed files with 23 additions and 11 deletions

View file

@ -10,14 +10,14 @@ from pydantic import BaseModel
class LlamaGuardShieldConfig(BaseModel): class LlamaGuardShieldConfig(BaseModel):
model_dir: str model: str
excluded_categories: List[str] excluded_categories: List[str]
disable_input_check: bool = False disable_input_check: bool = False
disable_output_check: bool = False disable_output_check: bool = False
class PromptGuardShieldConfig(BaseModel): class PromptGuardShieldConfig(BaseModel):
model_dir: str model: str
class SafetyConfig(BaseModel): class SafetyConfig(BaseModel):

View file

@ -5,13 +5,15 @@
# the root directory of this source tree. # the root directory of this source tree.
import asyncio import asyncio
from typing import Dict from typing import Dict
from llama_models.sku_list import resolve_model
from llama_toolchain.common.model_utils import model_local_dir
from llama_toolchain.distribution.datatypes import Api, ProviderSpec from llama_toolchain.distribution.datatypes import Api, ProviderSpec
from llama_toolchain.safety.api import * # noqa
from .config import SafetyConfig from .config import SafetyConfig
from llama_toolchain.safety.api import * # noqa
from .shields import ( from .shields import (
CodeScannerShield, CodeScannerShield,
InjectionShield, InjectionShield,
@ -31,6 +33,13 @@ async def get_provider_impl(config: SafetyConfig, _deps: Dict[Api, ProviderSpec]
return impl return impl
def resolve_and_get_path(model_name: str) -> str:
    """Resolve a model identifier and return its local model directory.

    Args:
        model_name: Model identifier as given in the shield config's
            ``model`` field; it is passed to ``resolve_model`` unchanged.

    Returns:
        The value ``model_local_dir`` produces for the resolved model.

    Raises:
        ValueError: If ``resolve_model`` returns ``None`` for ``model_name``.
    """
    model = resolve_model(model_name)
    # Raise explicitly instead of using `assert`: assertions are removed when
    # Python runs with -O, which would let None flow into model_local_dir and
    # fail far away from the actual cause (a bad model name in the config).
    if model is None:
        raise ValueError(f"Could not resolve model {model_name}")
    return model_local_dir(model)
class MetaReferenceSafetyImpl(Safety): class MetaReferenceSafetyImpl(Safety):
def __init__(self, config: SafetyConfig) -> None: def __init__(self, config: SafetyConfig) -> None:
@ -39,8 +48,9 @@ class MetaReferenceSafetyImpl(Safety):
async def initialize(self) -> None: async def initialize(self) -> None:
shield_cfg = self.config.llama_guard_shield shield_cfg = self.config.llama_guard_shield
if shield_cfg is not None: if shield_cfg is not None:
model_dir = resolve_and_get_path(shield_cfg.model)
_ = LlamaGuardShield.instance( _ = LlamaGuardShield.instance(
model_dir=shield_cfg.model_dir, model_dir=model_dir,
excluded_categories=shield_cfg.excluded_categories, excluded_categories=shield_cfg.excluded_categories,
disable_input_check=shield_cfg.disable_input_check, disable_input_check=shield_cfg.disable_input_check,
disable_output_check=shield_cfg.disable_output_check, disable_output_check=shield_cfg.disable_output_check,
@ -48,7 +58,8 @@ class MetaReferenceSafetyImpl(Safety):
shield_cfg = self.config.prompt_guard_shield shield_cfg = self.config.prompt_guard_shield
if shield_cfg is not None: if shield_cfg is not None:
_ = PromptGuardShield.instance(shield_cfg.model_dir) model_dir = resolve_and_get_path(shield_cfg.model)
_ = PromptGuardShield.instance(model_dir)
async def run_shields( async def run_shields(
self, self,
@ -70,19 +81,20 @@ def shield_config_to_shield(
assert ( assert (
safety_config.llama_guard_shield is not None safety_config.llama_guard_shield is not None
), "Cannot use LlamaGuardShield since not present in config" ), "Cannot use LlamaGuardShield since not present in config"
return LlamaGuardShield.instance( model_dir = resolve_and_get_path(safety_config.llama_guard_shield.model)
model_dir=safety_config.llama_guard_shield.model_dir return LlamaGuardShield.instance(model_dir=model_dir)
)
elif sc.shield_type == BuiltinShield.jailbreak_shield: elif sc.shield_type == BuiltinShield.jailbreak_shield:
assert ( assert (
safety_config.prompt_guard_shield is not None safety_config.prompt_guard_shield is not None
), "Cannot use Jailbreak Shield since Prompt Guard not present in config" ), "Cannot use Jailbreak Shield since Prompt Guard not present in config"
return JailbreakShield.instance(safety_config.prompt_guard_shield.model_dir) model_dir = resolve_and_get_path(safety_config.prompt_guard_shield.model)
return JailbreakShield.instance(model_dir)
elif sc.shield_type == BuiltinShield.injection_shield: elif sc.shield_type == BuiltinShield.injection_shield:
assert ( assert (
safety_config.prompt_guard_shield is not None safety_config.prompt_guard_shield is not None
), "Cannot use PromptGuardShield since not present in config" ), "Cannot use PromptGuardShield since not present in config"
return InjectionShield.instance(safety_config.prompt_guard_shield.model_dir) model_dir = resolve_and_get_path(safety_config.prompt_guard_shield.model)
return InjectionShield.instance(model_dir)
elif sc.shield_type == BuiltinShield.code_scanner_guard: elif sc.shield_type == BuiltinShield.code_scanner_guard:
return CodeScannerShield.instance() return CodeScannerShield.instance()
elif sc.shield_type == BuiltinShield.third_party_shield: elif sc.shield_type == BuiltinShield.third_party_shield: