update prompt-shield to reflect latest changes in agentic

Hardik Shah 2024-07-19 18:12:09 -07:00
parent ce0804556b
commit 9c9b834c0f
3 changed files with 45 additions and 5 deletions


@@ -16,6 +16,8 @@ class BuiltinShield(Enum):
     prompt_guard = "prompt_guard"
     code_scanner_guard = "code_scanner_guard"
     third_party_shield = "third_party_shield"
+    injection_shield = "injection_shield"
+    jailbreak_shield = "jailbreak_shield"


 ShieldType = Union[BuiltinShield, str]
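Note: the two new members give Prompt Guard's two modes distinct identities, so a response can report injection_shield or jailbreak_shield instead of the catch-all prompt_guard. A minimal consumer-side sketch (the helper name is hypothetical; BuiltinShield and ShieldType are the ones defined in this file):

def is_prompt_guard_type(shield_type: ShieldType) -> bool:
    # Matches either of the new mode-specific Prompt Guard variants.
    return shield_type in (
        BuiltinShield.injection_shield,
        BuiltinShield.jailbreak_shield,
    )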


@@ -11,7 +11,7 @@ from .base import (  # noqa: F401
 from .code_scanner import CodeScannerShield  # noqa: F401
 from .contrib.third_party_shield import ThirdPartyShield  # noqa: F401
 from .llama_guard import LlamaGuardShield  # noqa: F401
-from .prompt_guard import PromptGuardShield  # noqa: F401
+from .prompt_guard import InjectionShield, JailbreakShield, PromptGuardShield  # noqa: F401
 from .shield_runner import SafetyException, ShieldRunnerMixin  # noqa: F401

 transformers.logging.set_verbosity_error()
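Note: with this widened re-export, callers can pull the specialized shields from the shields package itself rather than reaching into the prompt_guard module. A sketch of the import, assuming a consumer module sitting next to the package (the absolute package name is not shown in this diff):

from .shields import InjectionShield, JailbreakShield  # hypothetical consumer module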


@@ -69,7 +69,11 @@ class PromptGuardShield(TextShield):
         self.mode = mode

     def get_shield_type(self) -> ShieldType:
-        return BuiltinShield.prompt_guard
+        return (
+            BuiltinShield.jailbreak_shield
+            if self.mode == self.Mode.JAILBREAK
+            else BuiltinShield.injection_shield
+        )

     def convert_messages_to_text(self, messages: List[Message]) -> str:
         return message_content_as_str(messages[-1])
@@ -93,20 +97,54 @@ class PromptGuardShield(TextShield):
             score_embedded + score_malicious > self.threshold
         ):
             return ShieldResponse(
-                shield_type=BuiltinShield.prompt_guard,
+                shield_type=self.get_shield_type(),
                 is_violation=True,
                 violation_type=f"prompt_injection:embedded={score_embedded},malicious={score_malicious}",
                 violation_return_message="Sorry, I cannot do this.",
             )
         elif self.mode == self.Mode.JAILBREAK and score_malicious > self.threshold:
             return ShieldResponse(
-                shield_type=BuiltinShield.prompt_guard,
+                shield_type=self.get_shield_type(),
                 is_violation=True,
                 violation_type=f"prompt_injection:malicious={score_malicious}",
                 violation_return_message="Sorry, I cannot do this.",
             )
         return ShieldResponse(
-            shield_type=BuiltinShield.prompt_guard,
+            shield_type=self.get_shield_type(),
             is_violation=False,
         )
+
+
+class JailbreakShield(PromptGuardShield):
+    def __init__(
+        self,
+        model_dir: str,
+        threshold: float = 0.9,
+        temperature: float = 1.0,
+        on_violation_action: OnViolationAction = OnViolationAction.RAISE,
+    ):
+        super().__init__(
+            model_dir=model_dir,
+            threshold=threshold,
+            temperature=temperature,
+            mode=PromptGuardShield.Mode.JAILBREAK,
+            on_violation_action=on_violation_action,
+        )
+
+
+class InjectionShield(PromptGuardShield):
+    def __init__(
+        self,
+        model_dir: str,
+        threshold: float = 0.9,
+        temperature: float = 1.0,
+        on_violation_action: OnViolationAction = OnViolationAction.RAISE,
+    ):
+        super().__init__(
+            model_dir=model_dir,
+            threshold=threshold,
+            temperature=temperature,
+            mode=PromptGuardShield.Mode.INJECTION,
+            on_violation_action=on_violation_action,
+        )
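Note: putting the pieces together, a minimal usage sketch of the new wrappers (the checkpoint path is a placeholder; only the constructor signature and get_shield_type() come from this diff):

# Hypothetical local checkpoint directory.
jailbreak = JailbreakShield(model_dir="/path/to/prompt-guard-checkpoint")
injection = InjectionShield(model_dir="/path/to/prompt-guard-checkpoint", threshold=0.9)

# Each wrapper pins its mode, so the reported shield type is now mode-specific.
assert jailbreak.get_shield_type() == BuiltinShield.jailbreak_shield
assert injection.get_shield_type() == BuiltinShield.injection_shield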