mirror of https://github.com/meta-llama/llama-stack.git
update prompt-shield to reflect latest changes in agentic
commit 9c9b834c0f
parent ce0804556b

3 changed files with 45 additions and 5 deletions
@@ -16,6 +16,8 @@ class BuiltinShield(Enum):
     prompt_guard = "prompt_guard"
     code_scanner_guard = "code_scanner_guard"
     third_party_shield = "third_party_shield"
+    injection_shield = "injection_shield"
+    jailbreak_shield = "jailbreak_shield"


 ShieldType = Union[BuiltinShield, str]
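For readers without the full file, a minimal sketch of the enum after this hunk, assuming BuiltinShield is a plain Enum as the hunk header shows. Members defined above the hunk are omitted, and the trailing union means a shield type may also be an arbitrary string.

# Minimal sketch (assumed context); only the members visible in the hunk are shown.
from enum import Enum
from typing import List, Union


class BuiltinShield(Enum):
    prompt_guard = "prompt_guard"
    code_scanner_guard = "code_scanner_guard"
    third_party_shield = "third_party_shield"
    injection_shield = "injection_shield"  # new in this commit
    jailbreak_shield = "jailbreak_shield"  # new in this commit


ShieldType = Union[BuiltinShield, str]

# A ShieldType can be a builtin member or a caller-defined string:
configured: List[ShieldType] = [BuiltinShield.jailbreak_shield, "my_custom_shield"]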
@@ -11,7 +11,7 @@ from .base import (  # noqa: F401
 from .code_scanner import CodeScannerShield  # noqa: F401
 from .contrib.third_party_shield import ThirdPartyShield  # noqa: F401
 from .llama_guard import LlamaGuardShield  # noqa: F401
-from .prompt_guard import PromptGuardShield  # noqa: F401
+from .prompt_guard import InjectionShield, JailbreakShield, PromptGuardShield  # noqa: F401
 from .shield_runner import SafetyException, ShieldRunnerMixin  # noqa: F401

 transformers.logging.set_verbosity_error()
@@ -69,7 +69,11 @@ class PromptGuardShield(TextShield):
         self.mode = mode

     def get_shield_type(self) -> ShieldType:
-        return BuiltinShield.prompt_guard
+        return (
+            BuiltinShield.jailbreak_shield
+            if self.mode == self.Mode.JAILBREAK
+            else BuiltinShield.injection_shield
+        )

     def convert_messages_to_text(self, messages: List[Message]) -> str:
         return message_content_as_str(messages[-1])
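To make the behavioral change concrete, a hedged sketch: a PromptGuardShield now reports a mode-specific shield type instead of the generic prompt_guard value. The keyword arguments mirror those forwarded by the subclasses added in the next hunk; the path is a placeholder.

# Sketch, not verbatim library usage; imports are omitted because the absolute
# package path is not shown in this diff. model_dir is assumed to point to a
# local Prompt Guard checkpoint that the constructor loads.
shield = PromptGuardShield(
    model_dir="/path/to/prompt-guard",  # placeholder path
    mode=PromptGuardShield.Mode.JAILBREAK,
)
assert shield.get_shield_type() == BuiltinShield.jailbreak_shield

shield = PromptGuardShield(
    model_dir="/path/to/prompt-guard",  # placeholder path
    mode=PromptGuardShield.Mode.INJECTION,
)
assert shield.get_shield_type() == BuiltinShield.injection_shield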
@@ -93,20 +97,54 @@ class PromptGuardShield(TextShield):
             score_embedded + score_malicious > self.threshold
         ):
             return ShieldResponse(
-                shield_type=BuiltinShield.prompt_guard,
+                shield_type=self.get_shield_type(),
                 is_violation=True,
                 violation_type=f"prompt_injection:embedded={score_embedded},malicious={score_malicious}",
                 violation_return_message="Sorry, I cannot do this.",
             )
         elif self.mode == self.Mode.JAILBREAK and score_malicious > self.threshold:
             return ShieldResponse(
-                shield_type=BuiltinShield.prompt_guard,
+                shield_type=self.get_shield_type(),
                 is_violation=True,
                 violation_type=f"prompt_injection:malicious={score_malicious}",
                 violation_return_message="Sorry, I cannot do this.",
             )

         return ShieldResponse(
-            shield_type=BuiltinShield.prompt_guard,
+            shield_type=self.get_shield_type(),
             is_violation=False,
         )
+
+
+class JailbreakShield(PromptGuardShield):
+    def __init__(
+        self,
+        model_dir: str,
+        threshold: float = 0.9,
+        temperature: float = 1.0,
+        on_violation_action: OnViolationAction = OnViolationAction.RAISE,
+    ):
+        super().__init__(
+            model_dir=model_dir,
+            threshold=threshold,
+            temperature=temperature,
+            mode=PromptGuardShield.Mode.JAILBREAK,
+            on_violation_action=on_violation_action,
+        )
+
+
+class InjectionShield(PromptGuardShield):
+    def __init__(
+        self,
+        model_dir: str,
+        threshold: float = 0.9,
+        temperature: float = 1.0,
+        on_violation_action: OnViolationAction = OnViolationAction.RAISE,
+    ):
+        super().__init__(
+            model_dir=model_dir,
+            threshold=threshold,
+            temperature=temperature,
+            mode=PromptGuardShield.Mode.INJECTION,
+            on_violation_action=on_violation_action,
+        )
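A brief usage sketch of the two convenience subclasses added above. They only pin the mode, so either can be used wherever a PromptGuardShield is expected; the path below is a placeholder for a local Prompt Guard checkpoint directory, and imports are assumed from the shields package re-exported in the second hunk.

# Sketch under the assumptions noted above.
jailbreak = JailbreakShield(model_dir="/path/to/prompt-guard", threshold=0.9)
injection = InjectionShield(model_dir="/path/to/prompt-guard", threshold=0.9)

# Each subclass now reports its mode-specific builtin shield type.
assert jailbreak.get_shield_type() == BuiltinShield.jailbreak_shield
assert injection.get_shield_type() == BuiltinShield.injection_shield

This lets callers, for example the agentic flows referenced in the commit message, register a jailbreak-only or injection-only check without supplying a PromptGuardShield.Mode value themselves.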