Initial commit

This commit is contained in:
Ashwin Bharambe 2024-06-25 15:47:57 -07:00 committed by Ashwin Bharambe
commit 5d5acc8ed5
81 changed files with 4458 additions and 0 deletions

View file

@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@ -0,0 +1,25 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import List, Optional
from pydantic import BaseModel
class LlamaGuardShieldConfig(BaseModel):
model_dir: str
excluded_categories: List[str]
disable_input_check: bool = False
disable_output_check: bool = False
class PromptGuardShieldConfig(BaseModel):
model_dir: str
class SafetyConfig(BaseModel):
llama_guard_shield: Optional[LlamaGuardShieldConfig] = None
prompt_guard_shield: Optional[PromptGuardShieldConfig] = None

View file

@ -0,0 +1,60 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from enum import Enum
from typing import Dict, Optional, Union
from llama_models.llama3_1.api.datatypes import ToolParamDefinition
from pydantic import BaseModel
from strong_typing.schema import json_schema_type
from llama_toolchain.common.deployment_types import RestAPIExecutionConfig
@json_schema_type
class BuiltinShield(Enum):
llama_guard = "llama_guard"
code_scanner_guard = "code_scanner_guard"
third_party_shield = "third_party_shield"
injection_shield = "injection_shield"
jailbreak_shield = "jailbreak_shield"
ShieldType = Union[BuiltinShield, str]
@json_schema_type
class OnViolationAction(Enum):
IGNORE = 0
WARN = 1
RAISE = 2
@json_schema_type
class ShieldDefinition(BaseModel):
shield_type: ShieldType
description: Optional[str] = None
parameters: Optional[Dict[str, ToolParamDefinition]] = None
on_violation_action: OnViolationAction = OnViolationAction.RAISE
execution_config: Optional[RestAPIExecutionConfig] = None
@json_schema_type
class ShieldCall(BaseModel):
call_id: str
shield_type: ShieldType
arguments: Dict[str, str]
@json_schema_type
class ShieldResponse(BaseModel):
shield_type: ShieldType
# TODO(ashwin): clean this up
is_violation: bool
violation_type: Optional[str] = None
violation_return_message: Optional[str] = None

View file

@ -0,0 +1,35 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# supress warnings and spew of logs from hugging face
import transformers
from .base import ( # noqa: F401
DummyShield,
OnViolationAction,
ShieldBase,
ShieldResponse,
TextShield,
)
from .code_scanner import CodeScannerShield # noqa: F401
from .contrib.third_party_shield import ThirdPartyShield # noqa: F401
from .llama_guard import LlamaGuardShield # noqa: F401
from .prompt_guard import ( # noqa: F401
InjectionShield,
JailbreakShield,
PromptGuardShield,
)
from .shield_runner import SafetyException, ShieldRunnerMixin # noqa: F401
transformers.logging.set_verbosity_error()
import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"
import warnings
warnings.filterwarnings("ignore")

View file

@ -0,0 +1,71 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from abc import ABC, abstractmethod
from typing import List, Union
from llama_models.llama3_1.api.datatypes import Attachment, Message
from llama_toolchain.safety.api.datatypes import * # noqa: F403
CANNED_RESPONSE_TEXT = "I can't answer that. Can I help with something else?"
class ShieldBase(ABC):
def __init__(
self,
on_violation_action: OnViolationAction = OnViolationAction.RAISE,
):
self.on_violation_action = on_violation_action
@abstractmethod
def get_shield_type(self) -> ShieldType:
raise NotImplementedError()
@abstractmethod
async def run(self, messages: List[Message]) -> ShieldResponse:
raise NotImplementedError()
def message_content_as_str(message: Message) -> str:
def _to_str(content: Union[str, Attachment]) -> str:
if isinstance(content, str):
return content
elif isinstance(content, Attachment):
return f"File: {str(content.url)}"
else:
raise
if isinstance(message.content, list) or isinstance(message.content, tuple):
return "\n".join([_to_str(c) for c in message.content])
else:
return _to_str(message.content)
# For shields that operate on simple strings
class TextShield(ShieldBase):
def convert_messages_to_text(self, messages: List[Message]) -> str:
return "\n".join([message_content_as_str(m) for m in messages])
async def run(self, messages: List[Message]) -> ShieldResponse:
text = self.convert_messages_to_text(messages)
return await self.run_impl(text)
@abstractmethod
async def run_impl(self, text: str) -> ShieldResponse:
raise NotImplementedError()
class DummyShield(TextShield):
def get_shield_type(self) -> ShieldType:
return "dummy"
async def run_impl(self, text: str) -> ShieldResponse:
# Dummy return LOW to test e2e
return ShieldResponse(
shield_type=BuiltinShield.third_party_shield, is_violation=False
)

