diff --git a/llama_toolchain/cli/inference/configure.py b/llama_toolchain/cli/inference/configure.py index 9fcc97e1a..5d2596315 100644 --- a/llama_toolchain/cli/inference/configure.py +++ b/llama_toolchain/cli/inference/configure.py @@ -69,12 +69,8 @@ class InferenceConfigure(Subcommand): checkpoint_dir, model_parallel_size = self.read_user_inputs() checkpoint_dir = os.path.expanduser(checkpoint_dir) - # Check if checkpoint_dir contains "consolidated.00.pth" - # HF keeps the original pth files in a "original" folder - # so we need to check for that as well - if not (Path(checkpoint_dir) / "consolidated.00.pth" ).exists(): - if (Path(checkpoint_dir) / "original" / "consolidated.00.pth" ).exists(): - checkpoint_dir = os.path.join(checkpoint_dir, "original") + assert Path(checkpoint_dir).exists() and Path(checkpoint_dir).is_dir(), \ + f"{checkpoint_dir} does not exist or it not a directory" os.makedirs(CONFIGS_BASE_DIR, exist_ok=True) yaml_output_path = Path(CONFIGS_BASE_DIR) / "inference.yaml"