From 317e80dc2cd529a9b1ea6e709c8b193a48ecebb2 Mon Sep 17 00:00:00 2001
From: Botao Chen
Date: Thu, 12 Dec 2024 15:41:39 -0800
Subject: [PATCH] refine

---
 llama_stack/apis/common/training_types.py                     | 2 +-
 .../torchtune/recipes/lora_finetuning_single_device.py        | 4 ++--
 2 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/llama_stack/apis/common/training_types.py b/llama_stack/apis/common/training_types.py
index a9e3cac7c..b4bd1b0c6 100644
--- a/llama_stack/apis/common/training_types.py
+++ b/llama_stack/apis/common/training_types.py
@@ -26,4 +26,4 @@ class Checkpoint(BaseModel):
     epoch: int
     post_training_job_id: str
     path: str
-    training_metric: Optional[PostTrainingMetric] = None
+    training_metrics: Optional[PostTrainingMetric] = None
diff --git a/llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py b/llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py
index ec72fedb5..b832d40ec 100644
--- a/llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py
+++ b/llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py
@@ -544,13 +544,13 @@ class LoraFinetuningSingleDevice:
                 )
             if self.training_config.data_config.validation_dataset_id:
                 validation_loss, perplexity = await self.validation()
-                training_metreic = PostTrainingMetric(
+                training_metrics = PostTrainingMetric(
                     epoch=curr_epoch,
                     train_loss=loss_to_log,
                     validation_loss=validation_loss,
                     perplexity=perplexity,
                 )
-                checkpoint.training_metric = training_metreic
+                checkpoint.training_metrics = training_metrics
             checkpoints.append(checkpoint)

         return (memory_stats, checkpoints)