feat: support cpu training for torchtune

Before this patch, CPU training attempts failed because CPU devices
don't support memory stats:

504cbea31e/torchtune/training/memory.py (L264)
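
For context, a minimal sketch of the failure mode (a hypothetical
reproduction, assuming only that torchtune's training.get_memory_stats
rejects non-CUDA devices, per the line linked above):

    import torch
    from torchtune import training

    device = torch.device("cpu")
    # torchtune guards memory stats behind a CUDA check and raises here;
    # this is what aborted CPU training runs before the patch.
    memory_stats = training.get_memory_stats(device=device)

The patch below skips the call on CPU instead of catching the exception.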

Signed-off-by: Ihar Hrachyshka <ihar.hrachyshka@gmail.com>
Ihar Hrachyshka 2025-02-18 12:03:30 -05:00
parent 6b1773d530
commit d4b6ddf96c


@@ -64,8 +64,6 @@ from torchtune.models.llama3._tokenizer import Llama3Tokenizer
 class LoraFinetuningSingleDevice:
-    # This recipe only supports GPU training
-
     # This recipe doesn't include several training efficiency setting within origin torchtune repo, including
     # - compile
     # - activation offloading
@@ -93,7 +91,7 @@ class LoraFinetuningSingleDevice:
         if not isinstance(algorithm_config, LoraFinetuningConfig):
             raise ValueError("You need to speicifc LoraFinetuningConfig for LoRA finetuning")
         self.algorithm_config = algorithm_config
-        self._device = torchtune_utils.get_device(device="cuda")
+        self._device = torchtune_utils.get_device()
         self._dtype = training.get_dtype(training_config.dtype, device=self._device)
         self.model_id = model
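
The device selection change above works because torchtune_utils.get_device()
with no argument auto-detects a device. A short sketch of the assumption
(not part of the diff): it resolves to CUDA when available and falls back
to CPU otherwise.

    from torchtune import utils as torchtune_utils

    # With no explicit device, torchtune resolves one from the environment,
    # so CUDA-less hosts get a CPU device instead of a hard failure.
    device = torchtune_utils.get_device()
    print(device.type)  # "cuda" on a GPU machine, "cpu" otherwise
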
@@ -231,6 +229,13 @@ class LoraFinetuningSingleDevice:
         # Used to ignore labels for loss computation
         self.ignore_labels_cache = torch.full((self._batch_size, 1), self._loss_fn.ignore_index, device=self._device)
 
+    def _log_memory_stats(self):
+        # torchtune raises: "Logging memory stats is not supported on CPU devices"; do nothing
+        if self._device.type == "cpu":
+            return
+        memory_stats = training.get_memory_stats(device=self._device)
+        training.log_memory_stats(memory_stats)
+
     async def _setup_model(
         self,
         enable_activation_checkpointing: bool,
@@ -293,8 +298,7 @@ class LoraFinetuningSingleDevice:
         # activation offloading
         self.activations_handling_ctx = training.get_act_offloading_ctx_manager(model, enable_activation_offloading)
 
-        memory_stats = training.get_memory_stats(device=self._device)
-        training.log_memory_stats(memory_stats)
+        self._log_memory_stats()
 
         return model
@@ -506,8 +510,7 @@ class LoraFinetuningSingleDevice:
                     "tokens_per_second_per_gpu": num_tokens / time_per_step,
                 }
-                memory_stats = training.get_memory_stats(device=self._device)
-                log_dict.update(memory_stats)
+                self._log_memory_stats()
 
                 if self._clip_grad_norm is not None:
                     log_dict.update({"grad_norm": grad_norm})