mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-29 07:14:20 +00:00
clean up and add license
This commit is contained in:
parent
7a8b5c1604
commit
16fe0e4594
1 changed files with 4 additions and 42 deletions
|
@ -1,4 +1,3 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
|
@ -16,8 +15,7 @@ from string import Template
|
|||
from typing import List, Optional
|
||||
|
||||
import torch
|
||||
from llama_models.llama3_1.api.datatypes import Message
|
||||
from termcolor import cprint
|
||||
from llama_models.llama3_1.api.datatypes import Message, Role
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
from .base import CANNED_RESPONSE_TEXT, OnViolationAction, ShieldBase, ShieldResponse
|
||||
|
@ -112,7 +110,7 @@ class LlamaGuardShield(ShieldBase):
|
|||
def instance(
|
||||
on_violation_action=OnViolationAction.RAISE,
|
||||
model_dir: str = None,
|
||||
excluded_categories: List[str] = None,
|
||||
excluded_categories: List[str] = [],
|
||||
disable_input_check: bool = False,
|
||||
disable_output_check: bool = False,
|
||||
) -> "LlamaGuardShield":
|
||||
|
@ -131,7 +129,7 @@ class LlamaGuardShield(ShieldBase):
|
|||
self,
|
||||
on_violation_action: OnViolationAction = OnViolationAction.RAISE,
|
||||
model_dir: str = None,
|
||||
excluded_categories: List[str] = None,
|
||||
excluded_categories: List[str] = [],
|
||||
disable_input_check: bool = False,
|
||||
disable_output_check: bool = False,
|
||||
):
|
||||
|
@ -141,8 +139,6 @@ class LlamaGuardShield(ShieldBase):
|
|||
|
||||
assert model_dir is not None, "Llama Guard model_dir is None"
|
||||
|
||||
if excluded_categories is None:
|
||||
excluded_categories = []
|
||||
assert len(excluded_categories) == 0 or all(
|
||||
x in SAFETY_CATEGORIES_TO_CODE_MAP.values() for x in excluded_categories
|
||||
), "Invalid categories in excluded categories. Expected format is ['S1', 'S2', ..]"
|
||||
|
@ -221,8 +217,7 @@ class LlamaGuardShield(ShieldBase):
|
|||
raise ValueError(f"Unexpected response: {response}")
|
||||
|
||||
async def run(self, messages: List[Message]) -> ShieldResponse:
|
||||
|
||||
if self.disable_input_check and messages[-1].role == Role.user.value:
|
||||
if self.disable_input_check and messages[-1].role == Role.user.value:
|
||||
return ShieldResponse(
|
||||
shield_type=BuiltinShield.llama_guard, is_violation=False
|
||||
)
|
||||
|
@ -254,39 +249,6 @@ class LlamaGuardShield(ShieldBase):
|
|||
response = self.tokenizer.decode(
|
||||
generated_tokens[0], skip_special_tokens=True
|
||||
)
|
||||
cprint(f" Llama Guard response {response}", color="magenta")
|
||||
response = response.strip()
|
||||
shield_response = self.get_shield_response(response)
|
||||
cprint(f"Final Llama Guard response {shield_response}", color="magenta")
|
||||
return shield_response
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
'''if self.disable_input_check and messages[-1].role == "user":
|
||||
return ShieldResponse(is_violation=False)
|
||||
elif self.disable_output_check and messages[-1].role == "assistant":
|
||||
return ShieldResponse(is_violation=False)
|
||||
else:
|
||||
prompt = self.build_prompt(messages)
|
||||
llama_guard_input = {
|
||||
"role": "user",
|
||||
"content": prompt,
|
||||
}
|
||||
input_ids = self.tokenizer.apply_chat_template(
|
||||
[llama_guard_input], return_tensors="pt", tokenize=True
|
||||
).to(self.device)
|
||||
prompt_len = input_ids.shape[1]
|
||||
output = self.model.generate(
|
||||
input_ids=input_ids,
|
||||
max_new_tokens=50,
|
||||
output_scores=True,
|
||||
return_dict_in_generate=True,
|
||||
pad_token_id=0
|
||||
)
|
||||
generated_tokens = output.sequences[:, prompt_len:]
|
||||
response = self.tokenizer.decode(generated_tokens[0], skip_special_tokens=True)
|
||||
response = response.strip()
|
||||
shield_response = self.get_shield_response(response)
|
||||
return shield_response'''
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue