Mirror of https://github.com/meta-llama/llama-stack.git, synced 2025-07-29 07:14:20 +00:00
update llama guard file to latest version
This commit is contained in:
parent 6f0d348b1c
commit d5019cf3b3

1 changed file with 17 additions and 33 deletions
@@ -5,7 +5,6 @@ from typing import List, Optional
import torch
from llama_models.llama3_1.api.datatypes import Message
from termcolor import cprint
from transformers import AutoModelForCausalLM, AutoTokenizer

from .base import CANNED_RESPONSE_TEXT, OnViolationAction, ShieldBase, ShieldResponse
@@ -138,16 +137,14 @@ class LlamaGuardShield(ShieldBase):
        self.disable_input_check = disable_input_check
        self.disable_output_check = disable_output_check

        # load model
        torch_dtype = torch.bfloat16
        self.tokenizer = AutoTokenizer.from_pretrained(model_dir)
        self.model = AutoModelForCausalLM.from_pretrained(
            model_dir, torch_dtype=torch_dtype, device_map=self.device
            model_dir, torch_dtype=torch_dtype, device_map= self.device
        )

    def get_shield_type(self) -> ShieldType:
        return BuiltinShield.llama_guard

    def check_unsafe_response(self, response: str) -> Optional[str]:
        match = re.match(r"^unsafe\n(.*)$", response)
        if match:
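For context, a minimal standalone sketch of how the regex in check_unsafe_response above parses raw model output; the sample strings are hypothetical Llama Guard verdicts, not taken from this file:

import re

# The shield treats the literal output "safe" as a pass; anything matching
# "unsafe\n<codes>" carries a comma-separated list of violated category codes.
for sample in ["safe", "unsafe\nO1,O3"]:
    match = re.match(r"^unsafe\n(.*)$", sample)
    print(repr(sample), "->", match.group(1) if match else None)
# 'safe' -> None
# 'unsafe\nO1,O3' -> O1,O3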
@@ -177,46 +174,37 @@ class LlamaGuardShield(ShieldBase):
        categories = self.get_safety_categories()
        categories_str = "\n".join(categories)
        conversations_str = "\n\n".join(
            [f"{m.role.capitalize()}: {m.content}" for m in messages]
            [f"{m.role.name.capitalize()}: {m.content}" for m in messages]
        )
        return PROMPT_TEMPLATE.substitute(
            agent_type=messages[-1].role.capitalize(),
            agent_type=messages[-1].role.name.capitalize(),
            categories=categories_str,
            conversations=conversations_str,
        )

    def get_shield_response(self, response: str) -> ShieldResponse:
        if response == SAFE_RESPONSE:
            return ShieldResponse(
                shield_type=BuiltinShield.llama_guard, is_violation=False
            )
            return ShieldResponse(is_violation=False)
        unsafe_code = self.check_unsafe_response(response)
        if unsafe_code:
            unsafe_code_list = unsafe_code.split(",")
            if set(unsafe_code_list).issubset(set(self.excluded_categories)):
                return ShieldResponse(is_violation=False)
                return ShieldResponse(
                    shield_type=BuiltinShield.llama_guard, is_violation=False
                )
            return ShieldResponse(
                shield_type=BuiltinShield.llama_guard,
                is_violation=True,
                violation_type=unsafe_code,
                violation_return_message=CANNED_RESPONSE_TEXT,
                on_violation_action=OnViolationAction.RAISE,
            )

        raise ValueError(f"Unexpected response: {response}")

    async def run(self, messages: List[Message]) -> ShieldResponse:

        if self.disable_input_check and messages[-1].role == Role.user.value:
            return ShieldResponse(
                shield_type=BuiltinShield.llama_guard, is_violation=False
            )
        elif self.disable_output_check and messages[-1].role == Role.assistant.value:
            return ShieldResponse(
                shield_type=BuiltinShield.llama_guard,
                is_violation=False,
            )
        if self.disable_input_check and messages[-1].role.name == "user":
            return ShieldResponse(is_violation=False)
        elif self.disable_output_check and messages[-1].role.name == "assistant":
            return ShieldResponse(is_violation=False)
        else:

            prompt = self.build_prompt(messages)
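The main difference between the two variants that appear side by side in this hunk is whether message.role is handled as a plain string or as an enum member. A minimal sketch for comparison, assuming a str-valued Role enum along the lines of the one in llama_models (this Role class is an illustrative stand-in, not the actual import):

from enum import Enum

class Role(Enum):
    user = "user"
    assistant = "assistant"

# One variant compares a plain-string role against the enum's value ...
role_as_str = "user"
print(role_as_str == Role.user.value)     # True

# ... the other accesses .name on an enum-typed role and compares or capitalizes it.
role_as_enum = Role.user
print(role_as_enum.name == "user")        # True
print(role_as_enum.name.capitalize())     # User, as interpolated into the prompt template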
@@ -233,16 +221,12 @@ class LlamaGuardShield(ShieldBase):
                max_new_tokens=20,
                output_scores=True,
                return_dict_in_generate=True,
                pad_token_id=0,
                pad_token_id=0
            )
            generated_tokens = output.sequences[:, prompt_len:]

            response = self.tokenizer.decode(
                generated_tokens[0], skip_special_tokens=True
            )
            response = self.tokenizer.decode(generated_tokens[0], skip_special_tokens=True)

            response = response.strip()
            shield_response = self.get_shield_response(response)

            cprint(f"Final Llama Guard response {shield_response}", color="magenta")
            return shield_response
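For readers unfamiliar with the generate/decode sequence in the hunk above, a sketch of how it typically fits together end to end. The checkpoint name, device, prompt placeholder, and the input_ids/prompt_len lines are assumptions filled in for completeness; only the generate kwargs and the decode call mirror the diff:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_dir = "meta-llama/Llama-Guard-3-8B"  # illustrative checkpoint, not from this file
device = "cuda"

tokenizer = AutoTokenizer.from_pretrained(model_dir)
model = AutoModelForCausalLM.from_pretrained(
    model_dir, torch_dtype=torch.bfloat16, device_map=device
)

prompt = "..."  # in the shield this comes from build_prompt(messages)
input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
prompt_len = input_ids.shape[-1]

output = model.generate(
    input_ids=input_ids,
    max_new_tokens=20,  # Llama Guard verdicts are short
    output_scores=True,
    return_dict_in_generate=True,
    pad_token_id=0,
)
# Drop the echoed prompt tokens, keep only the newly generated verdict.
generated_tokens = output.sequences[:, prompt_len:]
response = tokenizer.decode(generated_tokens[0], skip_special_tokens=True).strip()
print(response)  # expected to look like "safe" or "unsafe\n<category codes>"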