make shield imports more lazy

This commit is contained in:
Ashwin Bharambe 2024-09-17 21:27:16 -07:00
parent 81ff7476d3
commit 9fd431e710
2 changed files with 4 additions and 2 deletions

View file

@ -4,7 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from codeshield.cs import CodeShield
from termcolor import cprint
from .base import ShieldResponse, TextShield
@ -16,6 +15,8 @@ class CodeScannerShield(TextShield):
return BuiltinShield.code_scanner_guard
async def run_impl(self, text: str) -> ShieldResponse:
from codeshield.cs import CodeShield
cprint(f"Running CodeScannerShield on {text[50:]}", color="magenta")
result = await CodeShield.scan_code(text)
if result.is_insecure:

View file

@ -11,7 +11,6 @@ import torch
from llama_models.llama3.api.datatypes import Message
from termcolor import cprint
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from .base import message_content_as_str, OnViolationAction, ShieldResponse, TextShield
from llama_stack.apis.safety import * # noqa: F403
@ -61,6 +60,8 @@ class PromptGuardShield(TextShield):
raise ValueError("Temperature must be greater than 0")
self.device = "cuda"
if PromptGuardShield._model_cache is None:
from transformers import AutoModelForSequenceClassification, AutoTokenizer
# load model and tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_dir)
model = AutoModelForSequenceClassification.from_pretrained(