From d5019cf3b3661b372f5ac27d3a465e0bebfdac28 Mon Sep 17 00:00:00 2001
From: Kate Plawiak
Date: Mon, 22 Jul 2024 13:36:11 -0700
Subject: [PATCH] update llama guard file to latest version

---
 llama_toolchain/safety/shields/llama_guard.py | 50 +++++++------------
 1 file changed, 17 insertions(+), 33 deletions(-)

diff --git a/llama_toolchain/safety/shields/llama_guard.py b/llama_toolchain/safety/shields/llama_guard.py
index 94be0e06c..242b5d140 100644
--- a/llama_toolchain/safety/shields/llama_guard.py
+++ b/llama_toolchain/safety/shields/llama_guard.py
@@ -5,7 +5,6 @@ from typing import List, Optional
 
 import torch
 from llama_models.llama3_1.api.datatypes import Message
-from termcolor import cprint
 from transformers import AutoModelForCausalLM, AutoTokenizer
 
 from .base import CANNED_RESPONSE_TEXT, OnViolationAction, ShieldBase, ShieldResponse
@@ -138,16 +137,14 @@ class LlamaGuardShield(ShieldBase):
         self.disable_input_check = disable_input_check
         self.disable_output_check = disable_output_check
 
+        # load model
         torch_dtype = torch.bfloat16
         self.tokenizer = AutoTokenizer.from_pretrained(model_dir)
         self.model = AutoModelForCausalLM.from_pretrained(
-            model_dir, torch_dtype=torch_dtype, device_map=self.device
+            model_dir, torch_dtype=torch_dtype, device_map= self.device
         )
 
-    def get_shield_type(self) -> ShieldType:
-        return BuiltinShield.llama_guard
-
     def check_unsafe_response(self, response: str) -> Optional[str]:
         match = re.match(r"^unsafe\n(.*)$", response)
         if match:
@@ -177,52 +174,43 @@ class LlamaGuardShield(ShieldBase):
         categories = self.get_safety_categories()
         categories_str = "\n".join(categories)
         conversations_str = "\n\n".join(
-            [f"{m.role.capitalize()}: {m.content}" for m in messages]
+            [f"{m.role.name.capitalize()}: {m.content}" for m in messages]
        )
         return PROMPT_TEMPLATE.substitute(
-            agent_type=messages[-1].role.capitalize(),
+            agent_type=messages[-1].role.name.capitalize(),
             categories=categories_str,
             conversations=conversations_str,
         )
 
     def get_shield_response(self, response: str) -> ShieldResponse:
         if response == SAFE_RESPONSE:
-            return ShieldResponse(
-                shield_type=BuiltinShield.llama_guard, is_violation=False
-            )
+            return ShieldResponse(is_violation=False)
 
         unsafe_code = self.check_unsafe_response(response)
         if unsafe_code:
             unsafe_code_list = unsafe_code.split(",")
             if set(unsafe_code_list).issubset(set(self.excluded_categories)):
-                return ShieldResponse(
-                    shield_type=BuiltinShield.llama_guard, is_violation=False
-                )
+                return ShieldResponse(is_violation=False)
             return ShieldResponse(
-                shield_type=BuiltinShield.llama_guard,
                 is_violation=True,
                 violation_type=unsafe_code,
                 violation_return_message=CANNED_RESPONSE_TEXT,
+                on_violation_action=OnViolationAction.RAISE,
             )
 
         raise ValueError(f"Unexpected response: {response}")
 
     async def run(self, messages: List[Message]) -> ShieldResponse:
-        if self.disable_input_check and messages[-1].role == Role.user.value:
-            return ShieldResponse(
-                shield_type=BuiltinShield.llama_guard, is_violation=False
-            )
-        elif self.disable_output_check and messages[-1].role == Role.assistant.value:
-            return ShieldResponse(
-                shield_type=BuiltinShield.llama_guard,
-                is_violation=False,
-            )
+        if self.disable_input_check and messages[-1].role.name == "user":
+            return ShieldResponse(is_violation=False)
+        elif self.disable_output_check and messages[-1].role.name == "assistant":
+            return ShieldResponse(is_violation=False)
         else:
             prompt = self.build_prompt(messages)
             llama_guard_input = {
-                "role": "user",
-                "content": prompt,
+                "role": "user",
+                "content": prompt,
             }
             input_ids = self.tokenizer.apply_chat_template(
                 [llama_guard_input], return_tensors="pt", tokenize=True
@@ -233,16 +221,12 @@ class LlamaGuardShield(ShieldBase):
                 max_new_tokens=20,
                 output_scores=True,
                 return_dict_in_generate=True,
-                pad_token_id=0,
+                pad_token_id=0
             )
             generated_tokens = output.sequences[:, prompt_len:]
+
+            response = self.tokenizer.decode(generated_tokens[0], skip_special_tokens=True)
-            response = self.tokenizer.decode(
-                generated_tokens[0], skip_special_tokens=True
-            )
-
-            response = response.strip()
             shield_response = self.get_shield_response(response)
-
-            cprint(f"Final Llama Guard response {shield_response}", color="magenta")
+            return shield_response
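Note (illustrative, not part of the patch): the switch from m.role.capitalize() to
m.role.name.capitalize() in build_prompt and run appears to assume that Message.role
is now an enum member rather than a plain string. The sketch below uses a stand-in
Role enum (the real one lives in llama_models.llama3_1.api.datatypes) to show why
the .name access is needed.

    # Stand-in sketch only; assumes Role behaves like a plain Enum here.
    from enum import Enum

    class Role(Enum):
        user = "user"
        assistant = "assistant"

    role = Role.user
    print(role.name.capitalize())  # -> "User"
    # By contrast, role.capitalize() raises AttributeError, since a plain Enum
    # member has no str methods; that is what the .name change addresses.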