View file

@ -0,0 +1,34 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from codeshield.cs import CodeShield
from termcolor import cprint
from .base import ShieldResponse, TextShield
from llama_toolchain.safety.api.datatypes import * # noqa: F403
class CodeScannerShield(TextShield):
def get_shield_type(self) -> ShieldType:
return BuiltinShield.code_scanner_guard
async def run_impl(self, text: str) -> ShieldResponse:
cprint(f"Running CodeScannerShield on {text[50:]}", color="magenta")
result = await CodeShield.scan_code(text)
if result.is_insecure:
return ShieldResponse(
shield_type=BuiltinShield.code_scanner_guard,
is_violation=True,
violation_type=",".join(
[issue.pattern_id for issue in result.issues_found]
),
violation_return_message="Sorry, I found security concerns in the code.",
)
else:
return ShieldResponse(
shield_type=BuiltinShield.code_scanner_guard, is_violation=False
)

View file

@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@ -0,0 +1,38 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import sys
from typing import List
from llama_models.llama3_1.api.datatypes import Message
parent_dir = "../.."
sys.path.append(parent_dir)
from llama_toolchain.safety.shields.base import (
OnViolationAction,
ShieldBase,
ShieldResponse,
)
_INSTANCE = None
class ThirdPartyShield(ShieldBase):
@staticmethod
def instance(on_violation_action=OnViolationAction.RAISE) -> "ThirdPartyShield":
global _INSTANCE
if _INSTANCE is None:
_INSTANCE = ThirdPartyShield(on_violation_action)
return _INSTANCE
def __init__(
self,
on_violation_action: OnViolationAction = OnViolationAction.RAISE,
):
super().__init__(on_violation_action)
async def run(self, messages: List[Message]) -> ShieldResponse:
super.run() # will raise NotImplementedError

View file

