mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-29 15:23:51 +00:00
update safety to use model sku ids and not model dirs
This commit is contained in:
parent
a0e61a3c7a
commit
2a9bdb208b
2 changed files with 23 additions and 11 deletions
|
@ -10,14 +10,14 @@ from pydantic import BaseModel
|
|||
|
||||
|
||||
class LlamaGuardShieldConfig(BaseModel):
|
||||
model_dir: str
|
||||
model: str
|
||||
excluded_categories: List[str]
|
||||
disable_input_check: bool = False
|
||||
disable_output_check: bool = False
|
||||
|
||||
|
||||
class PromptGuardShieldConfig(BaseModel):
|
||||
model_dir: str
|
||||
model: str
|
||||
|
||||
|
||||
class SafetyConfig(BaseModel):
|
||||
|
|
|
@ -5,13 +5,15 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
import asyncio
|
||||
|
||||
from typing import Dict
|
||||
|
||||
from llama_models.sku_list import resolve_model
|
||||
|
||||
from llama_toolchain.common.model_utils import model_local_dir
|
||||
from llama_toolchain.distribution.datatypes import Api, ProviderSpec
|
||||
from llama_toolchain.safety.api import * # noqa
|
||||
|
||||
from .config import SafetyConfig
|
||||
from llama_toolchain.safety.api import * # noqa
|
||||
from .shields import (
|
||||
CodeScannerShield,
|
||||
InjectionShield,
|
||||
|
@ -31,6 +33,13 @@ async def get_provider_impl(config: SafetyConfig, _deps: Dict[Api, ProviderSpec]
|
|||
return impl
|
||||
|
||||
|
||||
def resolve_and_get_path(model_name: str) -> str:
|
||||
model = resolve_model(model_name)
|
||||
assert model is not None, f"Could not resolve model {model_name}"
|
||||
model_dir = model_local_dir(model)
|
||||
return model_dir
|
||||
|
||||
|
||||
class MetaReferenceSafetyImpl(Safety):
|
||||
|
||||
def __init__(self, config: SafetyConfig) -> None:
|
||||
|
@ -39,8 +48,9 @@ class MetaReferenceSafetyImpl(Safety):
|
|||
async def initialize(self) -> None:
|
||||
shield_cfg = self.config.llama_guard_shield
|
||||
if shield_cfg is not None:
|
||||
model_dir = resolve_and_get_path(shield_cfg.model)
|
||||
_ = LlamaGuardShield.instance(
|
||||
model_dir=shield_cfg.model_dir,
|
||||
model_dir=model_dir,
|
||||
excluded_categories=shield_cfg.excluded_categories,
|
||||
disable_input_check=shield_cfg.disable_input_check,
|
||||
disable_output_check=shield_cfg.disable_output_check,
|
||||
|
@ -48,7 +58,8 @@ class MetaReferenceSafetyImpl(Safety):
|
|||
|
||||
shield_cfg = self.config.prompt_guard_shield
|
||||
if shield_cfg is not None:
|
||||
_ = PromptGuardShield.instance(shield_cfg.model_dir)
|
||||
model_dir = resolve_and_get_path(shield_cfg.model)
|
||||
_ = PromptGuardShield.instance(model_dir)
|
||||
|
||||
async def run_shields(
|
||||
self,
|
||||
|
@ -70,19 +81,20 @@ def shield_config_to_shield(
|
|||
assert (
|
||||
safety_config.llama_guard_shield is not None
|
||||
), "Cannot use LlamaGuardShield since not present in config"
|
||||
return LlamaGuardShield.instance(
|
||||
model_dir=safety_config.llama_guard_shield.model_dir
|
||||
)
|
||||
model_dir = resolve_and_get_path(safety_config.llama_guard_shield.model)
|
||||
return LlamaGuardShield.instance(model_dir=model_dir)
|
||||
elif sc.shield_type == BuiltinShield.jailbreak_shield:
|
||||
assert (
|
||||
safety_config.prompt_guard_shield is not None
|
||||
), "Cannot use Jailbreak Shield since Prompt Guard not present in config"
|
||||
return JailbreakShield.instance(safety_config.prompt_guard_shield.model_dir)
|
||||
model_dir = resolve_and_get_path(safety_config.prompt_guard_shield.model)
|
||||
return JailbreakShield.instance(model_dir)
|
||||
elif sc.shield_type == BuiltinShield.injection_shield:
|
||||
assert (
|
||||
safety_config.prompt_guard_shield is not None
|
||||
), "Cannot use PromptGuardShield since not present in config"
|
||||
return InjectionShield.instance(safety_config.prompt_guard_shield.model_dir)
|
||||
model_dir = resolve_and_get_path(safety_config.prompt_guard_shield.model)
|
||||
return InjectionShield.instance(model_dir)
|
||||
elif sc.shield_type == BuiltinShield.code_scanner_guard:
|
||||
return CodeScannerShield.instance()
|
||||
elif sc.shield_type == BuiltinShield.third_party_shield:
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue