From 138b92ae69ec2eab4a90f9f894e81b379dc38bd7 Mon Sep 17 00:00:00 2001
From: Kate Plawiak
Date: Mon, 22 Jul 2024 21:26:03 -0700
Subject: [PATCH] llama_guard inference fix

---
 llama_toolchain/safety/shields/llama_guard.py | 40 ++++++++++++++++---
 1 file changed, 35 insertions(+), 5 deletions(-)

diff --git a/llama_toolchain/safety/shields/llama_guard.py b/llama_toolchain/safety/shields/llama_guard.py
index 94be0e06c..790ff4def 100644
--- a/llama_toolchain/safety/shields/llama_guard.py
+++ b/llama_toolchain/safety/shields/llama_guard.py
@@ -85,8 +85,8 @@ $conversations

 PROMPT_INSTRUCTIONS = """
 Provide your safety assessment for ONLY THE LAST $agent_type message in the above conversation:
-- First line must read 'safe' or 'unsafe'.
-- If unsafe, a second line must include a comma-separated list of violated categories."""
+ - First line must read 'safe' or 'unsafe'.
+ - If unsafe, a second line must include a comma-separated list of violated categories."""


 PROMPT_TEMPLATE = Template(
@@ -240,9 +240,39 @@ class LlamaGuardShield(ShieldBase):
         response = self.tokenizer.decode(
             generated_tokens[0], skip_special_tokens=True
         )
-
-        response = response.strip()
+        cprint(f" Llama Guard response {response}", color="magenta")
+        response = response.strip()
         shield_response = self.get_shield_response(response)
-
         cprint(f"Final Llama Guard response {shield_response}", color="magenta")
         return shield_response
+
+
+
+
+
+        '''if self.disable_input_check and messages[-1].role == "user":
+            return ShieldResponse(is_violation=False)
+        elif self.disable_output_check and messages[-1].role == "assistant":
+            return ShieldResponse(is_violation=False)
+        else:
+            prompt = self.build_prompt(messages)
+            llama_guard_input = {
+                "role": "user",
+                "content": prompt,
+            }
+            input_ids = self.tokenizer.apply_chat_template(
+                [llama_guard_input], return_tensors="pt", tokenize=True
+            ).to(self.device)
+            prompt_len = input_ids.shape[1]
+            output = self.model.generate(
+                input_ids=input_ids,
+                max_new_tokens=50,
+                output_scores=True,
+                return_dict_in_generate=True,
+                pad_token_id=0
+            )
+            generated_tokens = output.sequences[:, prompt_len:]
+            response = self.tokenizer.decode(generated_tokens[0], skip_special_tokens=True)
+            response = response.strip()
+            shield_response = self.get_shield_response(response)
+            return shield_response'''
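
Note: for readers who want to exercise the inference path this patch touches outside the shield class, the sketch below reproduces the same apply_chat_template -> generate -> decode -> strip flow against a standalone Hugging Face Llama Guard model. The checkpoint ID and the moderate() helper name are illustrative assumptions, not taken from the patch; only the call pattern mirrors the code above.

# Standalone sketch of the Llama Guard call pattern used in this patch.
# Assumptions (not from the patch): the checkpoint ID below, the moderate()
# helper, and access to the gated model weights on the Hugging Face Hub.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

MODEL_ID = "meta-llama/LlamaGuard-7b"  # illustrative checkpoint choice
device = "cuda" if torch.cuda.is_available() else "cpu"

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(MODEL_ID).to(device)

def moderate(prompt: str) -> str:
    # Wrap the fully rendered guard prompt as a single user turn,
    # as the shield does with its llama_guard_input dict.
    input_ids = tokenizer.apply_chat_template(
        [{"role": "user", "content": prompt}], return_tensors="pt"
    ).to(device)
    prompt_len = input_ids.shape[1]
    output = model.generate(
        input_ids=input_ids,
        max_new_tokens=50,
        pad_token_id=0,
    )
    # Decode only the newly generated tokens; per PROMPT_INSTRUCTIONS the
    # first line reads 'safe' or 'unsafe', and an 'unsafe' verdict is
    # followed by a comma-separated list of violated categories.
    return tokenizer.decode(output[0][prompt_len:], skip_special_tokens=True).strip()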