diff --git a/llama_toolchain/safety/shields/llama_guard.py b/llama_toolchain/safety/shields/llama_guard.py
index 94be0e06c..dc7151a3e 100644
--- a/llama_toolchain/safety/shields/llama_guard.py
+++ b/llama_toolchain/safety/shields/llama_guard.py
@@ -1,3 +1,4 @@
+
 import re
 from string import Template

@@ -5,7 +6,6 @@ from typing import List, Optional

 import torch
 from llama_models.llama3_1.api.datatypes import Message
-from termcolor import cprint
 from transformers import AutoModelForCausalLM, AutoTokenizer

 from .base import CANNED_RESPONSE_TEXT, OnViolationAction, ShieldBase, ShieldResponse
@@ -177,10 +177,10 @@ class LlamaGuardShield(ShieldBase):
         categories = self.get_safety_categories()
         categories_str = "\n".join(categories)
         conversations_str = "\n\n".join(
-            [f"{m.role.capitalize()}: {m.content}" for m in messages]
+            [f"{m.role.name.capitalize()}: {m.content}" for m in messages]
         )
         return PROMPT_TEMPLATE.substitute(
-            agent_type=messages[-1].role.capitalize(),
+            agent_type=messages[-1].role.name.capitalize(),
             categories=categories_str,
             conversations=conversations_str,
         )
@@ -230,7 +230,7 @@ class LlamaGuardShield(ShieldBase):
         prompt_len = input_ids.shape[1]
         output = self.model.generate(
             input_ids=input_ids,
-            max_new_tokens=20,
+            max_new_tokens=50,
             output_scores=True,
             return_dict_in_generate=True,
             pad_token_id=0,
@@ -244,5 +244,4 @@ class LlamaGuardShield(ShieldBase):
         response = response.strip()
         shield_response = self.get_shield_response(response)
-        cprint(f"Final Llama Guard response {shield_response}", color="magenta")
         return shield_response
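
Note on the `m.role.name.capitalize()` change: the switch implies `Message.role` is an enum member rather than a plain string, so string methods must go through `.name`. A minimal sketch of the behavior, with a hypothetical `Role` enum standing in for the real datatype in llama_models:

    from enum import Enum

    class Role(Enum):
        # hypothetical stand-in for the role type used by Message
        user = "user"
        assistant = "assistant"

    role = Role.user
    # Old code: role.capitalize() raises AttributeError, because enum
    # members have no str methods; str(role) would give "Role.user".
    print(role.name.capitalize())  # -> "User", the casing the prompt template expects

The `max_new_tokens` bump from 20 to 50 presumably leaves headroom for Llama Guard's multi-line verdicts (e.g. "unsafe" plus violated category codes), and the `cprint` removal drops a debug print along with its now-unused termcolor import.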