forked from phoenix-oss/llama-stack-mirror
Revert parts of 0d2eb3bd25
This commit is contained in:
parent
059e50b389
commit
a2465f3f9c
4 changed files with 75 additions and 43 deletions
|
@ -7,10 +7,8 @@
|
|||
from llama_models.sku_list import resolve_model
|
||||
|
||||
from llama_stack.distribution.utils.model_utils import model_local_dir
|
||||
from llama_stack.apis.inference import * # noqa: F403
|
||||
from llama_stack.apis.safety import * # noqa: F403
|
||||
from llama_models.llama3.api.datatypes import * # noqa: F403
|
||||
from llama_stack.distribution.datatypes import Api
|
||||
|
||||
from llama_stack.providers.impls.meta_reference.safety.shields.base import (
|
||||
OnViolationAction,
|
||||
|
@ -36,11 +34,20 @@ def resolve_and_get_path(model_name: str) -> str:
|
|||
|
||||
|
||||
class MetaReferenceSafetyImpl(Safety):
|
||||
def __init__(self, config: SafetyConfig, deps) -> None:
|
||||
def __init__(self, config: SafetyConfig) -> None:
|
||||
self.config = config
|
||||
self.inference_api = deps[Api.inference]
|
||||
|
||||
async def initialize(self) -> None:
|
||||
shield_cfg = self.config.llama_guard_shield
|
||||
if shield_cfg is not None:
|
||||
model_dir = resolve_and_get_path(shield_cfg.model)
|
||||
_ = LlamaGuardShield.instance(
|
||||
model_dir=model_dir,
|
||||
excluded_categories=shield_cfg.excluded_categories,
|
||||
disable_input_check=shield_cfg.disable_input_check,
|
||||
disable_output_check=shield_cfg.disable_output_check,
|
||||
)
|
||||
|
||||
shield_cfg = self.config.prompt_guard_shield
|
||||
if shield_cfg is not None:
|
||||
model_dir = resolve_and_get_path(shield_cfg.model)
|
||||
|
@ -84,18 +91,11 @@ class MetaReferenceSafetyImpl(Safety):
|
|||
def get_shield_impl(self, typ: MetaReferenceShieldType) -> ShieldBase:
|
||||
cfg = self.config
|
||||
if typ == MetaReferenceShieldType.llama_guard:
|
||||
cfg = cfg.llama_guard_shield
|
||||
assert (
|
||||
cfg is not None
|
||||
cfg.llama_guard_shield is not None
|
||||
), "Cannot use LlamaGuardShield since not present in config"
|
||||
|
||||
return LlamaGuardShield(
|
||||
model=cfg.model,
|
||||
inference_api=self.inference_api,
|
||||
excluded_categories=cfg.excluded_categories,
|
||||
disable_input_check=cfg.disable_input_check,
|
||||
disable_output_check=cfg.disable_output_check,
|
||||
)
|
||||
model_dir = resolve_and_get_path(cfg.llama_guard_shield.model)
|
||||
return LlamaGuardShield.instance(model_dir=model_dir)
|
||||
elif typ == MetaReferenceShieldType.jailbreak_shield:
|
||||
assert (
|
||||
cfg.prompt_guard_shield is not None
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue