mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-06-28 10:54:19 +00:00
feat: Enable CPU training for torchtune (#1140)
# What does this PR do? You are now able to run a training cycle on CPU. This is useful for debugging and testing purposes. [//]: # (If resolving an issue, uncomment and update the line below) [//]: # (Closes #[issue-number]) ## Test Plan On a Mac machine without CUDA devices: ``` 17:00:24.417 [START] /v1/post-training/supervised-fine-tune DEBUG 2025-02-18 12:00:24,419 torchtune.utils._logging:60: Setting manual seed to local seed 3268931494. Local seed is seed + rank = 3268931494 + 0 INFO 2025-02-18 12:00:24,463 torchtune.utils._logging:64: Identified model_type = Llama3_2. Ignoring output.weight in checkpoint in favor of the tok_embedding.weight tied weights. INFO 2025-02-18 12:00:46,699 llama_stack.providers.inline.post_training.torchtune.recipes.lora_finetuning_single_device:182: Model is initialized with precision torch.bfloat16. INFO 2025-02-18 12:00:46,784 llama_stack.providers.inline.post_training.torchtune.recipes.lora_finetuning_single_device:185: Tokenizer is initialized. INFO 2025-02-18 12:00:46,786 llama_stack.providers.inline.post_training.torchtune.recipes.lora_finetuning_single_device:188: Optimizer is initialized. INFO 2025-02-18 12:00:46,786 llama_stack.providers.inline.post_training.torchtune.recipes.lora_finetuning_single_device:192: Loss is initialized. INFO 2025-02-18 12:00:48,997 llama_stack.providers.inline.post_training.torchtune.recipes.lora_finetuning_single_device:209: Dataset and Sampler are initialized. INFO 2025-02-18 12:00:48,998 llama_stack.providers.inline.post_training.torchtune.recipes.lora_finetuning_single_device:227: Learning rate scheduler is initialized. Writing logs to /Users/ihrachys/.llama/checkpoints/meta-llama/Llama-3.2-3B-Instruct-sft-0/log_1739898049.txt 1|1|Loss: 1.7414989471435547: 100% 1/1 [03:46<00:00, 226.21s/it]INFO 2025-02-18 12:04:35,227 llama_stack.providers.inline.post_training.torchtune.recipes.lora_finetuning_single_device:528: Starting checkpoint save... INFO 2025-02-18 12:04:49,974 torchtune.utils._logging:121: Model checkpoint of size 6.43 GB saved to /Users/ihrachys/.llama/checkpoints/meta-llama/Llama-3.2-3B-Instruct-sft-0/consolidated.00.pth INFO 2025-02-18 12:04:49,981 torchtune.utils._logging:132: Adapter checkpoint of size 0.00 GB saved to /Users/ihrachys/.llama/checkpoints/meta-llama/Llama-3.2-3B-Instruct-sft-0/adapter/adapter.pth model_file_path /Users/ihrachys/.llama/checkpoints/meta-llama/Llama-3.2-3B-Instruct-sft-0 1|1|Loss: 1.7414989471435547: 100% 1/1 [04:01<00:00, 241.18s/it] INFO: ::1:64990 - "POST /v1/post-training/supervised-fine-tune HTTP/1.1" 200 OK 17:04:50.364 [END] /v1/post-training/supervised-fine-tune [StatusCode.OK] (265947.01ms) 17:00:24.419 [DEBUG] Setting manual seed to local seed 3268931494. Local seed is seed + rank = 3268931494 + 0 17:00:24.463 [INFO] Identified model_type = Llama3_2. Ignoring output.weight in checkpoint in favor of the tok_embedding.weight tied weights. 17:00:46.700 [INFO] Model is initialized with precision torch.bfloat16. 17:00:46.784 [INFO] Tokenizer is initialized. 17:00:46.786 [INFO] Optimizer is initialized. 17:00:46.786 [INFO] Loss is initialized. 17:00:48.997 [INFO] Dataset and Sampler are initialized. 17:00:48.998 [INFO] Learning rate scheduler is initialized. 17:04:35.227 [INFO] Starting checkpoint save... 17:04:49.974 [INFO] Model checkpoint of size 6.43 GB saved to /Users/ihrachys/.llama/checkpoints/meta-llama/Llama-3.2-3B-Instruct-sft-0/consolidated.00.pth 17:04:49.981 [INFO] Adapter checkpoint of size 0.00 GB saved to /Users/ihrachys/.llama/checkpoints/meta-llama/Llama-3.2-3B-Instruct-sft-0/adapter/adapter.pth ``` [//]: # (## Documentation) Signed-off-by: Ihar Hrachyshka <ihar.hrachyshka@gmail.com>
This commit is contained in:
parent
a324ceb9a9
commit
fb6a3efb1d
1 changed files with 10 additions and 7 deletions
|
@ -64,8 +64,6 @@ from torchtune.models.llama3._tokenizer import Llama3Tokenizer
|
||||||
|
|
||||||
|
|
||||||
class LoraFinetuningSingleDevice:
|
class LoraFinetuningSingleDevice:
|
||||||
# This recipe only supports GPU training
|
|
||||||
|
|
||||||
# This recipe doesn't include several training efficiency setting within origin torchtune repo, including
|
# This recipe doesn't include several training efficiency setting within origin torchtune repo, including
|
||||||
# - compile
|
# - compile
|
||||||
# - activation offloading
|
# - activation offloading
|
||||||
|
@ -93,7 +91,7 @@ class LoraFinetuningSingleDevice:
|
||||||
if not isinstance(algorithm_config, LoraFinetuningConfig):
|
if not isinstance(algorithm_config, LoraFinetuningConfig):
|
||||||
raise ValueError("You need to speicifc LoraFinetuningConfig for LoRA finetuning")
|
raise ValueError("You need to speicifc LoraFinetuningConfig for LoRA finetuning")
|
||||||
self.algorithm_config = algorithm_config
|
self.algorithm_config = algorithm_config
|
||||||
self._device = torchtune_utils.get_device(device="cuda")
|
self._device = torchtune_utils.get_device()
|
||||||
self._dtype = training.get_dtype(training_config.dtype, device=self._device)
|
self._dtype = training.get_dtype(training_config.dtype, device=self._device)
|
||||||
self.model_id = model
|
self.model_id = model
|
||||||
|
|
||||||
|
@ -231,6 +229,13 @@ class LoraFinetuningSingleDevice:
|
||||||
# Used to ignore labels for loss computation
|
# Used to ignore labels for loss computation
|
||||||
self.ignore_labels_cache = torch.full((self._batch_size, 1), self._loss_fn.ignore_index, device=self._device)
|
self.ignore_labels_cache = torch.full((self._batch_size, 1), self._loss_fn.ignore_index, device=self._device)
|
||||||
|
|
||||||
|
def _log_memory_stats(self):
|
||||||
|
# torchtune raises: "Logging memory stats is not supported on CPU devices"; do nothing
|
||||||
|
if self._device.type == "cpu":
|
||||||
|
return
|
||||||
|
memory_stats = training.get_memory_stats(device=self._device)
|
||||||
|
training.log_memory_stats(memory_stats)
|
||||||
|
|
||||||
async def _setup_model(
|
async def _setup_model(
|
||||||
self,
|
self,
|
||||||
enable_activation_checkpointing: bool,
|
enable_activation_checkpointing: bool,
|
||||||
|
@ -293,8 +298,7 @@ class LoraFinetuningSingleDevice:
|
||||||
# activation offloading
|
# activation offloading
|
||||||
self.activations_handling_ctx = training.get_act_offloading_ctx_manager(model, enable_activation_offloading)
|
self.activations_handling_ctx = training.get_act_offloading_ctx_manager(model, enable_activation_offloading)
|
||||||
|
|
||||||
memory_stats = training.get_memory_stats(device=self._device)
|
self._log_memory_stats()
|
||||||
training.log_memory_stats(memory_stats)
|
|
||||||
|
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
@ -506,8 +510,7 @@ class LoraFinetuningSingleDevice:
|
||||||
"tokens_per_second_per_gpu": num_tokens / time_per_step,
|
"tokens_per_second_per_gpu": num_tokens / time_per_step,
|
||||||
}
|
}
|
||||||
|
|
||||||
memory_stats = training.get_memory_stats(device=self._device)
|
self._log_memory_stats()
|
||||||
log_dict.update(memory_stats)
|
|
||||||
|
|
||||||
if self._clip_grad_norm is not None:
|
if self._clip_grad_norm is not None:
|
||||||
log_dict.update({"grad_norm": grad_norm})
|
log_dict.update({"grad_norm": grad_norm})
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue