diff --git a/toolchain/cli/download.py b/toolchain/cli/download.py index edb1eb3a3..a905cca55 100644 --- a/toolchain/cli/download.py +++ b/toolchain/cli/download.py @@ -63,9 +63,11 @@ class Download(Subcommand): os.makedirs(DEFAULT_OUTPUT_DIR, exist_ok=True) output_dir = args.output_dir + model_name = args.repo_id.split("/")[-1] if output_dir is None: - model_name = args.repo_id.split("/")[-1] output_dir = Path(DEFAULT_OUTPUT_DIR) / model_name + else: + output_dir = Path(output_dir) / model_name try: true_output_dir = snapshot_download(