From 79c525be948425c4e5e05246db89c40dbe8802c7 Mon Sep 17 00:00:00 2001 From: Botao Chen Date: Mon, 2 Dec 2024 17:24:25 -0800 Subject: [PATCH] temp commit --- .../apis/post_training/post_training.py | 22 +- .../meta_reference/post_training.py | 24 +- .../recipes/lora_finetuning_single_device.py | 216 ++++++------------ .../post_training/meta_reference/utils.py | 14 ++ .../templates/meta-reference-gpu/run.yaml | 2 +- 5 files changed, 115 insertions(+), 163 deletions(-) diff --git a/llama_stack/apis/post_training/post_training.py b/llama_stack/apis/post_training/post_training.py index d1e4d30e7..7338a6465 100644 --- a/llama_stack/apis/post_training/post_training.py +++ b/llama_stack/apis/post_training/post_training.py @@ -41,7 +41,7 @@ class TrainingConfig(BaseModel): gradient_accumulation_steps: int batch_size: int shuffle: bool - # n_iters: int + optimizer_config: OptimizerConfig enable_activation_checkpointing: bool memory_efficient_fsdp_wrap: Optional[bool] @@ -63,6 +63,7 @@ class LoraFinetuningConfig(BaseModel): apply_lora_to_output: bool rank: int alpha: int + use_dora: bool @json_schema_type @@ -116,7 +117,6 @@ class PostTrainingSFTRequest(BaseModel): algorithm: FinetuningAlgorithm algorithm_config: LoraFinetuningConfig - optimizer_config: OptimizerConfig training_config: TrainingConfig # TODO: define these @@ -178,7 +178,7 @@ class PostTrainingJobArtifactsResponse(BaseModel): class PostTraining(Protocol): @webmethod(route="/post-training/supervised-fine-tune") - def supervised_fine_tune( + async def supervised_fine_tune( self, job_uuid: str, model: str, @@ -186,14 +186,14 @@ class PostTraining(Protocol): validation_dataset_id: str, algorithm: FinetuningAlgorithm, algorithm_config: LoraFinetuningConfig, - optimizer_config: OptimizerConfig, + # optimizer_config: OptimizerConfig, training_config: TrainingConfig, hyperparam_search_config: Dict[str, Any], logger_config: Dict[str, Any], ) -> PostTrainingJob: ... @webmethod(route="/post-training/preference-optimize") - def preference_optimize( + async def preference_optimize( self, job_uuid: str, finetuned_model: URL, @@ -208,21 +208,23 @@ class PostTraining(Protocol): ) -> PostTrainingJob: ... @webmethod(route="/post-training/jobs") - def get_training_jobs(self) -> List[PostTrainingJob]: ... + async def get_training_jobs(self) -> List[PostTrainingJob]: ... # sends SSE stream of logs @webmethod(route="/post-training/job/logs") - def get_training_job_logstream(self, job_uuid: str) -> PostTrainingJobLogStream: ... + async def get_training_job_logstream( + self, job_uuid: str + ) -> PostTrainingJobLogStream: ... @webmethod(route="/post-training/job/status") - def get_training_job_status( + async def get_training_job_status( self, job_uuid: str ) -> PostTrainingJobStatusResponse: ... @webmethod(route="/post-training/job/cancel") - def cancel_training_job(self, job_uuid: str) -> None: ... + async def cancel_training_job(self, job_uuid: str) -> None: ... @webmethod(route="/post-training/job/artifacts") - def get_training_job_artifacts( + async def get_training_job_artifacts( self, job_uuid: str ) -> PostTrainingJobArtifactsResponse: ... diff --git a/llama_stack/providers/inline/post_training/meta_reference/post_training.py b/llama_stack/providers/inline/post_training/meta_reference/post_training.py index ca5ef67e5..2ff8de381 100644 --- a/llama_stack/providers/inline/post_training/meta_reference/post_training.py +++ b/llama_stack/providers/inline/post_training/meta_reference/post_training.py @@ -20,7 +20,7 @@ class MetaReferencePostTrainingImpl: self.config = config self.datasetio_api = datasetio_api - def supervised_fine_tune( + async def supervised_fine_tune( self, job_uuid: str, model: str, @@ -28,11 +28,11 @@ class MetaReferencePostTrainingImpl: validation_dataset_id: str, algorithm: FinetuningAlgorithm, algorithm_config: LoraFinetuningConfig, - optimizer_config: OptimizerConfig, training_config: TrainingConfig, hyperparam_search_config: Dict[str, Any], logger_config: Dict[str, Any], ) -> PostTrainingJob: + # wrapper request to make it easier to pass around (internal only, not exposed to API) request = PostTrainingSFTRequest( job_uuid=job_uuid, @@ -41,7 +41,6 @@ class MetaReferencePostTrainingImpl: validation_dataset_id=validation_dataset_id, algorithm=algorithm, algorithm_config=algorithm_config, - optimizer_config=optimizer_config, training_config=training_config, hyperparam_search_config=hyperparam_search_config, logger_config=logger_config, @@ -50,14 +49,14 @@ class MetaReferencePostTrainingImpl: recipe = LoraFinetuningSingleDevice( self.config, request, self.datasetio_api ) - recipe.setup(self.config) - recipe.train() + await recipe.setup(self.config) + await recipe.train() else: raise NotImplementedError() return PostTrainingJob(job_uuid=job_uuid) - def preference_optimize( + async def preference_optimize( self, job_uuid: str, finetuned_model: URL, @@ -71,21 +70,24 @@ class MetaReferencePostTrainingImpl: logger_config: Dict[str, Any], ) -> PostTrainingJob: ... - def get_training_jobs(self) -> List[PostTrainingJob]: ... + # TODO @markchen1015 impelment below APIs + async def get_training_jobs(self) -> List[PostTrainingJob]: ... # sends SSE stream of logs @webmethod(route="/post-training/job/logs") - def get_training_job_logstream(self, job_uuid: str) -> PostTrainingJobLogStream: ... + async def get_training_job_logstream( + self, job_uuid: str + ) -> PostTrainingJobLogStream: ... @webmethod(route="/post-training/job/status") - def get_training_job_status( + async def get_training_job_status( self, job_uuid: str ) -> PostTrainingJobStatusResponse: ... @webmethod(route="/post-training/job/cancel") - def cancel_training_job(self, job_uuid: str) -> None: ... + async def cancel_training_job(self, job_uuid: str) -> None: ... @webmethod(route="/post-training/job/artifacts") - def get_training_job_artifacts( + async def get_training_job_artifacts( self, job_uuid: str ) -> PostTrainingJobArtifactsResponse: ... diff --git a/llama_stack/providers/inline/post_training/meta_reference/recipes/lora_finetuning_single_device.py b/llama_stack/providers/inline/post_training/meta_reference/recipes/lora_finetuning_single_device.py index 047d2805a..94951d121 100644 --- a/llama_stack/providers/inline/post_training/meta_reference/recipes/lora_finetuning_single_device.py +++ b/llama_stack/providers/inline/post_training/meta_reference/recipes/lora_finetuning_single_device.py @@ -7,15 +7,20 @@ import asyncio import logging import os +import re from functools import partial +from pathlib import Path from typing import Any, Dict, List, Optional, Tuple, Union import torch +from llama_models.sku_list import resolve_model from llama_stack.apis.datasetio import DatasetIO from torch import nn from torchtune import utils as torchtune_utils +from torchtune.training.checkpointing._utils import ModelType from llama_stack.apis.post_training import * # noqa from llama_stack.apis.post_training import PostTrainingSFTRequest +from llama_stack.distribution.utils.model_utils import model_local_dir from llama_stack.providers.inline.post_training.meta_reference import utils from llama_stack.providers.inline.post_training.meta_reference.config import ( @@ -47,14 +52,24 @@ from torchtune.models.llama3._tokenizer import Llama3Tokenizer class LoraFinetuningSingleDevice: + # This recipe only supports GPU training + + # This recipe doesn't include several training efficiency setting within origin torchtune repo, including + # - compile + # - activation offloading + + # Resume from checkpoint hasn't been supported yet + # Validation hasn't been supported yet + + # TODO @markchen1015 figure out the logging for this training recipe + # and make it work with telemetry def __init__( self, config: MetaReferencePostTrainingConfig, request: PostTrainingSFTRequest, datasetio_api: DatasetIO, ) -> None: - # to make user config easier, assume the device is 'cuda' only - # self._device = utils.get_device(device=cfg.device) + # Assume the training only happens on GPU self.config = config self.request = request self._device = torchtune_utils.get_device(device="cuda") @@ -63,11 +78,30 @@ class LoraFinetuningSingleDevice: ) self.model_id = config.model - # hardcode it for now and see how it works with get_training_job_artifacts - self._output_dir = f"~/.llama/checkpoints/post_training/{self.model_id}" + def model_checkpoint_dir(model) -> str: + checkpoint_dir = Path(model_local_dir(model.descriptor())) - self._log_every_n_steps = 1 - self._log_peak_memory_stats = False + paths = [ + Path(checkpoint_dir / f"consolidated.{ext}") + for ext in ["pth", "00.pth"] + ] + if not any(p.exists() for p in paths): + checkpoint_dir = checkpoint_dir / "original" + + assert checkpoint_dir.exists(), ( + f"Could not find checkpoints in: {model_local_dir(model.descriptor())}. " + f"Please download model using `llama download --model-id {model.descriptor()}`" + ) + return str(checkpoint_dir) + + if config.checkpoint_dir and config.checkpoint_dir != "null": + self.checkpoint_dir = config.checkpoint_dir + else: + model = resolve_model(self.model_id) + self.checkpoint_dir = model_checkpoint_dir(model) + + # TODO @markchen1015 make it work with get_training_job_artifacts + self._output_dir = self.checkpoint_dir + "/posting_training/" self.seed = training.set_seed(seed=config.torch_seed or 42) self.epochs_run = 0 @@ -75,23 +109,15 @@ class LoraFinetuningSingleDevice: self._shuffle = request.training_config.shuffle self._batch_size = request.training_config.batch_size - self.checkpoint_dir = ( - self.config.checkpoint_dir or f"~/.llama/checkpoints/{self.model_id}" - ) - # this is important for debugging purpose self.max_steps_per_epoch = request.training_config.max_steps_per_epoch self.global_step = 0 - # not needed in MVP - # self._resume_from_checkpoint = cfg.resume_from_checkpoint - # self._save_adapter_weights_only = cfg.get("save_adapter_weights_only", False) - self._gradient_accumulation_steps = ( request.training_config.gradient_accumulation_steps ) - self._clip_grad_norm = 1.0 # hardcode + self._clip_grad_norm = 1.0 self._enable_activation_checkpointing = ( request.training_config.enable_activation_checkpointing ) @@ -99,12 +125,11 @@ class LoraFinetuningSingleDevice: self.datasetio_api = datasetio_api - def load_checkpoint(self): + async def load_checkpoint(self): def get_checkpoint_files(checkpoint_dir: str) -> List[str]: try: # List all files in the given directory files = os.listdir(checkpoint_dir) - # Filter files that end with .pth pth_files = [file for file in files if file.endswith(".pth")] return pth_files @@ -115,44 +140,40 @@ class LoraFinetuningSingleDevice: checkpoint_dir=self.checkpoint_dir, checkpoint_files=get_checkpoint_files(self.checkpoint_dir), output_dir=self._output_dir, - # todo: automatically get this info from model - model_type="LLAMA3", + model_type=utils.get_checkpointer_model_type(self.model_id), ) checkpoint_dict = self._checkpointer.load_checkpoint() return checkpoint_dict - def setup(self, config: MetaReferencePostTrainingConfig) -> None: - # todo: figure out how does it works with telemetry - # self._metric_logger = config.instantiate(cfg.metric_logger) - # self._metric_logger.log_config(cfg) + async def setup(self, config: MetaReferencePostTrainingConfig) -> None: + checkpoint_dict = await self.load_checkpoint() - checkpoint_dict = self.load_checkpoint() - - # hack to toggle to the low cpu ram version of the reparametrize_as_dtype - # hook based on the config. - # common_utils._use_low_cpu_ram = cfg.get("low_cpu_ram", False) - - # set up model - self._model = self._setup_model( + self._model = await self._setup_model( enable_activation_checkpointing=self._enable_activation_checkpointing, enable_activation_offloading=self._enable_activation_offloading, base_model_state_dict=checkpoint_dict[training.MODEL_KEY], lora_weights_state_dict=None, ) + log.info(f"Model is initialized with precision {self._dtype}.") - self._tokenizer = self._setup_tokenizer() + self._tokenizer = await self._setup_tokenizer() log.info("Tokenizer is initialized from file.") - self._optimizer = self._setup_optimizer( - optimizer_config=self.request.training_config.optimizer + self._optimizer = await self._setup_optimizer( + optimizer_config=self.request.training_config.optimizer_config ) + log.info("Optimizer is initialized.") self._loss_fn = CEWithChunkedOutputLoss() - self._sampler, self._dataloader = self._setup_data( + self._model.set_num_output_chunks(self._loss_fn.num_output_chunks) + log.info("Loss is initialized.") + + self._sampler, self._dataloader = await self._setup_data( tokenizer=self._tokenizer, shuffle=self._shuffle, batch_size=self._batch_size, ) + log.info("Dataset and Sampler are initialized.") # Number of training steps in each epoch depends on the number of batches produced # by the dataloader and the max_steps_per_epoch param set by the user and is used @@ -170,18 +191,19 @@ class LoraFinetuningSingleDevice: # Learning rate scheduler can only be set up after number of steps # has been computed - self._lr_scheduler = self._setup_lr_scheduler( - num_warmup_steps=self.request.optimizer_config.num_warmup_steps, + self._lr_scheduler = await self._setup_lr_scheduler( + num_warmup_steps=self.request.training_config.optimizer_config.num_warmup_steps, num_training_steps=self.total_epochs * self._steps_per_epoch, last_epoch=self.global_step - 1, ) + log.info("Learning rate scheduler is initialized.") # Used to ignore labels for loss computation self.ignore_labels_cache = torch.full( (self._batch_size, 1), self._loss_fn.ignore_index, device=self._device ) - def _setup_model( + async def _setup_model( self, enable_activation_checkpointing: bool, enable_activation_offloading: bool, @@ -243,9 +265,8 @@ class LoraFinetuningSingleDevice: lora_missing=lora_missing, lora_unexpected=lora_unexpected, ) + # Validate model adapter params were loaded in with the expected dtype - # TODO (rohan-varma): Further validation to ensure the appropriate base params - # are NF4 vs bf16 based on the quantization config. training.validate_expected_param_dtype( self.adapter_params.items(), dtype=self._dtype ) @@ -254,22 +275,16 @@ class LoraFinetuningSingleDevice: self.activations_handling_ctx = training.get_act_offloading_ctx_manager( model, enable_activation_offloading ) - - log.info(f"Model is initialized with precision {self._dtype}.") - - # if self._device.type != "cpu": - # memory_stats = training.get_memory_stats(device=self._device) - # training.log_memory_stats(memory_stats) return model - def _setup_tokenizer( + async def _setup_tokenizer( self, ) -> Llama3Tokenizer: tokenizer_path = self.checkpoint_dir + "/tokenizer.model" tokenizer_type = utils.get_tokenizer_type(self.model_id) return tokenizer_type(path=tokenizer_path) - def _setup_optimizer(self, optimizer_config: OptimizerConfig) -> Optimizer: + async def _setup_optimizer(self, optimizer_config: OptimizerConfig) -> Optimizer: optimizer = torch.optim.AdamW( params=self._model.parameters(), lr=optimizer_config.lr, @@ -277,11 +292,9 @@ class LoraFinetuningSingleDevice: eps=1e-8, weight_decay=0.1, ) - - log.info("Optimizer and loss are initialized.") return optimizer - def _setup_data( + async def _setup_data( self, tokenizer: Llama3Tokenizer, shuffle: bool, batch_size: int ) -> Tuple[DistributedSampler, DataLoader]: async def fetch_rows(): @@ -290,10 +303,11 @@ class LoraFinetuningSingleDevice: rows_in_page=-1, ) - # Run the async function in an event loop - all_rows = asyncio.run(fetch_rows()) + all_rows = await fetch_rows() rows = all_rows.rows + # Curretly only support instruct dataset + # TODO @markchen1015 make the message_transform swappable and support more dataset types ds = SFTDataset( rows, message_transform=InputOutputToMessages(), model_transform=tokenizer ) @@ -320,11 +334,9 @@ class LoraFinetuningSingleDevice: ), ) - log.info("Dataset and Sampler are initialized.") - return sampler, dataloader - def _setup_lr_scheduler( + async def _setup_lr_scheduler( self, num_warmup_steps: int, num_training_steps: int, @@ -332,33 +344,19 @@ class LoraFinetuningSingleDevice: ) -> Optimizer: lr_scheduler = get_cosine_schedule_with_warmup( self._optimizer, + num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, last_epoch=last_epoch, ) - - log.info("Learning rate scheduler is initialized.") return lr_scheduler - def save_checkpoint(self, epoch: int) -> None: - """ - Checkpoint the state of the recipe. The constructed checkpoint state dict - contains the following information: - - Merged weights with key MODEL_KEY - - Adapter weights with key ADAPTER_KEY - - Relevant recipe state if training is not complete - - If the `self._save_adapter_weights_only` option is True, the checkpointer will save only the adapter weights - - To correctly resume from training, the adapter weights and recipe state must be provided along with the base model weights. - """ + async def save_checkpoint(self, epoch: int) -> None: ckpt_dict = {} - intermediate_checkpoint = epoch + 1 < self.total_epochs - adapter_state_dict = get_adapter_state_dict(self._model.state_dict()) ckpt_dict.update({training.ADAPTER_KEY: adapter_state_dict}) # Construct the full state dict with LoRA weights merged into base LLM weights - # Move to CPU to avoid a copy on GPU state_dict = {k: v.cpu() for k, v in self._model.state_dict().items()} @@ -385,10 +383,9 @@ class LoraFinetuningSingleDevice: self._checkpointer.save_checkpoint( ckpt_dict, epoch=epoch, - intermediate_checkpoint=intermediate_checkpoint, ) - def _loss_step(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor: + async def _loss_step(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor: # Shape [b, s], needed for the loss not the model labels = batch.pop("labels") # run model @@ -412,16 +409,10 @@ class LoraFinetuningSingleDevice: return loss - def train(self) -> None: + async def train(self) -> None: """ The core training loop. """ - - # if self._compile: - # log.info( - # "NOTE: torch.compile is enabled and model is compiled in first forward. Expect a relatively slow first iteration." - # ) - # Initialize tokens count and running loss (for grad accumulation) # t0 = time.perf_counter() running_loss = 0 @@ -433,7 +424,6 @@ class LoraFinetuningSingleDevice: # in case shuffle is True self._sampler.set_epoch(curr_epoch) - # pbar = tqdm(total=self._steps_per_epoch) for idx, batch in enumerate(self._dataloader): if ( self.max_steps_per_epoch is not None @@ -442,14 +432,6 @@ class LoraFinetuningSingleDevice: ): break - # Start tracking CUDA memory for active steps for just the first epoch - # if ( - # curr_epoch == 0 - # and self.profiler_profile_memory - # and idx == self.profiler_wait_steps + self.profiler_warmup_steps - # ): - # torch.cuda.memory._record_memory_history() - torchtune_utils.batch_to_device(batch, self._device) # Calculate the number of unmasked tokens in the current batch @@ -461,14 +443,14 @@ class LoraFinetuningSingleDevice: # Loss is normalized by default so we multiply by the number of tokens # This way we can normalize by the total number of tokens if we're accumulating gradients - current_loss = self._loss_step(batch) * current_num_tokens + current_loss = await self._loss_step(batch) * current_num_tokens running_loss += current_loss current_loss.backward() # Step with optimizer if (idx + 1) % self._gradient_accumulation_steps == 0: training.scale_grads(self._model, 1 / num_tokens) - grad_norm = torch.nn.utils.clip_grad_norm_( + torch.nn.utils.clip_grad_norm_( self._model.parameters(), max_norm=float(self._clip_grad_norm), ) @@ -478,58 +460,10 @@ class LoraFinetuningSingleDevice: # Update the number of steps when the weights are updated self.global_step += 1 - # loss_to_log = running_loss.item() / num_tokens - # pbar.update(1) - # pbar.set_description( - # f"{curr_epoch + 1}|{self.global_step}|Loss: {loss_to_log}" - # ) - - # Log per-step metrics - # if self.global_step % self._log_every_n_steps == 0: - # time_per_step = time.perf_counter() - t0 - # log_dict = { - # "loss": loss_to_log, - # "lr": self._optimizer.param_groups[0]["lr"], - # "tokens_per_second_per_gpu": num_tokens / time_per_step, - # } - # if self._device.type == "cuda" and self._log_peak_memory_stats: - # log_dict.update( - # training.get_memory_stats(device=self._device) - # ) - # if self._clip_grad_norm is not None: - # log_dict.update({"grad_norm": grad_norm}) - # self._metric_logger.log_dict( - # log_dict, - # step=self.global_step, - # ) - # Reset running stats for the next step running_loss = 0 num_tokens = 0 - # t0 = time.perf_counter() - - # Stop tracking CUDA memory now that active steps are complete - # if ( - # curr_epoch == 0 - # and self.profiler_profile_memory - # and idx - # == self.profiler_wait_steps - # + self.profiler_warmup_steps - # + self.profiler_active_steps - # ): - # torch.cuda.memory._record_memory_history(enabled=None) - - # Step the profiler - # Note we are stepping each batch, which might not include optimizer step in the trace - # if the schedule cycle doesn't align with gradient accumulation. - # prof.step() self.epochs_run += 1 - # start_save_checkpoint = time.perf_counter() log.info("Starting checkpoint save...") - self.save_checkpoint(epoch=curr_epoch) - # log.info( - # "Checkpoint saved in {:.2f} seconds.".format( - # time.perf_counter() - start_save_checkpoint - # ) - # ) + await self.save_checkpoint(epoch=curr_epoch) diff --git a/llama_stack/providers/inline/post_training/meta_reference/utils.py b/llama_stack/providers/inline/post_training/meta_reference/utils.py index 4db5ab6df..70280081f 100644 --- a/llama_stack/providers/inline/post_training/meta_reference/utils.py +++ b/llama_stack/providers/inline/post_training/meta_reference/utils.py @@ -16,15 +16,22 @@ import torch from llama_models.sku_list import resolve_model from torchtune.models.llama3 import llama3_tokenizer, lora_llama3_8b from torchtune.models.llama3._tokenizer import Llama3Tokenizer +from torchtune.models.llama3_2 import lora_llama3_2_3b LORA_MODEL_TYPES: Dict[str, Any] = { + "Llama3.2-3B-Instruct": lora_llama3_2_3b, "Llama-3-8B-Instruct": lora_llama3_8b, } TOKENIZER_TYPES: Dict[str, Any] = { + "Llama3.2-3B-Instruct": llama3_tokenizer, "Llama-3-8B-Instruct": llama3_tokenizer, } +CHECKPOINT_MODEL_TYPES: Dict[str, str] = { + "Llama3.2-3B-Instruct": "LLAMA3_2", +} + BuildLoraModelCallable = Callable[..., torch.nn.Module] BuildTokenizerCallable = Callable[..., Llama3Tokenizer] @@ -41,3 +48,10 @@ def get_tokenizer_type( ) -> BuildTokenizerCallable: model = resolve_model(model_id) return TOKENIZER_TYPES[model.core_model_id.value] + + +def get_checkpointer_model_type( + model_id: str, +) -> str: + model = resolve_model(model_id) + return CHECKPOINT_MODEL_TYPES[model.core_model_id.value] diff --git a/llama_stack/templates/meta-reference-gpu/run.yaml b/llama_stack/templates/meta-reference-gpu/run.yaml index 8cd71b7b1..f19aa180e 100644 --- a/llama_stack/templates/meta-reference-gpu/run.yaml +++ b/llama_stack/templates/meta-reference-gpu/run.yaml @@ -71,7 +71,7 @@ datasets: uri: https://huggingface.co/datasets/tatsu-lab/alpaca metadata: path: tatsu-lab/alpaca - name: post_training_alpaca + name: split: train dataset_schema: instruction: