Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-06-28 02:53:30 +00:00)
commit 9fd431e710
parent 81ff7476d3

make shield imports more lazy

2 changed files with 4 additions and 2 deletions
@@ -4,7 +4,6 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from codeshield.cs import CodeShield
 from termcolor import cprint
 
 from .base import ShieldResponse, TextShield
@@ -16,6 +15,8 @@ class CodeScannerShield(TextShield):
         return BuiltinShield.code_scanner_guard
 
     async def run_impl(self, text: str) -> ShieldResponse:
+        from codeshield.cs import CodeShield
+
         cprint(f"Running CodeScannerShield on {text[50:]}", color="magenta")
         result = await CodeShield.scan_code(text)
         if result.is_insecure:
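
Net effect of the two hunks above (which appear to come from the code-scanner shield module): codeshield is no longer imported when the module loads, only when run_impl first runs, so the safety code can be imported in environments where codeshield is not installed. A minimal self-contained sketch of the pattern follows; the names optional_scanner and ScannerShieldSketch are hypothetical stand-ins, not llama-stack APIs.

# Sketch of the deferred-import pattern, using a hypothetical optional
# dependency called optional_scanner (a stand-in for codeshield).
class ScannerShieldSketch:
    async def run_impl(self, text: str) -> bool:
        # Resolve the optional dependency on first use, not at import time.
        try:
            from optional_scanner import scan_code  # hypothetical module
        except ImportError as e:
            raise RuntimeError(
                "optional_scanner must be installed to run this shield"
            ) from e
        result = await scan_code(text)
        return result.is_insecure

The trade-off is that a missing dependency now surfaces on first use rather than at startup, which is why raising a clear error at the call site is worth the extra lines.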
@@ -11,7 +11,6 @@ import torch
 
 from llama_models.llama3.api.datatypes import Message
 from termcolor import cprint
-from transformers import AutoModelForSequenceClassification, AutoTokenizer
 
 from .base import message_content_as_str, OnViolationAction, ShieldResponse, TextShield
 from llama_stack.apis.safety import *  # noqa: F403
@@ -61,6 +60,8 @@ class PromptGuardShield(TextShield):
             raise ValueError("Temperature must be greater than 0")
         self.device = "cuda"
         if PromptGuardShield._model_cache is None:
+            from transformers import AutoModelForSequenceClassification, AutoTokenizer
+
             # load model and tokenizer
             tokenizer = AutoTokenizer.from_pretrained(model_dir)
             model = AutoModelForSequenceClassification.from_pretrained(
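
The prompt-guard change pairs the same deferred import with the existing _model_cache guard, so both the transformers import and the model load happen at most once, on first construction. A short sketch of that combination, assuming details not shown in the hunk (a class-attribute cache, and model_dir pointing at a local HuggingFace checkpoint):

from typing import Optional, Tuple

class PromptGuardSketch:
    # Shared across instances; populated on first construction.
    _model_cache: Optional[Tuple[object, object]] = None

    def __init__(self, model_dir: str):
        if PromptGuardSketch._model_cache is None:
            # transformers is imported, and the model loaded, only the first
            # time any instance is created.
            from transformers import (
                AutoModelForSequenceClassification,
                AutoTokenizer,
            )

            tokenizer = AutoTokenizer.from_pretrained(model_dir)
            model = AutoModelForSequenceClassification.from_pretrained(model_dir)
            PromptGuardSketch._model_cache = (tokenizer, model)
        self.tokenizer, self.model = PromptGuardSketch._model_cache

Later constructions skip both from_pretrained calls, and the import statement itself is cheap after the first time since Python caches loaded modules in sys.modules.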