From d6a41d98d2be11d360d880f69d96fa8b426ea6bd Mon Sep 17 00:00:00 2001
From: Ashwin Bharambe
Date: Fri, 20 Sep 2024 14:27:54 -0700
Subject: [PATCH] Update safety implementation inside agents

---
 llama_stack/apis/agents/agents.py             |  2 +-
 .../adapters/safety/bedrock/bedrock.py        |  2 +-
 .../meta_reference/agents/agent_instance.py   | 23 +++++--------------
 .../agents/tests/test_chat_agent.py           |  7 +++---
 .../meta_reference/safety/shields/base.py     | 11 +--------
 .../safety/shields/code_scanner.py            |  9 +-------
 .../safety/shields/llama_guard.py             | 19 ++++-----------
 .../safety/shields/prompt_guard.py            | 11 ---------
 8 files changed, 17 insertions(+), 67 deletions(-)

diff --git a/llama_stack/apis/agents/agents.py b/llama_stack/apis/agents/agents.py
index bfbd5616e..e6cde46b5 100644
--- a/llama_stack/apis/agents/agents.py
+++ b/llama_stack/apis/agents/agents.py
@@ -208,7 +208,7 @@ class ToolExecutionStep(StepCommon):
 @json_schema_type
 class ShieldCallStep(StepCommon):
     step_type: Literal[StepType.shield_call.value] = StepType.shield_call.value
-    response: ShieldResponse
+    violation: Optional[SafetyViolation]
 
 
 @json_schema_type
diff --git a/llama_stack/providers/adapters/safety/bedrock/bedrock.py b/llama_stack/providers/adapters/safety/bedrock/bedrock.py
index f746eaa24..ee0800103 100644
--- a/llama_stack/providers/adapters/safety/bedrock/bedrock.py
+++ b/llama_stack/providers/adapters/safety/bedrock/bedrock.py
@@ -25,7 +25,7 @@ class BedrockSafetyAdapter(Safety):
 
     async def run_shield(
         self,
-        shield: ShieldType,
+        shield: str,
         messages: List[Message],
     ) -> RunShieldResponse:
         # clients will set api_keys by doing something like:
diff --git a/llama_stack/providers/impls/meta_reference/agents/agent_instance.py b/llama_stack/providers/impls/meta_reference/agents/agent_instance.py
index e01f5e82e..47bc74ff1 100644
--- a/llama_stack/providers/impls/meta_reference/agents/agent_instance.py
+++ b/llama_stack/providers/impls/meta_reference/agents/agent_instance.py
@@ -94,12 +94,11 @@ class ChatAgent(ShieldRunnerMixin):
                 )
             )
         elif step.step_type == StepType.shield_call.value:
-            response = step.response
-            if response.is_violation:
+            if step.violation:
                 # CompletionMessage itself in the ShieldResponse
                 messages.append(
                     CompletionMessage(
-                        content=response.violation_return_message,
+                        content=step.violation.user_message,
                         stop_reason=StopReason.end_of_turn,
                     )
                 )
@@ -276,7 +275,7 @@ class ChatAgent(ShieldRunnerMixin):
                         step_details=ShieldCallStep(
                             step_id=step_id,
                             turn_id=turn_id,
-                            response=e.response,
+                            violation=e.violation,
                         ),
                     )
                 )
@@ -295,12 +294,7 @@ class ChatAgent(ShieldRunnerMixin):
                     step_details=ShieldCallStep(
                         step_id=step_id,
                         turn_id=turn_id,
-                        response=ShieldResponse(
-                            # TODO: fix this, give each shield a shield type method and
-                            # fire one event for each shield run
-                            shield_type=BuiltinShield.llama_guard,
-                            is_violation=False,
-                        ),
+                        violation=None,
                     ),
                 )
             )
@@ -550,12 +544,7 @@ class ChatAgent(ShieldRunnerMixin):
                     step_details=ShieldCallStep(
                         step_id=str(uuid.uuid4()),
                         turn_id=turn_id,
-                        response=ShieldResponse(
-                            # TODO: fix this, give each shield a shield type method and
-                            # fire one event for each shield run
-                            shield_type=BuiltinShield.llama_guard,
-                            is_violation=False,
-                        ),
+                        violation=None,
                     ),
                 )
             )
@@ -569,7 +558,7 @@ class ChatAgent(ShieldRunnerMixin):
                     step_details=ShieldCallStep(
                         step_id=str(uuid.uuid4()),
                         turn_id=turn_id,
-                        response=e.response,
+                        violation=e.violation,
                     ),
                 )
             )
diff --git a/llama_stack/providers/impls/meta_reference/agents/tests/test_chat_agent.py b/llama_stack/providers/impls/meta_reference/agents/tests/test_chat_agent.py
index 6cbd6000c..7a9e0beae 100644
--- a/llama_stack/providers/impls/meta_reference/agents/tests/test_chat_agent.py
+++ b/llama_stack/providers/impls/meta_reference/agents/tests/test_chat_agent.py
@@ -5,7 +5,6 @@
 # the root directory of this source tree.
 
 from typing import AsyncIterator, List, Optional, Union
-from unittest.mock import MagicMock
 
 import pytest
 
@@ -80,9 +79,9 @@ class MockInferenceAPI:
 
 class MockSafetyAPI:
     async def run_shields(
-        self, messages: List[Message], shields: List[MagicMock]
-    ) -> List[ShieldResponse]:
-        return [ShieldResponse(shield_type="mock_shield", is_violation=False)]
+        self, shield_type: str, messages: List[Message]
+    ) -> RunShieldResponse:
+        return RunShieldResponse(violation=None)
 
 
 class MockMemoryAPI:
diff --git a/llama_stack/providers/impls/meta_reference/safety/shields/base.py b/llama_stack/providers/impls/meta_reference/safety/shields/base.py
index 86124b1e5..6a03d1e61 100644
--- a/llama_stack/providers/impls/meta_reference/safety/shields/base.py
+++ b/llama_stack/providers/impls/meta_reference/safety/shields/base.py
@@ -35,10 +35,6 @@ class ShieldBase(ABC):
     ):
         self.on_violation_action = on_violation_action
 
-    @abstractmethod
-    def get_shield_type(self) -> ShieldType:
-        raise NotImplementedError()
-
     @abstractmethod
     async def run(self, messages: List[Message]) -> ShieldResponse:
         raise NotImplementedError()
@@ -63,11 +59,6 @@ class TextShield(ShieldBase):
 
 
 class DummyShield(TextShield):
-    def get_shield_type(self) -> ShieldType:
-        return "dummy"
-
     async def run_impl(self, text: str) -> ShieldResponse:
         # Dummy return LOW to test e2e
-        return ShieldResponse(
-            shield_type=BuiltinShield.third_party_shield, is_violation=False
-        )
+        return ShieldResponse(is_violation=False)
diff --git a/llama_stack/providers/impls/meta_reference/safety/shields/code_scanner.py b/llama_stack/providers/impls/meta_reference/safety/shields/code_scanner.py
index 340ccb517..9b043ff04 100644
--- a/llama_stack/providers/impls/meta_reference/safety/shields/code_scanner.py
+++ b/llama_stack/providers/impls/meta_reference/safety/shields/code_scanner.py
@@ -7,13 +7,9 @@
 from termcolor import cprint
 
 from .base import ShieldResponse, TextShield
-from llama_stack.apis.safety import *  # noqa: F403
 
 
 class CodeScannerShield(TextShield):
-    def get_shield_type(self) -> ShieldType:
-        return BuiltinShield.code_scanner_guard
-
     async def run_impl(self, text: str) -> ShieldResponse:
         from codeshield.cs import CodeShield
 
@@ -21,7 +17,6 @@ class CodeScannerShield(TextShield):
         result = await CodeShield.scan_code(text)
         if result.is_insecure:
             return ShieldResponse(
-                shield_type=BuiltinShield.code_scanner_guard,
                 is_violation=True,
                 violation_type=",".join(
                     [issue.pattern_id for issue in result.issues_found]
@@ -29,6 +24,4 @@ class CodeScannerShield(TextShield):
                 violation_return_message="Sorry, I found security concerns in the code.",
             )
         else:
-            return ShieldResponse(
-                shield_type=BuiltinShield.code_scanner_guard, is_violation=False
-            )
+            return ShieldResponse(is_violation=False)
diff --git a/llama_stack/providers/impls/meta_reference/safety/shields/llama_guard.py b/llama_stack/providers/impls/meta_reference/safety/shields/llama_guard.py
index c5c4f58a6..c29361b95 100644
--- a/llama_stack/providers/impls/meta_reference/safety/shields/llama_guard.py
+++ b/llama_stack/providers/impls/meta_reference/safety/shields/llama_guard.py
@@ -14,7 +14,7 @@ from llama_models.llama3.api.datatypes import Message, Role
 from transformers import AutoModelForCausalLM, AutoTokenizer
 
 from .base import CANNED_RESPONSE_TEXT, OnViolationAction, ShieldBase, ShieldResponse
-from llama_stack.apis.safety import *  # noqa: F403
+
 
 SAFE_RESPONSE = "safe"
 _INSTANCE = None
@@ -152,9 +152,6 @@ class LlamaGuardShield(ShieldBase):
             model_dir, torch_dtype=torch_dtype, device_map=self.device
         )
 
-    def get_shield_type(self) -> ShieldType:
-        return BuiltinShield.llama_guard
-
     def check_unsafe_response(self, response: str) -> Optional[str]:
         match = re.match(r"^unsafe\n(.*)$", response)
         if match:
@@ -192,18 +189,13 @@ class LlamaGuardShield(ShieldBase):
     def get_shield_response(self, response: str) -> ShieldResponse:
         if response == SAFE_RESPONSE:
-            return ShieldResponse(
-                shield_type=BuiltinShield.llama_guard, is_violation=False
-            )
+            return ShieldResponse(is_violation=False)
 
         unsafe_code = self.check_unsafe_response(response)
         if unsafe_code:
             unsafe_code_list = unsafe_code.split(",")
             if set(unsafe_code_list).issubset(set(self.excluded_categories)):
-                return ShieldResponse(
-                    shield_type=BuiltinShield.llama_guard, is_violation=False
-                )
+                return ShieldResponse(is_violation=False)
             return ShieldResponse(
-                shield_type=BuiltinShield.llama_guard,
                 is_violation=True,
                 violation_type=unsafe_code,
                 violation_return_message=CANNED_RESPONSE_TEXT,
@@ -213,12 +205,9 @@ class LlamaGuardShield(ShieldBase):
 
     async def run(self, messages: List[Message]) -> ShieldResponse:
         if self.disable_input_check and messages[-1].role == Role.user.value:
-            return ShieldResponse(
-                shield_type=BuiltinShield.llama_guard, is_violation=False
-            )
+            return ShieldResponse(is_violation=False)
         elif self.disable_output_check and messages[-1].role == Role.assistant.value:
             return ShieldResponse(
-                shield_type=BuiltinShield.llama_guard,
                 is_violation=False,
             )
         else:
diff --git a/llama_stack/providers/impls/meta_reference/safety/shields/prompt_guard.py b/llama_stack/providers/impls/meta_reference/safety/shields/prompt_guard.py
index acaf515b5..54e911418 100644
--- a/llama_stack/providers/impls/meta_reference/safety/shields/prompt_guard.py
+++ b/llama_stack/providers/impls/meta_reference/safety/shields/prompt_guard.py
@@ -13,7 +13,6 @@ from llama_models.llama3.api.datatypes import Message
 from termcolor import cprint
 
 from .base import message_content_as_str, OnViolationAction, ShieldResponse, TextShield
-from llama_stack.apis.safety import *  # noqa: F403
 
 
 class PromptGuardShield(TextShield):
@@ -74,13 +73,6 @@ class PromptGuardShield(TextShield):
         self.threshold = threshold
         self.mode = mode
 
-    def get_shield_type(self) -> ShieldType:
-        return (
-            BuiltinShield.jailbreak_shield
-            if self.mode == self.Mode.JAILBREAK
-            else BuiltinShield.injection_shield
-        )
-
     def convert_messages_to_text(self, messages: List[Message]) -> str:
         return message_content_as_str(messages[-1])
 
@@ -103,21 +95,18 @@ class PromptGuardShield(TextShield):
             score_embedded + score_malicious > self.threshold
         ):
             return ShieldResponse(
-                shield_type=self.get_shield_type(),
                 is_violation=True,
                 violation_type=f"prompt_injection:embedded={score_embedded},malicious={score_malicious}",
                 violation_return_message="Sorry, I cannot do this.",
            )
         elif self.mode == self.Mode.JAILBREAK and score_malicious > self.threshold:
             return ShieldResponse(
-                shield_type=self.get_shield_type(),
                 is_violation=True,
                 violation_type=f"prompt_injection:malicious={score_malicious}",
                 violation_return_message="Sorry, I cannot do this.",
             )
 
         return ShieldResponse(
-            shield_type=self.get_shield_type(),
             is_violation=False,
         )
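
For reference, a minimal sketch of how a caller might consume the reworked safety surface introduced by this patch. It assumes a run_shield() call shaped like the Bedrock adapter above (a plain shield identifier string plus messages), a RunShieldResponse whose violation field is None when the content is safe, and a SafetyViolation carrying the user_message that the agent loop surfaces; the check_input helper itself is hypothetical and not part of the patch.

from typing import List, Optional

from llama_models.llama3.api.datatypes import Message


async def check_input(safety_api, shield: str, messages: List[Message]) -> Optional[str]:
    # run_shield now takes a plain shield identifier string rather than a ShieldType
    response = await safety_api.run_shield(shield=shield, messages=messages)
    # RunShieldResponse.violation is None when the shield found nothing
    if response.violation is not None:
        # return the violation's user-facing message, mirroring the agent loop above
        return response.violation.user_message
    return None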