mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-29 15:23:51 +00:00
Merge pull request #2 from meta-llama/revert-1-fix_llama_guard
Revert "Update llama guard file to latest version"
This commit is contained in:
commit
8cd2e4164c
1 changed files with 5 additions and 4 deletions
|
@ -1,4 +1,3 @@
|
||||||
|
|
||||||
import re
|
import re
|
||||||
|
|
||||||
from string import Template
|
from string import Template
|
||||||
|
@ -6,6 +5,7 @@ from typing import List, Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from llama_models.llama3_1.api.datatypes import Message
|
from llama_models.llama3_1.api.datatypes import Message
|
||||||
|
from termcolor import cprint
|
||||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||||
|
|
||||||
from .base import CANNED_RESPONSE_TEXT, OnViolationAction, ShieldBase, ShieldResponse
|
from .base import CANNED_RESPONSE_TEXT, OnViolationAction, ShieldBase, ShieldResponse
|
||||||
|
@ -177,10 +177,10 @@ class LlamaGuardShield(ShieldBase):
|
||||||
categories = self.get_safety_categories()
|
categories = self.get_safety_categories()
|
||||||
categories_str = "\n".join(categories)
|
categories_str = "\n".join(categories)
|
||||||
conversations_str = "\n\n".join(
|
conversations_str = "\n\n".join(
|
||||||
[f"{m.role.name.capitalize()}: {m.content}" for m in messages]
|
[f"{m.role.capitalize()}: {m.content}" for m in messages]
|
||||||
)
|
)
|
||||||
return PROMPT_TEMPLATE.substitute(
|
return PROMPT_TEMPLATE.substitute(
|
||||||
agent_type=messages[-1].role.name.capitalize(),
|
agent_type=messages[-1].role.capitalize(),
|
||||||
categories=categories_str,
|
categories=categories_str,
|
||||||
conversations=conversations_str,
|
conversations=conversations_str,
|
||||||
)
|
)
|
||||||
|
@ -230,7 +230,7 @@ class LlamaGuardShield(ShieldBase):
|
||||||
prompt_len = input_ids.shape[1]
|
prompt_len = input_ids.shape[1]
|
||||||
output = self.model.generate(
|
output = self.model.generate(
|
||||||
input_ids=input_ids,
|
input_ids=input_ids,
|
||||||
max_new_tokens=50,
|
max_new_tokens=20,
|
||||||
output_scores=True,
|
output_scores=True,
|
||||||
return_dict_in_generate=True,
|
return_dict_in_generate=True,
|
||||||
pad_token_id=0,
|
pad_token_id=0,
|
||||||
|
@ -244,4 +244,5 @@ class LlamaGuardShield(ShieldBase):
|
||||||
response = response.strip()
|
response = response.strip()
|
||||||
shield_response = self.get_shield_response(response)
|
shield_response = self.get_shield_response(response)
|
||||||
|
|
||||||
|
cprint(f"Final Llama Guard response {shield_response}", color="magenta")
|
||||||
return shield_response
|
return shield_response
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue