diff --git a/llama_stack/providers/impls/meta_reference/safety/shields/code_scanner.py b/llama_stack/providers/impls/meta_reference/safety/shields/code_scanner.py
index 75ec7c37b..340ccb517 100644
--- a/llama_stack/providers/impls/meta_reference/safety/shields/code_scanner.py
+++ b/llama_stack/providers/impls/meta_reference/safety/shields/code_scanner.py
@@ -4,7 +4,6 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from codeshield.cs import CodeShield
 from termcolor import cprint
 
 from .base import ShieldResponse, TextShield
@@ -16,6 +15,8 @@ class CodeScannerShield(TextShield):
         return BuiltinShield.code_scanner_guard
 
     async def run_impl(self, text: str) -> ShieldResponse:
+        from codeshield.cs import CodeShield
+
         cprint(f"Running CodeScannerShield on {text[50:]}", color="magenta")
         result = await CodeShield.scan_code(text)
         if result.is_insecure:
diff --git a/llama_stack/providers/impls/meta_reference/safety/shields/prompt_guard.py b/llama_stack/providers/impls/meta_reference/safety/shields/prompt_guard.py
index 67bc6a6db..acaf515b5 100644
--- a/llama_stack/providers/impls/meta_reference/safety/shields/prompt_guard.py
+++ b/llama_stack/providers/impls/meta_reference/safety/shields/prompt_guard.py
@@ -11,7 +11,6 @@ import torch
 from llama_models.llama3.api.datatypes import Message
 from termcolor import cprint
 
-from transformers import AutoModelForSequenceClassification, AutoTokenizer
 
 from .base import message_content_as_str, OnViolationAction, ShieldResponse, TextShield
 from llama_stack.apis.safety import *  # noqa: F403
@@ -61,6 +60,8 @@ class PromptGuardShield(TextShield):
             raise ValueError("Temperature must be greater than 0")
         self.device = "cuda"
         if PromptGuardShield._model_cache is None:
+            from transformers import AutoModelForSequenceClassification, AutoTokenizer
+
             # load model and tokenizer
             tokenizer = AutoTokenizer.from_pretrained(model_dir)
             model = AutoModelForSequenceClassification.from_pretrained(
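
Both hunks apply the same deferred-import pattern: the heavy optional dependencies (codeshield and transformers) move from module scope into the call sites that use them, so importing the shield modules no longer requires those packages to be installed, and their load cost is paid only when a shield actually runs. A minimal sketch of the pattern under that reading, with heavy_lib and HeavyModel as hypothetical stand-ins for the real dependency:

    class LazyShield:
        _model_cache = None  # class-level cache, mirroring PromptGuardShield._model_cache

        def _load_model(self, model_dir: str):
            if LazyShield._model_cache is None:
                # Deferred import: the module defining LazyShield imports cheaply,
                # and any ImportError for the heavy dependency surfaces only here,
                # on first use, rather than at module import time.
                from heavy_lib import HeavyModel  # hypothetical stand-in

                LazyShield._model_cache = HeavyModel.from_pretrained(model_dir)
            return LazyShield._model_cache

After the first call, the imported module is cached in sys.modules, so repeating the import statement inside the method costs only a dictionary lookup per call.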