diff --git a/llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py b/llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py
index 4ab59fec4..67de380c0 100644
--- a/llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py
+++ b/llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py
@@ -64,8 +64,6 @@ from torchtune.models.llama3._tokenizer import Llama3Tokenizer
 
 
 class LoraFinetuningSingleDevice:
-    # This recipe only supports GPU training
-
     # This recipe doesn't include several training efficiency setting within origin torchtune repo, including
     # - compile
     # - activation offloading
@@ -93,7 +91,7 @@ class LoraFinetuningSingleDevice:
         if not isinstance(algorithm_config, LoraFinetuningConfig):
             raise ValueError("You need to speicifc LoraFinetuningConfig for LoRA finetuning")
         self.algorithm_config = algorithm_config
-        self._device = torchtune_utils.get_device(device="cuda")
+        self._device = torchtune_utils.get_device()
         self._dtype = training.get_dtype(training_config.dtype, device=self._device)
         self.model_id = model
 
@@ -231,6 +229,13 @@ class LoraFinetuningSingleDevice:
         # Used to ignore labels for loss computation
         self.ignore_labels_cache = torch.full((self._batch_size, 1), self._loss_fn.ignore_index, device=self._device)
 
+    def _log_memory_stats(self):
+        # torchtune raises: "Logging memory stats is not supported on CPU devices"; do nothing
+        if self._device.type == "cpu":
+            return
+        memory_stats = training.get_memory_stats(device=self._device)
+        training.log_memory_stats(memory_stats)
+
     async def _setup_model(
         self,
         enable_activation_checkpointing: bool,
@@ -293,8 +298,7 @@ class LoraFinetuningSingleDevice:
         # activation offloading
         self.activations_handling_ctx = training.get_act_offloading_ctx_manager(model, enable_activation_offloading)
 
-        memory_stats = training.get_memory_stats(device=self._device)
-        training.log_memory_stats(memory_stats)
+        self._log_memory_stats()
 
         return model
 
@@ -506,8 +510,7 @@ class LoraFinetuningSingleDevice:
                         "tokens_per_second_per_gpu": num_tokens / time_per_step,
                     }
 
-                    memory_stats = training.get_memory_stats(device=self._device)
-                    log_dict.update(memory_stats)
+                    self._log_memory_stats()
 
                     if self._clip_grad_norm is not None:
                         log_dict.update({"grad_norm": grad_norm})
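
For context, outside the diff above: a minimal standalone sketch of the guard this patch introduces, assuming torchtune's training.get_memory_stats / training.log_memory_stats raise on CPU devices as the new comment states. The helper name below is hypothetical and only illustrates the pattern; it is not the recipe's implementation.

    import torch
    from torchtune import training
    from torchtune import utils as torchtune_utils

    def log_memory_stats_if_supported(device: torch.device) -> None:
        # torchtune's memory-stats helpers only support accelerator devices,
        # so silently skip when running on CPU.
        if device.type == "cpu":
            return
        training.log_memory_stats(training.get_memory_stats(device=device))

    # get_device() with no argument picks the best available device
    # (CUDA if present, otherwise CPU), which is what the patched recipe
    # relies on instead of hard-coding device="cuda".
    log_memory_stats_if_supported(torchtune_utils.get_device())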