feat: Enable CPU training for torchtune (#1140)

# What does this PR do?

You can now run a training cycle on CPU. This is useful for debugging
and testing.
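
The recipe now defers device selection to torchtune instead of pinning CUDA. A minimal sketch of the fallback behavior this relies on (assuming current `torchtune.utils.get_device` semantics):

```python
import torch
from torchtune import utils as torchtune_utils

# With no argument, torchtune selects CUDA when available and
# falls back to CPU otherwise -- the behavior this PR relies on.
device = torchtune_utils.get_device()
print(device)  # "cpu" on a machine without CUDA

# The recipe previously pinned the device, which cannot work on
# CPU-only hosts:
#   torchtune_utils.get_device(device="cuda")  # errors without CUDA
```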


## Test Plan

On a Mac machine without CUDA devices:

```
17:00:24.417 [START] /v1/post-training/supervised-fine-tune
DEBUG 2025-02-18 12:00:24,419 torchtune.utils._logging:60: Setting manual seed to local seed 3268931494. Local seed is seed + rank = 3268931494 + 0
INFO 2025-02-18 12:00:24,463 torchtune.utils._logging:64: Identified model_type = Llama3_2. Ignoring output.weight in checkpoint in favor of the tok_embedding.weight tied weights.
INFO 2025-02-18 12:00:46,699 llama_stack.providers.inline.post_training.torchtune.recipes.lora_finetuning_single_device:182: Model is initialized with precision torch.bfloat16.
INFO 2025-02-18 12:00:46,784 llama_stack.providers.inline.post_training.torchtune.recipes.lora_finetuning_single_device:185: Tokenizer is initialized.
INFO 2025-02-18 12:00:46,786 llama_stack.providers.inline.post_training.torchtune.recipes.lora_finetuning_single_device:188: Optimizer is initialized.
INFO 2025-02-18 12:00:46,786 llama_stack.providers.inline.post_training.torchtune.recipes.lora_finetuning_single_device:192: Loss is initialized.
INFO 2025-02-18 12:00:48,997 llama_stack.providers.inline.post_training.torchtune.recipes.lora_finetuning_single_device:209: Dataset and Sampler are initialized.
INFO 2025-02-18 12:00:48,998 llama_stack.providers.inline.post_training.torchtune.recipes.lora_finetuning_single_device:227: Learning rate scheduler is initialized.
Writing logs to /Users/ihrachys/.llama/checkpoints/meta-llama/Llama-3.2-3B-Instruct-sft-0/log_1739898049.txt
1|1|Loss: 1.7414989471435547: 100% 1/1 [03:46<00:00, 226.21s/it]INFO 2025-02-18 12:04:35,227 llama_stack.providers.inline.post_training.torchtune.recipes.lora_finetuning_single_device:528: Starting checkpoint save...
INFO 2025-02-18 12:04:49,974 torchtune.utils._logging:121: Model checkpoint of size 6.43 GB saved to /Users/ihrachys/.llama/checkpoints/meta-llama/Llama-3.2-3B-Instruct-sft-0/consolidated.00.pth
INFO 2025-02-18 12:04:49,981 torchtune.utils._logging:132: Adapter checkpoint of size 0.00 GB saved to /Users/ihrachys/.llama/checkpoints/meta-llama/Llama-3.2-3B-Instruct-sft-0/adapter/adapter.pth
model_file_path /Users/ihrachys/.llama/checkpoints/meta-llama/Llama-3.2-3B-Instruct-sft-0
1|1|Loss: 1.7414989471435547: 100% 1/1 [04:01<00:00, 241.18s/it]
INFO:     ::1:64990 - "POST /v1/post-training/supervised-fine-tune HTTP/1.1" 200 OK
17:04:50.364 [END] /v1/post-training/supervised-fine-tune [StatusCode.OK] (265947.01ms)
 17:00:24.419 [DEBUG] Setting manual seed to local seed 3268931494. Local seed is seed + rank = 3268931494 + 0
 17:00:24.463 [INFO] Identified model_type = Llama3_2. Ignoring output.weight in checkpoint in favor of the tok_embedding.weight tied weights.
 17:00:46.700 [INFO] Model is initialized with precision torch.bfloat16.
 17:00:46.784 [INFO] Tokenizer is initialized.
 17:00:46.786 [INFO] Optimizer is initialized.
 17:00:46.786 [INFO] Loss is initialized.
 17:00:48.997 [INFO] Dataset and Sampler are initialized.
 17:00:48.998 [INFO] Learning rate scheduler is initialized.
 17:04:35.227 [INFO] Starting checkpoint save...
 17:04:49.974 [INFO] Model checkpoint of size 6.43 GB saved to /Users/ihrachys/.llama/checkpoints/meta-llama/Llama-3.2-3B-Instruct-sft-0/consolidated.00.pth
 17:04:49.981 [INFO] Adapter checkpoint of size 0.00 GB saved to /Users/ihrachys/.llama/checkpoints/meta-llama/Llama-3.2-3B-Instruct-sft-0/adapter/adapter.pth
```
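
For reference, the run above was kicked off through the `/v1/post-training/supervised-fine-tune` endpoint shown in the log. A hypothetical invocation against a local server (the payload field names and port are illustrative assumptions, not copied from this PR; check the llama-stack post-training API schema for the exact request shape):

```python
import requests

# Hypothetical smoke test: the field names below are assumptions --
# consult the llama-stack post-training API schema before using.
payload = {
    "job_uuid": "sft-cpu-smoke-test",
    "model": "meta-llama/Llama-3.2-3B-Instruct",
    "algorithm_config": {"type": "LoRA"},  # LoraFinetuningConfig fields elided
    "training_config": {"n_epochs": 1},    # remaining knobs elided
}
resp = requests.post(
    "http://localhost:8321/v1/post-training/supervised-fine-tune",  # assumed port
    json=payload,
    timeout=600,
)
resp.raise_for_status()  # the run above returned 200 OK
```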


Signed-off-by: Ihar Hrachyshka <ihar.hrachyshka@gmail.com>
```diff
--- a/llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py
+++ b/llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py
@@ -64,8 +64,6 @@ from torchtune.models.llama3._tokenizer import Llama3Tokenizer
 class LoraFinetuningSingleDevice:
-    # This recipe only supports GPU training
-
     # This recipe doesn't include several training efficiency setting within origin torchtune repo, including
     # - compile
     # - activation offloading
@@ -93,7 +91,7 @@ class LoraFinetuningSingleDevice:
         if not isinstance(algorithm_config, LoraFinetuningConfig):
             raise ValueError("You need to speicifc LoraFinetuningConfig for LoRA finetuning")
         self.algorithm_config = algorithm_config
-        self._device = torchtune_utils.get_device(device="cuda")
+        self._device = torchtune_utils.get_device()
         self._dtype = training.get_dtype(training_config.dtype, device=self._device)
         self.model_id = model
@@ -231,6 +229,13 @@ class LoraFinetuningSingleDevice:
         # Used to ignore labels for loss computation
         self.ignore_labels_cache = torch.full((self._batch_size, 1), self._loss_fn.ignore_index, device=self._device)
 
+    def _log_memory_stats(self):
+        # torchtune raises: "Logging memory stats is not supported on CPU devices"; do nothing
+        if self._device.type == "cpu":
+            return
+        memory_stats = training.get_memory_stats(device=self._device)
+        training.log_memory_stats(memory_stats)
+
     async def _setup_model(
         self,
         enable_activation_checkpointing: bool,
@@ -293,8 +298,7 @@ class LoraFinetuningSingleDevice:
         # activation offloading
         self.activations_handling_ctx = training.get_act_offloading_ctx_manager(model, enable_activation_offloading)
 
-        memory_stats = training.get_memory_stats(device=self._device)
-        training.log_memory_stats(memory_stats)
+        self._log_memory_stats()
 
         return model
@@ -506,8 +510,7 @@ class LoraFinetuningSingleDevice:
                     "tokens_per_second_per_gpu": num_tokens / time_per_step,
                 }
 
-                memory_stats = training.get_memory_stats(device=self._device)
-                log_dict.update(memory_stats)
+                self._log_memory_stats()
 
                 if self._clip_grad_norm is not None:
                     log_dict.update({"grad_norm": grad_norm})
```
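
Read outside the diff, the new guard is a small standalone helper. A minimal sketch, assuming the same torchtune APIs used above:

```python
import torch
from torchtune import training

def log_memory_stats(device: torch.device) -> None:
    # torchtune's memory-stats utilities raise "Logging memory stats
    # is not supported on CPU devices", so skip CPU outright.
    if device.type == "cpu":
        return
    stats = training.get_memory_stats(device=device)
    training.log_memory_stats(stats)

# No-op on CPU; on a CUDA host this logs peak allocation stats.
log_memory_stats(torch.device("cpu"))
```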