Mirror of https://github.com/meta-llama/llama-stack.git, synced 2025-07-29 07:14:20 +00:00
redo and fix only specific lines

commit cb5829901f
parent d5019cf3b3

1 changed file with 30 additions and 15 deletions
@@ -1,3 +1,4 @@
 import re

 from string import Template
@@ -137,14 +138,16 @@ class LlamaGuardShield(ShieldBase):
         self.disable_input_check = disable_input_check
         self.disable_output_check = disable_output_check

         # load model
         torch_dtype = torch.bfloat16
         self.tokenizer = AutoTokenizer.from_pretrained(model_dir)
         self.model = AutoModelForCausalLM.from_pretrained(
-            model_dir, torch_dtype=torch_dtype, device_map= self.device
+            model_dir, torch_dtype=torch_dtype, device_map=self.device
         )

+    def get_shield_type(self) -> ShieldType:
+        return BuiltinShield.llama_guard

     def check_unsafe_response(self, response: str) -> Optional[str]:
         match = re.match(r"^unsafe\n(.*)$", response)
         if match:
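(For orientation, not part of the commit: the loading path touched in this hunk is plain Hugging Face transformers. A minimal sketch, assuming a hypothetical local checkpoint directory and a CUDA device string; in the shield itself these live on self and are set in __init__.)

# Sketch only, assuming a local Llama Guard checkpoint and a CUDA device.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_dir = "/path/to/llama-guard"   # hypothetical checkpoint directory
device = "cuda"                      # the shield passes self.device here

tokenizer = AutoTokenizer.from_pretrained(model_dir)
model = AutoModelForCausalLM.from_pretrained(
    model_dir, torch_dtype=torch.bfloat16, device_map=device
)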
@@ -184,33 +187,42 @@ class LlamaGuardShield(ShieldBase):

     def get_shield_response(self, response: str) -> ShieldResponse:
         if response == SAFE_RESPONSE:
-            return ShieldResponse(is_violation=False)
+            return ShieldResponse(
+                shield_type=BuiltinShield.llama_guard, is_violation=False
+            )
         unsafe_code = self.check_unsafe_response(response)
         if unsafe_code:
             unsafe_code_list = unsafe_code.split(",")
             if set(unsafe_code_list).issubset(set(self.excluded_categories)):
-                return ShieldResponse(is_violation=False)
+                return ShieldResponse(
+                    shield_type=BuiltinShield.llama_guard, is_violation=False
+                )
             return ShieldResponse(
+                shield_type=BuiltinShield.llama_guard,
                 is_violation=True,
                 violation_type=unsafe_code,
                 violation_return_message=CANNED_RESPONSE_TEXT,
-                on_violation_action=OnViolationAction.RAISE,
             )

         raise ValueError(f"Unexpected response: {response}")

     async def run(self, messages: List[Message]) -> ShieldResponse:

-        if self.disable_input_check and messages[-1].role.name == "user":
-            return ShieldResponse(is_violation=False)
-        elif self.disable_output_check and messages[-1].role.name == "assistant":
-            return ShieldResponse(is_violation=False)
+        if self.disable_input_check and messages[-1].role == Role.user.value:
+            return ShieldResponse(
+                shield_type=BuiltinShield.llama_guard, is_violation=False
+            )
+        elif self.disable_output_check and messages[-1].role == Role.assistant.value:
+            return ShieldResponse(
+                shield_type=BuiltinShield.llama_guard,
+                is_violation=False,
+            )
         else:

             prompt = self.build_prompt(messages)
             llama_guard_input = {
                 "role": "user",
                 "content": prompt,
             }
             input_ids = self.tokenizer.apply_chat_template(
                 [llama_guard_input], return_tensors="pt", tokenize=True
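(Not part of the commit: get_shield_response above relies on Llama Guard's output convention, a bare "safe" or "unsafe" followed by a newline and a comma-separated category list. A minimal sketch of that parsing, with a simplified result type standing in for ShieldResponse and made-up category codes:)

# Sketch only: the safe/unsafe parsing, with a stand-in Verdict type instead of
# ShieldResponse. Category codes like "O1" are illustrative.
import re
from dataclasses import dataclass
from typing import List, Optional

SAFE_RESPONSE = "safe"

@dataclass
class Verdict:
    is_violation: bool
    violation_type: Optional[str] = None

def check_unsafe_response(response: str) -> Optional[str]:
    # "unsafe\nO1,O3" -> "O1,O3"
    match = re.match(r"^unsafe\n(.*)$", response)
    return match.group(1) if match else None

def get_verdict(response: str, excluded_categories: List[str]) -> Verdict:
    if response == SAFE_RESPONSE:
        return Verdict(is_violation=False)
    unsafe_code = check_unsafe_response(response)
    if unsafe_code:
        codes = unsafe_code.split(",")
        # If every flagged category is explicitly excluded, treat it as safe,
        # mirroring the shield's excluded_categories check.
        if set(codes).issubset(set(excluded_categories)):
            return Verdict(is_violation=False)
        return Verdict(is_violation=True, violation_type=unsafe_code)
    raise ValueError(f"Unexpected response: {response}")

print(get_verdict("unsafe\nO1", excluded_categories=[]))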
@@ -221,12 +233,15 @@ class LlamaGuardShield(ShieldBase):
                 max_new_tokens=20,
                 output_scores=True,
                 return_dict_in_generate=True,
-                pad_token_id=0
+                pad_token_id=0,
             )
             generated_tokens = output.sequences[:, prompt_len:]

-            response = self.tokenizer.decode(generated_tokens[0], skip_special_tokens=True)
+            response = self.tokenizer.decode(
+                generated_tokens[0], skip_special_tokens=True
+            )

+            response = response.strip()
             shield_response = self.get_shield_response(response)

             return shield_response
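(Not part of the commit: the generation step above decodes only the tokens appended after the prompt, then strips whitespace before parsing the verdict. A sketch continuing from the loading example after the earlier hunk; tokenizer and model are assumed to come from that sketch, and the prompt string is a placeholder for what build_prompt(messages) produces.)

# Sketch only: classify one prompt. "tokenizer" and "model" are the objects from
# the loading sketch above; "prompt" stands in for build_prompt(messages).
prompt = "Task: Check if there is unsafe content in the conversation ..."  # placeholder

input_ids = tokenizer.apply_chat_template(
    [{"role": "user", "content": prompt}], return_tensors="pt", tokenize=True
).to(model.device)
prompt_len = input_ids.shape[-1]

output = model.generate(
    input_ids=input_ids,
    max_new_tokens=20,
    output_scores=True,
    return_dict_in_generate=True,
    pad_token_id=0,
)

# Keep only the newly generated tokens, decode, and strip before parsing.
generated_tokens = output.sequences[:, prompt_len:]
response = tokenizer.decode(generated_tokens[0], skip_special_tokens=True).strip()
print(response)   # "safe" or "unsafe" plus a category list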