From 9c9b834c0f1279386e6fc3507b6862ab941a8c64 Mon Sep 17 00:00:00 2001
From: Hardik Shah
Date: Fri, 19 Jul 2024 18:12:09 -0700
Subject: [PATCH] update prompt-shield to reflect latest changes in agentic

---
 toolchain/safety/api/datatypes.py        |  2 ++
 toolchain/safety/shields/__init__.py     |  2 +-
 toolchain/safety/shields/prompt_guard.py | 46 +++++++++++++++++++++---
 3 files changed, 45 insertions(+), 5 deletions(-)

diff --git a/toolchain/safety/api/datatypes.py b/toolchain/safety/api/datatypes.py
index 72541d10f..879358eab 100644
--- a/toolchain/safety/api/datatypes.py
+++ b/toolchain/safety/api/datatypes.py
@@ -16,6 +16,8 @@ class BuiltinShield(Enum):
     prompt_guard = "prompt_guard"
     code_scanner_guard = "code_scanner_guard"
     third_party_shield = "third_party_shield"
+    injection_shield = "injection_shield"
+    jailbreak_shield = "jailbreak_shield"


 ShieldType = Union[BuiltinShield, str]
diff --git a/toolchain/safety/shields/__init__.py b/toolchain/safety/shields/__init__.py
index d9ee5ea38..ed2309221 100644
--- a/toolchain/safety/shields/__init__.py
+++ b/toolchain/safety/shields/__init__.py
@@ -11,7 +11,7 @@ from .base import (  # noqa: F401
 from .code_scanner import CodeScannerShield  # noqa: F401
 from .contrib.third_party_shield import ThirdPartyShield  # noqa: F401
 from .llama_guard import LlamaGuardShield  # noqa: F401
-from .prompt_guard import PromptGuardShield  # noqa: F401
+from .prompt_guard import InjectionShield, JailbreakShield, PromptGuardShield  # noqa: F401
 from .shield_runner import SafetyException, ShieldRunnerMixin  # noqa: F401

 transformers.logging.set_verbosity_error()
diff --git a/toolchain/safety/shields/prompt_guard.py b/toolchain/safety/shields/prompt_guard.py
index 4cab03b36..f67ee71c1 100644
--- a/toolchain/safety/shields/prompt_guard.py
+++ b/toolchain/safety/shields/prompt_guard.py
@@ -69,7 +69,11 @@ class PromptGuardShield(TextShield):
         self.mode = mode

     def get_shield_type(self) -> ShieldType:
-        return BuiltinShield.prompt_guard
+        return (
+            BuiltinShield.jailbreak_shield
+            if self.mode == self.Mode.JAILBREAK
+            else BuiltinShield.injection_shield
+        )

     def convert_messages_to_text(self, messages: List[Message]) -> str:
         return message_content_as_str(messages[-1])
@@ -93,20 +97,54 @@
             score_embedded + score_malicious > self.threshold
         ):
             return ShieldResponse(
-                shield_type=BuiltinShield.prompt_guard,
+                shield_type=self.get_shield_type(),
                 is_violation=True,
                 violation_type=f"prompt_injection:embedded={score_embedded},malicious={score_malicious}",
                 violation_return_message="Sorry, I cannot do this.",
             )
         elif self.mode == self.Mode.JAILBREAK and score_malicious > self.threshold:
             return ShieldResponse(
-                shield_type=BuiltinShield.prompt_guard,
+                shield_type=self.get_shield_type(),
                 is_violation=True,
                 violation_type=f"prompt_injection:malicious={score_malicious}",
                 violation_return_message="Sorry, I cannot do this.",
             )
         return ShieldResponse(
-            shield_type=BuiltinShield.prompt_guard,
+            shield_type=self.get_shield_type(),
             is_violation=False,
         )
+
+
+class JailbreakShield(PromptGuardShield):
+    def __init__(
+        self,
+        model_dir: str,
+        threshold: float = 0.9,
+        temperature: float = 1.0,
+        on_violation_action: OnViolationAction = OnViolationAction.RAISE,
+    ):
+        super().__init__(
+            model_dir=model_dir,
+            threshold=threshold,
+            temperature=temperature,
+            mode=PromptGuardShield.Mode.JAILBREAK,
+            on_violation_action=on_violation_action,
+        )
+
+
+class InjectionShield(PromptGuardShield):
+    def __init__(
+        self,
+        model_dir: str,
+        threshold: float = 0.9,
+        temperature: float = 1.0,
+        on_violation_action: OnViolationAction = OnViolationAction.RAISE,
+    ):
+        super().__init__(
+            model_dir=model_dir,
+            threshold=threshold,
+            temperature=temperature,
+            mode=PromptGuardShield.Mode.INJECTION,
+            on_violation_action=on_violation_action,
+        )
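
A minimal usage sketch, not part of the patch: it shows how the two wrapper classes added above might be constructed and how their shield types resolve after this change. The model_dir value is a placeholder and the import paths assume the package layout shown in the diff; only the constructor signatures and the new BuiltinShield members come from the patch itself.

# Usage sketch (assumptions: package layout as in the diff; model_dir is a placeholder path).
from toolchain.safety.api.datatypes import BuiltinShield
from toolchain.safety.shields import InjectionShield, JailbreakShield

# Each wrapper is a thin subclass that pins PromptGuardShield's mode.
jailbreak = JailbreakShield(model_dir="/models/prompt-guard", threshold=0.9)
injection = InjectionShield(model_dir="/models/prompt-guard", threshold=0.9)

# With this patch, each shield reports its own type instead of the
# generic prompt_guard value, so responses can be told apart downstream.
assert jailbreak.get_shield_type() == BuiltinShield.jailbreak_shield
assert injection.get_shield_type() == BuiltinShield.injection_shield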