From 4d3b226275ee6debbf1c33f77dc25d86c5dba27c Mon Sep 17 00:00:00 2001 From: Hardik Shah Date: Mon, 22 Jul 2024 14:35:09 -0700 Subject: [PATCH] check original folder --- llama_toolchain/cli/inference/configure.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/llama_toolchain/cli/inference/configure.py b/llama_toolchain/cli/inference/configure.py index 593f02c09..9fcc97e1a 100644 --- a/llama_toolchain/cli/inference/configure.py +++ b/llama_toolchain/cli/inference/configure.py @@ -69,11 +69,12 @@ class InferenceConfigure(Subcommand): checkpoint_dir, model_parallel_size = self.read_user_inputs() checkpoint_dir = os.path.expanduser(checkpoint_dir) - if not ( - checkpoint_dir.endswith("original") or - checkpoint_dir.endswith("original/") - ): - checkpoint_dir = os.path.join(checkpoint_dir, "original") + # Check if checkpoint_dir contains "consolidated.00.pth" + # HF keeps the original pth files in a "original" folder + # so we need to check for that as well + if not (Path(checkpoint_dir) / "consolidated.00.pth" ).exists(): + if (Path(checkpoint_dir) / "original" / "consolidated.00.pth" ).exists(): + checkpoint_dir = os.path.join(checkpoint_dir, "original") os.makedirs(CONFIGS_BASE_DIR, exist_ok=True) yaml_output_path = Path(CONFIGS_BASE_DIR) / "inference.yaml"