update llama guard file to latest version

Kate Plawiak 2024-07-22 13:36:11 -07:00
parent 6f0d348b1c
commit d5019cf3b3


@@ -5,7 +5,6 @@ from typing import List, Optional
 
 import torch
 from llama_models.llama3_1.api.datatypes import Message
-from termcolor import cprint
 from transformers import AutoModelForCausalLM, AutoTokenizer
 
 from .base import CANNED_RESPONSE_TEXT, OnViolationAction, ShieldBase, ShieldResponse
@@ -138,6 +137,7 @@ class LlamaGuardShield(ShieldBase):
         self.disable_input_check = disable_input_check
         self.disable_output_check = disable_output_check
 
         # load model
         torch_dtype = torch.bfloat16
         self.tokenizer = AutoTokenizer.from_pretrained(model_dir)
@@ -145,9 +145,6 @@ class LlamaGuardShield(ShieldBase):
             model_dir, torch_dtype=torch_dtype, device_map=self.device
         )
 
-    def get_shield_type(self) -> ShieldType:
-        return BuiltinShield.llama_guard
-
     def check_unsafe_response(self, response: str) -> Optional[str]:
         match = re.match(r"^unsafe\n(.*)$", response)
         if match:
@@ -177,46 +174,37 @@ class LlamaGuardShield(ShieldBase):
         categories = self.get_safety_categories()
         categories_str = "\n".join(categories)
         conversations_str = "\n\n".join(
-            [f"{m.role.capitalize()}: {m.content}" for m in messages]
+            [f"{m.role.name.capitalize()}: {m.content}" for m in messages]
         )
         return PROMPT_TEMPLATE.substitute(
-            agent_type=messages[-1].role.capitalize(),
+            agent_type=messages[-1].role.name.capitalize(),
             categories=categories_str,
             conversations=conversations_str,
         )
 
     def get_shield_response(self, response: str) -> ShieldResponse:
         if response == SAFE_RESPONSE:
-            return ShieldResponse(
-                shield_type=BuiltinShield.llama_guard, is_violation=False
-            )
+            return ShieldResponse(is_violation=False)
 
         unsafe_code = self.check_unsafe_response(response)
         if unsafe_code:
             unsafe_code_list = unsafe_code.split(",")
             if set(unsafe_code_list).issubset(set(self.excluded_categories)):
-                return ShieldResponse(
-                    shield_type=BuiltinShield.llama_guard, is_violation=False
-                )
-            return ShieldResponse(
-                shield_type=BuiltinShield.llama_guard,
+                return ShieldResponse(is_violation=False)
+            return ShieldResponse(
                 is_violation=True,
                 violation_type=unsafe_code,
                 violation_return_message=CANNED_RESPONSE_TEXT,
-                on_violation_action=OnViolationAction.RAISE,
             )
 
         raise ValueError(f"Unexpected response: {response}")
 
     async def run(self, messages: List[Message]) -> ShieldResponse:
-        if self.disable_input_check and messages[-1].role == Role.user.value:
-            return ShieldResponse(
-                shield_type=BuiltinShield.llama_guard, is_violation=False
-            )
-        elif self.disable_output_check and messages[-1].role == Role.assistant.value:
-            return ShieldResponse(
-                shield_type=BuiltinShield.llama_guard,
-                is_violation=False,
-            )
+        if self.disable_input_check and messages[-1].role.name == "user":
+            return ShieldResponse(is_violation=False)
+        elif self.disable_output_check and messages[-1].role.name == "assistant":
+            return ShieldResponse(is_violation=False)
         else:
             prompt = self.build_prompt(messages)
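
For orientation, a minimal, self-contained sketch of the response handling as it stands after this hunk: shield_type/BuiltinShield and the on_violation_action override are dropped, so a bare ShieldResponse(is_violation=...) comes back. The SAFE_RESPONSE literal, the CANNED_RESPONSE_TEXT string, and the ShieldResponse dataclass below are illustrative stand-ins, not the toolchain's real definitions.

import re
from dataclasses import dataclass
from typing import List, Optional

SAFE_RESPONSE = "safe"  # assumed literal; the real constant lives alongside the shield
CANNED_RESPONSE_TEXT = "I can't help with that."  # placeholder message

@dataclass
class ShieldResponse:
    # stand-in for the toolchain type; note there is no shield_type field anymore
    is_violation: bool
    violation_type: Optional[str] = None
    violation_return_message: Optional[str] = None

def parse_guard_output(response: str, excluded_categories: List[str]) -> ShieldResponse:
    # "safe" -> no violation
    if response == SAFE_RESPONSE:
        return ShieldResponse(is_violation=False)
    # "unsafe\nS1,S2" -> violation, unless every listed category is excluded
    match = re.match(r"^unsafe\n(.*)$", response)
    if match:
        unsafe_code = match.group(1)
        if set(unsafe_code.split(",")).issubset(set(excluded_categories)):
            return ShieldResponse(is_violation=False)
        return ShieldResponse(
            is_violation=True,
            violation_type=unsafe_code,
            violation_return_message=CANNED_RESPONSE_TEXT,
        )
    raise ValueError(f"Unexpected response: {response}")

print(parse_guard_output("safe", []))                                   # no violation
print(parse_guard_output("unsafe\nS1,S3", excluded_categories=["S1"]))  # violation on S1,S3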
@@ -233,16 +221,12 @@ class LlamaGuardShield(ShieldBase):
                 max_new_tokens=20,
                 output_scores=True,
                 return_dict_in_generate=True,
-                pad_token_id=0,
+                pad_token_id=0
             )
             generated_tokens = output.sequences[:, prompt_len:]
 
-            response = self.tokenizer.decode(
-                generated_tokens[0], skip_special_tokens=True
-            )
-
-            response = response.strip()
+            response = self.tokenizer.decode(generated_tokens[0], skip_special_tokens=True)
             shield_response = self.get_shield_response(response)
 
-            cprint(f"Final Llama Guard response {shield_response}", color="magenta")
             return shield_response
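
A small aside on the generation post-processing above: generate() echoes the prompt tokens in output.sequences, so the prompt is sliced off before a single decode call, and that one-liner replaces the old multi-line decode plus the separate strip(). A sketch with dummy token ids standing in for a real model and tokenizer:

import torch

prompt_len = 4
# shape (batch, prompt_len + newly generated tokens), as returned by model.generate()
sequences = torch.tensor([[101, 2054, 2003, 102, 29572, 2078, 1013]])

# keep only the newly generated ids, as in the shield code
generated_tokens = sequences[:, prompt_len:]
print(generated_tokens.tolist())  # [[29572, 2078, 1013]]

# in the shield, a single call then yields the "safe" / "unsafe\n..." verdict string:
# response = self.tokenizer.decode(generated_tokens[0], skip_special_tokens=True)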