update inference config to take model and not model_dir

Hardik Shah 2024-08-06 15:02:41 -07:00
parent 08c3802f45
commit 039861f1c7
9 changed files with 400 additions and 101 deletions
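
The diff below covers the download paths; the inference-config side of the change, per the title, swaps a directory field for a model name. A minimal sketch of that shape, assuming a pydantic config as used elsewhere in llama_toolchain (the class and field names here are illustrative, not necessarily the actual ones):

    from pydantic import BaseModel

    class InferenceConfig(BaseModel):  # hypothetical name for illustration
        # Before this commit the config carried a filesystem path:
        #   model_dir: str
        # After, it names the model, and the local directory is derived
        # from the model descriptor instead of being passed in:
        model: str  # e.g. "Meta-Llama3.1-8B-Instruct"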

@@ -75,11 +75,13 @@ safetensors files to avoid downloading duplicate weights.
         from huggingface_hub import snapshot_download
         from huggingface_hub.utils import GatedRepoError, RepositoryNotFoundError
+        from llama_toolchain.common.model_utils import model_local_dir
+
         repo_id = model.huggingface_repo
         if repo_id is None:
             raise ValueError(f"No repo id found for model {model.descriptor()}")
-        output_dir = Path(DEFAULT_CHECKPOINT_DIR) / model.descriptor()
+        output_dir = model_local_dir(model)
         os.makedirs(output_dir, exist_ok=True)
         try:
             true_output_dir = snapshot_download(
@@ -107,8 +109,9 @@ safetensors files to avoid downloading duplicate weights.
     def _meta_download(self, model: "Model", meta_url: str):
         from llama_models.sku_list import llama_meta_net_info
+        from llama_toolchain.common.model_utils import model_local_dir
 
-        output_dir = Path(DEFAULT_CHECKPOINT_DIR) / model.descriptor()
+        output_dir = model_local_dir(model)
         os.makedirs(output_dir, exist_ok=True)
         info = llama_meta_net_info(model)
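
Both hunks replace the hand-rolled Path(DEFAULT_CHECKPOINT_DIR) / model.descriptor() with the new model_local_dir helper imported from llama_toolchain.common.model_utils. A minimal sketch of what that helper presumably does (the checkpoint root used here is an assumption; the real constant lives in the toolchain):

    import os

    # Assumed checkpoint root; the toolchain's actual DEFAULT_CHECKPOINT_DIR
    # may point elsewhere.
    DEFAULT_CHECKPOINT_DIR = os.path.expanduser("~/.llama/checkpoints")

    def model_local_dir(model) -> str:
        # Centralize the checkpoint-path convention so call sites no longer
        # build the path by hand from the model descriptor.
        return os.path.join(DEFAULT_CHECKPOINT_DIR, model.descriptor())

Funneling every call site through one helper means a later change to the on-disk layout touches a single function rather than each download and inference path separately.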