From d99a08c1ffd97388643f56ab657b5318daa97a8e Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Thu, 3 Oct 2024 09:29:15 -0700 Subject: [PATCH] simpler --- .../impls/meta_reference/safety/safety.py | 17 +++-------------- 1 file changed, 3 insertions(+), 14 deletions(-) diff --git a/llama_stack/providers/impls/meta_reference/safety/safety.py b/llama_stack/providers/impls/meta_reference/safety/safety.py index 5f8372e42..6cc7f8541 100644 --- a/llama_stack/providers/impls/meta_reference/safety/safety.py +++ b/llama_stack/providers/impls/meta_reference/safety/safety.py @@ -6,8 +6,6 @@ from typing import Any, Dict, List -from llama_models.sku_list import resolve_model - from llama_stack.distribution.utils.model_utils import model_local_dir from llama_stack.apis.inference import * # noqa: F403 from llama_stack.apis.safety import * # noqa: F403 @@ -32,15 +30,6 @@ from .shields import ( PROMPT_GUARD_MODEL = "Prompt-Guard-86M" -def resolve_and_get_path(model_name: str) -> str: - if model_name == PROMPT_GUARD_MODEL: - return model_local_dir(model_name) - - model = resolve_model(model_name) - assert model is not None, f"Could not resolve model {model_name}" - return model_local_dir(model.descriptor()) - - class MetaReferenceSafetyImpl(Safety, RoutableProvider): def __init__(self, config: SafetyConfig, deps) -> None: self.config = config @@ -49,7 +38,7 @@ class MetaReferenceSafetyImpl(Safety, RoutableProvider): async def initialize(self) -> None: shield_cfg = self.config.prompt_guard_shield if shield_cfg is not None: - model_dir = resolve_and_get_path(PROMPT_GUARD_MODEL) + model_dir = model_local_dir(PROMPT_GUARD_MODEL) _ = PromptGuardShield.instance(model_dir) async def shutdown(self) -> None: @@ -112,10 +101,10 @@ class MetaReferenceSafetyImpl(Safety, RoutableProvider): disable_output_check=cfg.disable_output_check, ) elif typ == MetaReferenceShieldType.jailbreak_shield: - model_dir = resolve_and_get_path(PROMPT_GUARD_MODEL) + model_dir = model_local_dir(PROMPT_GUARD_MODEL) return JailbreakShield.instance(model_dir) elif typ == MetaReferenceShieldType.injection_shield: - model_dir = resolve_and_get_path(PROMPT_GUARD_MODEL) + model_dir = model_local_dir(PROMPT_GUARD_MODEL) return InjectionShield.instance(model_dir) elif typ == MetaReferenceShieldType.code_scanner_guard: return CodeScannerShield.instance()