clean up and add license

Kate Plawiak 2024-07-22 21:59:57 -07:00
parent 7a8b5c1604
commit 16fe0e4594

@@ -1,4 +1,3 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
 # All rights reserved.
 #
 # This source code is licensed under the terms described in the LICENSE file in
@@ -16,8 +15,7 @@ from string import Template
 from typing import List, Optional

 import torch
-from llama_models.llama3_1.api.datatypes import Message
-from termcolor import cprint
+from llama_models.llama3_1.api.datatypes import Message, Role
 from transformers import AutoModelForCausalLM, AutoTokenizer

 from .base import CANNED_RESPONSE_TEXT, OnViolationAction, ShieldBase, ShieldResponse
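A note on this hunk: the termcolor import goes away together with the cprint debug lines deleted further down, and Role is brought in for the role comparisons in run() (the stale commented-out block deleted at the end still compared raw strings like "user"). For context, Role is presumably a string-valued enum along these lines; this is a sketch of the assumed shape, not the actual llama_models definition:

from enum import Enum

# Sketch of the Role enum this diff appears to rely on; the real definition
# lives in llama_models.llama3_1.api.datatypes and may differ.
class Role(Enum):
    system = "system"
    user = "user"
    assistant = "assistant"

assert Role.user.value == "user"  # .value gives the plain string stored in Message.role

Comparing against Role.user.value instead of a hard-coded "user" keeps the check in sync with the datatypes module.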
@@ -112,7 +110,7 @@ class LlamaGuardShield(ShieldBase):
     def instance(
         on_violation_action=OnViolationAction.RAISE,
         model_dir: str = None,
-        excluded_categories: List[str] = None,
+        excluded_categories: List[str] = [],
         disable_input_check: bool = False,
         disable_output_check: bool = False,
     ) -> "LlamaGuardShield":
@@ -131,7 +129,7 @@ class LlamaGuardShield(ShieldBase):
         self,
         on_violation_action: OnViolationAction = OnViolationAction.RAISE,
         model_dir: str = None,
-        excluded_categories: List[str] = None,
+        excluded_categories: List[str] = [],
         disable_input_check: bool = False,
         disable_output_check: bool = False,
     ):
@@ -141,8 +139,6 @@ class LlamaGuardShield(ShieldBase):
         assert model_dir is not None, "Llama Guard model_dir is None"

-        if excluded_categories is None:
-            excluded_categories = []

         assert len(excluded_categories) == 0 or all(
             x in SAFETY_CATEGORIES_TO_CODE_MAP.values() for x in excluded_categories
         ), "Invalid categories in excluded categories. Expected format is ['S1', 'S2', ..]"
@@ -221,7 +217,6 @@ class LlamaGuardShield(ShieldBase):
             raise ValueError(f"Unexpected response: {response}")

     async def run(self, messages: List[Message]) -> ShieldResponse:
         if self.disable_input_check and messages[-1].role == Role.user.value:
             return ShieldResponse(
                 shield_type=BuiltinShield.llama_guard, is_violation=False
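The raise ValueError(f"Unexpected response: {response}") context line comes from get_shield_response, whose body sits outside this diff. Llama Guard models conventionally reply with "safe", or "unsafe" followed by a violation code on the next line; a hedged sketch of what such a parser might look like, with the ShieldResponse field names beyond shield_type and is_violation assumed rather than taken from this repo:

def get_shield_response(self, response: str) -> ShieldResponse:
    # Assumed contract: the model emits "safe", or "unsafe\nS<n>".
    if response == "safe":
        return ShieldResponse(
            shield_type=BuiltinShield.llama_guard, is_violation=False
        )
    if response.startswith("unsafe"):
        parts = response.split("\n")
        return ShieldResponse(
            shield_type=BuiltinShield.llama_guard,
            is_violation=True,
            violation_type=parts[1] if len(parts) > 1 else None,  # assumed field
            violation_return_message=CANNED_RESPONSE_TEXT,        # assumed field
        )
    raise ValueError(f"Unexpected response: {response}")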
@@ -254,39 +249,6 @@ class LlamaGuardShield(ShieldBase):
         response = self.tokenizer.decode(
             generated_tokens[0], skip_special_tokens=True
         )
-        cprint(f" Llama Guard response {response}", color="magenta")
         response = response.strip()
         shield_response = self.get_shield_response(response)
-        cprint(f"Final Llama Guard response {shield_response}", color="magenta")
         return shield_response
-
-        '''if self.disable_input_check and messages[-1].role == "user":
-            return ShieldResponse(is_violation=False)
-        elif self.disable_output_check and messages[-1].role == "assistant":
-            return ShieldResponse(is_violation=False)
-        else:
-            prompt = self.build_prompt(messages)
-
-            llama_guard_input = {
-                "role": "user",
-                "content": prompt,
-            }
-            input_ids = self.tokenizer.apply_chat_template(
-                [llama_guard_input], return_tensors="pt", tokenize=True
-            ).to(self.device)
-            prompt_len = input_ids.shape[1]
-            output = self.model.generate(
-                input_ids=input_ids,
-                max_new_tokens=50,
-                output_scores=True,
-                return_dict_in_generate=True,
-                pad_token_id=0
-            )
-            generated_tokens = output.sequences[:, prompt_len:]
-
-            response = self.tokenizer.decode(generated_tokens[0], skip_special_tokens=True)
-
-            response = response.strip()
-            shield_response = self.get_shield_response(response)
-            return shield_response'''
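The block deleted above was a stale triple-quoted copy of the inference path that remains live earlier in run(). For reference, the end-to-end flow it describes, reconstructed from this diff as a standalone sketch (the model path is a placeholder, and prompt building is elided):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

MODEL_DIR = "/path/to/llama-guard"  # placeholder; the class receives this as model_dir
device = "cuda" if torch.cuda.is_available() else "cpu"

tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_DIR, torch_dtype=torch.bfloat16
).to(device)

def classify(prompt: str) -> str:
    """Run one Llama Guard pass and return the decoded verdict text."""
    # Wrap the already-templated Llama Guard prompt as a single user turn.
    llama_guard_input = {"role": "user", "content": prompt}
    input_ids = tokenizer.apply_chat_template(
        [llama_guard_input], return_tensors="pt", tokenize=True
    ).to(device)
    prompt_len = input_ids.shape[1]
    output = model.generate(
        input_ids=input_ids,
        max_new_tokens=50,
        output_scores=True,
        return_dict_in_generate=True,
        pad_token_id=0,
    )
    generated_tokens = output.sequences[:, prompt_len:]  # drop the echoed prompt
    return tokenizer.decode(generated_tokens[0], skip_special_tokens=True).strip()

The "safe"/"unsafe" verdict string returned by this flow is what get_shield_response maps to a ShieldResponse.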