mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-09 19:58:29 +00:00
feat: support cpu training for torchtune
Before this patch, attempts to train on CPU failed because torchtune does not
support collecting memory stats on CPU devices:
504cbea31e/torchtune/training/memory.py (L264)
Signed-off-by: Ihar Hrachyshka <ihar.hrachyshka@gmail.com>
This commit is contained in:
parent
6b1773d530
commit
d4b6ddf96c
1 changed file with 10 additions and 7 deletions
|
@ -64,8 +64,6 @@ from torchtune.models.llama3._tokenizer import Llama3Tokenizer
|
|||
|
||||
|
||||
class LoraFinetuningSingleDevice:
|
||||
# This recipe only supports GPU training
|
||||
|
||||
# This recipe doesn't include several training efficiency settings from the original torchtune repo, including
|
||||
# - compile
|
||||
# - activation offloading
|
||||
|
@ -93,7 +91,7 @@ class LoraFinetuningSingleDevice:
|
|||
if not isinstance(algorithm_config, LoraFinetuningConfig):
|
||||
raise ValueError("You need to speicifc LoraFinetuningConfig for LoRA finetuning")
|
||||
self.algorithm_config = algorithm_config
|
||||
self._device = torchtune_utils.get_device(device="cuda")
|
||||
self._device = torchtune_utils.get_device()
|
||||
self._dtype = training.get_dtype(training_config.dtype, device=self._device)
|
||||
self.model_id = model
|
||||
|
||||
|
@ -231,6 +229,13 @@ class LoraFinetuningSingleDevice:
|
|||
# Used to ignore labels for loss computation
|
||||
self.ignore_labels_cache = torch.full((self._batch_size, 1), self._loss_fn.ignore_index, device=self._device)
|
||||
|
||||
def _log_memory_stats(self):
|
||||
# torchtune raises: "Logging memory stats is not supported on CPU devices"; do nothing
|
||||
if self._device.type == "cpu":
|
||||
return
|
||||
memory_stats = training.get_memory_stats(device=self._device)
|
||||
training.log_memory_stats(memory_stats)
|
||||
|
||||
async def _setup_model(
|
||||
self,
|
||||
enable_activation_checkpointing: bool,
|
||||
|
@ -293,8 +298,7 @@ class LoraFinetuningSingleDevice:
|
|||
# activation offloading
|
||||
self.activations_handling_ctx = training.get_act_offloading_ctx_manager(model, enable_activation_offloading)
|
||||
|
||||
memory_stats = training.get_memory_stats(device=self._device)
|
||||
training.log_memory_stats(memory_stats)
|
||||
self._log_memory_stats()
|
||||
|
||||
return model
|
||||
|
||||
|
@ -506,8 +510,7 @@ class LoraFinetuningSingleDevice:
|
|||
"tokens_per_second_per_gpu": num_tokens / time_per_step,
|
||||
}
|
||||
|
||||
memory_stats = training.get_memory_stats(device=self._device)
|
||||
log_dict.update(memory_stats)
|
||||
self._log_memory_stats()
|
||||
|
||||
if self._clip_grad_norm is not None:
|
||||
log_dict.update({"grad_norm": grad_norm})
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue