mirror of https://github.com/meta-llama/llama-stack.git
synced 2025-07-29 07:14:20 +00:00
update prompt-shield to reflect latest changes in agentic
parent ce0804556b
commit 9c9b834c0f
3 changed files with 45 additions and 5 deletions
File 1 of 3 — shield type definitions (BuiltinShield / ShieldType):

@@ -16,6 +16,8 @@ class BuiltinShield(Enum):
     prompt_guard = "prompt_guard"
     code_scanner_guard = "code_scanner_guard"
     third_party_shield = "third_party_shield"
+    injection_shield = "injection_shield"
+    jailbreak_shield = "jailbreak_shield"


 ShieldType = Union[BuiltinShield, str]
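The two new members slot into the ShieldType alias shown just below them, so a shield type can be either a built-in enum member or a free-form string. A minimal, self-contained sketch of that relationship (the enum here is abbreviated, and the normalizing helper is hypothetical, not part of the codebase):

    from enum import Enum
    from typing import Union

    class BuiltinShield(Enum):
        # abbreviated copy of the enum after this change
        prompt_guard = "prompt_guard"
        injection_shield = "injection_shield"
        jailbreak_shield = "jailbreak_shield"

    ShieldType = Union[BuiltinShield, str]

    def shield_type_name(shield_type: ShieldType) -> str:
        # hypothetical helper: normalize either form to its string name
        return shield_type.value if isinstance(shield_type, BuiltinShield) else shield_type

    assert shield_type_name(BuiltinShield.jailbreak_shield) == "jailbreak_shield"
    assert shield_type_name("my_custom_shield") == "my_custom_shield"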
File 2 of 3 — shields package re-exports:

@@ -11,7 +11,7 @@ from .base import (  # noqa: F401
 from .code_scanner import CodeScannerShield  # noqa: F401
 from .contrib.third_party_shield import ThirdPartyShield  # noqa: F401
 from .llama_guard import LlamaGuardShield  # noqa: F401
-from .prompt_guard import PromptGuardShield  # noqa: F401
+from .prompt_guard import InjectionShield, JailbreakShield, PromptGuardShield  # noqa: F401
 from .shield_runner import SafetyException, ShieldRunnerMixin  # noqa: F401

 transformers.logging.set_verbosity_error()
File 3 of 3 — prompt_guard.py (two hunks):

@@ -69,7 +69,11 @@ class PromptGuardShield(TextShield):
         self.mode = mode

     def get_shield_type(self) -> ShieldType:
-        return BuiltinShield.prompt_guard
+        return (
+            BuiltinShield.jailbreak_shield
+            if self.mode == self.Mode.JAILBREAK
+            else BuiltinShield.injection_shield
+        )

     def convert_messages_to_text(self, messages: List[Message]) -> str:
         return message_content_as_str(messages[-1])
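The conditional above means a single PromptGuardShield instance now reports a mode-specific shield type rather than the generic prompt_guard value. A small stand-alone sketch of the same mapping, using a stand-in Mode enum (the real PromptGuardShield.Mode is referenced by this diff but not shown in it):

    from enum import Enum, auto

    class Mode(Enum):
        # stand-in for PromptGuardShield.Mode, which this diff does not show
        INJECTION = auto()
        JAILBREAK = auto()

    def shield_type_for_mode(mode: Mode) -> str:
        # mirrors the new get_shield_type(): jailbreak mode reports jailbreak_shield,
        # otherwise (injection mode) it reports injection_shield
        return "jailbreak_shield" if mode is Mode.JAILBREAK else "injection_shield"

    assert shield_type_for_mode(Mode.JAILBREAK) == "jailbreak_shield"
    assert shield_type_for_mode(Mode.INJECTION) == "injection_shield"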
@@ -93,20 +97,54 @@
             score_embedded + score_malicious > self.threshold
         ):
             return ShieldResponse(
-                shield_type=BuiltinShield.prompt_guard,
+                shield_type=self.get_shield_type(),
                 is_violation=True,
                 violation_type=f"prompt_injection:embedded={score_embedded},malicious={score_malicious}",
                 violation_return_message="Sorry, I cannot do this.",
             )
         elif self.mode == self.Mode.JAILBREAK and score_malicious > self.threshold:
             return ShieldResponse(
-                shield_type=BuiltinShield.prompt_guard,
+                shield_type=self.get_shield_type(),
                 is_violation=True,
                 violation_type=f"prompt_injection:malicious={score_malicious}",
                 violation_return_message="Sorry, I cannot do this.",
             )

         return ShieldResponse(
-            shield_type=BuiltinShield.prompt_guard,
+            shield_type=self.get_shield_type(),
             is_violation=False,
         )
+
+
+class JailbreakShield(PromptGuardShield):
+    def __init__(
+        self,
+        model_dir: str,
+        threshold: float = 0.9,
+        temperature: float = 1.0,
+        on_violation_action: OnViolationAction = OnViolationAction.RAISE,
+    ):
+        super().__init__(
+            model_dir=model_dir,
+            threshold=threshold,
+            temperature=temperature,
+            mode=PromptGuardShield.Mode.JAILBREAK,
+            on_violation_action=on_violation_action,
+        )
+
+
+class InjectionShield(PromptGuardShield):
+    def __init__(
+        self,
+        model_dir: str,
+        threshold: float = 0.9,
+        temperature: float = 1.0,
+        on_violation_action: OnViolationAction = OnViolationAction.RAISE,
+    ):
+        super().__init__(
+            model_dir=model_dir,
+            threshold=threshold,
+            temperature=temperature,
+            mode=PromptGuardShield.Mode.INJECTION,
+            on_violation_action=on_violation_action,
+        )
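For orientation, a hedged sketch of how the two new convenience subclasses might be used once this lands. Imports are elided because the full package path is not visible in this diff (per the re-export hunk above, both classes live alongside PromptGuardShield), and model_dir is a placeholder that would need to point at a local Prompt Guard checkpoint for construction to succeed:

    # Hypothetical usage, not part of this commit.
    model_dir = "/path/to/prompt-guard-checkpoint"  # placeholder

    # Each subclass only pins the `mode` argument of PromptGuardShield,
    # so they share the same model directory, threshold, and temperature defaults.
    jailbreak = JailbreakShield(model_dir=model_dir)
    injection = InjectionShield(model_dir=model_dir)

    # After this commit each instance reports its own shield type:
    #   jailbreak.get_shield_type() -> BuiltinShield.jailbreak_shield
    #   injection.get_shield_type() -> BuiltinShield.injection_shield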