Update safety implementation inside agents

Ashwin Bharambe 2024-09-20 14:27:54 -07:00 committed by Xi Yan
parent 82ddd851c8
commit d6a41d98d2
8 changed files with 17 additions and 67 deletions

View file

@@ -208,7 +208,7 @@ class ToolExecutionStep(StepCommon):
 @json_schema_type
 class ShieldCallStep(StepCommon):
     step_type: Literal[StepType.shield_call.value] = StepType.shield_call.value
-    response: ShieldResponse
+    violation: Optional[SafetyViolation]


 @json_schema_type
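In short, ShieldCallStep now carries violation: Optional[SafetyViolation] instead of a mandatory ShieldResponse, so "no violation" is simply None. Below is a minimal sketch of how a consumer might read the new field; the dataclasses are illustrative stand-ins for the real pydantic models, and SafetyViolation is reduced to the single attribute this commit touches (user_message).

from dataclasses import dataclass
from typing import Optional


@dataclass
class SafetyViolation:  # stand-in: only the attribute this commit references
    user_message: str


@dataclass
class ShieldCallStep:  # stand-in for the real @json_schema_type model
    step_id: str
    turn_id: str
    violation: Optional[SafetyViolation] = None


step = ShieldCallStep(step_id="step-1", turn_id="turn-1", violation=None)
if step.violation is not None:
    print(step.violation.user_message)  # surface the refusal text to the user
else:
    print("shield call passed; nothing to report")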

View file

@@ -25,7 +25,7 @@ class BedrockSafetyAdapter(Safety):

     async def run_shield(
         self,
-        shield: ShieldType,
+        shield: str,
         messages: List[Message],
     ) -> RunShieldResponse:
         # clients will set api_keys by doing something like:
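run_shield now takes the shield identifier as a plain string rather than a ShieldType value. A sketch of a caller against the new signature, assuming simplified stand-in types and a hypothetical "bedrock_guardrail" identifier (the real adapter forwards to Bedrock guardrails).

import asyncio
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class Message:  # stand-in for the llama_models Message type
    role: str
    content: str


@dataclass
class RunShieldResponse:  # stand-in; only the field this commit relies on
    violation: Optional[object] = None


class FakeBedrockSafetyAdapter:
    # Illustrative only: the real adapter calls out to Bedrock guardrails.
    async def run_shield(self, shield: str, messages: List[Message]) -> RunShieldResponse:
        return RunShieldResponse(violation=None)


async def main() -> None:
    resp = await FakeBedrockSafetyAdapter().run_shield(
        shield="bedrock_guardrail",  # hypothetical shield identifier string
        messages=[Message(role="user", content="hello")],
    )
    print("violation:", resp.violation)


asyncio.run(main())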

View file

@@ -94,12 +94,11 @@ class ChatAgent(ShieldRunnerMixin):
                     )
                 )
             elif step.step_type == StepType.shield_call.value:
-                response = step.response
-                if response.is_violation:
+                if step.violation:
                     # CompletionMessage itself in the ShieldResponse
                     messages.append(
                         CompletionMessage(
-                            content=response.violation_return_message,
+                            content=violation.user_message,
                             stop_reason=StopReason.end_of_turn,
                         )
                     )
@@ -276,7 +275,7 @@ class ChatAgent(ShieldRunnerMixin):
                         step_details=ShieldCallStep(
                             step_id=step_id,
                             turn_id=turn_id,
-                            response=e.response,
+                            violation=e.violation,
                         ),
                     )
                 )
@@ -295,12 +294,7 @@ class ChatAgent(ShieldRunnerMixin):
                         step_details=ShieldCallStep(
                             step_id=step_id,
                             turn_id=turn_id,
-                            response=ShieldResponse(
-                                # TODO: fix this, give each shield a shield type method and
-                                # fire one event for each shield run
-                                shield_type=BuiltinShield.llama_guard,
-                                is_violation=False,
-                            ),
+                            violation=None,
                         ),
                     )
                 )
@@ -550,12 +544,7 @@ class ChatAgent(ShieldRunnerMixin):
                         step_details=ShieldCallStep(
                             step_id=str(uuid.uuid4()),
                             turn_id=turn_id,
-                            response=ShieldResponse(
-                                # TODO: fix this, give each shield a shield type method and
-                                # fire one event for each shield run
-                                shield_type=BuiltinShield.llama_guard,
-                                is_violation=False,
-                            ),
+                            violation=None,
                         ),
                     )
                 )
@@ -569,7 +558,7 @@ class ChatAgent(ShieldRunnerMixin):
                         step_details=ShieldCallStep(
                             step_id=str(uuid.uuid4()),
                             turn_id=turn_id,
-                            response=e.response,
+                            violation=e.violation,
                         ),
                     )
                 )
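The agent code above now branches on step.violation directly and records violation=None (or e.violation from a raised safety error) instead of building placeholder ShieldResponse objects. A self-contained sketch of that branching, with stand-in types in place of the real agent machinery:

from dataclasses import dataclass
from typing import List, Optional


@dataclass
class SafetyViolation:  # stand-in: only the attribute the agent reads
    user_message: str


@dataclass
class ShieldCallStep:  # stand-in for the real step model
    step_id: str
    turn_id: str
    violation: Optional[SafetyViolation] = None


def violating_messages(steps: List[ShieldCallStep]) -> List[str]:
    # Mirrors the new flow: only shield calls that flagged a violation
    # contribute a canned message back to the conversation.
    messages: List[str] = []
    for step in steps:
        if step.violation:
            messages.append(step.violation.user_message)
    return messages


steps = [
    ShieldCallStep(step_id="1", turn_id="t", violation=None),
    ShieldCallStep(step_id="2", turn_id="t", violation=SafetyViolation("I can't help with that.")),
]
print(violating_messages(steps))  # ["I can't help with that."]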

View file

@@ -5,7 +5,6 @@
 # the root directory of this source tree.

 from typing import AsyncIterator, List, Optional, Union
-from unittest.mock import MagicMock

 import pytest
@@ -80,9 +79,9 @@ class MockInferenceAPI:

 class MockSafetyAPI:
     async def run_shields(
-        self, messages: List[Message], shields: List[MagicMock]
-    ) -> List[ShieldResponse]:
-        return [ShieldResponse(shield_type="mock_shield", is_violation=False)]
+        self, shield_type: str, messages: List[Message]
+    ) -> RunShieldResponse:
+        return RunShieldResponse(violation=None)


 class MockMemoryAPI:
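The safety mock now mirrors the single-response API: run_shields takes a shield_type string plus the messages and returns one RunShieldResponse. A sketch of how a test might drive it, with stand-in Message and RunShieldResponse types:

import asyncio
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class Message:  # stand-in
    role: str
    content: str


@dataclass
class RunShieldResponse:  # stand-in; only the field asserted on below
    violation: Optional[object] = None


class MockSafetyAPI:
    async def run_shields(self, shield_type: str, messages: List[Message]) -> RunShieldResponse:
        return RunShieldResponse(violation=None)


def test_run_shields_reports_no_violation() -> None:
    resp = asyncio.run(MockSafetyAPI().run_shields("mock_shield", [Message("user", "hi")]))
    assert resp.violation is None


test_run_shields_reports_no_violation()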

View file

@@ -35,10 +35,6 @@ class ShieldBase(ABC):
     ):
         self.on_violation_action = on_violation_action

-    @abstractmethod
-    def get_shield_type(self) -> ShieldType:
-        raise NotImplementedError()
-
     @abstractmethod
     async def run(self, messages: List[Message]) -> ShieldResponse:
         raise NotImplementedError()
@@ -63,11 +59,6 @@ class TextShield(ShieldBase):


 class DummyShield(TextShield):
-    def get_shield_type(self) -> ShieldType:
-        return "dummy"
-
     async def run_impl(self, text: str) -> ShieldResponse:
         # Dummy return LOW to test e2e
-        return ShieldResponse(
-            shield_type=BuiltinShield.third_party_shield, is_violation=False
-        )
+        return ShieldResponse(is_violation=False)
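With get_shield_type removed from ShieldBase, a shield only implements run (or run_impl for TextShield) and returns a ShieldResponse with no shield_type. A stand-alone sketch using simplified stand-ins for the base classes in .base:

import asyncio
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Optional


@dataclass
class ShieldResponse:  # stand-in mirroring the fields kept after this commit
    is_violation: bool
    violation_type: Optional[str] = None
    violation_return_message: Optional[str] = None


class TextShield(ABC):  # stand-in for the real TextShield base
    @abstractmethod
    async def run_impl(self, text: str) -> ShieldResponse: ...


class DummyShield(TextShield):
    async def run_impl(self, text: str) -> ShieldResponse:
        # No per-shield type to report anymore; only the verdict matters.
        return ShieldResponse(is_violation=False)


print(asyncio.run(DummyShield().run_impl("anything")))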

View file

@@ -7,13 +7,9 @@
 from termcolor import cprint

 from .base import ShieldResponse, TextShield
-from llama_stack.apis.safety import *  # noqa: F403


 class CodeScannerShield(TextShield):
-    def get_shield_type(self) -> ShieldType:
-        return BuiltinShield.code_scanner_guard
-
     async def run_impl(self, text: str) -> ShieldResponse:
         from codeshield.cs import CodeShield
@@ -21,7 +17,6 @@ class CodeScannerShield(TextShield):
         result = await CodeShield.scan_code(text)
         if result.is_insecure:
             return ShieldResponse(
-                shield_type=BuiltinShield.code_scanner_guard,
                 is_violation=True,
                 violation_type=",".join(
                     [issue.pattern_id for issue in result.issues_found]
@@ -29,6 +24,4 @@ class CodeScannerShield(TextShield):
                 violation_return_message="Sorry, I found security concerns in the code.",
             )
         else:
-            return ShieldResponse(
-                shield_type=BuiltinShield.code_scanner_guard, is_violation=False
-            )
+            return ShieldResponse(is_violation=False)
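The scanner result maps straight onto the slimmed-down ShieldResponse. A sketch with a stand-in scan result in place of codeshield's CodeShield.scan_code; the field names mirror the diff above and the pattern id is hypothetical.

from dataclasses import dataclass
from typing import List, Optional


@dataclass
class Issue:  # stand-in for a codeshield issue
    pattern_id: str


@dataclass
class ScanResult:  # stand-in for the object returned by CodeShield.scan_code
    is_insecure: bool
    issues_found: List[Issue]


@dataclass
class ShieldResponse:  # stand-in
    is_violation: bool
    violation_type: Optional[str] = None
    violation_return_message: Optional[str] = None


def to_shield_response(result: ScanResult) -> ShieldResponse:
    if result.is_insecure:
        return ShieldResponse(
            is_violation=True,
            violation_type=",".join([issue.pattern_id for issue in result.issues_found]),
            violation_return_message="Sorry, I found security concerns in the code.",
        )
    return ShieldResponse(is_violation=False)


print(to_shield_response(ScanResult(is_insecure=True, issues_found=[Issue("weak-hash")])))  # hypothetical pattern id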

View file

@@ -14,7 +14,7 @@ from llama_models.llama3.api.datatypes import Message, Role
 from transformers import AutoModelForCausalLM, AutoTokenizer

 from .base import CANNED_RESPONSE_TEXT, OnViolationAction, ShieldBase, ShieldResponse
-from llama_stack.apis.safety import *  # noqa: F403

 SAFE_RESPONSE = "safe"
 _INSTANCE = None
@@ -152,9 +152,6 @@
             model_dir, torch_dtype=torch_dtype, device_map=self.device
         )

-    def get_shield_type(self) -> ShieldType:
-        return BuiltinShield.llama_guard
-
     def check_unsafe_response(self, response: str) -> Optional[str]:
         match = re.match(r"^unsafe\n(.*)$", response)
         if match:
@@ -192,18 +189,13 @@
     def get_shield_response(self, response: str) -> ShieldResponse:
         if response == SAFE_RESPONSE:
-            return ShieldResponse(
-                shield_type=BuiltinShield.llama_guard, is_violation=False
-            )
+            return ShieldResponse(is_violation=False)

         unsafe_code = self.check_unsafe_response(response)
         if unsafe_code:
             unsafe_code_list = unsafe_code.split(",")
             if set(unsafe_code_list).issubset(set(self.excluded_categories)):
-                return ShieldResponse(
-                    shield_type=BuiltinShield.llama_guard, is_violation=False
-                )
+                return ShieldResponse(is_violation=False)
             return ShieldResponse(
-                shield_type=BuiltinShield.llama_guard,
                 is_violation=True,
                 violation_type=unsafe_code,
                 violation_return_message=CANNED_RESPONSE_TEXT,
@@ -213,12 +205,9 @@ class LlamaGuardShield(ShieldBase):

     async def run(self, messages: List[Message]) -> ShieldResponse:
         if self.disable_input_check and messages[-1].role == Role.user.value:
-            return ShieldResponse(
-                shield_type=BuiltinShield.llama_guard, is_violation=False
-            )
+            return ShieldResponse(is_violation=False)
         elif self.disable_output_check and messages[-1].role == Role.assistant.value:
             return ShieldResponse(
-                shield_type=BuiltinShield.llama_guard,
                 is_violation=False,
             )
         else:
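get_shield_response now keys only on the verdict text. A runnable sketch of that parsing logic, reusing the "safe" / "unsafe\n<codes>" convention and the regex from the diff; ShieldResponse is a stand-in and CANNED_RESPONSE_TEXT here is an illustrative placeholder string.

import re
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class ShieldResponse:  # stand-in
    is_violation: bool
    violation_type: Optional[str] = None
    violation_return_message: Optional[str] = None


SAFE_RESPONSE = "safe"
CANNED_RESPONSE_TEXT = "I can't answer that right now."  # illustrative placeholder


def get_shield_response(response: str, excluded_categories: List[str]) -> ShieldResponse:
    if response == SAFE_RESPONSE:
        return ShieldResponse(is_violation=False)

    match = re.match(r"^unsafe\n(.*)$", response)
    if match:
        unsafe_code = match.group(1)
        unsafe_code_list = unsafe_code.split(",")
        if set(unsafe_code_list).issubset(set(excluded_categories)):
            return ShieldResponse(is_violation=False)
        return ShieldResponse(
            is_violation=True,
            violation_type=unsafe_code,
            violation_return_message=CANNED_RESPONSE_TEXT,
        )
    # Simplification for this sketch: treat unrecognized output as non-violating.
    return ShieldResponse(is_violation=False)


print(get_shield_response("unsafe\nS1,S9", excluded_categories=["S1"]))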

View file

@@ -13,7 +13,6 @@ from llama_models.llama3.api.datatypes import Message
 from termcolor import cprint

 from .base import message_content_as_str, OnViolationAction, ShieldResponse, TextShield
-from llama_stack.apis.safety import *  # noqa: F403


 class PromptGuardShield(TextShield):
@@ -74,13 +73,6 @@ class PromptGuardShield(TextShield):
         self.threshold = threshold
         self.mode = mode

-    def get_shield_type(self) -> ShieldType:
-        return (
-            BuiltinShield.jailbreak_shield
-            if self.mode == self.Mode.JAILBREAK
-            else BuiltinShield.injection_shield
-        )
-
     def convert_messages_to_text(self, messages: List[Message]) -> str:
         return message_content_as_str(messages[-1])
@@ -103,21 +95,18 @@ class PromptGuardShield(TextShield):
             score_embedded + score_malicious > self.threshold
         ):
             return ShieldResponse(
-                shield_type=self.get_shield_type(),
                 is_violation=True,
                 violation_type=f"prompt_injection:embedded={score_embedded},malicious={score_malicious}",
                 violation_return_message="Sorry, I cannot do this.",
             )
         elif self.mode == self.Mode.JAILBREAK and score_malicious > self.threshold:
             return ShieldResponse(
-                shield_type=self.get_shield_type(),
                 is_violation=True,
                 violation_type=f"prompt_injection:malicious={score_malicious}",
                 violation_return_message="Sorry, I cannot do this.",
             )
         return ShieldResponse(
-            shield_type=self.get_shield_type(),
             is_violation=False,
         )
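The PromptGuard branches are otherwise unchanged; only shield_type drops out of each ShieldResponse. A small sketch of the threshold logic in isolation, with the mode reduced to a plain string, an illustrative threshold of 0.9, and a stand-in ShieldResponse.

from dataclasses import dataclass
from typing import Optional


@dataclass
class ShieldResponse:  # stand-in
    is_violation: bool
    violation_type: Optional[str] = None
    violation_return_message: Optional[str] = None


def classify(score_embedded: float, score_malicious: float, mode: str, threshold: float = 0.9) -> ShieldResponse:
    # Mirrors the branching in run_impl: injection mode sums both scores,
    # jailbreak mode looks at the malicious score alone.
    if mode == "injection" and (score_embedded + score_malicious > threshold):
        return ShieldResponse(
            is_violation=True,
            violation_type=f"prompt_injection:embedded={score_embedded},malicious={score_malicious}",
            violation_return_message="Sorry, I cannot do this.",
        )
    if mode == "jailbreak" and score_malicious > threshold:
        return ShieldResponse(
            is_violation=True,
            violation_type=f"prompt_injection:malicious={score_malicious}",
            violation_return_message="Sorry, I cannot do this.",
        )
    return ShieldResponse(is_violation=False)


print(classify(0.7, 0.5, mode="injection"))   # combined score over threshold -> violation
print(classify(0.1, 0.95, mode="jailbreak"))  # malicious score over threshold -> violation
print(classify(0.1, 0.2, mode="injection"))   # under threshold -> no violation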