llama_guard inference fix

Kate Plawiak 2024-07-22 21:26:03 -07:00
parent 8cd2e4164c
commit 138b92ae69


@@ -85,8 +85,8 @@ $conversations
 PROMPT_INSTRUCTIONS = """
 Provide your safety assessment for ONLY THE LAST $agent_type message in the above conversation:
 - First line must read 'safe' or 'unsafe'.
 - If unsafe, a second line must include a comma-separated list of violated categories."""
 PROMPT_TEMPLATE = Template(
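
Note: the instructions pin down a strict two-line output contract — a 'safe'/'unsafe' verdict, then an optional comma-separated category list. As a rough illustration of consuming that contract (a hypothetical helper, not the repo's actual get_shield_response), a parser might look like this:

# Hypothetical parser for the output format PROMPT_INSTRUCTIONS specifies;
# the real parsing lives in LlamaGuardShield.get_shield_response.
def parse_guard_response(response: str) -> tuple[bool, list[str]]:
    """Return (is_violation, violated_categories) from a Llama Guard reply."""
    lines = response.strip().split("\n")
    verdict = lines[0].strip().lower()
    if verdict == "safe":
        return False, []
    if verdict == "unsafe":
        # Second line, when present, is a comma-separated category list.
        categories = [c.strip() for c in lines[1].split(",")] if len(lines) > 1 else []
        return True, categories
    raise ValueError(f"Unexpected Llama Guard verdict: {verdict!r}")
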
@@ -240,9 +240,39 @@ class LlamaGuardShield(ShieldBase):
         response = self.tokenizer.decode(
             generated_tokens[0], skip_special_tokens=True
         )
+        cprint(f" Llama Guard response {response}", color="magenta")
         response = response.strip()
         shield_response = self.get_shield_response(response)
         cprint(f"Final Llama Guard response {shield_response}", color="magenta")
         return shield_response
+        '''if self.disable_input_check and messages[-1].role == "user":
+            return ShieldResponse(is_violation=False)
+        elif self.disable_output_check and messages[-1].role == "assistant":
+            return ShieldResponse(is_violation=False)
+        else:
+            prompt = self.build_prompt(messages)
+            llama_guard_input = {
+                "role": "user",
+                "content": prompt,
+            }
+            input_ids = self.tokenizer.apply_chat_template(
+                [llama_guard_input], return_tensors="pt", tokenize=True
+            ).to(self.device)
+            prompt_len = input_ids.shape[1]
+            output = self.model.generate(
+                input_ids=input_ids,
+                max_new_tokens=50,
+                output_scores=True,
+                return_dict_in_generate=True,
+                pad_token_id=0
+            )
+            generated_tokens = output.sequences[:, prompt_len:]
+            response = self.tokenizer.decode(generated_tokens[0], skip_special_tokens=True)
+            response = response.strip()
+            shield_response = self.get_shield_response(response)
+            return shield_response'''
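
Note: the triple-quoted block added above is the disabled old inference path; the live code keeps the same generate-then-slice pattern: tokenize the rendered prompt as a single user turn, generate, drop the echoed prompt tokens, and decode only the new ones. Below is a minimal standalone sketch of that pattern with Hugging Face transformers; the checkpoint name and prompt string are placeholders (the real shield renders its prompt from PROMPT_TEMPLATE):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "meta-llama/LlamaGuard-7b"  # placeholder; substitute the shield's actual checkpoint
device = "cuda" if torch.cuda.is_available() else "cpu"

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id).to(device)

prompt = "..."  # in the real code, rendered from PROMPT_TEMPLATE

# Tokenize the prompt as one user turn via the model's chat template.
input_ids = tokenizer.apply_chat_template(
    [{"role": "user", "content": prompt}], return_tensors="pt"
).to(device)
prompt_len = input_ids.shape[1]

output = model.generate(
    input_ids=input_ids,
    max_new_tokens=50,
    return_dict_in_generate=True,
    pad_token_id=0,
)
# Slice off the echoed prompt so only newly generated tokens are decoded.
generated_tokens = output.sequences[:, prompt_len:]
response = tokenizer.decode(generated_tokens[0], skip_special_tokens=True).strip()
print(response)  # expected: 'safe', or 'unsafe' plus a category line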