diff --git a/llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py b/llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py index 67de380c0..bff55e017 100644 --- a/llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py +++ b/llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py @@ -547,10 +547,11 @@ class LoraFinetuningSingleDevice: checkpoints.append(checkpoint) # clean up the memory after training finishes - self._model.to("cpu") + if self._device.type != "cpu": + self._model.to("cpu") + torch.cuda.empty_cache() del self._model gc.collect() - torch.cuda.empty_cache() return (memory_stats, checkpoints)