Mirror of https://github.com/meta-llama/llama-stack.git, synced 2025-06-27 18:50:41 +00:00
fix: proper checkpointing logic for HF trainer
Currently only the last saved model is reported as a checkpoint and associated with the job UUID. Since the HF trainer handles checkpoint collection during training, we need to add all of the `checkpoint-*` folders as Checkpoint objects. Adjust the save strategy to be per-epoch to make this easier and to use less storage.

Signed-off-by: Charlie Doern <cdoern@redhat.com>
parent 1d3f27fe5b
commit d6228bb90e
1 changed file with 36 additions and 18 deletions
@@ -442,22 +442,20 @@ class HFFinetuningSingleDevice:
         # Calculate steps
         total_steps = steps_per_epoch * config.n_epochs
         max_steps = min(config.max_steps_per_epoch, total_steps)
-        eval_steps = max(1, steps_per_epoch // 10)  # Evaluate 10 times per epoch
-        save_steps = max(1, steps_per_epoch // 5)  # Save 5 times per epoch
         logging_steps = max(1, steps_per_epoch // 50)  # Log 50 times per epoch

         logger.info("Training configuration:")
         logger.info(f"- Steps per epoch: {steps_per_epoch}")
         logger.info(f"- Total steps: {total_steps}")
         logger.info(f"- Max steps: {max_steps}")
-        logger.info(f"- Eval steps: {eval_steps}")
-        logger.info(f"- Save steps: {save_steps}")
         logger.info(f"- Logging steps: {logging_steps}")

         # Configure save strategy
         save_strategy = "no"
+        eval_strategy = "no"
         if output_dir_path:
-            save_strategy = "steps"
+            save_strategy = "epoch"
+            eval_strategy = "epoch"
             logger.info(f"Will save checkpoints to {output_dir_path}")

         return SFTConfig(
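As a quick sanity check on the step arithmetic this hunk keeps, a minimal runnable sketch follows; steps_per_epoch, n_epochs, and max_steps_per_epoch are made-up values, not taken from the diff:

# Illustration only: all numbers below are assumed.
steps_per_epoch = 100
n_epochs = 3
max_steps_per_epoch = 250                 # stands in for config.max_steps_per_epoch

total_steps = steps_per_epoch * n_epochs            # 300
max_steps = min(max_steps_per_epoch, total_steps)   # 250
logging_steps = max(1, steps_per_epoch // 50)       # 2 -> log every 2 steps

# With per-epoch saving and evaluation there is no need to derive eval_steps
# or save_steps from steps_per_epoch, which is why those lines are removed above.
output_dir_configured = True              # stands in for "output_dir_path is not None"
save_strategy = "epoch" if output_dir_configured else "no"
eval_strategy = "epoch" if output_dir_configured else "no"

print(total_steps, max_steps, logging_steps, save_strategy, eval_strategy)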
@@ -467,7 +465,7 @@ class HFFinetuningSingleDevice:
             per_device_train_batch_size=data_config.batch_size,
             fp16=device.type == "cuda",
             bf16=False,  # Causes CPU issues.
-            eval_strategy="steps",
+            eval_strategy=eval_strategy,
             use_cpu=True if device.type == "cpu" and not torch.backends.mps.is_available() else False,
             save_strategy=save_strategy,
             report_to="none",
@@ -485,8 +483,6 @@ class HFFinetuningSingleDevice:
             load_best_model_at_end=True if output_dir_path else False,
             metric_for_best_model="eval_loss",
             greater_is_better=False,
-            eval_steps=eval_steps,
-            save_steps=save_steps,
             logging_steps=logging_steps,
         )
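For context, a minimal sketch of how the fields touched in the two hunks above come together in trl's SFTConfig; the output directory and logging_steps value are placeholders, and only the keyword names that appear in the diff are taken from it:

# Sketch only: constructs a config equivalent in spirit to the one above.
from trl import SFTConfig

output_dir = "/tmp/llama-stack-sft"     # placeholder for output_dir_path
training_args = SFTConfig(
    output_dir=output_dir,
    save_strategy="epoch",              # HF Trainer writes checkpoint-<global_step>/ once per epoch
    eval_strategy="epoch",              # evaluate at each epoch boundary
    logging_steps=2,                    # placeholder; derived from steps_per_epoch in the real code
    report_to="none",
)
# After training, output_dir holds one checkpoint-* folder per epoch; this
# provider additionally writes a merged_model/ directory next to them.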
@@ -665,19 +661,41 @@ class HFFinetuningSingleDevice:

             memory_stats["after_training"] = get_memory_stats(device)

-            checkpoints = None
+            checkpoints = []
             if output_dir_path:
-                # Create checkpoint
-                checkpoint = Checkpoint(
-                    identifier=f"{model}-sft-{config.n_epochs}",
-                    created_at=datetime.now(UTC),
-                    epoch=config.n_epochs,
-                    post_training_job_id=job_uuid,
-                    path=str(output_dir_path / "merged_model"),
+                # Get all checkpoint directories and sort them numerically
+                checkpoint_dirs = sorted(
+                    [d for d in output_dir_path.glob("checkpoint-*") if d.is_dir()],
+                    key=lambda x: int(x.name.split("-")[1]),
                 )
-                checkpoints = [checkpoint]

-            return memory_stats, checkpoints
+                # Add all checkpoint directories
+                for epoch_number, checkpoint_dir in enumerate(checkpoint_dirs, start=1):
+                    # Get the creation time of the directory
+                    created_time = datetime.fromtimestamp(os.path.getctime(checkpoint_dir), tz=UTC)
+
+                    checkpoint = Checkpoint(
+                        identifier=checkpoint_dir.name,
+                        created_at=created_time,
+                        epoch=epoch_number,
+                        post_training_job_id=job_uuid,
+                        path=str(checkpoint_dir),
+                    )
+                    checkpoints.append(checkpoint)
+
+                # Add the merged model as a checkpoint
+                merged_model_path = output_dir_path / "merged_model"
+                if merged_model_path.exists():
+                    checkpoint = Checkpoint(
+                        identifier=f"{model}-sft-{config.n_epochs}",
+                        created_at=datetime.now(UTC),
+                        epoch=config.n_epochs,
+                        post_training_job_id=job_uuid,
+                        path=str(merged_model_path),
+                    )
+                    checkpoints.append(checkpoint)
+
+            return memory_stats, checkpoints if checkpoints else None
         finally:
             memory_stats["final"] = get_memory_stats(device)
             gc.collect()
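To exercise the new checkpoint-collection logic outside the trainer, here is a self-contained sketch of the same pattern; CheckpointInfo is a hypothetical stand-in for llama_stack's Checkpoint model, and the directory and job id in the usage block are placeholders:

# Standalone sketch of the collection logic above, runnable on its own.
import os
from dataclasses import dataclass
from datetime import UTC, datetime
from pathlib import Path


@dataclass
class CheckpointInfo:  # hypothetical stand-in, not the real Checkpoint model
    identifier: str
    created_at: datetime
    epoch: int
    post_training_job_id: str
    path: str


def collect_checkpoints(output_dir_path: Path, job_uuid: str) -> list[CheckpointInfo]:
    # checkpoint-* folders are written by the HF trainer; sort by their numeric suffix
    checkpoint_dirs = sorted(
        [d for d in output_dir_path.glob("checkpoint-*") if d.is_dir()],
        key=lambda x: int(x.name.split("-")[1]),
    )
    checkpoints = []
    for epoch_number, checkpoint_dir in enumerate(checkpoint_dirs, start=1):
        # Use the directory's creation time as the checkpoint timestamp
        created_time = datetime.fromtimestamp(os.path.getctime(checkpoint_dir), tz=UTC)
        checkpoints.append(
            CheckpointInfo(
                identifier=checkpoint_dir.name,
                created_at=created_time,
                epoch=epoch_number,
                post_training_job_id=job_uuid,
                path=str(checkpoint_dir),
            )
        )
    return checkpoints


if __name__ == "__main__":
    # Placeholder path and job id; prints an empty list if the directory is absent.
    print(collect_checkpoints(Path("/tmp/llama-stack-sft"), "job-1234"))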