Update Llama Guard file to latest version

This commit is contained in:
Kate Plawiak 2024-07-22 13:36:11 -07:00
parent 6f0d348b1c
commit d5019cf3b3

View file

@ -5,7 +5,6 @@ from typing import List, Optional
import torch
from llama_models.llama3_1.api.datatypes import Message
from termcolor import cprint
from transformers import AutoModelForCausalLM, AutoTokenizer
from .base import CANNED_RESPONSE_TEXT, OnViolationAction, ShieldBase, ShieldResponse
@ -138,16 +137,14 @@ class LlamaGuardShield(ShieldBase):
self.disable_input_check = disable_input_check
self.disable_output_check = disable_output_check
# load model
torch_dtype = torch.bfloat16
self.tokenizer = AutoTokenizer.from_pretrained(model_dir)
self.model = AutoModelForCausalLM.from_pretrained(
model_dir, torch_dtype=torch_dtype, device_map=self.device
model_dir, torch_dtype=torch_dtype, device_map= self.device
)
# Report which built-in shield this implementation corresponds to.
# (Indentation in this rendering is stripped by the diff viewer.)
def get_shield_type(self) -> ShieldType:
return BuiltinShield.llama_guard
def check_unsafe_response(self, response: str) -> Optional[str]:
match = re.match(r"^unsafe\n(.*)$", response)
if match:
@ -177,46 +174,37 @@ class LlamaGuardShield(ShieldBase):
categories = self.get_safety_categories()
categories_str = "\n".join(categories)
conversations_str = "\n\n".join(
[f"{m.role.capitalize()}: {m.content}" for m in messages]
[f"{m.role.name.capitalize()}: {m.content}" for m in messages]
)
return PROMPT_TEMPLATE.substitute(
agent_type=messages[-1].role.capitalize(),
agent_type=messages[-1].role.name.capitalize(),
categories=categories_str,
conversations=conversations_str,
)
# NOTE(review): this span is a rendered diff fragment — removed and added
# versions of several `return` statements appear back-to-back, so the text
# below is not valid Python as shown. Comments describe the apparent logic.
def get_shield_response(self, response: str) -> ShieldResponse:
# Canonical "safe" model output → no violation.
if response == SAFE_RESPONSE:
# apparently the old form, which carried an explicit shield_type:
return ShieldResponse(
shield_type=BuiltinShield.llama_guard, is_violation=False
)
# apparently the new form, with shield_type dropped:
return ShieldResponse(is_violation=False)
# Otherwise try to parse an "unsafe\n<codes>" response (see
# check_unsafe_response, which matches r"^unsafe\n(.*)$") and pull out
# the comma-separated category codes.
unsafe_code = self.check_unsafe_response(response)
if unsafe_code:
unsafe_code_list = unsafe_code.split(",")
# If every flagged category is in the configured exclusion list,
# treat the response as non-violating.
if set(unsafe_code_list).issubset(set(self.excluded_categories)):
return ShieldResponse(is_violation=False)
return ShieldResponse(
shield_type=BuiltinShield.llama_guard, is_violation=False
)
# Genuine violation: report the codes and the canned user-facing text,
# instructing the caller to raise.
return ShieldResponse(
shield_type=BuiltinShield.llama_guard,
is_violation=True,
violation_type=unsafe_code,
violation_return_message=CANNED_RESPONSE_TEXT,
on_violation_action=OnViolationAction.RAISE,
)
# Any other model output is unexpected.
raise ValueError(f"Unexpected response: {response}")
async def run(self, messages: List[Message]) -> ShieldResponse:
if self.disable_input_check and messages[-1].role == Role.user.value:
return ShieldResponse(
shield_type=BuiltinShield.llama_guard, is_violation=False
)
elif self.disable_output_check and messages[-1].role == Role.assistant.value:
return ShieldResponse(
shield_type=BuiltinShield.llama_guard,
is_violation=False,
)
if self.disable_input_check and messages[-1].role.name == "user":
return ShieldResponse(is_violation=False)
elif self.disable_output_check and messages[-1].role.name == "assistant":
return ShieldResponse(is_violation=False)
else:
prompt = self.build_prompt(messages)
@ -233,16 +221,12 @@ class LlamaGuardShield(ShieldBase):
max_new_tokens=20,
output_scores=True,
return_dict_in_generate=True,
pad_token_id=0,
pad_token_id=0
)
generated_tokens = output.sequences[:, prompt_len:]
response = self.tokenizer.decode(
generated_tokens[0], skip_special_tokens=True
)
response = self.tokenizer.decode(generated_tokens[0], skip_special_tokens=True)
response = response.strip()
shield_response = self.get_shield_response(response)
cprint(f"Final Llama Guard response {shield_response}", color="magenta")
return shield_response