API Updates (#73)

* API Keys passed from Client instead of distro configuration

* delete distribution registry

* Remove the word "package" from naming

* Introduce a "Router" layer for providers

Some providers need to be factored out and treated as thin routing
layers on top of other providers. Consider two examples:

- The inference API should be a routing layer over inference providers,
  routed using the "model" key
- The memory banks API is another instance where various memory bank
  types will be provided by independent providers (e.g., a vector store
  is served by Chroma while a key-value memory can be served by Redis or
  PGVector)

This commit introduces a generalized routing layer for this purpose.
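
To make the idea concrete, here is a minimal sketch of such a routing layer for inference. This is illustrative only: the class name `InferenceRouter` and its method signature are assumptions for this sketch, not the code introduced by this commit.

```python
from typing import Any, Dict


class InferenceRouter:
    """Thin routing layer that picks a concrete inference provider by "model" key."""

    def __init__(self, providers_by_model: Dict[str, Any]) -> None:
        # e.g. {"Llama3.1-8B-Instruct": local_impl, "Llama3.1-405B-Instruct": remote_impl}
        self.providers_by_model = providers_by_model

    async def chat_completion(self, model: str, **kwargs: Any) -> Any:
        provider = self.providers_by_model.get(model)
        if provider is None:
            raise ValueError(f"No inference provider registered for model: {model}")
        # The router adds no logic of its own; it only delegates to the selected provider.
        return await provider.chat_completion(model=model, **kwargs)
```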

* update `apis_to_serve`

* llama_toolchain -> llama_stack

* Codemod from llama_toolchain -> llama_stack

- added providers/registry
- cleaned up api/ subdirectories and moved impls away
- restructured api/api.py
- `from llama_stack.apis.<api> import foo` should work now (see the example after this list)
- updated imports to use `llama_stack.apis.<api>`
- updated many other imports
- added __init__, fixed some registry imports
- updated registry imports
- create_agentic_system -> create_agent
- AgenticSystem -> Agent
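
For downstream code, the codemod amounts to changes along these lines (schematic; the import shown is the new layout used throughout this diff, and the rename arrows simply restate the mapping above):

```python
# New-style import, as used throughout this diff:
from llama_stack.apis.safety import *  # noqa: F403

# Renames applied by the codemod (old name -> new name):
#   llama_toolchain.*        -> llama_stack.*
#   create_agentic_system()  -> create_agent()
#   AgenticSystem            -> Agent
```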

* Moved some stuff out of common/; re-generated OpenAPI spec

* llama-toolchain -> llama-stack (hyphens)

* add control plane API

* add redis adapter + sqlite provider

* move core -> distribution

* Some more toolchain -> stack changes

* small naming shenanigans

* Removing custom tool and agent utilities and moving them client side

* Move control plane to distribution server for now

* Remove control plane from API list

* Don't pull in the codeshield dependency where it isn't needed

* Add "fire" as a dependency

* add back event loggers

* stack configure fixes

* use brave instead of bing in the example client

* add init file so it gets packaged

* add init files so it gets packaged

* Update MANIFEST

* bug fix

---------

Co-authored-by: Hardik Shah <hjshah@fb.com>
Co-authored-by: Xi Yan <xiyan@meta.com>
Co-authored-by: Ashwin Bharambe <ashwin@meta.com>
Authored by Ashwin Bharambe on 2024-09-17 19:51:35 -07:00, committed via GitHub (commit 9487ad8294, parent f294eac5f5).
213 changed files with 1725 additions and 1204 deletions.

@@ -0,0 +1,17 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from .config import SafetyConfig


async def get_provider_impl(config: SafetyConfig, _deps):
    from .safety import MetaReferenceSafetyImpl

    assert isinstance(config, SafetyConfig), f"Unexpected config type: {type(config)}"

    impl = MetaReferenceSafetyImpl(config)
    await impl.initialize()
    return impl

@@ -0,0 +1,55 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import List, Optional

from llama_models.sku_list import CoreModelId, safety_models
from pydantic import BaseModel, validator


class LlamaGuardShieldConfig(BaseModel):
    model: str = "Llama-Guard-3-8B"
    excluded_categories: List[str] = []
    disable_input_check: bool = False
    disable_output_check: bool = False

    @validator("model")
    @classmethod
    def validate_model(cls, model: str) -> str:
        permitted_models = [
            m.descriptor()
            for m in safety_models()
            if m.core_model_id == CoreModelId.llama_guard_3_8b
        ]
        if model not in permitted_models:
            raise ValueError(
                f"Invalid model: {model}. Must be one of {permitted_models}"
            )
        return model


class PromptGuardShieldConfig(BaseModel):
    model: str = "Prompt-Guard-86M"

    @validator("model")
    @classmethod
    def validate_model(cls, model: str) -> str:
        permitted_models = [
            m.descriptor()
            for m in safety_models()
            if m.core_model_id == CoreModelId.prompt_guard_86m
        ]
        if model not in permitted_models:
            raise ValueError(
                f"Invalid model: {model}. Must be one of {permitted_models}"
            )
        return model


class SafetyConfig(BaseModel):
    llama_guard_shield: Optional[LlamaGuardShieldConfig] = None
    prompt_guard_shield: Optional[PromptGuardShieldConfig] = None
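
As a quick illustration of how this config composes, here is a sketch (the model strings are simply the defaults declared above; excluded-category codes such as "S13" refer to the category map used by the Llama Guard shield further below):

```python
# Sketch: enable both shields with their default models
# (assumes the config classes above are in scope).
config = SafetyConfig(
    llama_guard_shield=LlamaGuardShieldConfig(
        model="Llama-Guard-3-8B",
        excluded_categories=[],  # e.g. ["S13"] to exclude a category by code
    ),
    prompt_guard_shield=PromptGuardShieldConfig(model="Prompt-Guard-86M"),
)
```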

@@ -0,0 +1,95 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import asyncio

from llama_models.sku_list import resolve_model

from llama_stack.distribution.utils.model_utils import model_local_dir
from llama_stack.apis.safety import *  # noqa

from .config import SafetyConfig
from .shields import (
    CodeScannerShield,
    InjectionShield,
    JailbreakShield,
    LlamaGuardShield,
    PromptGuardShield,
    ShieldBase,
    ThirdPartyShield,
)


def resolve_and_get_path(model_name: str) -> str:
    model = resolve_model(model_name)
    assert model is not None, f"Could not resolve model {model_name}"
    model_dir = model_local_dir(model.descriptor())
    return model_dir


class MetaReferenceSafetyImpl(Safety):
    def __init__(self, config: SafetyConfig) -> None:
        self.config = config

    async def initialize(self) -> None:
        shield_cfg = self.config.llama_guard_shield
        if shield_cfg is not None:
            model_dir = resolve_and_get_path(shield_cfg.model)
            _ = LlamaGuardShield.instance(
                model_dir=model_dir,
                excluded_categories=shield_cfg.excluded_categories,
                disable_input_check=shield_cfg.disable_input_check,
                disable_output_check=shield_cfg.disable_output_check,
            )

        shield_cfg = self.config.prompt_guard_shield
        if shield_cfg is not None:
            model_dir = resolve_and_get_path(shield_cfg.model)
            _ = PromptGuardShield.instance(model_dir)

    async def run_shields(
        self,
        messages: List[Message],
        shields: List[ShieldDefinition],
    ) -> RunShieldResponse:
        shields = [shield_config_to_shield(c, self.config) for c in shields]

        responses = await asyncio.gather(*[shield.run(messages) for shield in shields])

        return RunShieldResponse(responses=responses)


def shield_type_equals(a: ShieldType, b: ShieldType):
    return a == b or a == b.value


def shield_config_to_shield(
    sc: ShieldDefinition, safety_config: SafetyConfig
) -> ShieldBase:
    if shield_type_equals(sc.shield_type, BuiltinShield.llama_guard):
        assert (
            safety_config.llama_guard_shield is not None
        ), "Cannot use LlamaGuardShield since not present in config"
        model_dir = resolve_and_get_path(safety_config.llama_guard_shield.model)
        return LlamaGuardShield.instance(model_dir=model_dir)
    elif shield_type_equals(sc.shield_type, BuiltinShield.jailbreak_shield):
        assert (
            safety_config.prompt_guard_shield is not None
        ), "Cannot use Jailbreak Shield since Prompt Guard not present in config"
        model_dir = resolve_and_get_path(safety_config.prompt_guard_shield.model)
        return JailbreakShield.instance(model_dir)
    elif shield_type_equals(sc.shield_type, BuiltinShield.injection_shield):
        assert (
            safety_config.prompt_guard_shield is not None
        ), "Cannot use PromptGuardShield since not present in config"
        model_dir = resolve_and_get_path(safety_config.prompt_guard_shield.model)
        return InjectionShield.instance(model_dir)
    elif shield_type_equals(sc.shield_type, BuiltinShield.code_scanner_guard):
        return CodeScannerShield.instance()
    elif shield_type_equals(sc.shield_type, BuiltinShield.third_party_shield):
        return ThirdPartyShield.instance()
    else:
        raise ValueError(f"Unknown shield type: {sc.shield_type}")

@@ -0,0 +1,34 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

# suppress warnings and log spew from Hugging Face
import transformers

from .base import (  # noqa: F401
    DummyShield,
    OnViolationAction,
    ShieldBase,
    ShieldResponse,
    TextShield,
)
from .code_scanner import CodeScannerShield  # noqa: F401
from .contrib.third_party_shield import ThirdPartyShield  # noqa: F401
from .llama_guard import LlamaGuardShield  # noqa: F401
from .prompt_guard import (  # noqa: F401
    InjectionShield,
    JailbreakShield,
    PromptGuardShield,
)

transformers.logging.set_verbosity_error()

import os

os.environ["TOKENIZERS_PARALLELISM"] = "false"

import warnings

warnings.filterwarnings("ignore")

@@ -0,0 +1,58 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from abc import ABC, abstractmethod
from typing import List

from llama_models.llama3.api.datatypes import interleaved_text_media_as_str, Message
from llama_stack.apis.safety import *  # noqa: F403

CANNED_RESPONSE_TEXT = "I can't answer that. Can I help with something else?"


class ShieldBase(ABC):
    def __init__(
        self,
        on_violation_action: OnViolationAction = OnViolationAction.RAISE,
    ):
        self.on_violation_action = on_violation_action

    @abstractmethod
    def get_shield_type(self) -> ShieldType:
        raise NotImplementedError()

    @abstractmethod
    async def run(self, messages: List[Message]) -> ShieldResponse:
        raise NotImplementedError()


def message_content_as_str(message: Message) -> str:
    return interleaved_text_media_as_str(message.content)


# For shields that operate on simple strings
class TextShield(ShieldBase):
    def convert_messages_to_text(self, messages: List[Message]) -> str:
        return "\n".join([message_content_as_str(m) for m in messages])

    async def run(self, messages: List[Message]) -> ShieldResponse:
        text = self.convert_messages_to_text(messages)
        return await self.run_impl(text)

    @abstractmethod
    async def run_impl(self, text: str) -> ShieldResponse:
        raise NotImplementedError()


class DummyShield(TextShield):
    def get_shield_type(self) -> ShieldType:
        return "dummy"

    async def run_impl(self, text: str) -> ShieldResponse:
        # Dummy shield: always reports no violation, used to exercise the e2e flow
        return ShieldResponse(
            shield_type=BuiltinShield.third_party_shield, is_violation=False
        )

@@ -0,0 +1,33 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from codeshield.cs import CodeShield
from termcolor import cprint

from .base import ShieldResponse, TextShield
from llama_stack.apis.safety import *  # noqa: F403


class CodeScannerShield(TextShield):
    def get_shield_type(self) -> ShieldType:
        return BuiltinShield.code_scanner_guard

    async def run_impl(self, text: str) -> ShieldResponse:
        # Log only a short preview of the scanned text
        cprint(f"Running CodeScannerShield on {text[:50]}", color="magenta")
        result = await CodeShield.scan_code(text)
        if result.is_insecure:
            return ShieldResponse(
                shield_type=BuiltinShield.code_scanner_guard,
                is_violation=True,
                violation_type=",".join(
                    [issue.pattern_id for issue in result.issues_found]
                ),
                violation_return_message="Sorry, I found security concerns in the code.",
            )
        else:
            return ShieldResponse(
                shield_type=BuiltinShield.code_scanner_guard, is_violation=False
            )

@@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

@@ -0,0 +1,35 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import List

from llama_models.llama3.api.datatypes import Message

from llama_stack.safety.meta_reference.shields.base import (
    OnViolationAction,
    ShieldBase,
    ShieldResponse,
)

_INSTANCE = None


class ThirdPartyShield(ShieldBase):
    @staticmethod
    def instance(on_violation_action=OnViolationAction.RAISE) -> "ThirdPartyShield":
        global _INSTANCE
        if _INSTANCE is None:
            _INSTANCE = ThirdPartyShield(on_violation_action)
        return _INSTANCE

    def __init__(
        self,
        on_violation_action: OnViolationAction = OnViolationAction.RAISE,
    ):
        super().__init__(on_violation_action)

    async def run(self, messages: List[Message]) -> ShieldResponse:
        await super().run(messages)  # will raise NotImplementedError

@@ -0,0 +1,248 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import re
from string import Template
from typing import List, Optional

import torch
from llama_models.llama3.api.datatypes import Message, Role
from transformers import AutoModelForCausalLM, AutoTokenizer

from .base import CANNED_RESPONSE_TEXT, OnViolationAction, ShieldBase, ShieldResponse
from llama_stack.apis.safety import *  # noqa: F403

SAFE_RESPONSE = "safe"
_INSTANCE = None

CAT_VIOLENT_CRIMES = "Violent Crimes"
CAT_NON_VIOLENT_CRIMES = "Non-Violent Crimes"
CAT_SEX_CRIMES = "Sex Crimes"
CAT_CHILD_EXPLOITATION = "Child Exploitation"
CAT_DEFAMATION = "Defamation"
CAT_SPECIALIZED_ADVICE = "Specialized Advice"
CAT_PRIVACY = "Privacy"
CAT_INTELLECTUAL_PROPERTY = "Intellectual Property"
CAT_INDISCRIMINATE_WEAPONS = "Indiscriminate Weapons"
CAT_HATE = "Hate"
CAT_SELF_HARM = "Self-Harm"
CAT_SEXUAL_CONTENT = "Sexual Content"
CAT_ELECTIONS = "Elections"
CAT_CODE_INTERPRETER_ABUSE = "Code Interpreter Abuse"

SAFETY_CATEGORIES_TO_CODE_MAP = {
    CAT_VIOLENT_CRIMES: "S1",
    CAT_NON_VIOLENT_CRIMES: "S2",
    CAT_SEX_CRIMES: "S3",
    CAT_CHILD_EXPLOITATION: "S4",
    CAT_DEFAMATION: "S5",
    CAT_SPECIALIZED_ADVICE: "S6",
    CAT_PRIVACY: "S7",
    CAT_INTELLECTUAL_PROPERTY: "S8",
    CAT_INDISCRIMINATE_WEAPONS: "S9",
    CAT_HATE: "S10",
    CAT_SELF_HARM: "S11",
    CAT_SEXUAL_CONTENT: "S12",
    CAT_ELECTIONS: "S13",
    CAT_CODE_INTERPRETER_ABUSE: "S14",
}

DEFAULT_LG_V3_SAFETY_CATEGORIES = [
    CAT_VIOLENT_CRIMES,
    CAT_NON_VIOLENT_CRIMES,
    CAT_SEX_CRIMES,
    CAT_CHILD_EXPLOITATION,
    CAT_DEFAMATION,
    CAT_SPECIALIZED_ADVICE,
    CAT_PRIVACY,
    CAT_INTELLECTUAL_PROPERTY,
    CAT_INDISCRIMINATE_WEAPONS,
    CAT_HATE,
    CAT_SELF_HARM,
    CAT_SEXUAL_CONTENT,
    CAT_ELECTIONS,
    CAT_CODE_INTERPRETER_ABUSE,
]

PROMPT_TASK = "Task: Check if there is unsafe content in '$agent_type' messages in conversations according our safety policy with the below categories."

SAFETY_CATEGORIES = """
<BEGIN UNSAFE CONTENT CATEGORIES>
$categories
<END UNSAFE CONTENT CATEGORIES>
"""

PROMPT_CONVERSATION = """
<BEGIN CONVERSATION>
$conversations
<END CONVERSATION>
"""

PROMPT_INSTRUCTIONS = """
Provide your safety assessment for ONLY THE LAST $agent_type message in the above conversation:
- First line must read 'safe' or 'unsafe'.
- If unsafe, a second line must include a comma-separated list of violated categories."""

PROMPT_TEMPLATE = Template(
    f"{PROMPT_TASK}{SAFETY_CATEGORIES}{PROMPT_CONVERSATION}{PROMPT_INSTRUCTIONS}"
)


class LlamaGuardShield(ShieldBase):
    @staticmethod
    def instance(
        on_violation_action=OnViolationAction.RAISE,
        model_dir: Optional[str] = None,
        excluded_categories: Optional[List[str]] = None,
        disable_input_check: bool = False,
        disable_output_check: bool = False,
    ) -> "LlamaGuardShield":
        global _INSTANCE
        if _INSTANCE is None:
            _INSTANCE = LlamaGuardShield(
                on_violation_action,
                model_dir,
                excluded_categories,
                disable_input_check,
                disable_output_check,
            )
        return _INSTANCE

    def __init__(
        self,
        on_violation_action: OnViolationAction = OnViolationAction.RAISE,
        model_dir: Optional[str] = None,
        excluded_categories: Optional[List[str]] = None,
        disable_input_check: bool = False,
        disable_output_check: bool = False,
    ):
        super().__init__(on_violation_action)

        dtype = torch.bfloat16

        assert model_dir is not None, "Llama Guard model_dir is None"

        if excluded_categories is None:
            excluded_categories = []

        assert len(excluded_categories) == 0 or all(
            x in SAFETY_CATEGORIES_TO_CODE_MAP.values() for x in excluded_categories
        ), "Invalid categories in excluded categories. Expected format is ['S1', 'S2', ..]"

        self.device = "cuda"
        self.excluded_categories = excluded_categories
        self.disable_input_check = disable_input_check
        self.disable_output_check = disable_output_check

        # load model
        torch_dtype = torch.bfloat16
        self.tokenizer = AutoTokenizer.from_pretrained(model_dir)
        self.model = AutoModelForCausalLM.from_pretrained(
            model_dir, torch_dtype=torch_dtype, device_map=self.device
        )

    def get_shield_type(self) -> ShieldType:
        return BuiltinShield.llama_guard

    def check_unsafe_response(self, response: str) -> Optional[str]:
        match = re.match(r"^unsafe\n(.*)$", response)
        if match:
            # extracts the unsafe code
            extracted = match.group(1)
            return extracted

        return None

    def get_safety_categories(self) -> List[str]:
        excluded_categories = self.excluded_categories
        if set(excluded_categories) == set(SAFETY_CATEGORIES_TO_CODE_MAP.values()):
            excluded_categories = []

        categories = []
        for cat in DEFAULT_LG_V3_SAFETY_CATEGORIES:
            cat_code = SAFETY_CATEGORIES_TO_CODE_MAP[cat]
            if cat_code in excluded_categories:
                continue
            categories.append(f"{cat_code}: {cat}.")

        return categories

    def build_prompt(self, messages: List[Message]) -> str:
        categories = self.get_safety_categories()
        categories_str = "\n".join(categories)
        conversations_str = "\n\n".join(
            [f"{m.role.capitalize()}: {m.content}" for m in messages]
        )
        return PROMPT_TEMPLATE.substitute(
            agent_type=messages[-1].role.capitalize(),
            categories=categories_str,
            conversations=conversations_str,
        )

    def get_shield_response(self, response: str) -> ShieldResponse:
        if response == SAFE_RESPONSE:
            return ShieldResponse(
                shield_type=BuiltinShield.llama_guard, is_violation=False
            )
        unsafe_code = self.check_unsafe_response(response)
        if unsafe_code:
            unsafe_code_list = unsafe_code.split(",")
            if set(unsafe_code_list).issubset(set(self.excluded_categories)):
                return ShieldResponse(
                    shield_type=BuiltinShield.llama_guard, is_violation=False
                )
            return ShieldResponse(
                shield_type=BuiltinShield.llama_guard,
                is_violation=True,
                violation_type=unsafe_code,
                violation_return_message=CANNED_RESPONSE_TEXT,
            )

        raise ValueError(f"Unexpected response: {response}")

    async def run(self, messages: List[Message]) -> ShieldResponse:
        if self.disable_input_check and messages[-1].role == Role.user.value:
            return ShieldResponse(
                shield_type=BuiltinShield.llama_guard, is_violation=False
            )
        elif self.disable_output_check and messages[-1].role == Role.assistant.value:
            return ShieldResponse(
                shield_type=BuiltinShield.llama_guard,
                is_violation=False,
            )
        else:
            prompt = self.build_prompt(messages)
            llama_guard_input = {
                "role": "user",
                "content": prompt,
            }
            input_ids = self.tokenizer.apply_chat_template(
                [llama_guard_input], return_tensors="pt", tokenize=True
            ).to(self.device)
            prompt_len = input_ids.shape[1]
            output = self.model.generate(
                input_ids=input_ids,
                max_new_tokens=20,
                output_scores=True,
                return_dict_in_generate=True,
                pad_token_id=0,
            )
            generated_tokens = output.sequences[:, prompt_len:]

            response = self.tokenizer.decode(
                generated_tokens[0], skip_special_tokens=True
            )
            response = response.strip()
            shield_response = self.get_shield_response(response)
            return shield_response

@@ -0,0 +1,155 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from enum import auto, Enum
from typing import List

import torch
from llama_models.llama3.api.datatypes import Message
from termcolor import cprint
from transformers import AutoModelForSequenceClassification, AutoTokenizer

from .base import message_content_as_str, OnViolationAction, ShieldResponse, TextShield
from llama_stack.apis.safety import *  # noqa: F403


class PromptGuardShield(TextShield):
    class Mode(Enum):
        INJECTION = auto()
        JAILBREAK = auto()

    _instances = {}
    _model_cache = None

    @staticmethod
    def instance(
        model_dir: str,
        threshold: float = 0.9,
        temperature: float = 1.0,
        mode: "PromptGuardShield.Mode" = Mode.JAILBREAK,
        on_violation_action=OnViolationAction.RAISE,
    ) -> "PromptGuardShield":
        action_value = on_violation_action.value
        key = (model_dir, threshold, temperature, mode, action_value)
        if key not in PromptGuardShield._instances:
            PromptGuardShield._instances[key] = PromptGuardShield(
                model_dir=model_dir,
                threshold=threshold,
                temperature=temperature,
                mode=mode,
                on_violation_action=on_violation_action,
            )
        return PromptGuardShield._instances[key]

    def __init__(
        self,
        model_dir: str,
        threshold: float = 0.9,
        temperature: float = 1.0,
        mode: "PromptGuardShield.Mode" = Mode.JAILBREAK,
        on_violation_action: OnViolationAction = OnViolationAction.RAISE,
    ):
        super().__init__(on_violation_action)
        assert (
            model_dir is not None
        ), "Must provide a model directory for prompt injection shield"
        if temperature <= 0:
            raise ValueError("Temperature must be greater than 0")
        self.device = "cuda"
        if PromptGuardShield._model_cache is None:
            # load model and tokenizer
            tokenizer = AutoTokenizer.from_pretrained(model_dir)
            model = AutoModelForSequenceClassification.from_pretrained(
                model_dir, device_map=self.device
            )
            PromptGuardShield._model_cache = (tokenizer, model)

        self.tokenizer, self.model = PromptGuardShield._model_cache
        self.temperature = temperature
        self.threshold = threshold
        self.mode = mode

    def get_shield_type(self) -> ShieldType:
        return (
            BuiltinShield.jailbreak_shield
            if self.mode == self.Mode.JAILBREAK
            else BuiltinShield.injection_shield
        )

    def convert_messages_to_text(self, messages: List[Message]) -> str:
        return message_content_as_str(messages[-1])

    async def run_impl(self, text: str) -> ShieldResponse:
        # run model on messages and return response
        inputs = self.tokenizer(text, return_tensors="pt")
        inputs = {name: tensor.to(self.model.device) for name, tensor in inputs.items()}
        with torch.no_grad():
            outputs = self.model(**inputs)
        logits = outputs[0]
        probabilities = torch.softmax(logits / self.temperature, dim=-1)
        score_embedded = probabilities[0, 1].item()
        score_malicious = probabilities[0, 2].item()
        cprint(
            f"Ran PromptGuardShield and got Scores: Embedded: {score_embedded}, Malicious: {score_malicious}",
            color="magenta",
        )

        if self.mode == self.Mode.INJECTION and (
            score_embedded + score_malicious > self.threshold
        ):
            return ShieldResponse(
                shield_type=self.get_shield_type(),
                is_violation=True,
                violation_type=f"prompt_injection:embedded={score_embedded},malicious={score_malicious}",
                violation_return_message="Sorry, I cannot do this.",
            )
        elif self.mode == self.Mode.JAILBREAK and score_malicious > self.threshold:
            return ShieldResponse(
                shield_type=self.get_shield_type(),
                is_violation=True,
                violation_type=f"prompt_injection:malicious={score_malicious}",
                violation_return_message="Sorry, I cannot do this.",
            )

        return ShieldResponse(
            shield_type=self.get_shield_type(),
            is_violation=False,
        )


class JailbreakShield(PromptGuardShield):
    def __init__(
        self,
        model_dir: str,
        threshold: float = 0.9,
        temperature: float = 1.0,
        on_violation_action: OnViolationAction = OnViolationAction.RAISE,
    ):
        super().__init__(
            model_dir=model_dir,
            threshold=threshold,
            temperature=temperature,
            mode=PromptGuardShield.Mode.JAILBREAK,
            on_violation_action=on_violation_action,
        )


class InjectionShield(PromptGuardShield):
    def __init__(
        self,
        model_dir: str,
        threshold: float = 0.9,
        temperature: float = 1.0,
        on_violation_action: OnViolationAction = OnViolationAction.RAISE,
    ):
        super().__init__(
            model_dir=model_dir,
            threshold=threshold,
            temperature=temperature,
            mode=PromptGuardShield.Mode.INJECTION,
            on_violation_action=on_violation_action,
        )