forked from phoenix-oss/llama-stack-mirror
Allow overridding checkpoint_dir via config
This commit is contained in:
parent
33afd34e6f
commit
71a905e93f
2 changed files with 16 additions and 9 deletions
|
@ -29,6 +29,10 @@ class MetaReferenceInferenceConfig(BaseModel):
|
|||
# (including our testing code) who might be using llama-stack as a library.
|
||||
create_distributed_process_group: bool = True
|
||||
|
||||
# By default, the implementation will look at ~/.llama/checkpoints/<model> but you
|
||||
# can override by specifying the directory explicitly
|
||||
checkpoint_dir: Optional[str] = None
|
||||
|
||||
@field_validator("model")
|
||||
@classmethod
|
||||
def validate_model(cls, model: str) -> str:
|
||||
|
|
|
@ -98,7 +98,10 @@ class Llama:
|
|||
sys.stdout = open(os.devnull, "w")
|
||||
|
||||
start_time = time.time()
|
||||
ckpt_dir = model_checkpoint_dir(model)
|
||||
if config.checkpoint_dir:
|
||||
ckpt_dir = config.checkpoint_dir
|
||||
else:
|
||||
ckpt_dir = model_checkpoint_dir(model)
|
||||
|
||||
checkpoints = sorted(Path(ckpt_dir).glob("*.pth"))
|
||||
assert len(checkpoints) > 0, f"no checkpoint files found in {ckpt_dir}"
|
||||
|
@ -119,9 +122,7 @@ class Llama:
|
|||
**params,
|
||||
)
|
||||
|
||||
tokenizer_path = os.path.join(ckpt_dir, "tokenizer.model")
|
||||
tokenizer = Tokenizer(model_path=tokenizer_path)
|
||||
|
||||
tokenizer = Tokenizer.get_instance()
|
||||
assert (
|
||||
model_args.vocab_size == tokenizer.n_words
|
||||
), f"model_args vocab = {model_args.vocab_size} but tokenizer vocab = {tokenizer.n_words}"
|
||||
|
@ -170,14 +171,16 @@ class Llama:
|
|||
logprobs: bool = False,
|
||||
echo: bool = False,
|
||||
include_stop_token: bool = False,
|
||||
print_input_tokens: bool = False,
|
||||
) -> Generator:
|
||||
params = self.model.params
|
||||
|
||||
# input_tokens = [
|
||||
# self.formatter.vision_token if t == 128256 else t
|
||||
# for t in model_input.tokens
|
||||
# ]
|
||||
# cprint("Input to model -> " + self.tokenizer.decode(input_tokens), "red")
|
||||
if print_input_tokens:
|
||||
input_tokens = [
|
||||
self.formatter.vision_token if t == 128256 else t
|
||||
for t in model_input.tokens
|
||||
]
|
||||
cprint("Input to model -> " + self.tokenizer.decode(input_tokens), "red")
|
||||
prompt_tokens = [model_input.tokens]
|
||||
|
||||
bsz = 1
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue