From 71a905e93f06b7779d37755d0c8831513f54cb8f Mon Sep 17 00:00:00 2001
From: Ashwin Bharambe
Date: Fri, 18 Oct 2024 14:28:06 -0700
Subject: [PATCH] Allow overriding checkpoint_dir via config

---
 .../impls/meta_reference/inference/config.py  |  4 ++++
 .../meta_reference/inference/generation.py    | 21 +++++++++++--------
 2 files changed, 16 insertions(+), 9 deletions(-)

diff --git a/llama_stack/providers/impls/meta_reference/inference/config.py b/llama_stack/providers/impls/meta_reference/inference/config.py
index 4e1161ced..48cba645b 100644
--- a/llama_stack/providers/impls/meta_reference/inference/config.py
+++ b/llama_stack/providers/impls/meta_reference/inference/config.py
@@ -29,6 +29,10 @@ class MetaReferenceInferenceConfig(BaseModel):
     # (including our testing code) who might be using llama-stack as a library.
     create_distributed_process_group: bool = True

+    # By default, the implementation will look at ~/.llama/checkpoints/ but you
+    # can override by specifying the directory explicitly
+    checkpoint_dir: Optional[str] = None
+
     @field_validator("model")
     @classmethod
     def validate_model(cls, model: str) -> str:
diff --git a/llama_stack/providers/impls/meta_reference/inference/generation.py b/llama_stack/providers/impls/meta_reference/inference/generation.py
index 9037b9acd..20a8addc7 100644
--- a/llama_stack/providers/impls/meta_reference/inference/generation.py
+++ b/llama_stack/providers/impls/meta_reference/inference/generation.py
@@ -98,7 +98,10 @@ class Llama:
             sys.stdout = open(os.devnull, "w")

         start_time = time.time()
-        ckpt_dir = model_checkpoint_dir(model)
+        if config.checkpoint_dir:
+            ckpt_dir = config.checkpoint_dir
+        else:
+            ckpt_dir = model_checkpoint_dir(model)
         checkpoints = sorted(Path(ckpt_dir).glob("*.pth"))
         assert len(checkpoints) > 0, f"no checkpoint files found in {ckpt_dir}"

@@ -119,9 +122,7 @@ class Llama:
             **params,
         )

-        tokenizer_path = os.path.join(ckpt_dir, "tokenizer.model")
-        tokenizer = Tokenizer(model_path=tokenizer_path)
-
+        tokenizer = Tokenizer.get_instance()
         assert (
             model_args.vocab_size == tokenizer.n_words
         ), f"model_args vocab = {model_args.vocab_size} but tokenizer vocab = {tokenizer.n_words}"
@@ -170,14 +171,16 @@
         logprobs: bool = False,
         echo: bool = False,
         include_stop_token: bool = False,
+        print_input_tokens: bool = False,
     ) -> Generator:
         params = self.model.params

-        # input_tokens = [
-        #     self.formatter.vision_token if t == 128256 else t
-        #     for t in model_input.tokens
-        # ]
-        # cprint("Input to model -> " + self.tokenizer.decode(input_tokens), "red")
+        if print_input_tokens:
+            input_tokens = [
+                self.formatter.vision_token if t == 128256 else t
+                for t in model_input.tokens
+            ]
+            cprint("Input to model -> " + self.tokenizer.decode(input_tokens), "red")

         prompt_tokens = [model_input.tokens]
         bsz = 1
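
A minimal usage sketch, assuming only the two config fields visible in the diff (`model` and `checkpoint_dir`): the model id and path below are placeholders, any other fields the config may require are omitted, and the model id must still satisfy the config's validate_model check.

    from llama_stack.providers.impls.meta_reference.inference.config import (
        MetaReferenceInferenceConfig,
    )

    # Point the provider at an explicit checkpoint directory instead of the
    # default ~/.llama/checkpoints/ lookup. Both values are placeholders.
    config = MetaReferenceInferenceConfig(
        model="Llama3.2-3B-Instruct",
        checkpoint_dir="/mnt/models/Llama3.2-3B-Instruct",
    )

With checkpoint_dir left as None, Llama.build keeps the previous behavior and resolves the directory via model_checkpoint_dir(model). The new print_input_tokens flag on generate() defaults to False, so the decoded-prompt debug output (formerly commented out) only appears when a caller explicitly requests it.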