temp commit

Botao Chen 2024-12-02 17:24:25 -08:00
parent 6c709abc4d
commit 79c525be94
5 changed files with 115 additions and 163 deletions

View file

@@ -41,7 +41,7 @@ class TrainingConfig(BaseModel):
     gradient_accumulation_steps: int
     batch_size: int
     shuffle: bool
-    # n_iters: int
+    optimizer_config: OptimizerConfig
     enable_activation_checkpointing: bool
     memory_efficient_fsdp_wrap: Optional[bool]
@@ -63,6 +63,7 @@ class LoraFinetuningConfig(BaseModel):
     apply_lora_to_output: bool
     rank: int
     alpha: int
+    use_dora: bool


 @json_schema_type
@@ -116,7 +117,6 @@ class PostTrainingSFTRequest(BaseModel):
     algorithm: FinetuningAlgorithm
     algorithm_config: LoraFinetuningConfig
-    optimizer_config: OptimizerConfig
     training_config: TrainingConfig

     # TODO: define these
@@ -178,7 +178,7 @@ class PostTrainingJobArtifactsResponse(BaseModel):
 class PostTraining(Protocol):
     @webmethod(route="/post-training/supervised-fine-tune")
-    def supervised_fine_tune(
+    async def supervised_fine_tune(
         self,
         job_uuid: str,
         model: str,
@@ -186,14 +186,14 @@ class PostTraining(Protocol):
         validation_dataset_id: str,
         algorithm: FinetuningAlgorithm,
         algorithm_config: LoraFinetuningConfig,
-        optimizer_config: OptimizerConfig,
+        # optimizer_config: OptimizerConfig,
         training_config: TrainingConfig,
         hyperparam_search_config: Dict[str, Any],
         logger_config: Dict[str, Any],
     ) -> PostTrainingJob: ...

     @webmethod(route="/post-training/preference-optimize")
-    def preference_optimize(
+    async def preference_optimize(
         self,
         job_uuid: str,
         finetuned_model: URL,
@@ -208,21 +208,23 @@ class PostTraining(Protocol):
     ) -> PostTrainingJob: ...

     @webmethod(route="/post-training/jobs")
-    def get_training_jobs(self) -> List[PostTrainingJob]: ...
+    async def get_training_jobs(self) -> List[PostTrainingJob]: ...

     # sends SSE stream of logs
     @webmethod(route="/post-training/job/logs")
-    def get_training_job_logstream(self, job_uuid: str) -> PostTrainingJobLogStream: ...
+    async def get_training_job_logstream(
+        self, job_uuid: str
+    ) -> PostTrainingJobLogStream: ...

     @webmethod(route="/post-training/job/status")
-    def get_training_job_status(
+    async def get_training_job_status(
         self, job_uuid: str
     ) -> PostTrainingJobStatusResponse: ...

     @webmethod(route="/post-training/job/cancel")
-    def cancel_training_job(self, job_uuid: str) -> None: ...
+    async def cancel_training_job(self, job_uuid: str) -> None: ...

     @webmethod(route="/post-training/job/artifacts")
-    def get_training_job_artifacts(
+    async def get_training_job_artifacts(
         self, job_uuid: str
     ) -> PostTrainingJobArtifactsResponse: ...
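For orientation, a minimal sketch of how a caller builds the request after this change: the optimizer settings now ride inside TrainingConfig instead of being passed as a separate optimizer_config argument, and supervised_fine_tune is awaited. Values are illustrative; OptimizerConfig is assumed to expose lr and num_warmup_steps (both are read by the recipe later in this diff), FinetuningAlgorithm.lora is an assumed enum member, and any TrainingConfig/LoraFinetuningConfig fields not visible in the hunks above are omitted.

    # Sketch under the assumptions noted above; not test code from this commit.
    training_config = TrainingConfig(
        gradient_accumulation_steps=1,
        batch_size=2,
        shuffle=True,
        optimizer_config=OptimizerConfig(lr=3e-4, num_warmup_steps=100),  # now nested
        enable_activation_checkpointing=True,
        memory_efficient_fsdp_wrap=None,
        max_steps_per_epoch=10,  # the recipe uses this as a per-epoch early cutoff
    )

    job = await post_training_impl.supervised_fine_tune(  # coroutine as of this commit
        job_uuid="sft-job-0",
        model="Llama3.2-3B-Instruct",
        dataset_id="alpaca",  # hypothetical registered dataset id
        validation_dataset_id="alpaca",
        algorithm=FinetuningAlgorithm.lora,
        algorithm_config=LoraFinetuningConfig(
            apply_lora_to_output=False,
            rank=8,
            alpha=16,
            use_dora=False,  # new field added in this commit
        ),
        training_config=training_config,
        hyperparam_search_config={},
        logger_config={},
    )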

View file

@@ -20,7 +20,7 @@ class MetaReferencePostTrainingImpl:
         self.config = config
         self.datasetio_api = datasetio_api

-    def supervised_fine_tune(
+    async def supervised_fine_tune(
         self,
         job_uuid: str,
         model: str,
@@ -28,11 +28,11 @@ class MetaReferencePostTrainingImpl:
         validation_dataset_id: str,
         algorithm: FinetuningAlgorithm,
         algorithm_config: LoraFinetuningConfig,
-        optimizer_config: OptimizerConfig,
         training_config: TrainingConfig,
         hyperparam_search_config: Dict[str, Any],
         logger_config: Dict[str, Any],
     ) -> PostTrainingJob:
         # wrapper request to make it easier to pass around (internal only, not exposed to API)
         request = PostTrainingSFTRequest(
             job_uuid=job_uuid,
@@ -41,7 +41,6 @@ class MetaReferencePostTrainingImpl:
             validation_dataset_id=validation_dataset_id,
             algorithm=algorithm,
             algorithm_config=algorithm_config,
-            optimizer_config=optimizer_config,
             training_config=training_config,
             hyperparam_search_config=hyperparam_search_config,
             logger_config=logger_config,
@@ -50,14 +49,14 @@ class MetaReferencePostTrainingImpl:
             recipe = LoraFinetuningSingleDevice(
                 self.config, request, self.datasetio_api
             )
-            recipe.setup(self.config)
-            recipe.train()
+            await recipe.setup(self.config)
+            await recipe.train()
         else:
             raise NotImplementedError()

         return PostTrainingJob(job_uuid=job_uuid)

-    def preference_optimize(
+    async def preference_optimize(
         self,
         job_uuid: str,
         finetuned_model: URL,
@@ -71,21 +70,24 @@ class MetaReferencePostTrainingImpl:
         logger_config: Dict[str, Any],
     ) -> PostTrainingJob: ...

-    def get_training_jobs(self) -> List[PostTrainingJob]: ...
+    # TODO @markchen1015 implement below APIs
+    async def get_training_jobs(self) -> List[PostTrainingJob]: ...

     # sends SSE stream of logs
     @webmethod(route="/post-training/job/logs")
-    def get_training_job_logstream(self, job_uuid: str) -> PostTrainingJobLogStream: ...
+    async def get_training_job_logstream(
+        self, job_uuid: str
+    ) -> PostTrainingJobLogStream: ...

     @webmethod(route="/post-training/job/status")
-    def get_training_job_status(
+    async def get_training_job_status(
         self, job_uuid: str
     ) -> PostTrainingJobStatusResponse: ...

     @webmethod(route="/post-training/job/cancel")
-    def cancel_training_job(self, job_uuid: str) -> None: ...
+    async def cancel_training_job(self, job_uuid: str) -> None: ...

     @webmethod(route="/post-training/job/artifacts")
-    def get_training_job_artifacts(
+    async def get_training_job_artifacts(
         self, job_uuid: str
     ) -> PostTrainingJobArtifactsResponse: ...
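Since every provider entry point is now a coroutine, a synchronous caller has to drive the call through an event loop. A small sketch, assuming impl is an already-wired MetaReferencePostTrainingImpl and sft_kwargs holds the arguments shown above:

    import asyncio

    async def main(impl, sft_kwargs):
        # The async chain now runs end to end: supervised_fine_tune awaits
        # recipe.setup() and recipe.train() internally.
        job = await impl.supervised_fine_tune(**sft_kwargs)
        print(f"finished: {job.job_uuid}")

    # From synchronous code, outside any running event loop:
    # asyncio.run(main(impl, sft_kwargs))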

View file

@@ -7,15 +7,20 @@
 import asyncio
 import logging
 import os
+import re
 from functools import partial
+from pathlib import Path
 from typing import Any, Dict, List, Optional, Tuple, Union

 import torch
+from llama_models.sku_list import resolve_model
 from llama_stack.apis.datasetio import DatasetIO
 from torch import nn
 from torchtune import utils as torchtune_utils
+from torchtune.training.checkpointing._utils import ModelType

 from llama_stack.apis.post_training import *  # noqa
 from llama_stack.apis.post_training import PostTrainingSFTRequest
+from llama_stack.distribution.utils.model_utils import model_local_dir
 from llama_stack.providers.inline.post_training.meta_reference import utils

 from llama_stack.providers.inline.post_training.meta_reference.config import (
@@ -47,14 +52,24 @@ from torchtune.models.llama3._tokenizer import Llama3Tokenizer
 class LoraFinetuningSingleDevice:
+    # This recipe only supports GPU training
+    # This recipe doesn't include several training-efficiency settings from the
+    # original torchtune repo, including:
+    # - compile
+    # - activation offloading
+    # Resume from checkpoint hasn't been supported yet
+    # Validation hasn't been supported yet
+    # TODO @markchen1015 figure out the logging for this training recipe
+    # and make it work with telemetry
     def __init__(
         self,
         config: MetaReferencePostTrainingConfig,
         request: PostTrainingSFTRequest,
         datasetio_api: DatasetIO,
     ) -> None:
-        # to make user config easier, assume the device is 'cuda' only
-        # self._device = utils.get_device(device=cfg.device)
+        # Assume the training only happens on GPU
         self.config = config
         self.request = request
         self._device = torchtune_utils.get_device(device="cuda")
@@ -63,11 +78,30 @@ class LoraFinetuningSingleDevice:
         )
         self.model_id = config.model

-        # hardcode it for now and see how it works with get_training_job_artifacts
-        self._output_dir = f"~/.llama/checkpoints/post_training/{self.model_id}"
-        self._log_every_n_steps = 1
-        self._log_peak_memory_stats = False
+        def model_checkpoint_dir(model) -> str:
+            checkpoint_dir = Path(model_local_dir(model.descriptor()))
+            paths = [
+                Path(checkpoint_dir / f"consolidated.{ext}")
+                for ext in ["pth", "00.pth"]
+            ]
+            if not any(p.exists() for p in paths):
+                checkpoint_dir = checkpoint_dir / "original"
+
+            assert checkpoint_dir.exists(), (
+                f"Could not find checkpoints in: {model_local_dir(model.descriptor())}. "
+                f"Please download model using `llama download --model-id {model.descriptor()}`"
+            )
+            return str(checkpoint_dir)
+
+        if config.checkpoint_dir and config.checkpoint_dir != "null":
+            self.checkpoint_dir = config.checkpoint_dir
+        else:
+            model = resolve_model(self.model_id)
+            self.checkpoint_dir = model_checkpoint_dir(model)
+
+        # TODO @markchen1015 make it work with get_training_job_artifacts
+        self._output_dir = self.checkpoint_dir + "/posting_training/"

         self.seed = training.set_seed(seed=config.torch_seed or 42)
         self.epochs_run = 0
@@ -75,23 +109,15 @@ class LoraFinetuningSingleDevice:
         self._shuffle = request.training_config.shuffle
         self._batch_size = request.training_config.batch_size

-        self.checkpoint_dir = (
-            self.config.checkpoint_dir or f"~/.llama/checkpoints/{self.model_id}"
-        )
-
         # this is important for debugging purpose
         self.max_steps_per_epoch = request.training_config.max_steps_per_epoch
         self.global_step = 0

-        # not needed in MVP
-        # self._resume_from_checkpoint = cfg.resume_from_checkpoint
-        # self._save_adapter_weights_only = cfg.get("save_adapter_weights_only", False)
-
         self._gradient_accumulation_steps = (
             request.training_config.gradient_accumulation_steps
         )
-        self._clip_grad_norm = 1.0  # hardcode
+        self._clip_grad_norm = 1.0

         self._enable_activation_checkpointing = (
             request.training_config.enable_activation_checkpointing
         )
@@ -99,12 +125,11 @@ class LoraFinetuningSingleDevice:
         self.datasetio_api = datasetio_api

-    def load_checkpoint(self):
+    async def load_checkpoint(self):
         def get_checkpoint_files(checkpoint_dir: str) -> List[str]:
             try:
                 # List all files in the given directory
                 files = os.listdir(checkpoint_dir)
                 # Filter files that end with .pth
                 pth_files = [file for file in files if file.endswith(".pth")]
                 return pth_files
@@ -115,44 +140,40 @@ class LoraFinetuningSingleDevice:
             checkpoint_dir=self.checkpoint_dir,
             checkpoint_files=get_checkpoint_files(self.checkpoint_dir),
             output_dir=self._output_dir,
-            # todo: automatically get this info from model
-            model_type="LLAMA3",
+            model_type=utils.get_checkpointer_model_type(self.model_id),
         )
         checkpoint_dict = self._checkpointer.load_checkpoint()
         return checkpoint_dict

-    def setup(self, config: MetaReferencePostTrainingConfig) -> None:
-        # todo: figure out how does it works with telemetry
-        # self._metric_logger = config.instantiate(cfg.metric_logger)
-        # self._metric_logger.log_config(cfg)
-
-        checkpoint_dict = self.load_checkpoint()
-
-        # hack to toggle to the low cpu ram version of the reparametrize_as_dtype
-        # hook based on the config.
-        # common_utils._use_low_cpu_ram = cfg.get("low_cpu_ram", False)
-
-        # set up model
-        self._model = self._setup_model(
+    async def setup(self, config: MetaReferencePostTrainingConfig) -> None:
+        checkpoint_dict = await self.load_checkpoint()
+
+        self._model = await self._setup_model(
             enable_activation_checkpointing=self._enable_activation_checkpointing,
             enable_activation_offloading=self._enable_activation_offloading,
             base_model_state_dict=checkpoint_dict[training.MODEL_KEY],
             lora_weights_state_dict=None,
         )
+        log.info(f"Model is initialized with precision {self._dtype}.")

-        self._tokenizer = self._setup_tokenizer()
+        self._tokenizer = await self._setup_tokenizer()
         log.info("Tokenizer is initialized from file.")

-        self._optimizer = self._setup_optimizer(
-            optimizer_config=self.request.training_config.optimizer
+        self._optimizer = await self._setup_optimizer(
+            optimizer_config=self.request.training_config.optimizer_config
         )
+        log.info("Optimizer is initialized.")

         self._loss_fn = CEWithChunkedOutputLoss()
+        self._model.set_num_output_chunks(self._loss_fn.num_output_chunks)
+        log.info("Loss is initialized.")

-        self._sampler, self._dataloader = self._setup_data(
+        self._sampler, self._dataloader = await self._setup_data(
             tokenizer=self._tokenizer,
             shuffle=self._shuffle,
             batch_size=self._batch_size,
         )
+        log.info("Dataset and Sampler are initialized.")

         # Number of training steps in each epoch depends on the number of batches produced
         # by the dataloader and the max_steps_per_epoch param set by the user and is used
@@ -170,18 +191,19 @@ class LoraFinetuningSingleDevice:
         # Learning rate scheduler can only be set up after number of steps
         # has been computed

-        self._lr_scheduler = self._setup_lr_scheduler(
-            num_warmup_steps=self.request.optimizer_config.num_warmup_steps,
+        self._lr_scheduler = await self._setup_lr_scheduler(
+            num_warmup_steps=self.request.training_config.optimizer_config.num_warmup_steps,
             num_training_steps=self.total_epochs * self._steps_per_epoch,
             last_epoch=self.global_step - 1,
         )
+        log.info("Learning rate scheduler is initialized.")

         # Used to ignore labels for loss computation
         self.ignore_labels_cache = torch.full(
             (self._batch_size, 1), self._loss_fn.ignore_index, device=self._device
         )

-    def _setup_model(
+    async def _setup_model(
         self,
         enable_activation_checkpointing: bool,
         enable_activation_offloading: bool,
@@ -243,9 +265,8 @@ class LoraFinetuningSingleDevice:
             lora_missing=lora_missing,
             lora_unexpected=lora_unexpected,
         )
         # Validate model adapter params were loaded in with the expected dtype
-        # TODO (rohan-varma): Further validation to ensure the appropriate base params
-        # are NF4 vs bf16 based on the quantization config.
         training.validate_expected_param_dtype(
             self.adapter_params.items(), dtype=self._dtype
         )
@@ -254,22 +275,16 @@ class LoraFinetuningSingleDevice:
         self.activations_handling_ctx = training.get_act_offloading_ctx_manager(
             model, enable_activation_offloading
         )

-        log.info(f"Model is initialized with precision {self._dtype}.")
-
-        # if self._device.type != "cpu":
-        #     memory_stats = training.get_memory_stats(device=self._device)
-        #     training.log_memory_stats(memory_stats)
-
         return model

-    def _setup_tokenizer(
+    async def _setup_tokenizer(
         self,
     ) -> Llama3Tokenizer:
         tokenizer_path = self.checkpoint_dir + "/tokenizer.model"
         tokenizer_type = utils.get_tokenizer_type(self.model_id)
         return tokenizer_type(path=tokenizer_path)

-    def _setup_optimizer(self, optimizer_config: OptimizerConfig) -> Optimizer:
+    async def _setup_optimizer(self, optimizer_config: OptimizerConfig) -> Optimizer:
         optimizer = torch.optim.AdamW(
             params=self._model.parameters(),
             lr=optimizer_config.lr,
@@ -277,11 +292,9 @@ class LoraFinetuningSingleDevice:
             eps=1e-8,
             weight_decay=0.1,
         )
-        log.info("Optimizer and loss are initialized.")
-
         return optimizer

-    def _setup_data(
+    async def _setup_data(
         self, tokenizer: Llama3Tokenizer, shuffle: bool, batch_size: int
     ) -> Tuple[DistributedSampler, DataLoader]:
         async def fetch_rows():
@@ -290,10 +303,11 @@ class LoraFinetuningSingleDevice:
                 rows_in_page=-1,
             )

-        # Run the async function in an event loop
-        all_rows = asyncio.run(fetch_rows())
+        all_rows = await fetch_rows()
         rows = all_rows.rows

+        # Currently only support instruct datasets
+        # TODO @markchen1015 make the message_transform swappable and support more dataset types
         ds = SFTDataset(
             rows, message_transform=InputOutputToMessages(), model_transform=tokenizer
         )
@@ -320,11 +334,9 @@ class LoraFinetuningSingleDevice:
             ),
         )

-        log.info("Dataset and Sampler are initialized.")
-
         return sampler, dataloader
 
-    def _setup_lr_scheduler(
+    async def _setup_lr_scheduler(
         self,
         num_warmup_steps: int,
         num_training_steps: int,
@@ -332,33 +344,19 @@ class LoraFinetuningSingleDevice:
     ) -> Optimizer:
         lr_scheduler = get_cosine_schedule_with_warmup(
             self._optimizer,
-            num_warmup_steps=num_warmup_steps,
             num_training_steps=num_training_steps,
             last_epoch=last_epoch,
         )
-        log.info("Learning rate scheduler is initialized.")
-
         return lr_scheduler

-    def save_checkpoint(self, epoch: int) -> None:
-        """
-        Checkpoint the state of the recipe. The constructed checkpoint state dict
-        contains the following information:
-        - Merged weights with key MODEL_KEY
-        - Adapter weights with key ADAPTER_KEY
-        - Relevant recipe state if training is not complete
-        - If the `self._save_adapter_weights_only` option is True, the checkpointer will save only the adapter weights
-
-        To correctly resume from training, the adapter weights and recipe state must be provided along with the base model weights.
-        """
+    async def save_checkpoint(self, epoch: int) -> None:
         ckpt_dict = {}

-        intermediate_checkpoint = epoch + 1 < self.total_epochs
-
         adapter_state_dict = get_adapter_state_dict(self._model.state_dict())
         ckpt_dict.update({training.ADAPTER_KEY: adapter_state_dict})

         # Construct the full state dict with LoRA weights merged into base LLM weights
         # Move to CPU to avoid a copy on GPU
         state_dict = {k: v.cpu() for k, v in self._model.state_dict().items()}
@@ -385,10 +383,9 @@ class LoraFinetuningSingleDevice:
         self._checkpointer.save_checkpoint(
             ckpt_dict,
             epoch=epoch,
-            intermediate_checkpoint=intermediate_checkpoint,
         )

-    def _loss_step(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
+    async def _loss_step(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
         # Shape [b, s], needed for the loss not the model
         labels = batch.pop("labels")
         # run model
@@ -412,16 +409,10 @@ class LoraFinetuningSingleDevice:
         return loss

-    def train(self) -> None:
+    async def train(self) -> None:
         """
         The core training loop.
         """
-        # if self._compile:
-        #     log.info(
-        #         "NOTE: torch.compile is enabled and model is compiled in first forward. Expect a relatively slow first iteration."
-        #     )
-
         # Initialize tokens count and running loss (for grad accumulation)
         # t0 = time.perf_counter()
         running_loss = 0
@@ -433,7 +424,6 @@ class LoraFinetuningSingleDevice:
             # in case shuffle is True
             self._sampler.set_epoch(curr_epoch)

-            # pbar = tqdm(total=self._steps_per_epoch)
             for idx, batch in enumerate(self._dataloader):
                 if (
                     self.max_steps_per_epoch is not None
@@ -442,14 +432,6 @@ class LoraFinetuningSingleDevice:
                 ):
                     break

-                # Start tracking CUDA memory for active steps for just the first epoch
-                # if (
-                #     curr_epoch == 0
-                #     and self.profiler_profile_memory
-                #     and idx == self.profiler_wait_steps + self.profiler_warmup_steps
-                # ):
-                #     torch.cuda.memory._record_memory_history()
-
                 torchtune_utils.batch_to_device(batch, self._device)

                 # Calculate the number of unmasked tokens in the current batch
@@ -461,14 +443,14 @@ class LoraFinetuningSingleDevice:
                 # Loss is normalized by default so we multiply by the number of tokens
                 # This way we can normalize by the total number of tokens if we're accumulating gradients
-                current_loss = self._loss_step(batch) * current_num_tokens
+                current_loss = await self._loss_step(batch) * current_num_tokens
                 running_loss += current_loss
                 current_loss.backward()

                 # Step with optimizer
                 if (idx + 1) % self._gradient_accumulation_steps == 0:
                     training.scale_grads(self._model, 1 / num_tokens)
-                    grad_norm = torch.nn.utils.clip_grad_norm_(
+                    torch.nn.utils.clip_grad_norm_(
                         self._model.parameters(),
                         max_norm=float(self._clip_grad_norm),
                     )
@@ -478,58 +460,10 @@ class LoraFinetuningSingleDevice:
                     # Update the number of steps when the weights are updated
                     self.global_step += 1

-                    # loss_to_log = running_loss.item() / num_tokens
-                    # pbar.update(1)
-                    # pbar.set_description(
-                    #     f"{curr_epoch + 1}|{self.global_step}|Loss: {loss_to_log}"
-                    # )
-
-                    # Log per-step metrics
-                    # if self.global_step % self._log_every_n_steps == 0:
-                    #     time_per_step = time.perf_counter() - t0
-                    #     log_dict = {
-                    #         "loss": loss_to_log,
-                    #         "lr": self._optimizer.param_groups[0]["lr"],
-                    #         "tokens_per_second_per_gpu": num_tokens / time_per_step,
-                    #     }
-                    #     if self._device.type == "cuda" and self._log_peak_memory_stats:
-                    #         log_dict.update(
-                    #             training.get_memory_stats(device=self._device)
-                    #         )
-                    #     if self._clip_grad_norm is not None:
-                    #         log_dict.update({"grad_norm": grad_norm})
-                    #     self._metric_logger.log_dict(
-                    #         log_dict,
-                    #         step=self.global_step,
-                    #     )
-
                     # Reset running stats for the next step
                     running_loss = 0
                     num_tokens = 0
-                    # t0 = time.perf_counter()

-                # Stop tracking CUDA memory now that active steps are complete
-                # if (
-                #     curr_epoch == 0
-                #     and self.profiler_profile_memory
-                #     and idx
-                #     == self.profiler_wait_steps
-                #     + self.profiler_warmup_steps
-                #     + self.profiler_active_steps
-                # ):
-                #     torch.cuda.memory._record_memory_history(enabled=None)
-
-                # Step the profiler
-                # Note we are stepping each batch, which might not include optimizer step in the trace
-                # if the schedule cycle doesn't align with gradient accumulation.
-                # prof.step()

             self.epochs_run += 1
-            # start_save_checkpoint = time.perf_counter()
             log.info("Starting checkpoint save...")
-            self.save_checkpoint(epoch=curr_epoch)
+            await self.save_checkpoint(epoch=curr_epoch)
-            # log.info(
-            #     "Checkpoint saved in {:.2f} seconds.".format(
-            #         time.perf_counter() - start_save_checkpoint
-            #     )
-            # )
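One detail worth noting in this file: _setup_data previously wrapped the DatasetIO call in asyncio.run(fetch_rows()), which raises RuntimeError when invoked from a thread that already has a running event loop, which is exactly the situation once the recipe itself is async. Awaiting the coroutine directly, as the new code does, avoids that. A standalone illustration of the rule (not code from this commit):

    import asyncio

    async def fetch_rows():
        return ["row1", "row2"]

    async def recipe_step():
        rows = await fetch_rows()  # correct inside an already-running loop
        # asyncio.run(fetch_rows())  # would raise: "asyncio.run() cannot be
        #                            # called from a running event loop"
        return rows

    print(asyncio.run(recipe_step()))  # a single top-level entry point is fine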

View file

@@ -16,15 +16,22 @@ import torch
 from llama_models.sku_list import resolve_model
 from torchtune.models.llama3 import llama3_tokenizer, lora_llama3_8b
 from torchtune.models.llama3._tokenizer import Llama3Tokenizer
+from torchtune.models.llama3_2 import lora_llama3_2_3b

 LORA_MODEL_TYPES: Dict[str, Any] = {
+    "Llama3.2-3B-Instruct": lora_llama3_2_3b,
     "Llama-3-8B-Instruct": lora_llama3_8b,
 }

 TOKENIZER_TYPES: Dict[str, Any] = {
+    "Llama3.2-3B-Instruct": llama3_tokenizer,
     "Llama-3-8B-Instruct": llama3_tokenizer,
 }

+CHECKPOINT_MODEL_TYPES: Dict[str, str] = {
+    "Llama3.2-3B-Instruct": "LLAMA3_2",
+}
+
 BuildLoraModelCallable = Callable[..., torch.nn.Module]
 BuildTokenizerCallable = Callable[..., Llama3Tokenizer]
@@ -41,3 +48,10 @@ def get_tokenizer_type(
 ) -> BuildTokenizerCallable:
     model = resolve_model(model_id)
     return TOKENIZER_TYPES[model.core_model_id.value]
+
+
+def get_checkpointer_model_type(
+    model_id: str,
+) -> str:
+    model = resolve_model(model_id)
+    return CHECKPOINT_MODEL_TYPES[model.core_model_id.value]
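A quick sketch of the new helper in use. It resolves the model id through resolve_model and keys into CHECKPOINT_MODEL_TYPES, which currently maps only Llama3.2-3B-Instruct, so any other id would raise a KeyError here (assumed behavior, based on the lookup tables above):

    from llama_stack.providers.inline.post_training.meta_reference import utils

    # Feeds the `model_type` argument of torchtune's checkpointer
    # (see the recipe diff above).
    model_type = utils.get_checkpointer_model_type("Llama3.2-3B-Instruct")
    assert model_type == "LLAMA3_2"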

View file

@@ -71,7 +71,7 @@ datasets:
       uri: https://huggingface.co/datasets/tatsu-lab/alpaca
     metadata:
       path: tatsu-lab/alpaca
-      name: post_training_alpaca
+      name:
       split: train
     dataset_schema:
       instruction: