Revert "Update llama guard file to latest version"

This commit is contained in:
Kate Plawiak 2024-07-22 17:27:19 -07:00 committed by GitHub
parent dfe0173b58
commit 5228bdc0f3
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@@ -1,4 +1,3 @@
 import re
 from string import Template
@@ -6,6 +5,7 @@ from typing import List, Optional
 import torch
 from llama_models.llama3_1.api.datatypes import Message
+from termcolor import cprint
 from transformers import AutoModelForCausalLM, AutoTokenizer
 from .base import CANNED_RESPONSE_TEXT, OnViolationAction, ShieldBase, ShieldResponse
@@ -177,10 +177,10 @@ class LlamaGuardShield(ShieldBase):
         categories = self.get_safety_categories()
         categories_str = "\n".join(categories)
         conversations_str = "\n\n".join(
-            [f"{m.role.name.capitalize()}: {m.content}" for m in messages]
+            [f"{m.role.capitalize()}: {m.content}" for m in messages]
         )
         return PROMPT_TEMPLATE.substitute(
-            agent_type=messages[-1].role.name.capitalize(),
+            agent_type=messages[-1].role.capitalize(),
             categories=categories_str,
             conversations=conversations_str,
         )
@@ -230,7 +230,7 @@ class LlamaGuardShield(ShieldBase):
         prompt_len = input_ids.shape[1]
         output = self.model.generate(
             input_ids=input_ids,
-            max_new_tokens=50,
+            max_new_tokens=20,
             output_scores=True,
             return_dict_in_generate=True,
             pad_token_id=0,
@@ -244,4 +244,5 @@ class LlamaGuardShield(ShieldBase):
         response = response.strip()
         shield_response = self.get_shield_response(response)
+        cprint(f"Final Llama Guard response {shield_response}", color="magenta")
         return shield_response