mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-07-29 15:23:51 +00:00)

commit 6e0f283f52 (parent 51245a417b)
Update safety implementation inside agents

8 changed files with 17 additions and 67 deletions
@@ -208,7 +208,7 @@ class ToolExecutionStep(StepCommon):
 @json_schema_type
 class ShieldCallStep(StepCommon):
     step_type: Literal[StepType.shield_call.value] = StepType.shield_call.value
-    response: ShieldResponse
+    violation: Optional[SafetyViolation]


 @json_schema_type
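With this change a shield-call step carries at most one SafetyViolation instead of a full ShieldResponse. A minimal sketch of how the step might now be populated; the import path and any SafetyViolation constructor arguments beyond user_message are assumptions, not taken from this diff:

# Illustrative sketch only; import location and extra SafetyViolation fields are assumed.
from llama_stack.apis.safety import SafetyViolation  # assumed import path

# A clean shield run records no violation on the step.
clean_step = ShieldCallStep(step_id="step-1", turn_id="turn-1", violation=None)

# A tripped shield attaches the violation directly to the step.
flagged_step = ShieldCallStep(
    step_id="step-2",
    turn_id="turn-1",
    violation=SafetyViolation(user_message="Sorry, I cannot do this."),
)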
@@ -25,7 +25,7 @@ class BedrockSafetyAdapter(Safety):

     async def run_shield(
         self,
-        shield: ShieldType,
+        shield: str,
         messages: List[Message],
     ) -> RunShieldResponse:
         # clients will set api_keys by doing something like:
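Callers of the Bedrock adapter now pass a plain string shield identifier rather than a ShieldType value. A hedged usage sketch; the shield name, message, and surrounding setup are invented for illustration:

# Illustrative call; the shield name and the adapter setup are assumptions.
response = await safety.run_shield(
    shield="llama_guard",  # now a plain string, not a ShieldType
    messages=[UserMessage(content="Tell me a joke.")],
)
if response.violation is None:
    print("no safety violation reported")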
@@ -94,12 +94,11 @@ class ChatAgent(ShieldRunnerMixin):
                     )
                 )
             elif step.step_type == StepType.shield_call.value:
-                response = step.response
-                if response.is_violation:
+                if step.violation:
                     # CompletionMessage itself in the ShieldResponse
                     messages.append(
                         CompletionMessage(
-                            content=response.violation_return_message,
+                            content=violation.user_message,
                             stop_reason=StopReason.end_of_turn,
                         )
                     )
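Condensed, the rebuilt-message branch now keys off the optional violation stored on the step. A sketch of the resulting logic, reading the violation from the step itself (an assumption where the diff writes a bare violation):

# Sketch of the branch above; the explicit step.violation access is spelled out for clarity.
if step.step_type == StepType.shield_call.value:
    if step.violation:
        messages.append(
            CompletionMessage(
                content=step.violation.user_message,
                stop_reason=StopReason.end_of_turn,
            )
        )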
@@ -276,7 +275,7 @@ class ChatAgent(ShieldRunnerMixin):
                         step_details=ShieldCallStep(
                             step_id=step_id,
                             turn_id=turn_id,
-                            response=e.response,
+                            violation=e.violation,
                         ),
                     )
                 )
@@ -295,12 +294,7 @@ class ChatAgent(ShieldRunnerMixin):
                         step_details=ShieldCallStep(
                             step_id=step_id,
                             turn_id=turn_id,
-                            response=ShieldResponse(
-                                # TODO: fix this, give each shield a shield type method and
-                                # fire one event for each shield run
-                                shield_type=BuiltinShield.llama_guard,
-                                is_violation=False,
-                            ),
+                            violation=None,
                         ),
                     )
                 )
@@ -550,12 +544,7 @@ class ChatAgent(ShieldRunnerMixin):
                         step_details=ShieldCallStep(
                             step_id=str(uuid.uuid4()),
                             turn_id=turn_id,
-                            response=ShieldResponse(
-                                # TODO: fix this, give each shield a shield type method and
-                                # fire one event for each shield run
-                                shield_type=BuiltinShield.llama_guard,
-                                is_violation=False,
-                            ),
+                            violation=None,
                         ),
                     )
                 )
@@ -569,7 +558,7 @@ class ChatAgent(ShieldRunnerMixin):
                         step_details=ShieldCallStep(
                             step_id=str(uuid.uuid4()),
                             turn_id=turn_id,
-                            response=e.response,
+                            violation=e.violation,
                         ),
                     )
                 )
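Taken together, the ChatAgent hunks above reduce shield-call events to two cases: a clean run emits a step with violation=None, and a raised safety error attaches its violation to the step. A condensed sketch; the exception class name and the shield-running call are assumptions, only the ShieldCallStep fields come from the diff:

# Condensed sketch of the emission pattern above.
try:
    await self.run_shields(messages)  # hypothetical shield-running call
    step_details = ShieldCallStep(
        step_id=str(uuid.uuid4()),
        turn_id=turn_id,
        violation=None,  # clean run: nothing to record
    )
except SafetyException as e:  # assumed exception type carrying .violation
    step_details = ShieldCallStep(
        step_id=str(uuid.uuid4()),
        turn_id=turn_id,
        violation=e.violation,
    )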
@@ -5,7 +5,6 @@
 # the root directory of this source tree.

 from typing import AsyncIterator, List, Optional, Union
-from unittest.mock import MagicMock

 import pytest

@@ -80,9 +79,9 @@ class MockInferenceAPI:


 class MockSafetyAPI:
     async def run_shields(
-        self, messages: List[Message], shields: List[MagicMock]
-    ) -> List[ShieldResponse]:
-        return [ShieldResponse(shield_type="mock_shield", is_violation=False)]
+        self, shield_type: str, messages: List[Message]
+    ) -> RunShieldResponse:
+        return RunShieldResponse(violation=None)


 class MockMemoryAPI:
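A hedged example of how a test might exercise the updated mock; only the mock's new signature and return shape come from this diff, the test body and pytest-asyncio setup are illustrative:

# Illustrative test; requires pytest-asyncio (an assumption about the test setup).
import pytest

@pytest.mark.asyncio
async def test_mock_safety_reports_no_violation():
    safety = MockSafetyAPI()
    response = await safety.run_shields("mock_shield", [])
    assert response.violation is None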
@@ -35,10 +35,6 @@ class ShieldBase(ABC):
     ):
         self.on_violation_action = on_violation_action

-    @abstractmethod
-    def get_shield_type(self) -> ShieldType:
-        raise NotImplementedError()
-
     @abstractmethod
     async def run(self, messages: List[Message]) -> ShieldResponse:
         raise NotImplementedError()
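After the removal, ShieldBase keeps only the violation action and a single abstract run method. A minimal sketch of the resulting base class; the constructor signature and default action are assumptions:

# Minimal sketch; the __init__ signature and default action are assumed.
from abc import ABC, abstractmethod
from typing import List

class ShieldBase(ABC):
    def __init__(self, on_violation_action=OnViolationAction.RAISE):
        self.on_violation_action = on_violation_action

    @abstractmethod
    async def run(self, messages: List[Message]) -> ShieldResponse:
        raise NotImplementedError()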
@@ -63,11 +59,6 @@ class TextShield(ShieldBase):


 class DummyShield(TextShield):
-    def get_shield_type(self) -> ShieldType:
-        return "dummy"
-
     async def run_impl(self, text: str) -> ShieldResponse:
         # Dummy return LOW to test e2e
-        return ShieldResponse(
-            shield_type=BuiltinShield.third_party_shield, is_violation=False
-        )
+        return ShieldResponse(is_violation=False)
@@ -7,13 +7,9 @@
 from termcolor import cprint

 from .base import ShieldResponse, TextShield
-from llama_stack.apis.safety import *  # noqa: F403


 class CodeScannerShield(TextShield):
-    def get_shield_type(self) -> ShieldType:
-        return BuiltinShield.code_scanner_guard
-
     async def run_impl(self, text: str) -> ShieldResponse:
         from codeshield.cs import CodeShield

@@ -21,7 +17,6 @@ class CodeScannerShield(TextShield):
         result = await CodeShield.scan_code(text)
         if result.is_insecure:
             return ShieldResponse(
-                shield_type=BuiltinShield.code_scanner_guard,
                 is_violation=True,
                 violation_type=",".join(
                     [issue.pattern_id for issue in result.issues_found]
@@ -29,6 +24,4 @@ class CodeScannerShield(TextShield):
                 violation_return_message="Sorry, I found security concerns in the code.",
             )
         else:
-            return ShieldResponse(
-                shield_type=BuiltinShield.code_scanner_guard, is_violation=False
-            )
+            return ShieldResponse(is_violation=False)
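Every shield implementation in this commit follows the same pattern: ShieldResponse no longer names the shield that produced it. A before/after sketch of the no-violation return taken from this file:

# Before: each response carried its shield's identity.
return ShieldResponse(
    shield_type=BuiltinShield.code_scanner_guard, is_violation=False
)

# After: the response only says whether a violation occurred.
return ShieldResponse(is_violation=False)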
@@ -14,7 +14,7 @@ from llama_models.llama3.api.datatypes import Message, Role
 from transformers import AutoModelForCausalLM, AutoTokenizer

 from .base import CANNED_RESPONSE_TEXT, OnViolationAction, ShieldBase, ShieldResponse
-from llama_stack.apis.safety import *  # noqa: F403

 SAFE_RESPONSE = "safe"
 _INSTANCE = None
@@ -152,9 +152,6 @@ class LlamaGuardShield(ShieldBase):
             model_dir, torch_dtype=torch_dtype, device_map=self.device
         )

-    def get_shield_type(self) -> ShieldType:
-        return BuiltinShield.llama_guard
-
     def check_unsafe_response(self, response: str) -> Optional[str]:
         match = re.match(r"^unsafe\n(.*)$", response)
         if match:
@@ -192,18 +189,13 @@ class LlamaGuardShield(ShieldBase):

     def get_shield_response(self, response: str) -> ShieldResponse:
         if response == SAFE_RESPONSE:
-            return ShieldResponse(
-                shield_type=BuiltinShield.llama_guard, is_violation=False
-            )
+            return ShieldResponse(is_violation=False)
         unsafe_code = self.check_unsafe_response(response)
         if unsafe_code:
             unsafe_code_list = unsafe_code.split(",")
             if set(unsafe_code_list).issubset(set(self.excluded_categories)):
-                return ShieldResponse(
-                    shield_type=BuiltinShield.llama_guard, is_violation=False
-                )
+                return ShieldResponse(is_violation=False)
             return ShieldResponse(
-                shield_type=BuiltinShield.llama_guard,
                 is_violation=True,
                 violation_type=unsafe_code,
                 violation_return_message=CANNED_RESPONSE_TEXT,
@@ -213,12 +205,9 @@ class LlamaGuardShield(ShieldBase):

     async def run(self, messages: List[Message]) -> ShieldResponse:
         if self.disable_input_check and messages[-1].role == Role.user.value:
-            return ShieldResponse(
-                shield_type=BuiltinShield.llama_guard, is_violation=False
-            )
+            return ShieldResponse(is_violation=False)
         elif self.disable_output_check and messages[-1].role == Role.assistant.value:
             return ShieldResponse(
-                shield_type=BuiltinShield.llama_guard,
                 is_violation=False,
             )
         else:
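The decision path in get_shield_response is unchanged apart from the dropped shield_type. A standalone sketch of that logic, written as a free function for illustration; the behaviour when the model output matches neither pattern is an assumption:

# Free-function sketch of the classification logic above; not part of the repository.
import re
from typing import List

def is_violation(raw: str, excluded_categories: List[str]) -> bool:
    if raw == "safe":
        return False
    match = re.match(r"^unsafe\n(.*)$", raw)
    if not match:
        return False  # assumption: unparseable output treated as non-violating
    codes = match.group(1).split(",")
    # A violation made up entirely of excluded categories is suppressed.
    return not set(codes).issubset(set(excluded_categories))

# e.g. is_violation("unsafe\nS1,S2", excluded_categories=["S1"]) -> True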
@@ -13,7 +13,6 @@ from llama_models.llama3.api.datatypes import Message
 from termcolor import cprint

 from .base import message_content_as_str, OnViolationAction, ShieldResponse, TextShield
-from llama_stack.apis.safety import *  # noqa: F403


 class PromptGuardShield(TextShield):
@@ -74,13 +73,6 @@ class PromptGuardShield(TextShield):
         self.threshold = threshold
         self.mode = mode

-    def get_shield_type(self) -> ShieldType:
-        return (
-            BuiltinShield.jailbreak_shield
-            if self.mode == self.Mode.JAILBREAK
-            else BuiltinShield.injection_shield
-        )
-
     def convert_messages_to_text(self, messages: List[Message]) -> str:
         return message_content_as_str(messages[-1])

@@ -103,21 +95,18 @@ class PromptGuardShield(TextShield):
             score_embedded + score_malicious > self.threshold
         ):
             return ShieldResponse(
-                shield_type=self.get_shield_type(),
                 is_violation=True,
                 violation_type=f"prompt_injection:embedded={score_embedded},malicious={score_malicious}",
                 violation_return_message="Sorry, I cannot do this.",
             )
         elif self.mode == self.Mode.JAILBREAK and score_malicious > self.threshold:
             return ShieldResponse(
-                shield_type=self.get_shield_type(),
                 is_violation=True,
                 violation_type=f"prompt_injection:malicious={score_malicious}",
                 violation_return_message="Sorry, I cannot do this.",
             )

         return ShieldResponse(
-            shield_type=self.get_shield_type(),
             is_violation=False,
         )

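End to end, the commit leaves shields returning their internal ShieldResponse while the safety API and the agent surface only a SafetyViolation. A hedged sketch of how the two shapes might be bridged; the helper and the SafetyViolation constructor are assumptions, only the field names appear in this diff:

# Hypothetical bridging helper; field names from the diff, everything else assumed.
def to_run_shield_response(resp: ShieldResponse) -> RunShieldResponse:
    if not resp.is_violation:
        return RunShieldResponse(violation=None)
    return RunShieldResponse(
        violation=SafetyViolation(user_message=resp.violation_return_message),
    )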