From d5019cf3b3661b372f5ac27d3a465e0bebfdac28 Mon Sep 17 00:00:00 2001
From: Kate Plawiak
Date: Mon, 22 Jul 2024 13:36:11 -0700
Subject: [PATCH 1/3] update llama guard file to latest version

---
 llama_toolchain/safety/shields/llama_guard.py | 50 +++++++------------
 1 file changed, 17 insertions(+), 33 deletions(-)

diff --git a/llama_toolchain/safety/shields/llama_guard.py b/llama_toolchain/safety/shields/llama_guard.py
index 94be0e06c..242b5d140 100644
--- a/llama_toolchain/safety/shields/llama_guard.py
+++ b/llama_toolchain/safety/shields/llama_guard.py
@@ -5,7 +5,6 @@ from typing import List, Optional
 
 import torch
 from llama_models.llama3_1.api.datatypes import Message
-from termcolor import cprint
 from transformers import AutoModelForCausalLM, AutoTokenizer
 
 from .base import CANNED_RESPONSE_TEXT, OnViolationAction, ShieldBase, ShieldResponse
@@ -138,16 +137,14 @@ class LlamaGuardShield(ShieldBase):
         self.disable_input_check = disable_input_check
         self.disable_output_check = disable_output_check
 
+        # load model
         torch_dtype = torch.bfloat16
         self.tokenizer = AutoTokenizer.from_pretrained(model_dir)
         self.model = AutoModelForCausalLM.from_pretrained(
-            model_dir, torch_dtype=torch_dtype, device_map=self.device
+            model_dir, torch_dtype=torch_dtype, device_map= self.device
         )
 
-    def get_shield_type(self) -> ShieldType:
-        return BuiltinShield.llama_guard
-
     def check_unsafe_response(self, response: str) -> Optional[str]:
         match = re.match(r"^unsafe\n(.*)$", response)
         if match:
@@ -177,52 +174,43 @@ class LlamaGuardShield(ShieldBase):
         categories = self.get_safety_categories()
         categories_str = "\n".join(categories)
         conversations_str = "\n\n".join(
-            [f"{m.role.capitalize()}: {m.content}" for m in messages]
+            [f"{m.role.name.capitalize()}: {m.content}" for m in messages]
         )
         return PROMPT_TEMPLATE.substitute(
-            agent_type=messages[-1].role.capitalize(),
+            agent_type=messages[-1].role.name.capitalize(),
             categories=categories_str,
             conversations=conversations_str,
         )
 
     def get_shield_response(self, response: str) -> ShieldResponse:
         if response == SAFE_RESPONSE:
-            return ShieldResponse(
-                shield_type=BuiltinShield.llama_guard, is_violation=False
-            )
+            return ShieldResponse(is_violation=False)
         unsafe_code = self.check_unsafe_response(response)
         if unsafe_code:
             unsafe_code_list = unsafe_code.split(",")
             if set(unsafe_code_list).issubset(set(self.excluded_categories)):
-                return ShieldResponse(
-                    shield_type=BuiltinShield.llama_guard, is_violation=False
-                )
+                return ShieldResponse(is_violation=False)
             return ShieldResponse(
-                shield_type=BuiltinShield.llama_guard,
                 is_violation=True,
                 violation_type=unsafe_code,
                 violation_return_message=CANNED_RESPONSE_TEXT,
+                on_violation_action=OnViolationAction.RAISE,
             )
         raise ValueError(f"Unexpected response: {response}")
 
     async def run(self, messages: List[Message]) -> ShieldResponse:
-        if self.disable_input_check and messages[-1].role == Role.user.value:
-            return ShieldResponse(
-                shield_type=BuiltinShield.llama_guard, is_violation=False
-            )
-        elif self.disable_output_check and messages[-1].role == Role.assistant.value:
-            return ShieldResponse(
-                shield_type=BuiltinShield.llama_guard,
-                is_violation=False,
-            )
+        if self.disable_input_check and messages[-1].role.name == "user":
+            return ShieldResponse(is_violation=False)
+        elif self.disable_output_check and messages[-1].role.name == "assistant":
+            return ShieldResponse(is_violation=False)
         else:
             prompt = self.build_prompt(messages)
             llama_guard_input = {
-                "role": "user",
-                "content": prompt,
+                "role": "user",
+                "content": prompt,
             }
             input_ids = self.tokenizer.apply_chat_template(
                 [llama_guard_input], return_tensors="pt", tokenize=True
@@ -233,16 +221,12 @@ class LlamaGuardShield(ShieldBase):
                 max_new_tokens=20,
                 output_scores=True,
                 return_dict_in_generate=True,
-                pad_token_id=0,
+                pad_token_id=0
             )
             generated_tokens = output.sequences[:, prompt_len:]
+
+            response = self.tokenizer.decode(generated_tokens[0], skip_special_tokens=True)
 
-            response = self.tokenizer.decode(
-                generated_tokens[0], skip_special_tokens=True
-            )
-
-            response = response.strip()
             shield_response = self.get_shield_response(response)
-
-            cprint(f"Final Llama Guard response {shield_response}", color="magenta")
+
             return shield_response

