From 5228bdc0f3de0b3ce72cc298d6a399f4c1793656 Mon Sep 17 00:00:00 2001 From: Kate Plawiak <113949869+kplawiak@users.noreply.github.com> Date: Mon, 22 Jul 2024 17:27:19 -0700 Subject: [PATCH] Revert "Update llama guard file to latest version" --- llama_toolchain/safety/shields/llama_guard.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/llama_toolchain/safety/shields/llama_guard.py b/llama_toolchain/safety/shields/llama_guard.py index dc7151a3e..94be0e06c 100644 --- a/llama_toolchain/safety/shields/llama_guard.py +++ b/llama_toolchain/safety/shields/llama_guard.py @@ -1,4 +1,3 @@ - import re from string import Template @@ -6,6 +5,7 @@ from typing import List, Optional import torch from llama_models.llama3_1.api.datatypes import Message +from termcolor import cprint from transformers import AutoModelForCausalLM, AutoTokenizer from .base import CANNED_RESPONSE_TEXT, OnViolationAction, ShieldBase, ShieldResponse @@ -177,10 +177,10 @@ class LlamaGuardShield(ShieldBase): categories = self.get_safety_categories() categories_str = "\n".join(categories) conversations_str = "\n\n".join( - [f"{m.role.name.capitalize()}: {m.content}" for m in messages] + [f"{m.role.capitalize()}: {m.content}" for m in messages] ) return PROMPT_TEMPLATE.substitute( - agent_type=messages[-1].role.name.capitalize(), + agent_type=messages[-1].role.capitalize(), categories=categories_str, conversations=conversations_str, ) @@ -230,7 +230,7 @@ class LlamaGuardShield(ShieldBase): prompt_len = input_ids.shape[1] output = self.model.generate( input_ids=input_ids, - max_new_tokens=50, + max_new_tokens=20, output_scores=True, return_dict_in_generate=True, pad_token_id=0, @@ -244,4 +244,5 @@ class LlamaGuardShield(ShieldBase): response = response.strip() shield_response = self.get_shield_response(response) + cprint(f"Final Llama Guard response {shield_response}", color="magenta") return shield_response