forked from phoenix-oss/llama-stack-mirror
156 lines
5.4 KiB
Python
156 lines
5.4 KiB
Python
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
# All rights reserved.
|
|
#
|
|
# This source code is licensed under the terms described in the LICENSE file in
|
|
# the root directory of this source tree.
|
|
|
|
from enum import auto, Enum
|
|
from typing import List
|
|
|
|
import torch
|
|
|
|
from llama_models.llama3_1.api.datatypes import Message
|
|
from termcolor import cprint
|
|
from transformers import AutoModelForSequenceClassification, AutoTokenizer
|
|
|
|
from .base import message_content_as_str, OnViolationAction, ShieldResponse, TextShield
|
|
from llama_toolchain.safety.api.datatypes import * # noqa: F403
|
|
|
|
|
|
class PromptGuardShield(TextShield):
|
|
|
|
class Mode(Enum):
|
|
INJECTION = auto()
|
|
JAILBREAK = auto()
|
|
|
|
_instances = {}
|
|
_model_cache = None
|
|
|
|
@staticmethod
|
|
def instance(
|
|
model_dir: str,
|
|
threshold: float = 0.9,
|
|
temperature: float = 1.0,
|
|
mode: "PromptGuardShield.Mode" = Mode.JAILBREAK,
|
|
on_violation_action=OnViolationAction.RAISE,
|
|
) -> "PromptGuardShield":
|
|
action_value = on_violation_action.value
|
|
key = (model_dir, threshold, temperature, mode, action_value)
|
|
if key not in PromptGuardShield._instances:
|
|
PromptGuardShield._instances[key] = PromptGuardShield(
|
|
model_dir=model_dir,
|
|
threshold=threshold,
|
|
temperature=temperature,
|
|
mode=mode,
|
|
on_violation_action=on_violation_action,
|
|
)
|
|
return PromptGuardShield._instances[key]
|
|
|
|
def __init__(
|
|
self,
|
|
model_dir: str,
|
|
threshold: float = 0.9,
|
|
temperature: float = 1.0,
|
|
mode: "PromptGuardShield.Mode" = Mode.JAILBREAK,
|
|
on_violation_action: OnViolationAction = OnViolationAction.RAISE,
|
|
):
|
|
super().__init__(on_violation_action)
|
|
assert (
|
|
model_dir is not None
|
|
), "Must provide a model directory for prompt injection shield"
|
|
if temperature <= 0:
|
|
raise ValueError("Temperature must be greater than 0")
|
|
self.device = "cuda"
|
|
if PromptGuardShield._model_cache is None:
|
|
# load model and tokenizer
|
|
tokenizer = AutoTokenizer.from_pretrained(model_dir)
|
|
model = AutoModelForSequenceClassification.from_pretrained(
|
|
model_dir, device_map=self.device
|
|
)
|
|
PromptGuardShield._model_cache = (tokenizer, model)
|
|
|
|
self.tokenizer, self.model = PromptGuardShield._model_cache
|
|
self.temperature = temperature
|
|
self.threshold = threshold
|
|
self.mode = mode
|
|
|
|
def get_shield_type(self) -> ShieldType:
|
|
return (
|
|
BuiltinShield.jailbreak_shield
|
|
if self.mode == self.Mode.JAILBREAK
|
|
else BuiltinShield.injection_shield
|
|
)
|
|
|
|
def convert_messages_to_text(self, messages: List[Message]) -> str:
|
|
return message_content_as_str(messages[-1])
|
|
|
|
async def run_impl(self, text: str) -> ShieldResponse:
|
|
# run model on messages and return response
|
|
inputs = self.tokenizer(text, return_tensors="pt")
|
|
inputs = {name: tensor.to(self.model.device) for name, tensor in inputs.items()}
|
|
with torch.no_grad():
|
|
outputs = self.model(**inputs)
|
|
logits = outputs[0]
|
|
probabilities = torch.softmax(logits / self.temperature, dim=-1)
|
|
score_embedded = probabilities[0, 1].item()
|
|
score_malicious = probabilities[0, 2].item()
|
|
cprint(
|
|
f"Ran PromptGuardShield and got Scores: Embedded: {score_embedded}, Malicious: {score_malicious}",
|
|
color="magenta",
|
|
)
|
|
|
|
if self.mode == self.Mode.INJECTION and (
|
|
score_embedded + score_malicious > self.threshold
|
|
):
|
|
return ShieldResponse(
|
|
shield_type=self.get_shield_type(),
|
|
is_violation=True,
|
|
violation_type=f"prompt_injection:embedded={score_embedded},malicious={score_malicious}",
|
|
violation_return_message="Sorry, I cannot do this.",
|
|
)
|
|
elif self.mode == self.Mode.JAILBREAK and score_malicious > self.threshold:
|
|
return ShieldResponse(
|
|
shield_type=self.get_shield_type(),
|
|
is_violation=True,
|
|
violation_type=f"prompt_injection:malicious={score_malicious}",
|
|
violation_return_message="Sorry, I cannot do this.",
|
|
)
|
|
|
|
return ShieldResponse(
|
|
shield_type=self.get_shield_type(),
|
|
is_violation=False,
|
|
)
|
|
|
|
|
|
class JailbreakShield(PromptGuardShield):
|
|
def __init__(
|
|
self,
|
|
model_dir: str,
|
|
threshold: float = 0.9,
|
|
temperature: float = 1.0,
|
|
on_violation_action: OnViolationAction = OnViolationAction.RAISE,
|
|
):
|
|
super().__init__(
|
|
model_dir=model_dir,
|
|
threshold=threshold,
|
|
temperature=temperature,
|
|
mode=PromptGuardShield.Mode.JAILBREAK,
|
|
on_violation_action=on_violation_action,
|
|
)
|
|
|
|
|
|
class InjectionShield(PromptGuardShield):
|
|
def __init__(
|
|
self,
|
|
model_dir: str,
|
|
threshold: float = 0.9,
|
|
temperature: float = 1.0,
|
|
on_violation_action: OnViolationAction = OnViolationAction.RAISE,
|
|
):
|
|
super().__init__(
|
|
model_dir=model_dir,
|
|
threshold=threshold,
|
|
temperature=temperature,
|
|
mode=PromptGuardShield.Mode.INJECTION,
|
|
on_violation_action=on_violation_action,
|
|
)
|