@ -0,0 +1,252 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import re
from string import Template
from typing import List, Optional
import torch
from llama_models.llama3_1.api.datatypes import Message, Role
from transformers import AutoModelForCausalLM, AutoTokenizer
from .base import CANNED_RESPONSE_TEXT, OnViolationAction, ShieldBase, ShieldResponse
from llama_toolchain.safety.api.datatypes import * # noqa: F403
SAFE_RESPONSE = "safe"
_INSTANCE = None
CAT_VIOLENT_CRIMES = "Violent Crimes"
CAT_NON_VIOLENT_CRIMES = "Non-Violent Crimes"
CAT_SEX_CRIMES = "Sex Crimes"
CAT_CHILD_EXPLOITATION = "Child Exploitation"
CAT_DEFAMATION = "Defamation"
CAT_SPECIALIZED_ADVICE = "Specialized Advice"
CAT_PRIVACY = "Privacy"
CAT_INTELLECTUAL_PROPERTY = "Intellectual Property"
CAT_INDISCRIMINATE_WEAPONS = "Indiscriminate Weapons"
CAT_HATE = "Hate"
CAT_SELF_HARM = "Self-Harm"
CAT_SEXUAL_CONTENT = "Sexual Content"
CAT_ELECTIONS = "Elections"
CAT_CODE_INTERPRETER_ABUSE = "Code Interpreter Abuse"
SAFETY_CATEGORIES_TO_CODE_MAP = {
CAT_VIOLENT_CRIMES: "S1",
CAT_NON_VIOLENT_CRIMES: "S2",
CAT_SEX_CRIMES: "S3",
CAT_CHILD_EXPLOITATION: "S4",
CAT_DEFAMATION: "S5",
CAT_SPECIALIZED_ADVICE: "S6",
CAT_PRIVACY: "S7",
CAT_INTELLECTUAL_PROPERTY: "S8",
CAT_INDISCRIMINATE_WEAPONS: "S9",
CAT_HATE: "S10",
CAT_SELF_HARM: "S11",
CAT_SEXUAL_CONTENT: "S12",
CAT_ELECTIONS: "S13",
CAT_CODE_INTERPRETER_ABUSE: "S14",
}
DEFAULT_LG_V3_SAFETY_CATEGORIES = [
CAT_VIOLENT_CRIMES,
CAT_NON_VIOLENT_CRIMES,
CAT_SEX_CRIMES,
CAT_CHILD_EXPLOITATION,
CAT_DEFAMATION,
CAT_SPECIALIZED_ADVICE,
CAT_PRIVACY,
CAT_INTELLECTUAL_PROPERTY,
CAT_INDISCRIMINATE_WEAPONS,
CAT_HATE,
CAT_SELF_HARM,
CAT_SEXUAL_CONTENT,
CAT_ELECTIONS,
CAT_CODE_INTERPRETER_ABUSE,
]
PROMPT_TASK = "Task: Check if there is unsafe content in '$agent_type' messages in conversations according our safety policy with the below categories."
SAFETY_CATEGORIES = """
<BEGIN UNSAFE CONTENT CATEGORIES>
$categories
<END UNSAFE CONTENT CATEGORIES>
"""
PROMPT_CONVERSATION = """
<BEGIN CONVERSATION>
$conversations
<END CONVERSATION>
"""
PROMPT_INSTRUCTIONS = """
Provide your safety assessment for ONLY THE LAST $agent_type message in the above conversation:
- First line must read 'safe' or 'unsafe'.
- If unsafe, a second line must include a comma-separated list of violated categories."""
PROMPT_TEMPLATE = Template(
f"{PROMPT_TASK}{SAFETY_CATEGORIES}{PROMPT_CONVERSATION}{PROMPT_INSTRUCTIONS}"
)
class LlamaGuardShield(ShieldBase):
@staticmethod
def instance(
on_violation_action=OnViolationAction.RAISE,
model_dir: str = None,
excluded_categories: List[str] = None,
disable_input_check: bool = False,
disable_output_check: bool = False,
) -> "LlamaGuardShield":
global _INSTANCE
if _INSTANCE is None:
_INSTANCE = LlamaGuardShield(
on_violation_action,
model_dir,
excluded_categories,
disable_input_check,
disable_output_check,
)
return _INSTANCE
def __init__(
self,
on_violation_action: OnViolationAction = OnViolationAction.RAISE,
model_dir: str = None,
excluded_categories: List[str] = None,
disable_input_check: bool = False,
disable_output_check: bool = False,
):
super().__init__(on_violation_action)
dtype = torch.bfloat16
assert model_dir is not None, "Llama Guard model_dir is None"
if excluded_categories is None:
excluded_categories = []
assert len(excluded_categories) == 0 or all(
x in SAFETY_CATEGORIES_TO_CODE_MAP.values() for x in excluded_categories
), "Invalid categories in excluded categories. Expected format is ['S1', 'S2', ..]"
self.device = "cuda"
self.excluded_categories = excluded_categories
self.disable_input_check = disable_input_check
self.disable_output_check = disable_output_check
# load model
torch_dtype = torch.bfloat16
self.tokenizer = AutoTokenizer.from_pretrained(model_dir)
self.model = AutoModelForCausalLM.from_pretrained(
model_dir, torch_dtype=torch_dtype, device_map=self.device
)
def get_shield_type(self) -> ShieldType:
return BuiltinShield.llama_guard
def check_unsafe_response(self, response: str) -> Optional[str]:
match = re.match(r"^unsafe\n(.*)$", response)
if match:
# extracts the unsafe code
extracted = match.group(1)
return extracted
return None
def get_safety_categories(self) -> List[str]:
excluded_categories = self.excluded_categories
if set(excluded_categories) == set(SAFETY_CATEGORIES_TO_CODE_MAP.values()):
excluded_categories = []
categories = []
for cat in DEFAULT_LG_V3_SAFETY_CATEGORIES:
cat_code = SAFETY_CATEGORIES_TO_CODE_MAP[cat]
if cat_code in excluded_categories:
continue
categories.append(f"{cat_code}: {cat}.")
return categories
def build_prompt(self, messages: List[Message]) -> str:
categories = self.get_safety_categories()
categories_str = "\n".join(categories)
conversations_str = "\n\n".join(
[f"{m.role.capitalize()}: {m.content}" for m in messages]
)
return PROMPT_TEMPLATE.substitute(
agent_type=messages[-1].role.capitalize(),
categories=categories_str,
conversations=conversations_str,
)
def get_shield_response(self, response: str) -> ShieldResponse:
if response == SAFE_RESPONSE:
return ShieldResponse(
shield_type=BuiltinShield.llama_guard, is_violation=False
)
unsafe_code = self.check_unsafe_response(response)
if unsafe_code:
unsafe_code_list = unsafe_code.split(",")
if set(unsafe_code_list).issubset(set(self.excluded_categories)):
return ShieldResponse(
shield_type=BuiltinShield.llama_guard, is_violation=False
)
return ShieldResponse(
shield_type=BuiltinShield.llama_guard,
is_violation=True,
violation_type=unsafe_code,
violation_return_message=CANNED_RESPONSE_TEXT,
)
raise ValueError(f"Unexpected response: {response}")
async def run(self, messages: List[Message]) -> ShieldResponse:
if self.disable_input_check and messages[-1].role == Role.user.value:
return ShieldResponse(
shield_type=BuiltinShield.llama_guard, is_violation=False
)
elif self.disable_output_check and messages[-1].role == Role.assistant.value:
return ShieldResponse(
shield_type=BuiltinShield.llama_guard,
is_violation=False,
)
else:
prompt = self.build_prompt(messages)
llama_guard_input = {
"role": "user",
"content": prompt,
}
input_ids = self.tokenizer.apply_chat_template(
[llama_guard_input], return_tensors="pt", tokenize=True
).to(self.device)
prompt_len = input_ids.shape[1]
output = self.model.generate(
input_ids=input_ids,
max_new_tokens=20,
output_scores=True,
return_dict_in_generate=True,
pad_token_id=0,
)
generated_tokens = output.sequences[:, prompt_len:]
response = self.tokenizer.decode(
generated_tokens[0], skip_special_tokens=True
)
response = response.strip()
shield_response = self.get_shield_response(response)
return shield_response

