mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-28 15:02:37 +00:00
redo and fix only specific lines
This commit is contained in:
parent d5019cf3b3
commit cb5829901f
1 changed file with 30 additions and 15 deletions
@@ -1,3 +1,4 @@
+import re
 
 from string import Template
@@ -137,14 +138,16 @@ class LlamaGuardShield(ShieldBase):
         self.disable_input_check = disable_input_check
         self.disable_output_check = disable_output_check
 
         # load model
         torch_dtype = torch.bfloat16
         self.tokenizer = AutoTokenizer.from_pretrained(model_dir)
         self.model = AutoModelForCausalLM.from_pretrained(
-            model_dir, torch_dtype=torch_dtype, device_map= self.device
+            model_dir, torch_dtype=torch_dtype, device_map=self.device
         )
 
     def get_shield_type(self) -> ShieldType:
         return BuiltinShield.llama_guard
 
     def check_unsafe_response(self, response: str) -> Optional[str]:
         match = re.match(r"^unsafe\n(.*)$", response)
         if match:
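The hunk above is cut off inside check_unsafe_response, so the body of the if branch is not shown. A minimal standalone sketch of what that regex parse does, assuming the method returns the captured category string (consistent with the later call unsafe_code.split(",") in this diff); the "O1,O3" value is only illustrative:

import re
from typing import Optional


def check_unsafe_response(response: str) -> Optional[str]:
    # Llama Guard replies with either "safe" or "unsafe" followed by a
    # newline and a comma-separated list of violated categories.
    match = re.match(r"^unsafe\n(.*)$", response)
    if match:
        # Return only the category list, e.g. "O1,O3" (illustrative value).
        return match.group(1)
    return None


assert check_unsafe_response("safe") is None
assert check_unsafe_response("unsafe\nO1,O3") == "O1,O3"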
@@ -184,33 +187,42 @@ class LlamaGuardShield(ShieldBase):
 
     def get_shield_response(self, response: str) -> ShieldResponse:
         if response == SAFE_RESPONSE:
-            return ShieldResponse(is_violation=False)
+            return ShieldResponse(
+                shield_type=BuiltinShield.llama_guard, is_violation=False
+            )
         unsafe_code = self.check_unsafe_response(response)
         if unsafe_code:
             unsafe_code_list = unsafe_code.split(",")
             if set(unsafe_code_list).issubset(set(self.excluded_categories)):
-                return ShieldResponse(is_violation=False)
+                return ShieldResponse(
+                    shield_type=BuiltinShield.llama_guard, is_violation=False
+                )
             return ShieldResponse(
+                shield_type=BuiltinShield.llama_guard,
                 is_violation=True,
                 violation_type=unsafe_code,
                 violation_return_message=CANNED_RESPONSE_TEXT,
                 on_violation_action=OnViolationAction.RAISE,
             )
 
         raise ValueError(f"Unexpected response: {response}")
 
     async def run(self, messages: List[Message]) -> ShieldResponse:
-        if self.disable_input_check and messages[-1].role.name == "user":
-            return ShieldResponse(is_violation=False)
-        elif self.disable_output_check and messages[-1].role.name == "assistant":
-            return ShieldResponse(is_violation=False)
+        if self.disable_input_check and messages[-1].role == Role.user.value:
+            return ShieldResponse(
+                shield_type=BuiltinShield.llama_guard, is_violation=False
+            )
+        elif self.disable_output_check and messages[-1].role == Role.assistant.value:
+            return ShieldResponse(
+                shield_type=BuiltinShield.llama_guard,
+                is_violation=False,
+            )
         else:
             prompt = self.build_prompt(messages)
             llama_guard_input = {
                 "role": "user",
                 "content": prompt,
             }
             input_ids = self.tokenizer.apply_chat_template(
                 [llama_guard_input], return_tensors="pt", tokenize=True
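A side note on the role check rewritten in this hunk: the new comparison matches the message's role string against the enum's value rather than the enum member itself. A toy sketch of why that distinction matters; the Role class below is a hypothetical stand-in for the enum the shield code imports, not taken from the diff:

from enum import Enum


class Role(Enum):
    # hypothetical stand-in, only to illustrate the comparison
    user = "user"
    assistant = "assistant"


last_role = "user"                     # role carried on the message as a plain string
print(last_role == Role.user.value)    # True  -- string compared to "user"
print(last_role == Role.user)          # False -- a string never equals the enum member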
@@ -221,12 +233,15 @@ class LlamaGuardShield(ShieldBase):
                 max_new_tokens=20,
                 output_scores=True,
                 return_dict_in_generate=True,
-                pad_token_id=0
+                pad_token_id=0,
             )
             generated_tokens = output.sequences[:, prompt_len:]
 
-            response = self.tokenizer.decode(generated_tokens[0], skip_special_tokens=True)
+            response = self.tokenizer.decode(
+                generated_tokens[0], skip_special_tokens=True
+            )
 
             response = response.strip()
             shield_response = self.get_shield_response(response)
 
             return shield_response
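One more note on the decode path in the last hunk: generate() returns the prompt tokens followed by the newly generated tokens in a single sequence, which is why the output is sliced at prompt_len before decoding. A toy illustration with made-up token ids (not real model output):

import torch

prompt_len = 3
# generate() output: prompt ids followed by newly generated ids
sequences = torch.tensor([[11, 12, 13, 101, 102, 103]])
generated_tokens = sequences[:, prompt_len:]
print(generated_tokens.tolist())  # [[101, 102, 103]]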