fix all iterrows callsites

This commit is contained in:
Xi Yan 2025-03-15 17:11:52 -07:00
parent f2d93324e9
commit f117407af6
6 changed files with 173 additions and 67 deletions

View file

@ -17,8 +17,7 @@ import torch
from torch import nn
from torch.optim import Optimizer
from torch.utils.data import DataLoader, DistributedSampler
from torchtune import modules, training
from torchtune import utils as torchtune_utils
from torchtune import modules, training, utils as torchtune_utils
from torchtune.data import padded_collate_sft
from torchtune.modules.loss import CEWithChunkedOutputLoss
from torchtune.modules.peft import (
@ -89,7 +88,9 @@ class LoraFinetuningSingleDevice:
self.job_uuid = job_uuid
self.training_config = training_config
if not isinstance(algorithm_config, LoraFinetuningConfig):
raise ValueError("You need to speicifc LoraFinetuningConfig for LoRA finetuning")
raise ValueError(
"You need to speicifc LoraFinetuningConfig for LoRA finetuning"
)
self.algorithm_config = algorithm_config
self._device = torchtune_utils.get_device()
self._dtype = training.get_dtype(training_config.dtype, device=self._device)
@ -98,7 +99,10 @@ class LoraFinetuningSingleDevice:
def model_checkpoint_dir(model) -> str:
checkpoint_dir = Path(model_local_dir(model.descriptor()))
paths = [Path(checkpoint_dir / f"consolidated.{ext}") for ext in ["pth", "00.pth"]]
paths = [
Path(checkpoint_dir / f"consolidated.{ext}")
for ext in ["pth", "00.pth"]
]
if not any(p.exists() for p in paths):
checkpoint_dir = checkpoint_dir / "original"
@ -113,7 +117,9 @@ class LoraFinetuningSingleDevice:
else:
model = resolve_model(self.model_id)
if model is None:
raise ValueError(f"{self.model_id} not found. Your model id should be in the llama models SKU list")
raise ValueError(
f"{self.model_id} not found. Your model id should be in the llama models SKU list"
)
self.checkpoint_dir = model_checkpoint_dir(model)
self._output_dir = str(DEFAULT_CHECKPOINT_DIR)
@ -185,7 +191,9 @@ class LoraFinetuningSingleDevice:
self._tokenizer = await self._setup_tokenizer()
log.info("Tokenizer is initialized.")
self._optimizer = await self._setup_optimizer(optimizer_config=self.training_config.optimizer_config)
self._optimizer = await self._setup_optimizer(
optimizer_config=self.training_config.optimizer_config
)
log.info("Optimizer is initialized.")
self._loss_fn = CEWithChunkedOutputLoss()
@ -213,8 +221,13 @@ class LoraFinetuningSingleDevice:
# by the dataloader and the max_steps_per_epoch param set by the user and is used
# for logging and tracking training state. This should be computed after the dataloader
# has been set up
self._steps_per_epoch = len(self._training_dataloader) // self._gradient_accumulation_steps
if self.max_steps_per_epoch is not None and self.max_steps_per_epoch < self._steps_per_epoch:
self._steps_per_epoch = (
len(self._training_dataloader) // self._gradient_accumulation_steps
)
if (
self.max_steps_per_epoch is not None
and self.max_steps_per_epoch < self._steps_per_epoch
):
self._steps_per_epoch = self.max_steps_per_epoch
self.global_step = self.epochs_run * self._steps_per_epoch
@ -228,7 +241,9 @@ class LoraFinetuningSingleDevice:
log.info("Learning rate scheduler is initialized.")
# Used to ignore labels for loss computation
self.ignore_labels_cache = torch.full((self._batch_size, 1), self._loss_fn.ignore_index, device=self._device)
self.ignore_labels_cache = torch.full(
(self._batch_size, 1), self._loss_fn.ignore_index, device=self._device
)
def _log_memory_stats(self):
# torchtune raises: "Logging memory stats is not supported on CPU devices"; do nothing
@ -269,9 +284,13 @@ class LoraFinetuningSingleDevice:
set_trainable_params(model, self.adapter_params)
if enable_activation_checkpointing:
training.set_activation_checkpointing(model, auto_wrap_policy={modules.TransformerSelfAttentionLayer})
training.set_activation_checkpointing(
model, auto_wrap_policy={modules.TransformerSelfAttentionLayer}
)
base_missing, base_unexpected = model.load_state_dict(base_model_state_dict, strict=False)
base_missing, base_unexpected = model.load_state_dict(
base_model_state_dict, strict=False
)
# This is for any adapters that need to be initialized after base weights
# have been loaded (e.g. DoRA).
@ -280,7 +299,9 @@ class LoraFinetuningSingleDevice:
if hasattr(m, "initialize_dora_magnitude"):
m.initialize_dora_magnitude()
if lora_weights_state_dict:
lora_missing, lora_unexpected = model.load_state_dict(lora_weights_state_dict, strict=False)
lora_missing, lora_unexpected = model.load_state_dict(
lora_weights_state_dict, strict=False
)
else:
lora_missing, lora_unexpected = None, None
validate_missing_and_unexpected_for_lora(
@ -294,10 +315,14 @@ class LoraFinetuningSingleDevice:
)
# Validate model adapter params were loaded in with the expected dtype
training.validate_expected_param_dtype(self.adapter_params.items(), dtype=self._dtype)
training.validate_expected_param_dtype(
self.adapter_params.items(), dtype=self._dtype
)
# activation offloading
self.activations_handling_ctx = training.get_act_offloading_ctx_manager(model, enable_activation_offloading)
self.activations_handling_ctx = training.get_act_offloading_ctx_manager(
model, enable_activation_offloading
)
self._log_memory_stats()
@ -330,11 +355,11 @@ class LoraFinetuningSingleDevice:
async def fetch_rows(dataset_id: str):
return await self.datasetio_api.iterrows(
dataset_id=dataset_id,
rows_in_page=-1,
limit=-1,
)
all_rows = await fetch_rows(dataset_id)
rows = all_rows.rows
rows = all_rows.data
await validate_input_dataset_schema(
datasets_api=self.datasets_api,
@ -433,7 +458,9 @@ class LoraFinetuningSingleDevice:
# Shift labels to compute loss
# equivalent to doing labels[..., 1:] and logits[..., :-1, :]
# But this way we don't need to slice the logits. We just add an ignore index to labels.
labels = torch.hstack((labels[..., 1:], self.ignore_labels_cache[: labels.shape[0]]))
labels = torch.hstack(
(labels[..., 1:], self.ignore_labels_cache[: labels.shape[0]])
)
if not isinstance(logits, list):
labels = labels.reshape(-1)
logits = logits.reshape(-1, logits.size(-1))
@ -462,7 +489,9 @@ class LoraFinetuningSingleDevice:
for curr_epoch in range(self.epochs_run, self.total_epochs):
# Update the sampler to ensure data is correctly shuffled across epochs
# in case shuffle is True
metric_logger = DiskLogger(log_dir=self._output_dir + f"/{self.model_id}-sft-{curr_epoch}/log")
metric_logger = DiskLogger(
log_dir=self._output_dir + f"/{self.model_id}-sft-{curr_epoch}/log"
)
self._training_sampler.set_epoch(curr_epoch)
loss_to_log = 0.0
@ -470,7 +499,8 @@ class LoraFinetuningSingleDevice:
for idx, batch in enumerate(self._training_dataloader):
if (
self.max_steps_per_epoch is not None
and (idx // self._gradient_accumulation_steps) == self.max_steps_per_epoch
and (idx // self._gradient_accumulation_steps)
== self.max_steps_per_epoch
):
break
@ -478,7 +508,9 @@ class LoraFinetuningSingleDevice:
# Calculate the number of unmasked tokens in the current batch
# and increment the total number of tokens seen in the step
current_num_tokens = (batch["labels"] != self._loss_fn.ignore_index).sum()
current_num_tokens = (
batch["labels"] != self._loss_fn.ignore_index
).sum()
num_tokens += current_num_tokens
# Loss is normalized by default so we multiply by the number of tokens
@ -503,7 +535,9 @@ class LoraFinetuningSingleDevice:
loss_to_log = running_loss.item() / num_tokens
pbar.update(1)
pbar.set_description(f"{curr_epoch + 1}|{self.global_step}|Loss: {loss_to_log}")
pbar.set_description(
f"{curr_epoch + 1}|{self.global_step}|Loss: {loss_to_log}"
)
time_per_step = time.perf_counter() - t0
log_dict = {