Mirror of https://github.com/meta-llama/llama-stack.git
Commit 7ab807ad76 ("refine"), parent 9e5b7d5c9e
3 changed files with 17 additions and 2 deletions
@@ -31,6 +31,7 @@ from llama_models.llama3.reference_impl.model import Transformer
 from llama_models.llama3.reference_impl.multimodal.model import (
     CrossAttentionTransformer,
 )
+from llama_models.sku_list import resolve_model
 from pydantic import BaseModel
 
 from llama_stack.apis.inference import *  # noqa: F403
@@ -113,7 +114,13 @@ class Llama:
         if config.checkpoint_dir and config.checkpoint_dir != "null":
             ckpt_dir = config.checkpoint_dir
         else:
-            ckpt_dir = model_checkpoint_dir(model_id)
+            resolved_model = resolve_model(model_id)
+            if resolved_model is None:
+                # if the model is not a native llama model, get the default checkpoint_dir based on model id
+                ckpt_dir = model_checkpoint_dir(model_id)
+            else:
+                # if the model is a native llama model, get the default checkpoint_dir based on model core_model_id value
+                ckpt_dir = model_checkpoint_dir(resolved_model.descriptor())
 
         checkpoints = sorted(Path(ckpt_dir).glob("*.pth"))
         assert len(checkpoints) > 0, f"no checkpoint files found in {ckpt_dir}"
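
For context, the resolution order this hunk introduces can be summarized as a standalone helper. The sketch below is illustrative only, not the module's actual code: pick_checkpoint_dir is a hypothetical name, and the model_checkpoint_dir stand-in only assumes a plausible local layout; in the repository that helper already exists and its real implementation may differ.

# Minimal sketch of the checkpoint-directory resolution order introduced above.
from pathlib import Path
from typing import Optional

from llama_models.sku_list import resolve_model


def model_checkpoint_dir(descriptor: str) -> str:
    # Hypothetical stand-in for the existing helper in this module; the assumed
    # layout (~/.llama/checkpoints/<descriptor>) is an illustration, not a spec.
    return str(Path.home() / ".llama" / "checkpoints" / descriptor)


def pick_checkpoint_dir(model_id: str, configured_dir: Optional[str]) -> str:
    # 1. An explicitly configured checkpoint_dir always wins.
    if configured_dir and configured_dir != "null":
        return configured_dir
    # 2. Ids that resolve to a native llama model use the SKU descriptor.
    resolved_model = resolve_model(model_id)
    if resolved_model is not None:
        return model_checkpoint_dir(resolved_model.descriptor())
    # 3. Anything else falls back to the raw model id.
    return model_checkpoint_dir(model_id)

The behavioral change is step 2: a model id that resolves to a native llama SKU now maps to a directory named after resolved_model.descriptor() instead of the raw model id, while unresolved ids keep the old fallback.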