diff --git a/llama_toolchain/safety/shields/llama_guard.py b/llama_toolchain/safety/shields/llama_guard.py
index 242b5d140..a63d71844 100644
--- a/llama_toolchain/safety/shields/llama_guard.py
+++ b/llama_toolchain/safety/shields/llama_guard.py
@@ -1,3 +1,4 @@
+
 import re
 
 from string import Template
@@ -137,14 +138,16 @@ class LlamaGuardShield(ShieldBase):
         self.disable_input_check = disable_input_check
         self.disable_output_check = disable_output_check
 
-        # load model
         torch_dtype = torch.bfloat16
 
         self.tokenizer = AutoTokenizer.from_pretrained(model_dir)
         self.model = AutoModelForCausalLM.from_pretrained(
-            model_dir, torch_dtype=torch_dtype, device_map= self.device
+            model_dir, torch_dtype=torch_dtype, device_map=self.device
         )
 
+    def get_shield_type(self) -> ShieldType:
+        return BuiltinShield.llama_guard
+
     def check_unsafe_response(self, response: str) -> Optional[str]:
         match = re.match(r"^unsafe\n(.*)$", response)
         if match:
@@ -184,33 +187,42 @@ class LlamaGuardShield(ShieldBase):
 
     def get_shield_response(self, response: str) -> ShieldResponse:
         if response == SAFE_RESPONSE:
-            return ShieldResponse(is_violation=False)
+            return ShieldResponse(
+                shield_type=BuiltinShield.llama_guard, is_violation=False
+            )
 
         unsafe_code = self.check_unsafe_response(response)
         if unsafe_code:
             unsafe_code_list = unsafe_code.split(",")
             if set(unsafe_code_list).issubset(set(self.excluded_categories)):
-                return ShieldResponse(is_violation=False)
+                return ShieldResponse(
+                    shield_type=BuiltinShield.llama_guard, is_violation=False
+                )
             return ShieldResponse(
+                shield_type=BuiltinShield.llama_guard,
                 is_violation=True,
                 violation_type=unsafe_code,
                 violation_return_message=CANNED_RESPONSE_TEXT,
-                on_violation_action=OnViolationAction.RAISE,
             )
 
         raise ValueError(f"Unexpected response: {response}")
 
     async def run(self, messages: List[Message]) -> ShieldResponse:
-        if self.disable_input_check and messages[-1].role.name == "user":
-            return ShieldResponse(is_violation=False)
-        elif self.disable_output_check and messages[-1].role.name == "assistant":
-            return ShieldResponse(is_violation=False)
+        if self.disable_input_check and messages[-1].role == Role.user.value:
+            return ShieldResponse(
+                shield_type=BuiltinShield.llama_guard, is_violation=False
+            )
+        elif self.disable_output_check and messages[-1].role == Role.assistant.value:
+            return ShieldResponse(
+                shield_type=BuiltinShield.llama_guard,
+                is_violation=False,
+            )
         else:
             prompt = self.build_prompt(messages)
 
             llama_guard_input = {
-                "role": "user", 
-                "content": prompt, 
+                "role": "user",
+                "content": prompt,
             }
             input_ids = self.tokenizer.apply_chat_template(
                 [llama_guard_input], return_tensors="pt", tokenize=True
@@ -221,12 +233,15 @@ class LlamaGuardShield(ShieldBase):
                 max_new_tokens=20,
                 output_scores=True,
                 return_dict_in_generate=True,
-                pad_token_id=0
+                pad_token_id=0,
             )
 
             generated_tokens = output.sequences[:, prompt_len:]
-
-            response = self.tokenizer.decode(generated_tokens[0], skip_special_tokens=True)
+            response = self.tokenizer.decode(
+                generated_tokens[0], skip_special_tokens=True
+            )
+
+            response = response.strip()
             shield_response = self.get_shield_response(response)
-
+
             return shield_response