mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-04 20:14:13 +00:00
update inference config to take model and not model_dir
This commit is contained in:
parent
08c3802f45
commit
039861f1c7
9 changed files with 400 additions and 101 deletions
|
@ -75,11 +75,13 @@ safetensors files to avoid downloading duplicate weights.
|
|||
from huggingface_hub import snapshot_download
|
||||
from huggingface_hub.utils import GatedRepoError, RepositoryNotFoundError
|
||||
|
||||
from llama_toolchain.common.model_utils import model_local_dir
|
||||
|
||||
repo_id = model.huggingface_repo
|
||||
if repo_id is None:
|
||||
raise ValueError(f"No repo id found for model {model.descriptor()}")
|
||||
|
||||
output_dir = Path(DEFAULT_CHECKPOINT_DIR) / model.descriptor()
|
||||
output_dir = model_local_dir(model)
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
try:
|
||||
true_output_dir = snapshot_download(
|
||||
|
@ -107,8 +109,9 @@ safetensors files to avoid downloading duplicate weights.
|
|||
|
||||
def _meta_download(self, model: "Model", meta_url: str):
|
||||
from llama_models.sku_list import llama_meta_net_info
|
||||
from llama_toolchain.common.model_utils import model_local_dir
|
||||
|
||||
output_dir = Path(DEFAULT_CHECKPOINT_DIR) / model.descriptor()
|
||||
output_dir = model_local_dir(model)
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
|
||||
info = llama_meta_net_info(model)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue