redo and fix only specific lines

Kate Plawiak 2024-07-22 13:46:43 -07:00
parent d5019cf3b3
commit cb5829901f

@@ -1,3 +1,4 @@
 import re
 from string import Template
@@ -137,7 +138,6 @@ class LlamaGuardShield(ShieldBase):
         self.disable_input_check = disable_input_check
         self.disable_output_check = disable_output_check
         # load model
         torch_dtype = torch.bfloat16
         self.tokenizer = AutoTokenizer.from_pretrained(model_dir)
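
For reference, the load path in this hunk is the standard Hugging Face transformers pattern. A minimal standalone sketch under the same assumptions (the model_dir value and device here are placeholders, not values this commit specifies):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_dir = "meta-llama/LlamaGuard-7b"  # placeholder; the shield receives model_dir as a constructor argument

# Same pattern as the shield: bf16 weights, tokenizer and model from one directory
torch_dtype = torch.bfloat16
tokenizer = AutoTokenizer.from_pretrained(model_dir)
model = AutoModelForCausalLM.from_pretrained(
    model_dir, torch_dtype=torch_dtype, device_map="cuda"
)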
@@ -145,6 +145,9 @@ class LlamaGuardShield(ShieldBase):
             model_dir, torch_dtype=torch_dtype, device_map=self.device
         )

+    def get_shield_type(self) -> ShieldType:
+        return BuiltinShield.llama_guard
+
     def check_unsafe_response(self, response: str) -> Optional[str]:
         match = re.match(r"^unsafe\n(.*)$", response)
         if match:
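
The regex in this hunk relies on Llama Guard's two-line verdict format: "safe", or "unsafe" followed by a newline and comma-separated category codes. A quick sketch of how the parser behaves on both shapes, assuming the method returns the captured group (sample codes are illustrative):

import re
from typing import Optional

# Mirrors check_unsafe_response, assuming it returns match.group(1)
def check_unsafe_response(response: str) -> Optional[str]:
    match = re.match(r"^unsafe\n(.*)$", response)
    if match:
        return match.group(1)
    return None

assert check_unsafe_response("safe") is None
assert check_unsafe_response("unsafe\nO1,O3") == "O1,O3"  # category codes are illustrative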
@@ -184,27 +187,36 @@ class LlamaGuardShield(ShieldBase):
     def get_shield_response(self, response: str) -> ShieldResponse:
         if response == SAFE_RESPONSE:
-            return ShieldResponse(is_violation=False)
+            return ShieldResponse(
+                shield_type=BuiltinShield.llama_guard, is_violation=False
+            )
         unsafe_code = self.check_unsafe_response(response)
         if unsafe_code:
             unsafe_code_list = unsafe_code.split(",")
             if set(unsafe_code_list).issubset(set(self.excluded_categories)):
-                return ShieldResponse(is_violation=False)
+                return ShieldResponse(
+                    shield_type=BuiltinShield.llama_guard, is_violation=False
+                )
             return ShieldResponse(
+                shield_type=BuiltinShield.llama_guard,
                 is_violation=True,
                 violation_type=unsafe_code,
                 violation_return_message=CANNED_RESPONSE_TEXT,
-                on_violation_action=OnViolationAction.RAISE,
             )
         raise ValueError(f"Unexpected response: {response}")

     async def run(self, messages: List[Message]) -> ShieldResponse:
-        if self.disable_input_check and messages[-1].role.name == "user":
-            return ShieldResponse(is_violation=False)
-        elif self.disable_output_check and messages[-1].role.name == "assistant":
-            return ShieldResponse(is_violation=False)
+        if self.disable_input_check and messages[-1].role == Role.user.value:
+            return ShieldResponse(
+                shield_type=BuiltinShield.llama_guard, is_violation=False
+            )
+        elif self.disable_output_check and messages[-1].role == Role.assistant.value:
+            return ShieldResponse(
+                shield_type=BuiltinShield.llama_guard,
+                is_violation=False,
+            )
         else:
             prompt = self.build_prompt(messages)
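
The excluded_categories branch in this hunk clears a response only when every flagged code is in the excluded set; a single non-excluded code still counts as a violation. An illustration of that subset rule (the codes and excluded set are made up):

# Hypothetical configuration: category codes this deployment chooses to ignore
excluded_categories = ["O1", "O3"]

for codes in ("O1", "O1,O3", "O1,O2"):
    unsafe_code_list = codes.split(",")
    cleared = set(unsafe_code_list).issubset(set(excluded_categories))
    print(f"{codes}: {'no violation' if cleared else 'violation'}")
# O1: no violation
# O1,O3: no violation
# O1,O2: violation  (O2 is not in the excluded set)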
@@ -221,12 +233,15 @@ class LlamaGuardShield(ShieldBase):
                 max_new_tokens=20,
                 output_scores=True,
                 return_dict_in_generate=True,
-                pad_token_id=0
+                pad_token_id=0,
             )
             generated_tokens = output.sequences[:, prompt_len:]
-            response = self.tokenizer.decode(generated_tokens[0], skip_special_tokens=True)
+            response = self.tokenizer.decode(
+                generated_tokens[0], skip_special_tokens=True
+            )
+            response = response.strip()
             shield_response = self.get_shield_response(response)
             return shield_response
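
The added response.strip() matters because get_shield_response compares the decoded text against SAFE_RESPONSE with ==, so stray whitespace from decoding would otherwise fall through to the ValueError. A minimal sketch of the failure mode (the constant's value is assumed, and the decoded string is simulated):

SAFE_RESPONSE = "safe"  # assumed value of the module-level constant

decoded = " safe\n"  # decoded model output can carry surrounding whitespace
assert decoded != SAFE_RESPONSE          # the exact == comparison would fail
assert decoded.strip() == SAFE_RESPONSE  # strip() restores the exact match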