From cb5829901f77cbde35cba82744f32337ece4e58b Mon Sep 17 00:00:00 2001
From: Kate Plawiak
Date: Mon, 22 Jul 2024 13:46:43 -0700
Subject: [PATCH 2/3] redo and fix only specific lines

---
 llama_toolchain/safety/shields/llama_guard.py | 45 ++++++++++++-------
 1 file changed, 30 insertions(+), 15 deletions(-)

diff --git a/llama_toolchain/safety/shields/llama_guard.py b/llama_toolchain/safety/shields/llama_guard.py
index 242b5d140..a63d71844 100644
--- a/llama_toolchain/safety/shields/llama_guard.py
+++ b/llama_toolchain/safety/shields/llama_guard.py
@@ -1,3 +1,4 @@
+
 import re
 
 from string import Template
@@ -137,14 +138,16 @@ class LlamaGuardShield(ShieldBase):
         self.disable_input_check = disable_input_check
         self.disable_output_check = disable_output_check
 
-        # load model
         torch_dtype = torch.bfloat16
         self.tokenizer = AutoTokenizer.from_pretrained(model_dir)
         self.model = AutoModelForCausalLM.from_pretrained(
-            model_dir, torch_dtype=torch_dtype, device_map= self.device
+            model_dir, torch_dtype=torch_dtype, device_map=self.device
        )
 
+    def get_shield_type(self) -> ShieldType:
+        return BuiltinShield.llama_guard
+
     def check_unsafe_response(self, response: str) -> Optional[str]:
         match = re.match(r"^unsafe\n(.*)$", response)
         if match:
@@ -184,33 +187,42 @@ class LlamaGuardShield(ShieldBase):
 
     def get_shield_response(self, response: str) -> ShieldResponse:
         if response == SAFE_RESPONSE:
-            return ShieldResponse(is_violation=False)
+            return ShieldResponse(
+                shield_type=BuiltinShield.llama_guard, is_violation=False
+            )
         unsafe_code = self.check_unsafe_response(response)
         if unsafe_code:
             unsafe_code_list = unsafe_code.split(",")
             if set(unsafe_code_list).issubset(set(self.excluded_categories)):
-                return ShieldResponse(is_violation=False)
+                return ShieldResponse(
+                    shield_type=BuiltinShield.llama_guard, is_violation=False
+                )
             return ShieldResponse(
+                shield_type=BuiltinShield.llama_guard,
                 is_violation=True,
                 violation_type=unsafe_code,
                 violation_return_message=CANNED_RESPONSE_TEXT,
-                on_violation_action=OnViolationAction.RAISE,
             )
         raise ValueError(f"Unexpected response: {response}")
 
     async def run(self, messages: List[Message]) -> ShieldResponse:
-        if self.disable_input_check and messages[-1].role.name == "user":
-            return ShieldResponse(is_violation=False)
-        elif self.disable_output_check and messages[-1].role.name == "assistant":
-            return ShieldResponse(is_violation=False)
+        if self.disable_input_check and messages[-1].role == Role.user.value:
+            return ShieldResponse(
+                shield_type=BuiltinShield.llama_guard, is_violation=False
+            )
+        elif self.disable_output_check and messages[-1].role == Role.assistant.value:
+            return ShieldResponse(
+                shield_type=BuiltinShield.llama_guard,
+                is_violation=False,
+            )
         else:
             prompt = self.build_prompt(messages)
             llama_guard_input = {
-                "role": "user",
-                "content": prompt,
+                "role": "user",
+                "content": prompt,
             }
             input_ids = self.tokenizer.apply_chat_template(
                 [llama_guard_input], return_tensors="pt", tokenize=True
@@ -221,12 +233,15 @@ class LlamaGuardShield(ShieldBase):
                 max_new_tokens=20,
                 output_scores=True,
                 return_dict_in_generate=True,
-                pad_token_id=0
+                pad_token_id=0,
             )
             generated_tokens = output.sequences[:, prompt_len:]
-
-            response = self.tokenizer.decode(generated_tokens[0], skip_special_tokens=True)
+            response = self.tokenizer.decode(
+                generated_tokens[0], skip_special_tokens=True
+            )
 
+            response = response.strip()
             shield_response = self.get_shield_response(response)
-
+
             return shield_response

From 91b43600f73985146c47cd88f453525ef3e744dd Mon Sep 17 00:00:00 2001
From: Kate Plawiak
Date: Mon, 22 Jul 2024 13:58:51 -0700
Subject: [PATCH 3/3] increase max_new_tokens

---
 llama_toolchain/safety/shields/llama_guard.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/llama_toolchain/safety/shields/llama_guard.py b/llama_toolchain/safety/shields/llama_guard.py
index a63d71844..dc7151a3e 100644
--- a/llama_toolchain/safety/shields/llama_guard.py
+++ b/llama_toolchain/safety/shields/llama_guard.py
@@ -230,7 +230,7 @@ class LlamaGuardShield(ShieldBase):
             prompt_len = input_ids.shape[1]
             output = self.model.generate(
                 input_ids=input_ids,
-                max_new_tokens=20,
+                max_new_tokens=50,
                 output_scores=True,
                 return_dict_in_generate=True,
                 pad_token_id=0,