check original folder

This commit is contained in:
Hardik Shah 2024-07-22 14:35:09 -07:00
parent 74442e88b1
commit 4d3b226275

View file

@ -69,11 +69,12 @@ class InferenceConfigure(Subcommand):
checkpoint_dir, model_parallel_size = self.read_user_inputs()
checkpoint_dir = os.path.expanduser(checkpoint_dir)
if not (
checkpoint_dir.endswith("original") or
checkpoint_dir.endswith("original/")
):
checkpoint_dir = os.path.join(checkpoint_dir, "original")
# Check if checkpoint_dir contains "consolidated.00.pth"
# HF keeps the original pth files in a "original" folder
# so we need to check for that as well
if not (Path(checkpoint_dir) / "consolidated.00.pth" ).exists():
if (Path(checkpoint_dir) / "original" / "consolidated.00.pth" ).exists():
checkpoint_dir = os.path.join(checkpoint_dir, "original")
os.makedirs(CONFIGS_BASE_DIR, exist_ok=True)
yaml_output_path = Path(CONFIGS_BASE_DIR) / "inference.yaml"