View file

@ -0,0 +1,156 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from enum import auto, Enum
from typing import List
import torch
from llama_models.llama3_1.api.datatypes import Message
from termcolor import cprint
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from .base import message_content_as_str, OnViolationAction, ShieldResponse, TextShield
from llama_toolchain.safety.api.datatypes import * # noqa: F403
class PromptGuardShield(TextShield):
class Mode(Enum):
INJECTION = auto()
JAILBREAK = auto()
_instances = {}
_model_cache = None
@staticmethod
def instance(
model_dir: str,
threshold: float = 0.9,
temperature: float = 1.0,
mode: "PromptGuardShield.Mode" = Mode.JAILBREAK,
on_violation_action=OnViolationAction.RAISE,
) -> "PromptGuardShield":
action_value = on_violation_action.value
key = (model_dir, threshold, temperature, mode, action_value)
if key not in PromptGuardShield._instances:
PromptGuardShield._instances[key] = PromptGuardShield(
model_dir=model_dir,
threshold=threshold,
temperature=temperature,
mode=mode,
on_violation_action=on_violation_action,
)
return PromptGuardShield._instances[key]
def __init__(
self,
model_dir: str,
threshold: float = 0.9,
temperature: float = 1.0,
mode: "PromptGuardShield.Mode" = Mode.JAILBREAK,
on_violation_action: OnViolationAction = OnViolationAction.RAISE,
):
super().__init__(on_violation_action)
assert (
model_dir is not None
), "Must provide a model directory for prompt injection shield"
if temperature <= 0:
raise ValueError("Temperature must be greater than 0")
self.device = "cuda"
if PromptGuardShield._model_cache is None:
# load model and tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_dir)
model = AutoModelForSequenceClassification.from_pretrained(
model_dir, device_map=self.device
)
PromptGuardShield._model_cache = (tokenizer, model)
self.tokenizer, self.model = PromptGuardShield._model_cache
self.temperature = temperature
self.threshold = threshold
self.mode = mode
def get_shield_type(self) -> ShieldType:
return (
BuiltinShield.jailbreak_shield
if self.mode == self.Mode.JAILBREAK
else BuiltinShield.injection_shield
)
def convert_messages_to_text(self, messages: List[Message]) -> str:
return message_content_as_str(messages[-1])
async def run_impl(self, text: str) -> ShieldResponse:
# run model on messages and return response
inputs = self.tokenizer(text, return_tensors="pt")
inputs = {name: tensor.to(self.model.device) for name, tensor in inputs.items()}
with torch.no_grad():
outputs = self.model(**inputs)
logits = outputs[0]
probabilities = torch.softmax(logits / self.temperature, dim=-1)
score_embedded = probabilities[0, 1].item()
score_malicious = probabilities[0, 2].item()
cprint(
f"Ran PromptGuardShield and got Scores: Embedded: {score_embedded}, Malicious: {score_malicious}",
color="magenta",
)
if self.mode == self.Mode.INJECTION and (
score_embedded + score_malicious > self.threshold
):
return ShieldResponse(
shield_type=self.get_shield_type(),
is_violation=True,
violation_type=f"prompt_injection:embedded={score_embedded},malicious={score_malicious}",
violation_return_message="Sorry, I cannot do this.",
)
elif self.mode == self.Mode.JAILBREAK and score_malicious > self.threshold:
return ShieldResponse(
shield_type=self.get_shield_type(),
is_violation=True,
violation_type=f"prompt_injection:malicious={score_malicious}",
violation_return_message="Sorry, I cannot do this.",
)
return ShieldResponse(
shield_type=self.get_shield_type(),
is_violation=False,
)
class JailbreakShield(PromptGuardShield):
def __init__(
self,
model_dir: str,
threshold: float = 0.9,
temperature: float = 1.0,
on_violation_action: OnViolationAction = OnViolationAction.RAISE,
):
super().__init__(
model_dir=model_dir,
threshold=threshold,
temperature=temperature,
mode=PromptGuardShield.Mode.JAILBREAK,
on_violation_action=on_violation_action,
)
class InjectionShield(PromptGuardShield):
def __init__(
self,
model_dir: str,
threshold: float = 0.9,
temperature: float = 1.0,
on_violation_action: OnViolationAction = OnViolationAction.RAISE,
):
super().__init__(
model_dir=model_dir,
threshold=threshold,
temperature=temperature,
mode=PromptGuardShield.Mode.INJECTION,
on_violation_action=on_violation_action,
)

