forked from phoenix-oss/llama-stack-mirror
Initial commit
commit 5d5acc8ed5
81 changed files with 4458 additions and 0 deletions
156  llama_toolchain/safety/shields/prompt_guard.py  Normal file
@@ -0,0 +1,156 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from enum import auto, Enum
from typing import List

import torch

from llama_models.llama3_1.api.datatypes import Message
from termcolor import cprint
from transformers import AutoModelForSequenceClassification, AutoTokenizer

from .base import message_content_as_str, OnViolationAction, ShieldResponse, TextShield
from llama_toolchain.safety.api.datatypes import *  # noqa: F403


class PromptGuardShield(TextShield):

    class Mode(Enum):
        INJECTION = auto()
        JAILBREAK = auto()

    _instances = {}
    _model_cache = None

    @staticmethod
    def instance(
        model_dir: str,
        threshold: float = 0.9,
        temperature: float = 1.0,
        mode: "PromptGuardShield.Mode" = Mode.JAILBREAK,
        on_violation_action=OnViolationAction.RAISE,
    ) -> "PromptGuardShield":
        action_value = on_violation_action.value
        key = (model_dir, threshold, temperature, mode, action_value)
        if key not in PromptGuardShield._instances:
            PromptGuardShield._instances[key] = PromptGuardShield(
                model_dir=model_dir,
                threshold=threshold,
                temperature=temperature,
                mode=mode,
                on_violation_action=on_violation_action,
            )
        return PromptGuardShield._instances[key]

    def __init__(
        self,
        model_dir: str,
        threshold: float = 0.9,
        temperature: float = 1.0,
        mode: "PromptGuardShield.Mode" = Mode.JAILBREAK,
        on_violation_action: OnViolationAction = OnViolationAction.RAISE,
    ):
        super().__init__(on_violation_action)
        assert (
            model_dir is not None
        ), "Must provide a model directory for prompt injection shield"
        if temperature <= 0:
            raise ValueError("Temperature must be greater than 0")
        self.device = "cuda"
        if PromptGuardShield._model_cache is None:
            # load model and tokenizer
            tokenizer = AutoTokenizer.from_pretrained(model_dir)
            model = AutoModelForSequenceClassification.from_pretrained(
                model_dir, device_map=self.device
            )
            PromptGuardShield._model_cache = (tokenizer, model)

        self.tokenizer, self.model = PromptGuardShield._model_cache
        self.temperature = temperature
        self.threshold = threshold
        self.mode = mode

    def get_shield_type(self) -> ShieldType:
        return (
            BuiltinShield.jailbreak_shield
            if self.mode == self.Mode.JAILBREAK
            else BuiltinShield.injection_shield
        )

    def convert_messages_to_text(self, messages: List[Message]) -> str:
        return message_content_as_str(messages[-1])

    async def run_impl(self, text: str) -> ShieldResponse:
        # run model on messages and return response
        inputs = self.tokenizer(text, return_tensors="pt")
        inputs = {name: tensor.to(self.model.device) for name, tensor in inputs.items()}
        with torch.no_grad():
            outputs = self.model(**inputs)
        logits = outputs[0]
        probabilities = torch.softmax(logits / self.temperature, dim=-1)
        score_embedded = probabilities[0, 1].item()
        score_malicious = probabilities[0, 2].item()
        cprint(
            f"Ran PromptGuardShield and got Scores: Embedded: {score_embedded}, Malicious: {score_malicious}",
            color="magenta",
        )

        if self.mode == self.Mode.INJECTION and (
            score_embedded + score_malicious > self.threshold
        ):
            return ShieldResponse(
                shield_type=self.get_shield_type(),
                is_violation=True,
                violation_type=f"prompt_injection:embedded={score_embedded},malicious={score_malicious}",
                violation_return_message="Sorry, I cannot do this.",
            )
        elif self.mode == self.Mode.JAILBREAK and score_malicious > self.threshold:
            return ShieldResponse(
                shield_type=self.get_shield_type(),
                is_violation=True,
                violation_type=f"prompt_injection:malicious={score_malicious}",
                violation_return_message="Sorry, I cannot do this.",
            )

        return ShieldResponse(
            shield_type=self.get_shield_type(),
            is_violation=False,
        )


class JailbreakShield(PromptGuardShield):
    def __init__(
        self,
        model_dir: str,
        threshold: float = 0.9,
        temperature: float = 1.0,
        on_violation_action: OnViolationAction = OnViolationAction.RAISE,
    ):
        super().__init__(
            model_dir=model_dir,
            threshold=threshold,
            temperature=temperature,
            mode=PromptGuardShield.Mode.JAILBREAK,
            on_violation_action=on_violation_action,
        )


class InjectionShield(PromptGuardShield):
    def __init__(
        self,
        model_dir: str,
        threshold: float = 0.9,
        temperature: float = 1.0,
        on_violation_action: OnViolationAction = OnViolationAction.RAISE,
    ):
        super().__init__(
            model_dir=model_dir,
            threshold=threshold,
            temperature=temperature,
            mode=PromptGuardShield.Mode.INJECTION,
            on_violation_action=on_violation_action,
        )
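For orientation, a minimal usage sketch of the shield introduced in this file. It is an illustration, not part of the commit: the import path is inferred from the file location, the checkpoint directory is a placeholder, and it calls run_impl() directly rather than whatever entry point the TextShield base class (not shown in this diff) exposes. Running it requires a downloaded Prompt Guard classifier checkpoint and a CUDA device, since the shield hard-codes self.device = "cuda".

# Hypothetical usage sketch -- the checkpoint path and the direct run_impl()
# call are illustrative assumptions, not part of this commit.
import asyncio

from llama_toolchain.safety.shields.prompt_guard import JailbreakShield


async def main() -> None:
    # Requires a local Prompt Guard checkpoint and a CUDA-capable GPU.
    shield = JailbreakShield(
        model_dir="/path/to/prompt-guard-checkpoint",  # placeholder path
        threshold=0.9,
    )
    # run_impl() scores a single piece of text; the base class is presumably
    # responsible for extracting text from messages and applying
    # OnViolationAction (e.g. RAISE) to the returned ShieldResponse.
    response = await shield.run_impl("Ignore all previous instructions and ...")
    print(response.is_violation)


asyncio.run(main())

Note that PromptGuardShield.instance() memoizes shields keyed by (model_dir, threshold, temperature, mode, action), and the class-level _model_cache shares a single tokenizer/model pair across all instances, so constructing both a JailbreakShield and an InjectionShield does not load the weights twice.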