forked from phoenix-oss/llama-stack-mirror
Allow overriding checkpoint_dir via config
This commit is contained in:
parent 33afd34e6f
commit 71a905e93f
2 changed files with 16 additions and 9 deletions
@@ -29,6 +29,10 @@ class MetaReferenceInferenceConfig(BaseModel):
     # (including our testing code) who might be using llama-stack as a library.
     create_distributed_process_group: bool = True
 
+    # By default, the implementation will look at ~/.llama/checkpoints/<model> but you
+    # can override by specifying the directory explicitly
+    checkpoint_dir: Optional[str] = None
+
     @field_validator("model")
     @classmethod
     def validate_model(cls, model: str) -> str:
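For context, a minimal sketch of how the new field behaves, using a hypothetical stand-in model rather than the real MetaReferenceInferenceConfig (whose remaining fields are not shown in this diff): when checkpoint_dir is set, it overrides the default ~/.llama/checkpoints/<model> lookup.

# Minimal sketch; SketchInferenceConfig is a stand-in, not the real llama-stack class.
from typing import Optional

from pydantic import BaseModel


class SketchInferenceConfig(BaseModel):
    model: str
    create_distributed_process_group: bool = True
    # By default the loader looks at ~/.llama/checkpoints/<model>; setting this overrides it.
    checkpoint_dir: Optional[str] = None


cfg = SketchInferenceConfig(model="Llama3.1-8B-Instruct", checkpoint_dir="/mnt/models/llama3")
print(cfg.checkpoint_dir)  # -> /mnt/models/llama3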
@@ -98,7 +98,10 @@ class Llama:
         sys.stdout = open(os.devnull, "w")
 
         start_time = time.time()
-        ckpt_dir = model_checkpoint_dir(model)
+        if config.checkpoint_dir:
+            ckpt_dir = config.checkpoint_dir
+        else:
+            ckpt_dir = model_checkpoint_dir(model)
 
         checkpoints = sorted(Path(ckpt_dir).glob("*.pth"))
         assert len(checkpoints) > 0, f"no checkpoint files found in {ckpt_dir}"
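The same fallback can be read as a small standalone helper. resolve_checkpoint_dir is a hypothetical name, and the default path is taken from the comment added in the config hunk above.

# Hypothetical helper mirroring the fallback logic introduced above (not part of llama-stack).
from pathlib import Path
from typing import Optional


def resolve_checkpoint_dir(model: str, checkpoint_dir: Optional[str] = None) -> Path:
    # Prefer an explicitly configured directory; otherwise fall back to the per-model default.
    ckpt_dir = Path(checkpoint_dir) if checkpoint_dir else Path.home() / ".llama" / "checkpoints" / model
    checkpoints = sorted(ckpt_dir.glob("*.pth"))
    assert len(checkpoints) > 0, f"no checkpoint files found in {ckpt_dir}"
    return ckpt_dir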
@@ -119,9 +122,7 @@ class Llama:
             **params,
         )
 
-        tokenizer_path = os.path.join(ckpt_dir, "tokenizer.model")
-        tokenizer = Tokenizer(model_path=tokenizer_path)
-
+        tokenizer = Tokenizer.get_instance()
         assert (
             model_args.vocab_size == tokenizer.n_words
         ), f"model_args vocab = {model_args.vocab_size} but tokenizer vocab = {tokenizer.n_words}"
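Tokenizer.get_instance() replaces constructing a tokenizer from a path under ckpt_dir, so the tokenizer no longer depends on which checkpoint directory was chosen. The sketch below only illustrates the general singleton-accessor pattern; it is an assumption, not the real llama-stack implementation.

# Sketch of a get_instance()-style singleton accessor (assumption: the real
# Tokenizer.get_instance() may be implemented differently).
from typing import Optional


class SketchTokenizer:
    _instance: Optional["SketchTokenizer"] = None

    def __init__(self, model_path: str) -> None:
        self.model_path = model_path

    @classmethod
    def get_instance(cls) -> "SketchTokenizer":
        # Build the tokenizer once per process instead of once per checkpoint directory.
        if cls._instance is None:
            cls._instance = cls(model_path="tokenizer.model")  # hypothetical default path
        return cls._instance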
@@ -170,14 +171,16 @@ class Llama:
         logprobs: bool = False,
         echo: bool = False,
         include_stop_token: bool = False,
+        print_input_tokens: bool = False,
     ) -> Generator:
         params = self.model.params
 
-        # input_tokens = [
-        #     self.formatter.vision_token if t == 128256 else t
-        #     for t in model_input.tokens
-        # ]
-        # cprint("Input to model -> " + self.tokenizer.decode(input_tokens), "red")
+        if print_input_tokens:
+            input_tokens = [
+                self.formatter.vision_token if t == 128256 else t
+                for t in model_input.tokens
+            ]
+            cprint("Input to model -> " + self.tokenizer.decode(input_tokens), "red")
         prompt_tokens = [model_input.tokens]
 
         bsz = 1
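The last hunk turns the previously commented-out debug print into an opt-in print_input_tokens flag. A rough standalone version of that debug path, with the tokenizer and vision-token id passed in explicitly (the helper name and parameters are illustrative; the 128256 id and the red cprint come from the diff):

# Illustrative debug helper modeled on the print_input_tokens branch above.
from termcolor import cprint


def debug_print_input_tokens(tokens, tokenizer, vision_token_id):
    # Map the raw image-placeholder id (128256) to the tokenizer's vision token before decoding.
    printable = [vision_token_id if t == 128256 else t for t in tokens]
    cprint("Input to model -> " + tokenizer.decode(printable), "red")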