View file

@ -0,0 +1,52 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import asyncio
from typing import List
from llama_models.llama3_1.api.datatypes import Message, Role
from .base import OnViolationAction, ShieldBase, ShieldResponse
class SafetyException(Exception): # noqa: N818
def __init__(self, response: ShieldResponse):
self.response = response
super().__init__(response.violation_return_message)
class ShieldRunnerMixin:
def __init__(
self,
input_shields: List[ShieldBase] = None,
output_shields: List[ShieldBase] = None,
):
self.input_shields = input_shields
self.output_shields = output_shields
async def run_shields(
self, messages: List[Message], shields: List[ShieldBase]
) -> List[ShieldResponse]:
# some shields like llama-guard require the first message to be a user message
# since this might be a tool call, first role might not be user
if len(messages) > 0 and messages[0].role != Role.user.value:
# TODO(ashwin): we need to change the type of the message, this kind of modification
# is no longer appropriate
messages[0].role = Role.user.value
results = await asyncio.gather(*[s.run(messages) for s in shields])
for shield, r in zip(shields, results):
if r.is_violation:
if shield.on_violation_action == OnViolationAction.RAISE:
raise SafetyException(r)
elif shield.on_violation_action == OnViolationAction.WARN:
cprint(
f"[Warn]{shield.__class__.__name__} raised a warning",
color="red",
)
return results