diff --git a/llama_toolchain/safety/shields/llama_guard.py b/llama_toolchain/safety/shields/llama_guard.py
index 94f111fb7..12aadc305 100644
--- a/llama_toolchain/safety/shields/llama_guard.py
+++ b/llama_toolchain/safety/shields/llama_guard.py
@@ -1,4 +1,3 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
 # All rights reserved.
 #
 # This source code is licensed under the terms described in the LICENSE file in
@@ -16,8 +15,7 @@ from string import Template
 from typing import List, Optional
 
 import torch
-from llama_models.llama3_1.api.datatypes import Message
-from termcolor import cprint
+from llama_models.llama3_1.api.datatypes import Message, Role
 from transformers import AutoModelForCausalLM, AutoTokenizer
 
 from .base import CANNED_RESPONSE_TEXT, OnViolationAction, ShieldBase, ShieldResponse
@@ -112,7 +110,7 @@ class LlamaGuardShield(ShieldBase):
     def instance(
         on_violation_action=OnViolationAction.RAISE,
         model_dir: str = None,
-        excluded_categories: List[str] = None,
+        excluded_categories: List[str] = [],
         disable_input_check: bool = False,
         disable_output_check: bool = False,
     ) -> "LlamaGuardShield":
@@ -131,7 +129,7 @@ class LlamaGuardShield(ShieldBase):
         self,
         on_violation_action: OnViolationAction = OnViolationAction.RAISE,
         model_dir: str = None,
-        excluded_categories: List[str] = None,
+        excluded_categories: List[str] = [],
         disable_input_check: bool = False,
         disable_output_check: bool = False,
     ):
@@ -141,8 +139,6 @@ class LlamaGuardShield(ShieldBase):
 
         assert model_dir is not None, "Llama Guard model_dir is None"
 
-        if excluded_categories is None:
-            excluded_categories = []
         assert len(excluded_categories) == 0 or all(
             x in SAFETY_CATEGORIES_TO_CODE_MAP.values() for x in excluded_categories
         ), "Invalid categories in excluded categories. Expected format is ['S1', 'S2', ..]"
@@ -221,8 +217,7 @@ class LlamaGuardShield(ShieldBase):
         raise ValueError(f"Unexpected response: {response}")
 
     async def run(self, messages: List[Message]) -> ShieldResponse:
-
-        if self.disable_input_check and messages[-1].role == Role.user.value:
+        if self.disable_input_check and messages[-1].role == Role.user.value:
             return ShieldResponse(
                 shield_type=BuiltinShield.llama_guard, is_violation=False
             )
@@ -254,39 +249,6 @@ class LlamaGuardShield(ShieldBase):
         response = self.tokenizer.decode(
             generated_tokens[0], skip_special_tokens=True
         )
-        cprint(f" Llama Guard response {response}", color="magenta")
         response = response.strip()
         shield_response = self.get_shield_response(response)
-        cprint(f"Final Llama Guard response {shield_response}", color="magenta")
         return shield_response
-
-
-
-
-
-        '''if self.disable_input_check and messages[-1].role == "user":
-            return ShieldResponse(is_violation=False)
-        elif self.disable_output_check and messages[-1].role == "assistant":
-            return ShieldResponse(is_violation=False)
-        else:
-            prompt = self.build_prompt(messages)
-            llama_guard_input = {
-                "role": "user",
-                "content": prompt,
-            }
-            input_ids = self.tokenizer.apply_chat_template(
-                [llama_guard_input], return_tensors="pt", tokenize=True
-            ).to(self.device)
-            prompt_len = input_ids.shape[1]
-            output = self.model.generate(
-                input_ids=input_ids,
-                max_new_tokens=50,
-                output_scores=True,
-                return_dict_in_generate=True,
-                pad_token_id=0
-            )
-            generated_tokens = output.sequences[:, prompt_len:]
-            response = self.tokenizer.decode(generated_tokens[0], skip_special_tokens=True)
-            response = response.strip()
-            shield_response = self.get_shield_response(response)
-            return shield_response'''
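
A minimal sketch of how the updated factory might be called after this change; the checkpoint path and category code below are placeholders for illustration, not values taken from this patch:

```python
# Hypothetical usage sketch (not part of the patch). The kwargs match the
# LlamaGuardShield.instance() signature as modified above; the path and
# category code are placeholders.
from llama_toolchain.safety.shields.base import OnViolationAction
from llama_toolchain.safety.shields.llama_guard import LlamaGuardShield

shield = LlamaGuardShield.instance(
    on_violation_action=OnViolationAction.RAISE,
    model_dir="/path/to/llama-guard-checkpoint",  # placeholder; asserted non-None
    excluded_categories=["S1"],  # validated against SAFETY_CATEGORIES_TO_CODE_MAP.values()
    disable_input_check=False,
    disable_output_check=False,
)
```

One note on the new default: a mutable default such as `excluded_categories: List[str] = []` is evaluated once at function-definition time and shared across calls, which is safe only as long as the shield never mutates the list.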