Update safety implementation inside agents

Ashwin Bharambe authored on 2024-09-20 14:27:54 -07:00; committed by Xi Yan
parent 82ddd851c8
commit d6a41d98d2
8 changed files with 17 additions and 67 deletions

View file

@@ -208,7 +208,7 @@ class ToolExecutionStep(StepCommon):
 @json_schema_type
 class ShieldCallStep(StepCommon):
     step_type: Literal[StepType.shield_call.value] = StepType.shield_call.value
-    response: ShieldResponse
+    violation: Optional[SafetyViolation]
 
 
 @json_schema_type
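
For context: after this change a ShieldCallStep records either nothing or a single SafetyViolation, rather than a full ShieldResponse. A minimal sketch of the two shapes using plain pydantic models; apart from user_message (which the ChatAgent diff below reads), the SafetyViolation fields here are assumptions:

from typing import Any, Dict, Optional

from pydantic import BaseModel


class SafetyViolation(BaseModel):
    user_message: Optional[str] = None  # read by the ChatAgent diff below
    metadata: Dict[str, Any] = {}       # assumed field


class ShieldCallStep(BaseModel):
    step_id: str
    turn_id: str
    violation: Optional[SafetyViolation] = None  # None means the shield passed


clean = ShieldCallStep(step_id="s1", turn_id="t1", violation=None)
flagged = ShieldCallStep(
    step_id="s2",
    turn_id="t1",
    violation=SafetyViolation(user_message="Sorry, I cannot do this."),
)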

View file

@@ -25,7 +25,7 @@ class BedrockSafetyAdapter(Safety):
     async def run_shield(
         self,
-        shield: ShieldType,
+        shield: str,
         messages: List[Message],
     ) -> RunShieldResponse:
         # clients will set api_keys by doing something like:
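
The adapter now addresses the shield by a plain string identifier instead of a ShieldType enum value. A hedged caller sketch (the shield name "llama_guard" and the helper itself are illustrative, not from the diff):

async def check(safety, messages):
    # `safety` is any implementation of the Safety API's run_shield
    response = await safety.run_shield(shield="llama_guard", messages=messages)
    if response.violation is not None:
        print("blocked:", response.violation.user_message)
    return response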

View file

@@ -94,12 +94,11 @@ class ChatAgent(ShieldRunnerMixin):
                     )
                 )
             elif step.step_type == StepType.shield_call.value:
-                response = step.response
-                if response.is_violation:
+                if step.violation:
                     # surface the violation's message to the user as a CompletionMessage
                     messages.append(
                         CompletionMessage(
-                            content=response.violation_return_message,
+                            content=step.violation.user_message,
                             stop_reason=StopReason.end_of_turn,
                         )
                     )
@@ -276,7 +275,7 @@ class ChatAgent(ShieldRunnerMixin):
                         step_details=ShieldCallStep(
                             step_id=step_id,
                             turn_id=turn_id,
-                            response=e.response,
+                            violation=e.violation,
                         ),
                     )
                 )
@@ -295,12 +294,7 @@ class ChatAgent(ShieldRunnerMixin):
                         step_details=ShieldCallStep(
                             step_id=step_id,
                             turn_id=turn_id,
-                            response=ShieldResponse(
-                                # TODO: fix this, give each shield a shield type method and
-                                # fire one event for each shield run
-                                shield_type=BuiltinShield.llama_guard,
-                                is_violation=False,
-                            ),
+                            violation=None,
                         ),
                     )
                 )
@@ -550,12 +544,7 @@ class ChatAgent(ShieldRunnerMixin):
                         step_details=ShieldCallStep(
                             step_id=str(uuid.uuid4()),
                             turn_id=turn_id,
-                            response=ShieldResponse(
-                                # TODO: fix this, give each shield a shield type method and
-                                # fire one event for each shield run
-                                shield_type=BuiltinShield.llama_guard,
-                                is_violation=False,
-                            ),
+                            violation=None,
                         ),
                     )
                 )
@@ -569,7 +558,7 @@ class ChatAgent(ShieldRunnerMixin):
                         step_details=ShieldCallStep(
                             step_id=str(uuid.uuid4()),
                             turn_id=turn_id,
-                            response=e.response,
+                            violation=e.violation,
                         ),
                     )
                 )
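
Taken together, the ChatAgent hunks implement one pattern: a failing shield raises an exception carrying the violation, and every shield run, pass or fail, is recorded as a ShieldCallStep. A sketch of that control flow, assuming a SafetyException with a .violation attribute as implied by `e.violation` above:

class SafetyException(Exception):  # assumed name; matches the e.violation usage
    def __init__(self, violation):
        self.violation = violation
        super().__init__(getattr(violation, "user_message", "") or "shield violation")


async def guarded(run_shields, messages):
    """Run shields and report the optional violation for the ShieldCallStep."""
    try:
        await run_shields(messages)
    except SafetyException as e:
        return e.violation  # failure: the step carries the violation
    return None             # success: the step carries violation=None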

View file

@@ -5,7 +5,6 @@
 # the root directory of this source tree.
 from typing import AsyncIterator, List, Optional, Union
-from unittest.mock import MagicMock
 
 import pytest
@@ -80,9 +79,9 @@ class MockInferenceAPI:
 class MockSafetyAPI:
     async def run_shields(
-        self, messages: List[Message], shields: List[MagicMock]
-    ) -> List[ShieldResponse]:
-        return [ShieldResponse(shield_type="mock_shield", is_violation=False)]
+        self, shield_type: str, messages: List[Message]
+    ) -> RunShieldResponse:
+        return RunShieldResponse(violation=None)
 
 
 class MockMemoryAPI:
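
A hedged example of driving the updated mock from a test (requires pytest-asyncio; the test name is hypothetical):

import pytest


@pytest.mark.asyncio
async def test_mock_safety_api_reports_no_violation():
    api = MockSafetyAPI()
    response = await api.run_shields("mock_shield", [])
    assert response.violation is None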

View file

@@ -35,10 +35,6 @@ class ShieldBase(ABC):
     ):
         self.on_violation_action = on_violation_action
 
-    @abstractmethod
-    def get_shield_type(self) -> ShieldType:
-        raise NotImplementedError()
-
     @abstractmethod
     async def run(self, messages: List[Message]) -> ShieldResponse:
         raise NotImplementedError()
@@ -63,11 +59,6 @@ class TextShield(ShieldBase):
 class DummyShield(TextShield):
-    def get_shield_type(self) -> ShieldType:
-        return "dummy"
-
     async def run_impl(self, text: str) -> ShieldResponse:
         # Dummy return LOW to test e2e
-        return ShieldResponse(
-            shield_type=BuiltinShield.third_party_shield, is_violation=False
-        )
+        return ShieldResponse(is_violation=False)
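
With get_shield_type gone from ShieldBase, a custom shield now only implements run_impl. A minimal sketch mirroring DummyShield above (AllowAllShield is a hypothetical name):

class AllowAllShield(TextShield):
    async def run_impl(self, text: str) -> ShieldResponse:
        # Never flags anything; useful as a no-op placeholder in tests
        return ShieldResponse(is_violation=False)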

View file

@@ -7,13 +7,9 @@
 from termcolor import cprint
 
 from .base import ShieldResponse, TextShield
-from llama_stack.apis.safety import *  # noqa: F403
 
 
 class CodeScannerShield(TextShield):
-    def get_shield_type(self) -> ShieldType:
-        return BuiltinShield.code_scanner_guard
-
     async def run_impl(self, text: str) -> ShieldResponse:
         from codeshield.cs import CodeShield
@@ -21,7 +17,6 @@ class CodeScannerShield(TextShield):
         result = await CodeShield.scan_code(text)
         if result.is_insecure:
             return ShieldResponse(
-                shield_type=BuiltinShield.code_scanner_guard,
                 is_violation=True,
                 violation_type=",".join(
                     [issue.pattern_id for issue in result.issues_found]
@ -29,6 +24,4 @@ class CodeScannerShield(TextShield):
violation_return_message="Sorry, I found security concerns in the code.",
)
else:
return ShieldResponse(
shield_type=BuiltinShield.code_scanner_guard, is_violation=False
)
return ShieldResponse(is_violation=False)
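
A hedged driver for the shield above (assumes the codeshield package is installed and that ShieldBase's on_violation_action parameter has a default; the scanned snippet is only illustrative):

import asyncio


async def scan(snippet: str) -> None:
    shield = CodeScannerShield()
    response = await shield.run_impl(snippet)
    if response.is_violation:
        print(response.violation_return_message, "->", response.violation_type)


asyncio.run(scan("eval(input())"))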

View file

@@ -14,7 +14,7 @@ from llama_models.llama3.api.datatypes import Message, Role
 from transformers import AutoModelForCausalLM, AutoTokenizer
 
 from .base import CANNED_RESPONSE_TEXT, OnViolationAction, ShieldBase, ShieldResponse
 from llama_stack.apis.safety import *  # noqa: F403
 
 SAFE_RESPONSE = "safe"
 _INSTANCE = None
@@ -152,9 +152,6 @@ class LlamaGuardShield(ShieldBase):
             model_dir, torch_dtype=torch_dtype, device_map=self.device
         )
 
-    def get_shield_type(self) -> ShieldType:
-        return BuiltinShield.llama_guard
-
     def check_unsafe_response(self, response: str) -> Optional[str]:
         match = re.match(r"^unsafe\n(.*)$", response)
         if match:
@@ -192,18 +189,13 @@ class LlamaGuardShield(ShieldBase):
     def get_shield_response(self, response: str) -> ShieldResponse:
         if response == SAFE_RESPONSE:
-            return ShieldResponse(
-                shield_type=BuiltinShield.llama_guard, is_violation=False
-            )
+            return ShieldResponse(is_violation=False)
 
         unsafe_code = self.check_unsafe_response(response)
         if unsafe_code:
             unsafe_code_list = unsafe_code.split(",")
             if set(unsafe_code_list).issubset(set(self.excluded_categories)):
-                return ShieldResponse(
-                    shield_type=BuiltinShield.llama_guard, is_violation=False
-                )
+                return ShieldResponse(is_violation=False)
             return ShieldResponse(
-                shield_type=BuiltinShield.llama_guard,
                 is_violation=True,
                 violation_type=unsafe_code,
                 violation_return_message=CANNED_RESPONSE_TEXT,
@@ -213,12 +205,9 @@ class LlamaGuardShield(ShieldBase):
     async def run(self, messages: List[Message]) -> ShieldResponse:
         if self.disable_input_check and messages[-1].role == Role.user.value:
-            return ShieldResponse(
-                shield_type=BuiltinShield.llama_guard, is_violation=False
-            )
+            return ShieldResponse(is_violation=False)
         elif self.disable_output_check and messages[-1].role == Role.assistant.value:
             return ShieldResponse(
-                shield_type=BuiltinShield.llama_guard,
                 is_violation=False,
             )
         else:
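
A worked example of the parsing that get_shield_response relies on (the category codes are illustrative):

import re


def unsafe_codes(response: str):
    match = re.match(r"^unsafe\n(.*)$", response)
    return match.group(1) if match else None


assert unsafe_codes("safe") is None
assert unsafe_codes("unsafe\nO1,O5") == "O1,O5"
# "O1,O5".split(",") == ["O1", "O5"]; if every code is in
# excluded_categories, the shield still returns is_violation=False.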

View file

@@ -13,7 +13,6 @@ from llama_models.llama3.api.datatypes import Message
 from termcolor import cprint
 
 from .base import message_content_as_str, OnViolationAction, ShieldResponse, TextShield
-from llama_stack.apis.safety import *  # noqa: F403
 
 
 class PromptGuardShield(TextShield):
@@ -74,13 +73,6 @@ class PromptGuardShield(TextShield):
         self.threshold = threshold
         self.mode = mode
 
-    def get_shield_type(self) -> ShieldType:
-        return (
-            BuiltinShield.jailbreak_shield
-            if self.mode == self.Mode.JAILBREAK
-            else BuiltinShield.injection_shield
-        )
-
     def convert_messages_to_text(self, messages: List[Message]) -> str:
         return message_content_as_str(messages[-1])
@@ -103,21 +95,18 @@ class PromptGuardShield(TextShield):
             score_embedded + score_malicious > self.threshold
         ):
             return ShieldResponse(
-                shield_type=self.get_shield_type(),
                 is_violation=True,
                 violation_type=f"prompt_injection:embedded={score_embedded},malicious={score_malicious}",
                 violation_return_message="Sorry, I cannot do this.",
             )
         elif self.mode == self.Mode.JAILBREAK and score_malicious > self.threshold:
             return ShieldResponse(
-                shield_type=self.get_shield_type(),
                 is_violation=True,
                 violation_type=f"prompt_injection:malicious={score_malicious}",
                 violation_return_message="Sorry, I cannot do this.",
             )
         return ShieldResponse(
-            shield_type=self.get_shield_type(),
             is_violation=False,
         )
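
Worked numbers for the threshold logic above (the scores and threshold are illustrative; the shield's actual default threshold is not shown in this diff):

threshold = 0.9
score_embedded, score_malicious = 0.55, 0.45

# INJECTION mode fires on the sum of both scores:
assert score_embedded + score_malicious > threshold  # 1.0 > 0.9 -> violation

# JAILBREAK mode considers the malicious score alone:
assert not (score_malicious > threshold)  # 0.45 <= 0.9 -> no violation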