diff --git a/llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py b/llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py
index 6e5ec6050..41387474f 100644
--- a/llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py
+++ b/llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py
@@ -549,10 +549,11 @@ class LoraFinetuningSingleDevice:
             checkpoints.append(checkpoint)
 
         # clean up the memory after training finishes
-        self._model.to("cpu")
+        if self._device.type != "cpu":
+            self._model.to("cpu")
+            torch.cuda.empty_cache()
         del self._model
         gc.collect()
-        torch.cuda.empty_cache()
 
         return (memory_stats, checkpoints)
 
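For readers skimming the hunk, here is a minimal standalone sketch of the cleanup path after this change. The class name, constructor, and cleanup() method are hypothetical stand-ins for illustration; only self._model, self._device, and the guarded cleanup sequence come from the recipe itself:

    import gc

    import torch
    from torch import nn


    class CleanupSketch:
        # Hypothetical stand-in for the tail of LoraFinetuningSingleDevice's
        # training loop; not the actual recipe class.
        def __init__(self, model: nn.Module, device: torch.device) -> None:
            self._model = model
            self._device = device

        def cleanup(self) -> None:
            # On GPU runs, move the weights back to host RAM first, then flush
            # the CUDA caching allocator; on a CPU device both steps are
            # unnecessary, which is why the diff guards them.
            if self._device.type != "cpu":
                self._model.to("cpu")
                torch.cuda.empty_cache()
            del self._model  # drop the last strong reference to the weights
            gc.collect()     # reclaim the freed tensors promptly

The guard skips a pointless self-copy of the model on CPU-only hosts, and moving torch.cuda.empty_cache() inside it means the call only runs when a GPU was actually in use, right after the weights leave the device.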