diff --git a/llama_toolchain/safety/shields/llama_guard.py b/llama_toolchain/safety/shields/llama_guard.py index a63d71844..dc7151a3e 100644 --- a/llama_toolchain/safety/shields/llama_guard.py +++ b/llama_toolchain/safety/shields/llama_guard.py @@ -230,7 +230,7 @@ class LlamaGuardShield(ShieldBase): prompt_len = input_ids.shape[1] output = self.model.generate( input_ids=input_ids, - max_new_tokens=20, + max_new_tokens=50, output_scores=True, return_dict_in_generate=True, pad_token_id=0,