Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-07-29 07:14:20 +00:00)
Update safety implementation inside agents
Parent: 82ddd851c8
Commit: d6a41d98d2
8 changed files with 17 additions and 67 deletions
@@ -208,7 +208,7 @@ class ToolExecutionStep(StepCommon):
 @json_schema_type
 class ShieldCallStep(StepCommon):
     step_type: Literal[StepType.shield_call.value] = StepType.shield_call.value
-    response: ShieldResponse
+    violation: Optional[SafetyViolation]


 @json_schema_type
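Taken together, the commit drops the per-shield shield_type/ShieldResponse plumbing and records safety results as an Optional[SafetyViolation] on ShieldCallStep (and as RunShieldResponse.violation at the API boundary). The sketch below only illustrates that shape: the dataclasses are stand-ins, not the real llama_stack models, and every field other than user_message (which appears later in this diff) is an assumption.

from dataclasses import dataclass
from typing import Optional


@dataclass
class SafetyViolation:
    # user_message is the field the agent surfaces back to the user (seen below in
    # ChatAgent); any other fields are not shown in this diff.
    user_message: Optional[str] = None


@dataclass
class ShieldCallStep:
    step_id: str
    turn_id: str
    violation: Optional[SafetyViolation] = None  # None means the shield run passed


# A passing shield run is now recorded as violation=None rather than a synthetic
# ShieldResponse(is_violation=False):
passing_step = ShieldCallStep(step_id="step-1", turn_id="turn-1", violation=None)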
@@ -25,7 +25,7 @@ class BedrockSafetyAdapter(Safety):

     async def run_shield(
         self,
-        shield: ShieldType,
+        shield: str,
         messages: List[Message],
     ) -> RunShieldResponse:
         # clients will set api_keys by doing something like:
@@ -94,12 +94,11 @@ class ChatAgent(ShieldRunnerMixin):
                 )
             )
         elif step.step_type == StepType.shield_call.value:
-            response = step.response
-            if response.is_violation:
+            if step.violation:
                 # CompletionMessage itself in the ShieldResponse
                 messages.append(
                     CompletionMessage(
-                        content=response.violation_return_message,
+                        content=violation.user_message,
                         stop_reason=StopReason.end_of_turn,
                     )
                 )
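For clarity, a minimal, self-contained rendering of the new branch: when the agent replays earlier steps, a shield-call step contributes a refusal message only if it recorded a violation. The helper below is hypothetical (it is not ChatAgent code) and uses the same stand-in types as the sketch above.

from dataclasses import dataclass
from typing import List, Optional


@dataclass
class SafetyViolation:
    user_message: Optional[str] = None


@dataclass
class ShieldCallStep:
    violation: Optional[SafetyViolation] = None


def refusal_messages(steps: List[ShieldCallStep]) -> List[str]:
    # Steps with violation=None are skipped; steps with a violation contribute the
    # user-facing message the shield attached to it.
    return [s.violation.user_message or "" for s in steps if s.violation is not None]


print(refusal_messages([
    ShieldCallStep(),  # passed, contributes nothing
    ShieldCallStep(SafetyViolation("I can't help with that request.")),
]))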
@@ -276,7 +275,7 @@ class ChatAgent(ShieldRunnerMixin):
                     step_details=ShieldCallStep(
                         step_id=step_id,
                         turn_id=turn_id,
-                        response=e.response,
+                        violation=e.violation,
                     ),
                 )
             )
@@ -295,12 +294,7 @@ class ChatAgent(ShieldRunnerMixin):
                     step_details=ShieldCallStep(
                         step_id=step_id,
                         turn_id=turn_id,
-                        response=ShieldResponse(
-                            # TODO: fix this, give each shield a shield type method and
-                            # fire one event for each shield run
-                            shield_type=BuiltinShield.llama_guard,
-                            is_violation=False,
-                        ),
+                        violation=None,
                     ),
                 )
             )
@@ -550,12 +544,7 @@ class ChatAgent(ShieldRunnerMixin):
                     step_details=ShieldCallStep(
                         step_id=str(uuid.uuid4()),
                         turn_id=turn_id,
-                        response=ShieldResponse(
-                            # TODO: fix this, give each shield a shield type method and
-                            # fire one event for each shield run
-                            shield_type=BuiltinShield.llama_guard,
-                            is_violation=False,
-                        ),
+                        violation=None,
                     ),
                 )
             )
@@ -569,7 +558,7 @@ class ChatAgent(ShieldRunnerMixin):
                     step_details=ShieldCallStep(
                         step_id=str(uuid.uuid4()),
                         turn_id=turn_id,
-                        response=e.response,
+                        violation=e.violation,
                     ),
                 )
             )
@@ -5,7 +5,6 @@
 # the root directory of this source tree.

 from typing import AsyncIterator, List, Optional, Union
-from unittest.mock import MagicMock

 import pytest

@@ -80,9 +79,9 @@ class MockInferenceAPI:

 class MockSafetyAPI:
     async def run_shields(
-        self, messages: List[Message], shields: List[MagicMock]
-    ) -> List[ShieldResponse]:
-        return [ShieldResponse(shield_type="mock_shield", is_violation=False)]
+        self, shield_type: str, messages: List[Message]
+    ) -> RunShieldResponse:
+        return RunShieldResponse(violation=None)


 class MockMemoryAPI:
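As a usage sketch, under the updated mock a caller checks response.violation instead of unpacking a list of ShieldResponse objects. The types below are simplified stand-ins and the shield identifier "llama_guard" is just an example value.

import asyncio
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class RunShieldResponse:
    violation: Optional[object] = None  # None -> no safety violation detected


class MockSafetyAPI:
    async def run_shields(self, shield_type: str, messages: List[str]) -> RunShieldResponse:
        # Always reports "safe", mirroring the mock above
        return RunShieldResponse(violation=None)


async def main() -> None:
    safety = MockSafetyAPI()
    response = await safety.run_shields("llama_guard", ["hello"])
    assert response.violation is None  # the agent would proceed with the turn


asyncio.run(main())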
@@ -35,10 +35,6 @@ class ShieldBase(ABC):
     ):
         self.on_violation_action = on_violation_action

-    @abstractmethod
-    def get_shield_type(self) -> ShieldType:
-        raise NotImplementedError()
-
     @abstractmethod
     async def run(self, messages: List[Message]) -> ShieldResponse:
         raise NotImplementedError()
@@ -63,11 +59,6 @@ class TextShield(ShieldBase):


 class DummyShield(TextShield):
-    def get_shield_type(self) -> ShieldType:
-        return "dummy"
-
     async def run_impl(self, text: str) -> ShieldResponse:
         # Dummy return LOW to test e2e
-        return ShieldResponse(
-            shield_type=BuiltinShield.third_party_shield, is_violation=False
-        )
+        return ShieldResponse(is_violation=False)
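After this change a shield subclass only implements run() (or run_impl() for text shields); there is no longer a per-shield type accessor. A bare-bones skeleton of that contract, with stand-in types rather than the real base classes:

import asyncio
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import List


@dataclass
class ShieldResponse:
    is_violation: bool = False


class ShieldBase(ABC):
    @abstractmethod
    async def run(self, messages: List[str]) -> ShieldResponse:
        raise NotImplementedError()


class DummyShield(ShieldBase):
    async def run(self, messages: List[str]) -> ShieldResponse:
        # Always passes, like the simplified DummyShield above
        return ShieldResponse(is_violation=False)


print(asyncio.run(DummyShield().run(["test message"])))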
@@ -7,13 +7,9 @@
 from termcolor import cprint

 from .base import ShieldResponse, TextShield
 from llama_stack.apis.safety import *  # noqa: F403


 class CodeScannerShield(TextShield):
-    def get_shield_type(self) -> ShieldType:
-        return BuiltinShield.code_scanner_guard
-
     async def run_impl(self, text: str) -> ShieldResponse:
         from codeshield.cs import CodeShield

@@ -21,7 +17,6 @@ class CodeScannerShield(TextShield):
         result = await CodeShield.scan_code(text)
         if result.is_insecure:
             return ShieldResponse(
-                shield_type=BuiltinShield.code_scanner_guard,
                 is_violation=True,
                 violation_type=",".join(
                     [issue.pattern_id for issue in result.issues_found]
@@ -29,6 +24,4 @@ class CodeScannerShield(TextShield):
                 violation_return_message="Sorry, I found security concerns in the code.",
             )
         else:
-            return ShieldResponse(
-                shield_type=BuiltinShield.code_scanner_guard, is_violation=False
-            )
+            return ShieldResponse(is_violation=False)
@@ -14,7 +14,7 @@ from llama_models.llama3.api.datatypes import Message, Role
 from transformers import AutoModelForCausalLM, AutoTokenizer

 from .base import CANNED_RESPONSE_TEXT, OnViolationAction, ShieldBase, ShieldResponse
 from llama_stack.apis.safety import *  # noqa: F403


 SAFE_RESPONSE = "safe"
 _INSTANCE = None
@@ -152,9 +152,6 @@ class LlamaGuardShield(ShieldBase):
             model_dir, torch_dtype=torch_dtype, device_map=self.device
         )

-    def get_shield_type(self) -> ShieldType:
-        return BuiltinShield.llama_guard
-
     def check_unsafe_response(self, response: str) -> Optional[str]:
         match = re.match(r"^unsafe\n(.*)$", response)
         if match:
@@ -192,18 +189,13 @@ class LlamaGuardShield(ShieldBase):

     def get_shield_response(self, response: str) -> ShieldResponse:
         if response == SAFE_RESPONSE:
-            return ShieldResponse(
-                shield_type=BuiltinShield.llama_guard, is_violation=False
-            )
+            return ShieldResponse(is_violation=False)
         unsafe_code = self.check_unsafe_response(response)
         if unsafe_code:
             unsafe_code_list = unsafe_code.split(",")
             if set(unsafe_code_list).issubset(set(self.excluded_categories)):
-                return ShieldResponse(
-                    shield_type=BuiltinShield.llama_guard, is_violation=False
-                )
+                return ShieldResponse(is_violation=False)
             return ShieldResponse(
-                shield_type=BuiltinShield.llama_guard,
                 is_violation=True,
                 violation_type=unsafe_code,
                 violation_return_message=CANNED_RESPONSE_TEXT,
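The verdict parsing above reads as follows: a bare "safe" means no violation, while "unsafe" followed by a newline and a comma-separated list of category codes is a violation unless every code is in the excluded list. A self-contained sketch of that logic (ShieldResponse is a stand-in type and the canned message is a placeholder, not the real CANNED_RESPONSE_TEXT):

import re
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class ShieldResponse:
    is_violation: bool
    violation_type: Optional[str] = None
    violation_return_message: Optional[str] = None


def parse_verdict(response: str, excluded_categories: List[str]) -> ShieldResponse:
    if response == "safe":
        return ShieldResponse(is_violation=False)
    match = re.match(r"^unsafe\n(.*)$", response)
    if match:
        unsafe_code = match.group(1)
        if set(unsafe_code.split(",")).issubset(set(excluded_categories)):
            # every flagged category is explicitly excluded -> treat as safe
            return ShieldResponse(is_violation=False)
        return ShieldResponse(
            is_violation=True,
            violation_type=unsafe_code,
            violation_return_message="Sorry, I cannot help with that.",  # placeholder
        )
    return ShieldResponse(is_violation=False)


print(parse_verdict("unsafe\nS1,S10", excluded_categories=["S10"]))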
@@ -213,12 +205,9 @@ class LlamaGuardShield(ShieldBase):

     async def run(self, messages: List[Message]) -> ShieldResponse:
         if self.disable_input_check and messages[-1].role == Role.user.value:
-            return ShieldResponse(
-                shield_type=BuiltinShield.llama_guard, is_violation=False
-            )
+            return ShieldResponse(is_violation=False)
         elif self.disable_output_check and messages[-1].role == Role.assistant.value:
             return ShieldResponse(
-                shield_type=BuiltinShield.llama_guard,
                 is_violation=False,
             )
         else:
@@ -13,7 +13,6 @@ from llama_models.llama3.api.datatypes import Message
 from termcolor import cprint

 from .base import message_content_as_str, OnViolationAction, ShieldResponse, TextShield
 from llama_stack.apis.safety import *  # noqa: F403


 class PromptGuardShield(TextShield):
@@ -74,13 +73,6 @@ class PromptGuardShield(TextShield):
         self.threshold = threshold
         self.mode = mode

-    def get_shield_type(self) -> ShieldType:
-        return (
-            BuiltinShield.jailbreak_shield
-            if self.mode == self.Mode.JAILBREAK
-            else BuiltinShield.injection_shield
-        )
-
     def convert_messages_to_text(self, messages: List[Message]) -> str:
         return message_content_as_str(messages[-1])
@@ -103,21 +95,18 @@ class PromptGuardShield(TextShield):
                 score_embedded + score_malicious > self.threshold
             ):
                 return ShieldResponse(
-                    shield_type=self.get_shield_type(),
                     is_violation=True,
                     violation_type=f"prompt_injection:embedded={score_embedded},malicious={score_malicious}",
                     violation_return_message="Sorry, I cannot do this.",
                 )
             elif self.mode == self.Mode.JAILBREAK and score_malicious > self.threshold:
                 return ShieldResponse(
-                    shield_type=self.get_shield_type(),
                     is_violation=True,
                     violation_type=f"prompt_injection:malicious={score_malicious}",
                     violation_return_message="Sorry, I cannot do this.",
                 )

             return ShieldResponse(
-                shield_type=self.get_shield_type(),
                 is_violation=False,
             )
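For reference, the scoring branch above behaves roughly as sketched below: in injection mode a violation is flagged when the embedded and malicious scores together exceed the threshold, in jailbreak mode when the malicious score alone does. The Mode enum and ShieldResponse type are stand-ins, and the injection-mode condition is partly inferred since its first half falls outside this hunk.

from dataclasses import dataclass
from enum import Enum
from typing import Optional


class Mode(Enum):
    INJECTION = "injection"
    JAILBREAK = "jailbreak"


@dataclass
class ShieldResponse:
    is_violation: bool
    violation_type: Optional[str] = None
    violation_return_message: Optional[str] = None


def score_to_response(
    mode: Mode, score_embedded: float, score_malicious: float, threshold: float
) -> ShieldResponse:
    if mode == Mode.INJECTION and score_embedded + score_malicious > threshold:
        return ShieldResponse(
            is_violation=True,
            violation_type=f"prompt_injection:embedded={score_embedded},malicious={score_malicious}",
            violation_return_message="Sorry, I cannot do this.",
        )
    if mode == Mode.JAILBREAK and score_malicious > threshold:
        return ShieldResponse(
            is_violation=True,
            violation_type=f"prompt_injection:malicious={score_malicious}",
            violation_return_message="Sorry, I cannot do this.",
        )
    return ShieldResponse(is_violation=False)


print(score_to_response(Mode.JAILBREAK, 0.1, 0.97, threshold=0.9))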