chore: make cprint write to stderr (#2250)

Also call sys.exit(1) when an error occurs.
This commit is contained in:
raghotham 2025-05-24 23:39:57 -07:00 committed by GitHub
parent c25bd0ad58
commit 5a422e236c
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
11 changed files with 81 additions and 44 deletions

View file

@@ -174,6 +174,7 @@ class Llama3:
cprint(
"Input to model:\n" + self.tokenizer.decode(tokens_to_print) + "\n",
"red",
file=sys.stderr,
)
prompt_tokens = [inp.tokens for inp in llm_inputs]
@@ -184,7 +185,11 @@ class Llama3:
max_prompt_len = max(len(t) for t in prompt_tokens)
if max_prompt_len >= params.max_seq_len:
cprint(f"Out of token budget {max_prompt_len} vs {params.max_seq_len}", "red")
cprint(
f"Out of token budget {max_prompt_len} vs {params.max_seq_len}",
color="red",
file=sys.stderr,
)
return
total_len = min(max_gen_len + max_prompt_len, params.max_seq_len)

View file

@@ -133,9 +133,9 @@ class Llama4:
print_model_input = print_model_input or os.environ.get("LLAMA_MODELS_DEBUG", "0") == "1"
if print_model_input:
cprint("Input to model:\n", "yellow")
cprint("Input to model:\n", color="yellow", file=sys.stderr)
for inp in llm_inputs:
cprint(self.tokenizer.decode(inp.tokens), "grey")
cprint(self.tokenizer.decode(inp.tokens), color="grey", file=sys.stderr)
prompt_tokens = [inp.tokens for inp in llm_inputs]
bsz = len(llm_inputs)
@@ -145,7 +145,7 @@ class Llama4:
max_prompt_len = max(len(t) for t in prompt_tokens)
if max_prompt_len >= params.max_seq_len:
cprint(f"Out of token budget {max_prompt_len} vs {params.max_seq_len}", "red")
cprint(f"Out of token budget {max_prompt_len} vs {params.max_seq_len}", color="red", file=sys.stderr)
return
total_len = min(max_gen_len + max_prompt_len, params.max_seq_len)