mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-29 15:23:51 +00:00
llama_guard inference fix
This commit is contained in:
parent
8cd2e4164c
commit
138b92ae69
1 changed files with 35 additions and 5 deletions
|
@ -240,9 +240,39 @@ class LlamaGuardShield(ShieldBase):
|
||||||
response = self.tokenizer.decode(
|
response = self.tokenizer.decode(
|
||||||
generated_tokens[0], skip_special_tokens=True
|
generated_tokens[0], skip_special_tokens=True
|
||||||
)
|
)
|
||||||
|
cprint(f" Llama Guard response {response}", color="magenta")
|
||||||
response = response.strip()
|
response = response.strip()
|
||||||
shield_response = self.get_shield_response(response)
|
shield_response = self.get_shield_response(response)
|
||||||
|
|
||||||
cprint(f"Final Llama Guard response {shield_response}", color="magenta")
|
cprint(f"Final Llama Guard response {shield_response}", color="magenta")
|
||||||
return shield_response
|
return shield_response
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
'''if self.disable_input_check and messages[-1].role == "user":
|
||||||
|
return ShieldResponse(is_violation=False)
|
||||||
|
elif self.disable_output_check and messages[-1].role == "assistant":
|
||||||
|
return ShieldResponse(is_violation=False)
|
||||||
|
else:
|
||||||
|
prompt = self.build_prompt(messages)
|
||||||
|
llama_guard_input = {
|
||||||
|
"role": "user",
|
||||||
|
"content": prompt,
|
||||||
|
}
|
||||||
|
input_ids = self.tokenizer.apply_chat_template(
|
||||||
|
[llama_guard_input], return_tensors="pt", tokenize=True
|
||||||
|
).to(self.device)
|
||||||
|
prompt_len = input_ids.shape[1]
|
||||||
|
output = self.model.generate(
|
||||||
|
input_ids=input_ids,
|
||||||
|
max_new_tokens=50,
|
||||||
|
output_scores=True,
|
||||||
|
return_dict_in_generate=True,
|
||||||
|
pad_token_id=0
|
||||||
|
)
|
||||||
|
generated_tokens = output.sequences[:, prompt_len:]
|
||||||
|
response = self.tokenizer.decode(generated_tokens[0], skip_special_tokens=True)
|
||||||
|
response = response.strip()
|
||||||
|
shield_response = self.get_shield_response(response)
|
||||||
|
return shield_response'''
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue