Allow overriding checkpoint_dir via config

Ashwin Bharambe 2024-10-18 14:28:06 -07:00
parent 33afd34e6f
commit 71a905e93f
2 changed files with 16 additions and 9 deletions


@@ -29,6 +29,10 @@ class MetaReferenceInferenceConfig(BaseModel):
     # (including our testing code) who might be using llama-stack as a library.
     create_distributed_process_group: bool = True

+    # By default, the implementation will look at ~/.llama/checkpoints/<model> but you
+    # can override by specifying the directory explicitly
+    checkpoint_dir: Optional[str] = None
+
     @field_validator("model")
     @classmethod
     def validate_model(cls, model: str) -> str:
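
For context, the new field can be set directly when constructing the config. A minimal usage sketch (the import path and the model and directory values here are illustrative assumptions, not part of this diff):

    # Hypothetical usage: point the meta-reference provider at a custom
    # checkpoint directory instead of the ~/.llama/checkpoints/<model> default.
    # The import path is an assumption and may differ across versions.
    from llama_stack.providers.impls.meta_reference.inference.config import (
        MetaReferenceInferenceConfig,
    )

    config = MetaReferenceInferenceConfig(
        model="Llama3.1-8B-Instruct",
        checkpoint_dir="/data/checkpoints/my-finetune",  # overrides the default lookup
    )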


@@ -98,7 +98,10 @@ class Llama:
             sys.stdout = open(os.devnull, "w")

         start_time = time.time()
-        ckpt_dir = model_checkpoint_dir(model)
+        if config.checkpoint_dir:
+            ckpt_dir = config.checkpoint_dir
+        else:
+            ckpt_dir = model_checkpoint_dir(model)
         checkpoints = sorted(Path(ckpt_dir).glob("*.pth"))
         assert len(checkpoints) > 0, f"no checkpoint files found in {ckpt_dir}"
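
The lookup order is: an explicit config.checkpoint_dir wins, otherwise the conventional per-model directory is used. The same fallback as a standalone sketch (resolve_checkpoint_dir is a hypothetical helper; model_checkpoint_dir in the diff plays the role of the expanduser branch below):

    import os
    from pathlib import Path
    from typing import Optional

    def resolve_checkpoint_dir(model: str, checkpoint_dir: Optional[str]) -> Path:
        # Prefer the explicit override from the config...
        if checkpoint_dir:
            return Path(checkpoint_dir)
        # ...else fall back to the default ~/.llama/checkpoints/<model> location.
        return Path(os.path.expanduser("~/.llama/checkpoints")) / model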
@@ -119,9 +122,7 @@ class Llama:
             **params,
         )

-        tokenizer_path = os.path.join(ckpt_dir, "tokenizer.model")
-        tokenizer = Tokenizer(model_path=tokenizer_path)
+        tokenizer = Tokenizer.get_instance()
         assert (
             model_args.vocab_size == tokenizer.n_words
         ), f"model_args vocab = {model_args.vocab_size} but tokenizer vocab = {tokenizer.n_words}"
@@ -170,14 +171,16 @@ class Llama:
         logprobs: bool = False,
         echo: bool = False,
         include_stop_token: bool = False,
+        print_input_tokens: bool = False,
     ) -> Generator:
         params = self.model.params

-        # input_tokens = [
-        #     self.formatter.vision_token if t == 128256 else t
-        #     for t in model_input.tokens
-        # ]
-        # cprint("Input to model -> " + self.tokenizer.decode(input_tokens), "red")
+        if print_input_tokens:
+            input_tokens = [
+                self.formatter.vision_token if t == 128256 else t
+                for t in model_input.tokens
+            ]
+            cprint("Input to model -> " + self.tokenizer.decode(input_tokens), "red")

         prompt_tokens = [model_input.tokens]
         bsz = 1
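
A hypothetical call showing the new debugging flag; aside from print_input_tokens (added in this diff), the argument names are assumptions:

    # Dump the exact token stream fed to the model while debugging a prompt.
    for token_result in llama.generate(
        model_input=model_input,
        max_gen_len=64,
        print_input_tokens=True,  # new in this commit
    ):
        ...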