check original folder

This commit is contained in:
Hardik Shah 2024-07-22 14:35:09 -07:00
parent 74442e88b1
commit 4d3b226275

View file

@ -69,11 +69,12 @@ class InferenceConfigure(Subcommand):
checkpoint_dir, model_parallel_size = self.read_user_inputs() checkpoint_dir, model_parallel_size = self.read_user_inputs()
checkpoint_dir = os.path.expanduser(checkpoint_dir) checkpoint_dir = os.path.expanduser(checkpoint_dir)
if not ( # Check if checkpoint_dir contains "consolidated.00.pth"
checkpoint_dir.endswith("original") or # HF keeps the original pth files in a "original" folder
checkpoint_dir.endswith("original/") # so we need to check for that as well
): if not (Path(checkpoint_dir) / "consolidated.00.pth" ).exists():
checkpoint_dir = os.path.join(checkpoint_dir, "original") if (Path(checkpoint_dir) / "original" / "consolidated.00.pth" ).exists():
checkpoint_dir = os.path.join(checkpoint_dir, "original")
os.makedirs(CONFIGS_BASE_DIR, exist_ok=True) os.makedirs(CONFIGS_BASE_DIR, exist_ok=True)
yaml_output_path = Path(CONFIGS_BASE_DIR) / "inference.yaml" yaml_output_path = Path(CONFIGS_BASE_DIR) / "inference.yaml"