mirror of https://github.com/meta-llama/llama-stack.git
synced 2025-07-29 15:23:51 +00:00

commit 16fe0e4594
parent 7a8b5c1604

clean up and add license

1 changed file with 4 additions and 42 deletions
@@ -1,4 +1,3 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
 # All rights reserved.
 #
 # This source code is licensed under the terms described in the LICENSE file in
@@ -16,8 +15,7 @@ from string import Template
 from typing import List, Optional

 import torch
-from llama_models.llama3_1.api.datatypes import Message
-from termcolor import cprint
+from llama_models.llama3_1.api.datatypes import Message, Role
 from transformers import AutoModelForCausalLM, AutoTokenizer

 from .base import CANNED_RESPONSE_TEXT, OnViolationAction, ShieldBase, ShieldResponse
@@ -112,7 +110,7 @@ class LlamaGuardShield(ShieldBase):
     def instance(
         on_violation_action=OnViolationAction.RAISE,
         model_dir: str = None,
-        excluded_categories: List[str] = None,
+        excluded_categories: List[str] = [],
         disable_input_check: bool = False,
         disable_output_check: bool = False,
     ) -> "LlamaGuardShield":
@@ -131,7 +129,7 @@ class LlamaGuardShield(ShieldBase):
         self,
         on_violation_action: OnViolationAction = OnViolationAction.RAISE,
         model_dir: str = None,
-        excluded_categories: List[str] = None,
+        excluded_categories: List[str] = [],
         disable_input_check: bool = False,
         disable_output_check: bool = False,
     ):
@@ -141,8 +139,6 @@ class LlamaGuardShield(ShieldBase):

         assert model_dir is not None, "Llama Guard model_dir is None"

-        if excluded_categories is None:
-            excluded_categories = []
         assert len(excluded_categories) == 0 or all(
             x in SAFETY_CATEGORIES_TO_CODE_MAP.values() for x in excluded_categories
         ), "Invalid categories in excluded categories. Expected format is ['S1', 'S2', ..]"
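Note on the change above: with the None guard gone, excluded_categories defaults to a literal [], so the remaining assert is the only validation left. A mutable [] default is shared across calls if it is ever mutated, which is what the removed check used to sidestep. Below is a minimal sketch of that validation; the stand-in contents of SAFETY_CATEGORIES_TO_CODE_MAP are assumptions, since the real map is defined elsewhere in the file and not part of this diff.

from typing import List

# Stand-in for the module-level map; the actual category names and codes are assumptions.
SAFETY_CATEGORIES_TO_CODE_MAP = {"Violent Crimes": "S1", "Non-Violent Crimes": "S2"}

def validate_excluded_categories(excluded_categories: List[str]) -> List[str]:
    # Mirrors the assert kept in the diff: either nothing is excluded,
    # or every entry is a known category code such as "S1", "S2", ...
    assert len(excluded_categories) == 0 or all(
        x in SAFETY_CATEGORIES_TO_CODE_MAP.values() for x in excluded_categories
    ), "Invalid categories in excluded categories. Expected format is ['S1', 'S2', ..]"
    return excluded_categories

validate_excluded_categories([])        # passes: nothing excluded
validate_excluded_categories(["S1"])    # passes: known code
# validate_excluded_categories(["foo"]) # would raise AssertionError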
@@ -221,8 +217,7 @@ class LlamaGuardShield(ShieldBase):
             raise ValueError(f"Unexpected response: {response}")

     async def run(self, messages: List[Message]) -> ShieldResponse:
+        if self.disable_input_check and messages[-1].role == Role.user.value:
             return ShieldResponse(
                 shield_type=BuiltinShield.llama_guard, is_violation=False
             )
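The guard added above short-circuits run() before any model call when input checking is disabled and the last message came from the user. A schematic of that control flow, using simplified stand-in types (the real Message, Role, ShieldResponse, and BuiltinShield live in llama_models and this package and are not reproduced here):

from dataclasses import dataclass
from typing import List

# Simplified stand-ins for the real datatypes; field names follow the diff.
@dataclass
class Msg:
    role: str
    content: str

@dataclass
class ShieldResult:
    shield_type: str
    is_violation: bool

def check(messages: List[Msg], disable_input_check: bool) -> ShieldResult:
    # Same guard as the added line: skip inference when the last turn is a
    # user message and input checking has been disabled.
    if disable_input_check and messages[-1].role == "user":
        return ShieldResult(shield_type="llama_guard", is_violation=False)
    raise NotImplementedError("model inference elided in this sketch")

print(check([Msg(role="user", content="hi")], disable_input_check=True))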
@@ -254,39 +249,6 @@ class LlamaGuardShield(ShieldBase):
         response = self.tokenizer.decode(
             generated_tokens[0], skip_special_tokens=True
         )
-        cprint(f" Llama Guard response {response}", color="magenta")
         response = response.strip()
         shield_response = self.get_shield_response(response)
-        cprint(f"Final Llama Guard response {shield_response}", color="magenta")
         return shield_response
-
-
-
-
-
-        '''if self.disable_input_check and messages[-1].role == "user":
-            return ShieldResponse(is_violation=False)
-        elif self.disable_output_check and messages[-1].role == "assistant":
-            return ShieldResponse(is_violation=False)
-        else:
-            prompt = self.build_prompt(messages)
-            llama_guard_input = {
-                "role": "user",
-                "content": prompt,
-            }
-            input_ids = self.tokenizer.apply_chat_template(
-                [llama_guard_input], return_tensors="pt", tokenize=True
-            ).to(self.device)
-            prompt_len = input_ids.shape[1]
-            output = self.model.generate(
-                input_ids=input_ids,
-                max_new_tokens=50,
-                output_scores=True,
-                return_dict_in_generate=True,
-                pad_token_id=0
-            )
-            generated_tokens = output.sequences[:, prompt_len:]
-            response = self.tokenizer.decode(generated_tokens[0], skip_special_tokens=True)
-            response = response.strip()
-            shield_response = self.get_shield_response(response)
-            return shield_response'''
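For reference, a usage sketch of the shield as it stands after this commit. The module path, the UserMessage helper, and the model directory are illustrative assumptions rather than facts taken from the diff, and running it requires a local Llama Guard checkpoint.

import asyncio

from llama_models.llama3_1.api.datatypes import UserMessage  # assumed message constructor

# Assumed import path for this file within the repo at the time of the commit.
from llama_toolchain.safety.shields.llama_guard import LlamaGuardShield

async def main() -> None:
    shield = LlamaGuardShield.instance(
        model_dir="/path/to/llama-guard",  # local checkpoint directory (placeholder)
        excluded_categories=[],            # codes like "S1", "S2" may be excluded
        disable_input_check=False,
        disable_output_check=False,
    )
    response = await shield.run([UserMessage(content="How do I bake bread?")])
    print(response.is_violation)

if __name__ == "__main__":
    asyncio.run(main())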