fix: make cprint write to stderr

This commit is contained in:
Raghotham Murthy 2025-05-24 21:46:40 -07:00
parent c290999c63
commit 8658109454
11 changed files with 81 additions and 44 deletions

View file

@ -174,6 +174,7 @@ class Llama3:
cprint(
"Input to model:\n" + self.tokenizer.decode(tokens_to_print) + "\n",
"red",
file=sys.stderr,
)
prompt_tokens = [inp.tokens for inp in llm_inputs]
@ -184,7 +185,11 @@ class Llama3:
max_prompt_len = max(len(t) for t in prompt_tokens)
if max_prompt_len >= params.max_seq_len:
cprint(f"Out of token budget {max_prompt_len} vs {params.max_seq_len}", "red")
cprint(
f"Out of token budget {max_prompt_len} vs {params.max_seq_len}",
color="red",
file=sys.stderr,
)
return
total_len = min(max_gen_len + max_prompt_len, params.max_seq_len)