fix all iterrows callsites

This commit is contained in:
Xi Yan 2025-03-15 17:14:10 -07:00
parent f117407af6
commit b561cfd902
6 changed files with 56 additions and 163 deletions

View file

@ -17,7 +17,8 @@ import torch
from torch import nn
from torch.optim import Optimizer
from torch.utils.data import DataLoader, DistributedSampler
from torchtune import modules, training, utils as torchtune_utils
from torchtune import modules, training
from torchtune import utils as torchtune_utils
from torchtune.data import padded_collate_sft
from torchtune.modules.loss import CEWithChunkedOutputLoss
from torchtune.modules.peft import (
@ -88,9 +89,7 @@ class LoraFinetuningSingleDevice:
self.job_uuid = job_uuid
self.training_config = training_config
if not isinstance(algorithm_config, LoraFinetuningConfig):
raise ValueError(
"You need to speicifc LoraFinetuningConfig for LoRA finetuning"
)
raise ValueError("You need to speicifc LoraFinetuningConfig for LoRA finetuning")
self.algorithm_config = algorithm_config
self._device = torchtune_utils.get_device()
self._dtype = training.get_dtype(training_config.dtype, device=self._device)
@ -99,10 +98,7 @@ class LoraFinetuningSingleDevice:
def model_checkpoint_dir(model) -> str:
checkpoint_dir = Path(model_local_dir(model.descriptor()))
paths = [
Path(checkpoint_dir / f"consolidated.{ext}")
for ext in ["pth", "00.pth"]
]
paths = [Path(checkpoint_dir / f"consolidated.{ext}") for ext in ["pth", "00.pth"]]
if not any(p.exists() for p in paths):
checkpoint_dir = checkpoint_dir / "original"
@ -117,9 +113,7 @@ class LoraFinetuningSingleDevice:
else:
model = resolve_model(self.model_id)
if model is None:
raise ValueError(
f"{self.model_id} not found. Your model id should be in the llama models SKU list"
)
raise ValueError(f"{self.model_id} not found. Your model id should be in the llama models SKU list")
self.checkpoint_dir = model_checkpoint_dir(model)
self._output_dir = str(DEFAULT_CHECKPOINT_DIR)
@ -191,9 +185,7 @@ class LoraFinetuningSingleDevice:
self._tokenizer = await self._setup_tokenizer()
log.info("Tokenizer is initialized.")
self._optimizer = await self._setup_optimizer(
optimizer_config=self.training_config.optimizer_config
)
self._optimizer = await self._setup_optimizer(optimizer_config=self.training_config.optimizer_config)
log.info("Optimizer is initialized.")
self._loss_fn = CEWithChunkedOutputLoss()
@ -221,13 +213,8 @@ class LoraFinetuningSingleDevice:
# by the dataloader and the max_steps_per_epoch param set by the user and is used
# for logging and tracking training state. This should be computed after the dataloader
# has been setup
self._steps_per_epoch = (
len(self._training_dataloader) // self._gradient_accumulation_steps
)
if (
self.max_steps_per_epoch is not None
and self.max_steps_per_epoch < self._steps_per_epoch
):
self._steps_per_epoch = len(self._training_dataloader) // self._gradient_accumulation_steps
if self.max_steps_per_epoch is not None and self.max_steps_per_epoch < self._steps_per_epoch:
self._steps_per_epoch = self.max_steps_per_epoch
self.global_step = self.epochs_run * self._steps_per_epoch
@ -241,9 +228,7 @@ class LoraFinetuningSingleDevice:
log.info("Learning rate scheduler is initialized.")
# Used to ignore labels for loss computation
self.ignore_labels_cache = torch.full(
(self._batch_size, 1), self._loss_fn.ignore_index, device=self._device
)
self.ignore_labels_cache = torch.full((self._batch_size, 1), self._loss_fn.ignore_index, device=self._device)
def _log_memory_stats(self):
# torchtune raises: "Logging memory stats is not supported on CPU devices"; do nothing
@ -284,13 +269,9 @@ class LoraFinetuningSingleDevice:
set_trainable_params(model, self.adapter_params)
if enable_activation_checkpointing:
training.set_activation_checkpointing(
model, auto_wrap_policy={modules.TransformerSelfAttentionLayer}
)
training.set_activation_checkpointing(model, auto_wrap_policy={modules.TransformerSelfAttentionLayer})
base_missing, base_unexpected = model.load_state_dict(
base_model_state_dict, strict=False
)
base_missing, base_unexpected = model.load_state_dict(base_model_state_dict, strict=False)
# This is for any adapters that need to be initialized after base weights
# have been loaded (e.g. DoRA).
@ -299,9 +280,7 @@ class LoraFinetuningSingleDevice:
if hasattr(m, "initialize_dora_magnitude"):
m.initialize_dora_magnitude()
if lora_weights_state_dict:
lora_missing, lora_unexpected = model.load_state_dict(
lora_weights_state_dict, strict=False
)
lora_missing, lora_unexpected = model.load_state_dict(lora_weights_state_dict, strict=False)
else:
lora_missing, lora_unexpected = None, None
validate_missing_and_unexpected_for_lora(
@ -315,14 +294,10 @@ class LoraFinetuningSingleDevice:
)
# Validate model adapter params were loaded in with the expected dtype
training.validate_expected_param_dtype(
self.adapter_params.items(), dtype=self._dtype
)
training.validate_expected_param_dtype(self.adapter_params.items(), dtype=self._dtype)
# activation offloading
self.activations_handling_ctx = training.get_act_offloading_ctx_manager(
model, enable_activation_offloading
)
self.activations_handling_ctx = training.get_act_offloading_ctx_manager(model, enable_activation_offloading)
self._log_memory_stats()
@ -458,9 +433,7 @@ class LoraFinetuningSingleDevice:
# Shift labels to compute loss
# equivalent to doing labels[..., 1:] and logits[..., :-1, :]
# But this way we dont need to slice the logits. We just add an ignore index to labels.
labels = torch.hstack(
(labels[..., 1:], self.ignore_labels_cache[: labels.shape[0]])
)
labels = torch.hstack((labels[..., 1:], self.ignore_labels_cache[: labels.shape[0]]))
if not isinstance(logits, list):
labels = labels.reshape(-1)
logits = logits.reshape(-1, logits.size(-1))
@ -489,9 +462,7 @@ class LoraFinetuningSingleDevice:
for curr_epoch in range(self.epochs_run, self.total_epochs):
# Update the sampler to ensure data is correctly shuffled across epochs
# in case shuffle is True
metric_logger = DiskLogger(
log_dir=self._output_dir + f"/{self.model_id}-sft-{curr_epoch}/log"
)
metric_logger = DiskLogger(log_dir=self._output_dir + f"/{self.model_id}-sft-{curr_epoch}/log")
self._training_sampler.set_epoch(curr_epoch)
loss_to_log = 0.0
@ -499,8 +470,7 @@ class LoraFinetuningSingleDevice:
for idx, batch in enumerate(self._training_dataloader):
if (
self.max_steps_per_epoch is not None
and (idx // self._gradient_accumulation_steps)
== self.max_steps_per_epoch
and (idx // self._gradient_accumulation_steps) == self.max_steps_per_epoch
):
break
@ -508,9 +478,7 @@ class LoraFinetuningSingleDevice:
# Calculate the number of unmasked tokens in the current batch
# and increment the total number of tokens seen in the step
current_num_tokens = (
batch["labels"] != self._loss_fn.ignore_index
).sum()
current_num_tokens = (batch["labels"] != self._loss_fn.ignore_index).sum()
num_tokens += current_num_tokens
# Loss is normalized by default so we multiply by the number of tokens
@ -535,9 +503,7 @@ class LoraFinetuningSingleDevice:
loss_to_log = running_loss.item() / num_tokens
pbar.update(1)
pbar.set_description(
f"{curr_epoch + 1}|{self.global_step}|Loss: {loss_to_log}"
)
pbar.set_description(f"{curr_epoch + 1}|{self.global_step}|Loss: {loss_to_log}")
time_per_step = time.perf_counter() - t0
log_dict = {