no special casign for original

This commit is contained in:
Hardik Shah 2024-07-22 14:42:38 -07:00
parent 4d3b226275
commit 441e5da6ed

View file

@ -69,12 +69,8 @@ class InferenceConfigure(Subcommand):
checkpoint_dir, model_parallel_size = self.read_user_inputs() checkpoint_dir, model_parallel_size = self.read_user_inputs()
checkpoint_dir = os.path.expanduser(checkpoint_dir) checkpoint_dir = os.path.expanduser(checkpoint_dir)
# Check if checkpoint_dir contains "consolidated.00.pth" assert Path(checkpoint_dir).exists() and Path(checkpoint_dir).is_dir(), \
# HF keeps the original pth files in a "original" folder f"{checkpoint_dir} does not exist or it not a directory"
# so we need to check for that as well
if not (Path(checkpoint_dir) / "consolidated.00.pth" ).exists():
if (Path(checkpoint_dir) / "original" / "consolidated.00.pth" ).exists():
checkpoint_dir = os.path.join(checkpoint_dir, "original")
os.makedirs(CONFIGS_BASE_DIR, exist_ok=True) os.makedirs(CONFIGS_BASE_DIR, exist_ok=True)
yaml_output_path = Path(CONFIGS_BASE_DIR) / "inference.yaml" yaml_output_path = Path(CONFIGS_BASE_DIR) / "inference.yaml"