Introduce Llama stack distributions (#22)

* Add distribution CLI scaffolding

* More progress towards `llama distribution install`

* getting closer to a distro definition, distro install + configure works

* Distribution server now functioning

* read existing configuration, save enums properly

* Remove inference uvicorn server entrypoint and llama inference CLI command

* updated dependency and client model name

* Improved exception handling

* local imports for faster cli

* undo a typo, add a passthrough distribution

* implement full-passthrough in the server

* add safety adapters, configuration handling, server + clients

* cleanup, moving stuff to common, nuke utils

* Add a Path() wrapper at the earliest place

* fixes

* Bring agentic system api to toolchain

Add adapter dependencies and resolve adapters using a topological sort

* refactor to reduce size of `agentic_system`

* move straggler files and fix some important existing bugs

* ApiSurface -> Api

* refactor a method out

* Adapter -> Provider

* Make each inference provider into its own subdirectory

* installation fixes

* Rename Distribution -> DistributionSpec, simplify RemoteProviders

* dict key instead of attr

* update inference config to take model and not model_dir

* Fix passthrough streaming, send headers properly not part of body :facepalm

* update safety to use model sku ids and not model dirs

* Update cli_reference.md

* minor fixes

* add DistributionConfig, fix a bug in model download

* Make install + start scripts do proper configuration automatically

* Update CLI_reference

* Nuke fp8_requirements, fold fbgemm into common requirements

* Update README, add newline between API surface configurations

* Refactor download functionality out of the Command so can be reused

* Add `llama model download` alias for `llama download`

* Show message about checksum file so users can check themselves

* Simpler intro statements

* get ollama working

* Reduce a bunch of dependencies from toolchain

Some improvements to the distribution install script

* Avoid using `conda run` since it buffers everything

* update dependencies and rely on LLAMA_TOOLCHAIN_DIR for dev purposes

* add validation for configuration input

* resort imports

* make optional subclasses default to yes for configuration

* Remove additional_pip_packages; move deps to providers

* for inline make 8b model the default

* Add scripts to MANIFEST

* allow installing from test.pypi.org

* Fix #2 to help with testing packages

* Must install llama-models at that same version first

* fix PIP_ARGS

---------

Co-authored-by: Hardik Shah <hjshah@fb.com>
Co-authored-by: Hardik Shah <hjshah@meta.com>
This commit is contained in:
Ashwin Bharambe 2024-08-08 13:38:41 -07:00 committed by GitHub
parent da4645a27a
commit e830814399
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
115 changed files with 5839 additions and 1120 deletions

View file

@ -0,0 +1,8 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .config import SafetyConfig # noqa
from .safety import get_provider_impl # noqa

View file

@ -0,0 +1,55 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import List, Optional
from llama_models.sku_list import CoreModelId, safety_models
from pydantic import BaseModel, validator
class LlamaGuardShieldConfig(BaseModel):
model: str = "Llama-Guard-3-8B"
excluded_categories: List[str] = []
disable_input_check: bool = False
disable_output_check: bool = False
@validator("model")
@classmethod
def validate_model(cls, model: str) -> str:
permitted_models = [
m.descriptor()
for m in safety_models()
if m.core_model_id == CoreModelId.llama_guard_3_8b
]
if model not in permitted_models:
raise ValueError(
f"Invalid model: {model}. Must be one of {permitted_models}"
)
return model
class PromptGuardShieldConfig(BaseModel):
model: str = "Prompt-Guard-86M"
@validator("model")
@classmethod
def validate_model(cls, model: str) -> str:
permitted_models = [
m.descriptor()
for m in safety_models()
if m.core_model_id == CoreModelId.prompt_guard_86m
]
if model not in permitted_models:
raise ValueError(
f"Invalid model: {model}. Must be one of {permitted_models}"
)
return model
class SafetyConfig(BaseModel):
llama_guard_shield: Optional[LlamaGuardShieldConfig] = None
prompt_guard_shield: Optional[PromptGuardShieldConfig] = None

View file

@ -0,0 +1,103 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import asyncio
from typing import Dict
from llama_models.sku_list import resolve_model
from llama_toolchain.common.model_utils import model_local_dir
from llama_toolchain.distribution.datatypes import Api, ProviderSpec
from llama_toolchain.safety.api import * # noqa
from .config import SafetyConfig
from .shields import (
CodeScannerShield,
InjectionShield,
JailbreakShield,
LlamaGuardShield,
PromptGuardShield,
ShieldBase,
ThirdPartyShield,
)
async def get_provider_impl(config: SafetyConfig, _deps: Dict[Api, ProviderSpec]):
assert isinstance(config, SafetyConfig), f"Unexpected config type: {type(config)}"
impl = MetaReferenceSafetyImpl(config)
await impl.initialize()
return impl
def resolve_and_get_path(model_name: str) -> str:
model = resolve_model(model_name)
assert model is not None, f"Could not resolve model {model_name}"
model_dir = model_local_dir(model)
return model_dir
class MetaReferenceSafetyImpl(Safety):
def __init__(self, config: SafetyConfig) -> None:
self.config = config
async def initialize(self) -> None:
shield_cfg = self.config.llama_guard_shield
if shield_cfg is not None:
model_dir = resolve_and_get_path(shield_cfg.model)
_ = LlamaGuardShield.instance(
model_dir=model_dir,
excluded_categories=shield_cfg.excluded_categories,
disable_input_check=shield_cfg.disable_input_check,
disable_output_check=shield_cfg.disable_output_check,
)
shield_cfg = self.config.prompt_guard_shield
if shield_cfg is not None:
model_dir = resolve_and_get_path(shield_cfg.model)
_ = PromptGuardShield.instance(model_dir)
async def run_shields(
self,
request: RunShieldRequest,
) -> RunShieldResponse:
shields = [shield_config_to_shield(c, self.config) for c in request.shields]
responses = await asyncio.gather(
*[shield.run(request.messages) for shield in shields]
)
return RunShieldResponse(responses=responses)
def shield_config_to_shield(
sc: ShieldDefinition, safety_config: SafetyConfig
) -> ShieldBase:
if sc.shield_type == BuiltinShield.llama_guard:
assert (
safety_config.llama_guard_shield is not None
), "Cannot use LlamaGuardShield since not present in config"
model_dir = resolve_and_get_path(safety_config.llama_guard_shield.model)
return LlamaGuardShield.instance(model_dir=model_dir)
elif sc.shield_type == BuiltinShield.jailbreak_shield:
assert (
safety_config.prompt_guard_shield is not None
), "Cannot use Jailbreak Shield since Prompt Guard not present in config"
model_dir = resolve_and_get_path(safety_config.prompt_guard_shield.model)
return JailbreakShield.instance(model_dir)
elif sc.shield_type == BuiltinShield.injection_shield:
assert (
safety_config.prompt_guard_shield is not None
), "Cannot use PromptGuardShield since not present in config"
model_dir = resolve_and_get_path(safety_config.prompt_guard_shield.model)
return InjectionShield.instance(model_dir)
elif sc.shield_type == BuiltinShield.code_scanner_guard:
return CodeScannerShield.instance()
elif sc.shield_type == BuiltinShield.third_party_shield:
return ThirdPartyShield.instance()
else:
raise ValueError(f"Unknown shield type: {sc.shield_type}")

View file

@ -0,0 +1,34 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# supress warnings and spew of logs from hugging face
import transformers
from .base import ( # noqa: F401
DummyShield,
OnViolationAction,
ShieldBase,
ShieldResponse,
TextShield,
)
from .code_scanner import CodeScannerShield # noqa: F401
from .contrib.third_party_shield import ThirdPartyShield # noqa: F401
from .llama_guard import LlamaGuardShield # noqa: F401
from .prompt_guard import ( # noqa: F401
InjectionShield,
JailbreakShield,
PromptGuardShield,
)
transformers.logging.set_verbosity_error()
import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"
import warnings
warnings.filterwarnings("ignore")

View file

@ -0,0 +1,71 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from abc import ABC, abstractmethod
from typing import List, Union
from llama_models.llama3_1.api.datatypes import Attachment, Message
from llama_toolchain.safety.api.datatypes import * # noqa: F403
CANNED_RESPONSE_TEXT = "I can't answer that. Can I help with something else?"
class ShieldBase(ABC):
def __init__(
self,
on_violation_action: OnViolationAction = OnViolationAction.RAISE,
):
self.on_violation_action = on_violation_action
@abstractmethod
def get_shield_type(self) -> ShieldType:
raise NotImplementedError()
@abstractmethod
async def run(self, messages: List[Message]) -> ShieldResponse:
raise NotImplementedError()
def message_content_as_str(message: Message) -> str:
def _to_str(content: Union[str, Attachment]) -> str:
if isinstance(content, str):
return content
elif isinstance(content, Attachment):
return f"File: {str(content.url)}"
else:
raise
if isinstance(message.content, list) or isinstance(message.content, tuple):
return "\n".join([_to_str(c) for c in message.content])
else:
return _to_str(message.content)
# For shields that operate on simple strings
class TextShield(ShieldBase):
def convert_messages_to_text(self, messages: List[Message]) -> str:
return "\n".join([message_content_as_str(m) for m in messages])
async def run(self, messages: List[Message]) -> ShieldResponse:
text = self.convert_messages_to_text(messages)
return await self.run_impl(text)
@abstractmethod
async def run_impl(self, text: str) -> ShieldResponse:
raise NotImplementedError()
class DummyShield(TextShield):
def get_shield_type(self) -> ShieldType:
return "dummy"
async def run_impl(self, text: str) -> ShieldResponse:
# Dummy return LOW to test e2e
return ShieldResponse(
shield_type=BuiltinShield.third_party_shield, is_violation=False
)

View file

@ -0,0 +1,34 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from codeshield.cs import CodeShield
from termcolor import cprint
from .base import ShieldResponse, TextShield
from llama_toolchain.safety.api.datatypes import * # noqa: F403
class CodeScannerShield(TextShield):
def get_shield_type(self) -> ShieldType:
return BuiltinShield.code_scanner_guard
async def run_impl(self, text: str) -> ShieldResponse:
cprint(f"Running CodeScannerShield on {text[50:]}", color="magenta")
result = await CodeShield.scan_code(text)
if result.is_insecure:
return ShieldResponse(
shield_type=BuiltinShield.code_scanner_guard,
is_violation=True,
violation_type=",".join(
[issue.pattern_id for issue in result.issues_found]
),
violation_return_message="Sorry, I found security concerns in the code.",
)
else:
return ShieldResponse(
shield_type=BuiltinShield.code_scanner_guard, is_violation=False
)

View file

@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@ -0,0 +1,35 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import List
from llama_models.llama3_1.api.datatypes import Message
from llama_toolchain.safety.meta_reference.shields.base import (
OnViolationAction,
ShieldBase,
ShieldResponse,
)
_INSTANCE = None
class ThirdPartyShield(ShieldBase):
@staticmethod
def instance(on_violation_action=OnViolationAction.RAISE) -> "ThirdPartyShield":
global _INSTANCE
if _INSTANCE is None:
_INSTANCE = ThirdPartyShield(on_violation_action)
return _INSTANCE
def __init__(
self,
on_violation_action: OnViolationAction = OnViolationAction.RAISE,
):
super().__init__(on_violation_action)
async def run(self, messages: List[Message]) -> ShieldResponse:
super.run() # will raise NotImplementedError

View file

@ -0,0 +1,252 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import re
from string import Template
from typing import List, Optional
import torch
from llama_models.llama3_1.api.datatypes import Message, Role
from transformers import AutoModelForCausalLM, AutoTokenizer
from .base import CANNED_RESPONSE_TEXT, OnViolationAction, ShieldBase, ShieldResponse
from llama_toolchain.safety.api.datatypes import * # noqa: F403
SAFE_RESPONSE = "safe"
_INSTANCE = None
CAT_VIOLENT_CRIMES = "Violent Crimes"
CAT_NON_VIOLENT_CRIMES = "Non-Violent Crimes"
CAT_SEX_CRIMES = "Sex Crimes"
CAT_CHILD_EXPLOITATION = "Child Exploitation"
CAT_DEFAMATION = "Defamation"
CAT_SPECIALIZED_ADVICE = "Specialized Advice"
CAT_PRIVACY = "Privacy"
CAT_INTELLECTUAL_PROPERTY = "Intellectual Property"
CAT_INDISCRIMINATE_WEAPONS = "Indiscriminate Weapons"
CAT_HATE = "Hate"
CAT_SELF_HARM = "Self-Harm"
CAT_SEXUAL_CONTENT = "Sexual Content"
CAT_ELECTIONS = "Elections"
CAT_CODE_INTERPRETER_ABUSE = "Code Interpreter Abuse"
SAFETY_CATEGORIES_TO_CODE_MAP = {
CAT_VIOLENT_CRIMES: "S1",
CAT_NON_VIOLENT_CRIMES: "S2",
CAT_SEX_CRIMES: "S3",
CAT_CHILD_EXPLOITATION: "S4",
CAT_DEFAMATION: "S5",
CAT_SPECIALIZED_ADVICE: "S6",
CAT_PRIVACY: "S7",
CAT_INTELLECTUAL_PROPERTY: "S8",
CAT_INDISCRIMINATE_WEAPONS: "S9",
CAT_HATE: "S10",
CAT_SELF_HARM: "S11",
CAT_SEXUAL_CONTENT: "S12",
CAT_ELECTIONS: "S13",
CAT_CODE_INTERPRETER_ABUSE: "S14",
}
DEFAULT_LG_V3_SAFETY_CATEGORIES = [
CAT_VIOLENT_CRIMES,
CAT_NON_VIOLENT_CRIMES,
CAT_SEX_CRIMES,
CAT_CHILD_EXPLOITATION,
CAT_DEFAMATION,
CAT_SPECIALIZED_ADVICE,
CAT_PRIVACY,
CAT_INTELLECTUAL_PROPERTY,
CAT_INDISCRIMINATE_WEAPONS,
CAT_HATE,
CAT_SELF_HARM,
CAT_SEXUAL_CONTENT,
CAT_ELECTIONS,
CAT_CODE_INTERPRETER_ABUSE,
]
PROMPT_TASK = "Task: Check if there is unsafe content in '$agent_type' messages in conversations according our safety policy with the below categories."
SAFETY_CATEGORIES = """
<BEGIN UNSAFE CONTENT CATEGORIES>
$categories
<END UNSAFE CONTENT CATEGORIES>
"""
PROMPT_CONVERSATION = """
<BEGIN CONVERSATION>
$conversations
<END CONVERSATION>
"""
PROMPT_INSTRUCTIONS = """
Provide your safety assessment for ONLY THE LAST $agent_type message in the above conversation:
- First line must read 'safe' or 'unsafe'.
- If unsafe, a second line must include a comma-separated list of violated categories."""
PROMPT_TEMPLATE = Template(
f"{PROMPT_TASK}{SAFETY_CATEGORIES}{PROMPT_CONVERSATION}{PROMPT_INSTRUCTIONS}"
)
class LlamaGuardShield(ShieldBase):
@staticmethod
def instance(
on_violation_action=OnViolationAction.RAISE,
model_dir: str = None,
excluded_categories: List[str] = None,
disable_input_check: bool = False,
disable_output_check: bool = False,
) -> "LlamaGuardShield":
global _INSTANCE
if _INSTANCE is None:
_INSTANCE = LlamaGuardShield(
on_violation_action,
model_dir,
excluded_categories,
disable_input_check,
disable_output_check,
)
return _INSTANCE
def __init__(
self,
on_violation_action: OnViolationAction = OnViolationAction.RAISE,
model_dir: str = None,
excluded_categories: List[str] = None,
disable_input_check: bool = False,
disable_output_check: bool = False,
):
super().__init__(on_violation_action)
dtype = torch.bfloat16
assert model_dir is not None, "Llama Guard model_dir is None"
if excluded_categories is None:
excluded_categories = []
assert len(excluded_categories) == 0 or all(
x in SAFETY_CATEGORIES_TO_CODE_MAP.values() for x in excluded_categories
), "Invalid categories in excluded categories. Expected format is ['S1', 'S2', ..]"
self.device = "cuda"
self.excluded_categories = excluded_categories
self.disable_input_check = disable_input_check
self.disable_output_check = disable_output_check
# load model
torch_dtype = torch.bfloat16
self.tokenizer = AutoTokenizer.from_pretrained(model_dir)
self.model = AutoModelForCausalLM.from_pretrained(
model_dir, torch_dtype=torch_dtype, device_map=self.device
)
def get_shield_type(self) -> ShieldType:
return BuiltinShield.llama_guard
def check_unsafe_response(self, response: str) -> Optional[str]:
match = re.match(r"^unsafe\n(.*)$", response)
if match:
# extracts the unsafe code
extracted = match.group(1)
return extracted
return None
def get_safety_categories(self) -> List[str]:
excluded_categories = self.excluded_categories
if set(excluded_categories) == set(SAFETY_CATEGORIES_TO_CODE_MAP.values()):
excluded_categories = []
categories = []
for cat in DEFAULT_LG_V3_SAFETY_CATEGORIES:
cat_code = SAFETY_CATEGORIES_TO_CODE_MAP[cat]
if cat_code in excluded_categories:
continue
categories.append(f"{cat_code}: {cat}.")
return categories
def build_prompt(self, messages: List[Message]) -> str:
categories = self.get_safety_categories()
categories_str = "\n".join(categories)
conversations_str = "\n\n".join(
[f"{m.role.capitalize()}: {m.content}" for m in messages]
)
return PROMPT_TEMPLATE.substitute(
agent_type=messages[-1].role.capitalize(),
categories=categories_str,
conversations=conversations_str,
)
def get_shield_response(self, response: str) -> ShieldResponse:
if response == SAFE_RESPONSE:
return ShieldResponse(
shield_type=BuiltinShield.llama_guard, is_violation=False
)
unsafe_code = self.check_unsafe_response(response)
if unsafe_code:
unsafe_code_list = unsafe_code.split(",")
if set(unsafe_code_list).issubset(set(self.excluded_categories)):
return ShieldResponse(
shield_type=BuiltinShield.llama_guard, is_violation=False
)
return ShieldResponse(
shield_type=BuiltinShield.llama_guard,
is_violation=True,
violation_type=unsafe_code,
violation_return_message=CANNED_RESPONSE_TEXT,
)
raise ValueError(f"Unexpected response: {response}")
async def run(self, messages: List[Message]) -> ShieldResponse:
if self.disable_input_check and messages[-1].role == Role.user.value:
return ShieldResponse(
shield_type=BuiltinShield.llama_guard, is_violation=False
)
elif self.disable_output_check and messages[-1].role == Role.assistant.value:
return ShieldResponse(
shield_type=BuiltinShield.llama_guard,
is_violation=False,
)
else:
prompt = self.build_prompt(messages)
llama_guard_input = {
"role": "user",
"content": prompt,
}
input_ids = self.tokenizer.apply_chat_template(
[llama_guard_input], return_tensors="pt", tokenize=True
).to(self.device)
prompt_len = input_ids.shape[1]
output = self.model.generate(
input_ids=input_ids,
max_new_tokens=20,
output_scores=True,
return_dict_in_generate=True,
pad_token_id=0,
)
generated_tokens = output.sequences[:, prompt_len:]
response = self.tokenizer.decode(
generated_tokens[0], skip_special_tokens=True
)
response = response.strip()
shield_response = self.get_shield_response(response)
return shield_response

View file

@ -0,0 +1,156 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from enum import auto, Enum
from typing import List
import torch
from llama_models.llama3_1.api.datatypes import Message
from termcolor import cprint
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from .base import message_content_as_str, OnViolationAction, ShieldResponse, TextShield
from llama_toolchain.safety.api.datatypes import * # noqa: F403
class PromptGuardShield(TextShield):
class Mode(Enum):
INJECTION = auto()
JAILBREAK = auto()
_instances = {}
_model_cache = None
@staticmethod
def instance(
model_dir: str,
threshold: float = 0.9,
temperature: float = 1.0,
mode: "PromptGuardShield.Mode" = Mode.JAILBREAK,
on_violation_action=OnViolationAction.RAISE,
) -> "PromptGuardShield":
action_value = on_violation_action.value
key = (model_dir, threshold, temperature, mode, action_value)
if key not in PromptGuardShield._instances:
PromptGuardShield._instances[key] = PromptGuardShield(
model_dir=model_dir,
threshold=threshold,
temperature=temperature,
mode=mode,
on_violation_action=on_violation_action,
)
return PromptGuardShield._instances[key]
def __init__(
self,
model_dir: str,
threshold: float = 0.9,
temperature: float = 1.0,
mode: "PromptGuardShield.Mode" = Mode.JAILBREAK,
on_violation_action: OnViolationAction = OnViolationAction.RAISE,
):
super().__init__(on_violation_action)
assert (
model_dir is not None
), "Must provide a model directory for prompt injection shield"
if temperature <= 0:
raise ValueError("Temperature must be greater than 0")
self.device = "cuda"
if PromptGuardShield._model_cache is None:
# load model and tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_dir)
model = AutoModelForSequenceClassification.from_pretrained(
model_dir, device_map=self.device
)
PromptGuardShield._model_cache = (tokenizer, model)
self.tokenizer, self.model = PromptGuardShield._model_cache
self.temperature = temperature
self.threshold = threshold
self.mode = mode
def get_shield_type(self) -> ShieldType:
return (
BuiltinShield.jailbreak_shield
if self.mode == self.Mode.JAILBREAK
else BuiltinShield.injection_shield
)
def convert_messages_to_text(self, messages: List[Message]) -> str:
return message_content_as_str(messages[-1])
async def run_impl(self, text: str) -> ShieldResponse:
# run model on messages and return response
inputs = self.tokenizer(text, return_tensors="pt")
inputs = {name: tensor.to(self.model.device) for name, tensor in inputs.items()}
with torch.no_grad():
outputs = self.model(**inputs)
logits = outputs[0]
probabilities = torch.softmax(logits / self.temperature, dim=-1)
score_embedded = probabilities[0, 1].item()
score_malicious = probabilities[0, 2].item()
cprint(
f"Ran PromptGuardShield and got Scores: Embedded: {score_embedded}, Malicious: {score_malicious}",
color="magenta",
)
if self.mode == self.Mode.INJECTION and (
score_embedded + score_malicious > self.threshold
):
return ShieldResponse(
shield_type=self.get_shield_type(),
is_violation=True,
violation_type=f"prompt_injection:embedded={score_embedded},malicious={score_malicious}",
violation_return_message="Sorry, I cannot do this.",
)
elif self.mode == self.Mode.JAILBREAK and score_malicious > self.threshold:
return ShieldResponse(
shield_type=self.get_shield_type(),
is_violation=True,
violation_type=f"prompt_injection:malicious={score_malicious}",
violation_return_message="Sorry, I cannot do this.",
)
return ShieldResponse(
shield_type=self.get_shield_type(),
is_violation=False,
)
class JailbreakShield(PromptGuardShield):
def __init__(
self,
model_dir: str,
threshold: float = 0.9,
temperature: float = 1.0,
on_violation_action: OnViolationAction = OnViolationAction.RAISE,
):
super().__init__(
model_dir=model_dir,
threshold=threshold,
temperature=temperature,
mode=PromptGuardShield.Mode.JAILBREAK,
on_violation_action=on_violation_action,
)
class InjectionShield(PromptGuardShield):
def __init__(
self,
model_dir: str,
threshold: float = 0.9,
temperature: float = 1.0,
on_violation_action: OnViolationAction = OnViolationAction.RAISE,
):
super().__init__(
model_dir=model_dir,
threshold=threshold,
temperature=temperature,
mode=PromptGuardShield.Mode.INJECTION,
on_violation_action=on_violation_action,
)