llama_guard inference fix

This commit is contained in:
parent 8cd2e4164c
commit 138b92ae69

1 changed file with 35 additions and 5 deletions
@@ -85,8 +85,8 @@ $conversations

PROMPT_INSTRUCTIONS = """
Provide your safety assessment for ONLY THE LAST $agent_type message in the above conversation:
- First line must read 'safe' or 'unsafe'.
- If unsafe, a second line must include a comma-separated list of violated categories."""


PROMPT_TEMPLATE = Template(
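For context on this hunk: PROMPT_TEMPLATE is a string.Template, so the $agent_type and $conversations placeholders above are filled in by a substitute() call (via build_prompt) before the prompt is sent to Llama Guard. The sketch below only illustrates that substitution step; the template text and example values are simplified assumptions, not the exact contents of this file.

from string import Template

# Illustrative only: the real PROMPT_TEMPLATE in llama_guard.py wraps more text
# (safety categories, conversation delimiters) around these placeholders.
EXAMPLE_TEMPLATE = Template(
    "$conversations\n"
    "Provide your safety assessment for ONLY THE LAST $agent_type message "
    "in the above conversation:\n"
    "- First line must read 'safe' or 'unsafe'.\n"
    "- If unsafe, a second line must include a comma-separated list of violated categories."
)

prompt = EXAMPLE_TEMPLATE.substitute(
    agent_type="User",
    conversations="User: Hello, how are you today?",
)
print(prompt)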
@@ -240,9 +240,39 @@ class LlamaGuardShield(ShieldBase):
        response = self.tokenizer.decode(
            generated_tokens[0], skip_special_tokens=True
        )

        response = response.strip()
        cprint(f" Llama Guard response {response}", color="magenta")
        shield_response = self.get_shield_response(response)

        cprint(f"Final Llama Guard response {shield_response}", color="magenta")
        return shield_response

        '''if self.disable_input_check and messages[-1].role == "user":
            return ShieldResponse(is_violation=False)
        elif self.disable_output_check and messages[-1].role == "assistant":
            return ShieldResponse(is_violation=False)
        else:
            prompt = self.build_prompt(messages)
            llama_guard_input = {
                "role": "user",
                "content": prompt,
            }
            input_ids = self.tokenizer.apply_chat_template(
                [llama_guard_input], return_tensors="pt", tokenize=True
            ).to(self.device)
            prompt_len = input_ids.shape[1]
            output = self.model.generate(
                input_ids=input_ids,
                max_new_tokens=50,
                output_scores=True,
                return_dict_in_generate=True,
                pad_token_id=0
            )
            generated_tokens = output.sequences[:, prompt_len:]
            response = self.tokenizer.decode(generated_tokens[0], skip_special_tokens=True)
            response = response.strip()
            shield_response = self.get_shield_response(response)
            return shield_response'''
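For reference, get_shield_response (called above) is what turns the decoded text into a ShieldResponse, based on the output format PROMPT_INSTRUCTIONS asks for: 'safe' or 'unsafe' on the first line, and a comma-separated list of violated categories on the second. The parser below is only a sketch of that contract, not the implementation in this file; the function name and return shape are illustrative.

from typing import List, Tuple

def parse_guard_output(response: str) -> Tuple[bool, List[str]]:
    # Sketch of the safe/unsafe contract described in PROMPT_INSTRUCTIONS.
    # Returns (is_violation, violated_categories); the real get_shield_response
    # wraps this information in a ShieldResponse instead.
    lines = [line.strip() for line in response.strip().split("\n") if line.strip()]
    if not lines or lines[0].lower() == "safe":
        return False, []
    categories = lines[1].split(",") if len(lines) > 1 else []
    return True, [category.strip() for category in categories]

# Example: an unsafe verdict (category codes depend on the Llama Guard version).
is_violation, categories = parse_guard_output("unsafe\nS1,S8")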