Allow overriding checkpoint_dir via config

Ashwin Bharambe 2024-10-18 14:28:06 -07:00
parent 33afd34e6f
commit 71a905e93f
2 changed files with 16 additions and 9 deletions


@@ -98,7 +98,10 @@ class Llama:
             sys.stdout = open(os.devnull, "w")
 
         start_time = time.time()
-        ckpt_dir = model_checkpoint_dir(model)
+        if config.checkpoint_dir:
+            ckpt_dir = config.checkpoint_dir
+        else:
+            ckpt_dir = model_checkpoint_dir(model)
         checkpoints = sorted(Path(ckpt_dir).glob("*.pth"))
         assert len(checkpoints) > 0, f"no checkpoint files found in {ckpt_dir}"
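
For orientation, here is a minimal sketch of the precedence the new branch implements, assuming a config object with an optional checkpoint_dir field; the InferenceConfig name, the default path, and the helper body are illustrative, not taken from the repo:

import os
from dataclasses import dataclass
from typing import Optional

@dataclass
class InferenceConfig:  # hypothetical container; only the checkpoint_dir field appears in the diff
    checkpoint_dir: Optional[str] = None

def model_checkpoint_dir(model_id: str) -> str:
    # Stand-in for the repo's helper that maps a model id to its
    # default checkpoint directory; the real logic lives elsewhere.
    return os.path.expanduser(f"~/.llama/checkpoints/{model_id}")

def resolve_ckpt_dir(config: InferenceConfig, model_id: str) -> str:
    # Same precedence as the diff: an explicit config value wins,
    # otherwise we fall back to the per-model default.
    return config.checkpoint_dir or model_checkpoint_dir(model_id)

One subtlety the diff's truthiness check preserves: an empty checkpoint_dir string is falsy, so it falls through to the default rather than pointing the loader at a blank path.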
@@ -119,9 +122,7 @@ class Llama:
             **params,
         )
-        tokenizer_path = os.path.join(ckpt_dir, "tokenizer.model")
-        tokenizer = Tokenizer(model_path=tokenizer_path)
+        tokenizer = Tokenizer.get_instance()
         assert (
             model_args.vocab_size == tokenizer.n_words
         ), f"model_args vocab = {model_args.vocab_size} but tokenizer vocab = {tokenizer.n_words}"
@@ -170,14 +171,16 @@ class Llama:
         logprobs: bool = False,
         echo: bool = False,
         include_stop_token: bool = False,
+        print_input_tokens: bool = False,
     ) -> Generator:
         params = self.model.params
 
-        # input_tokens = [
-        #     self.formatter.vision_token if t == 128256 else t
-        #     for t in model_input.tokens
-        # ]
-        # cprint("Input to model -> " + self.tokenizer.decode(input_tokens), "red")
+        if print_input_tokens:
+            input_tokens = [
+                self.formatter.vision_token if t == 128256 else t
+                for t in model_input.tokens
+            ]
+            cprint("Input to model -> " + self.tokenizer.decode(input_tokens), "red")
 
         prompt_tokens = [model_input.tokens]
         bsz = 1
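
As a usage note, the new print_input_tokens flag simply gates the debug print that was previously commented out. A hedged sketch of a call site; everything about the surrounding call except the keyword flags visible in this hunk is assumed:

# Hypothetical call site; only the keyword flags shown in the hunk
# are confirmed by the diff.
for result in llama.generate(
    model_input=model_input,
    logprobs=False,
    echo=False,
    include_stop_token=False,
    print_input_tokens=True,  # decode the prompt and cprint it in red,
    # rendering token id 128256 as the formatter's vision token
):
    ...

Since the decode and cprint run only under the flag, the generation hot path pays nothing when it is off.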