From 41cf2bb0a7223c4c7cdf7b4048dbd68a3f9cad65 Mon Sep 17 00:00:00 2001
From: Botao Chen
Date: Tue, 3 Dec 2024 20:01:27 -0800
Subject: [PATCH] refine api

---
 .../apis/post_training/post_training.py       | 58 ++++++++++---------
 .../meta_reference/post_training.py           |  4 --
 .../recipes/lora_finetuning_single_device.py  | 52 ++++++++++++++---
 3 files changed, 74 insertions(+), 40 deletions(-)

diff --git a/llama_stack/apis/post_training/post_training.py b/llama_stack/apis/post_training/post_training.py
index 9882adfaf..63df97c68 100644
--- a/llama_stack/apis/post_training/post_training.py
+++ b/llama_stack/apis/post_training/post_training.py
@@ -7,7 +7,7 @@
 from datetime import datetime
 from enum import Enum
 
-from typing import Any, Dict, List, Optional, Protocol
+from typing import Any, Dict, List, Optional, Protocol, Union
 
 from llama_models.schema_utils import json_schema_type, webmethod
 
@@ -18,42 +18,55 @@ from llama_stack.apis.datasets import *  # noqa: F403
 from llama_stack.apis.common.training_types import *  # noqa: F403
 
 
+@json_schema_type
 class OptimizerType(Enum):
     adam = "adam"
     adamw = "adamw"
     sgd = "sgd"
 
 
+@json_schema_type
+class DataConfig(BaseModel):
+    dataset_id: str
+    batch_size: int
+    shuffle: bool
+    validation_dataset_id: Optional[str] = None
+    packed: Optional[bool] = False
+    train_on_input: Optional[bool] = False
+
+
 @json_schema_type
 class OptimizerConfig(BaseModel):
     optimizer_type: OptimizerType
     lr: float
-    lr_min: float
     weight_decay: float
     num_warmup_steps: int
 
 
+@json_schema_type
+class EfficiencyConfig(BaseModel):
+    enable_activation_checkpointing: Optional[bool] = False
+    enable_activation_offloading: Optional[bool] = False
+    memory_efficient_fsdp_wrap: Optional[bool] = False
+    fsdp_cpu_offload: Optional[bool] = False
+
+
 @json_schema_type
 class TrainingConfig(BaseModel):
-    dtype: str
     n_epochs: int
     max_steps_per_epoch: int
     gradient_accumulation_steps: int
-    batch_size: int
-    shuffle: bool
+    data_config: DataConfig
     optimizer_config: OptimizerConfig
-
-    enable_activation_checkpointing: bool
-    memory_efficient_fsdp_wrap: Optional[bool]
-    fsdp_cpu_offload: Optional[bool]
+    efficiency_config: Optional[EfficiencyConfig] = None
+    dtype: Optional[str] = "bf16"
 
 
 @json_schema_type
 class FinetuningAlgorithm(Enum):
     full = "full"
     lora = "lora"
-    qlora = "qlora"
-    dora = "dora"
+    qat = "qat"
 
 
@@ -63,17 +76,14 @@ class LoraFinetuningConfig(BaseModel):
     apply_lora_to_output: bool
     rank: int
     alpha: int
-    use_dora: bool
+    use_dora: Optional[bool] = False
+    quantize_base: Optional[bool] = False
 
 
 @json_schema_type
-class QLoraFinetuningConfig(LoraFinetuningConfig):
-    pass
-
-
-@json_schema_type
-class DoraFinetuningConfig(LoraFinetuningConfig):
-    pass
+class QATFinetuningConfig(BaseModel):
+    quantizer_name: str
+    group_size: int
 
 
 @json_schema_type
@@ -110,13 +120,9 @@ class PostTrainingSFTRequest(BaseModel):
     """Request to finetune a model."""
 
     job_uuid: str
-
     model: str
-    dataset_id: str
-    validation_dataset_id: str
-
     algorithm: FinetuningAlgorithm
-    algorithm_config: LoraFinetuningConfig
+    algorithm_config: Optional[Union[LoraFinetuningConfig, QATFinetuningConfig]] = None
     training_config: TrainingConfig
 
     # TODO: define these
@@ -182,13 +188,11 @@ class PostTraining(Protocol):
         self,
         job_uuid: str,
         model: str,
-        dataset_id: str,
-        validation_dataset_id: str,
         algorithm: FinetuningAlgorithm,
-        algorithm_config: LoraFinetuningConfig,
         training_config: TrainingConfig,
         hyperparam_search_config: Dict[str, Any],
         logger_config: Dict[str, Any],
+        algorithm_config: Optional[LoraFinetuningConfig] = None,
     ) -> PostTrainingJob: ...

     @webmethod(route="/post-training/preference-optimize")
diff --git a/llama_stack/providers/inline/post_training/meta_reference/post_training.py b/llama_stack/providers/inline/post_training/meta_reference/post_training.py
index 2ff8de381..8ab98f7d4 100644
--- a/llama_stack/providers/inline/post_training/meta_reference/post_training.py
+++ b/llama_stack/providers/inline/post_training/meta_reference/post_training.py
@@ -24,8 +24,6 @@ class MetaReferencePostTrainingImpl:
         self,
         job_uuid: str,
         model: str,
-        dataset_id: str,
-        validation_dataset_id: str,
         algorithm: FinetuningAlgorithm,
         algorithm_config: LoraFinetuningConfig,
         training_config: TrainingConfig,
@@ -37,8 +35,6 @@ class MetaReferencePostTrainingImpl:
         request = PostTrainingSFTRequest(
             job_uuid=job_uuid,
             model=model,
-            dataset_id=dataset_id,
-            validation_dataset_id=validation_dataset_id,
             algorithm=algorithm,
             algorithm_config=algorithm_config,
             training_config=training_config,
diff --git a/llama_stack/providers/inline/post_training/meta_reference/recipes/lora_finetuning_single_device.py b/llama_stack/providers/inline/post_training/meta_reference/recipes/lora_finetuning_single_device.py
index c6677d6d2..30b315329 100644
--- a/llama_stack/providers/inline/post_training/meta_reference/recipes/lora_finetuning_single_device.py
+++ b/llama_stack/providers/inline/post_training/meta_reference/recipes/lora_finetuning_single_device.py
@@ -6,6 +6,7 @@
 
 import logging
 import os
+import time
 from functools import partial
 from pathlib import Path
 from typing import Any, Dict, List, Optional, Tuple
@@ -15,6 +16,7 @@ from llama_models.sku_list import resolve_model
 from llama_stack.apis.datasetio import DatasetIO
 from torch import nn
 from torchtune import utils as torchtune_utils
+from torchtune.training.metric_logging import DiskLogger
 from llama_stack.apis.post_training import *  # noqa
 from llama_stack.apis.post_training import PostTrainingSFTRequest
 from llama_stack.distribution.utils.model_utils import model_local_dir
@@ -29,7 +31,7 @@ from llama_stack.providers.inline.post_training.meta_reference.datasets.sft impo
 from torch.optim import Optimizer
 from torch.utils.data import DataLoader, DistributedSampler
 from torchtune import modules, training
-from torchtune.data import InputOutputToMessages, padded_collate_sft
+from torchtune.data import AlpacaToMessages, padded_collate_sft
 
 from torchtune.modules.loss import CEWithChunkedOutputLoss
 from torchtune.modules.peft import (
@@ -103,8 +105,8 @@ class LoraFinetuningSingleDevice:
         self.seed = training.set_seed(seed=config.torch_seed or 42)
         self.epochs_run = 0
         self.total_epochs = request.training_config.n_epochs
-        self._shuffle = request.training_config.shuffle
-        self._batch_size = request.training_config.batch_size
+        self._shuffle = request.training_config.data_config.shuffle
+        self._batch_size = request.training_config.data_config.batch_size
 
         # this is important for debugging purpose
         self.max_steps_per_epoch = request.training_config.max_steps_per_epoch
@@ -116,9 +118,15 @@ class LoraFinetuningSingleDevice:
         self._clip_grad_norm = 1.0
         self._enable_activation_checkpointing = (
-            request.training_config.enable_activation_checkpointing
+            (request.training_config.efficiency_config.enable_activation_checkpointing)
+            if request.training_config.efficiency_config
+            else False
+        )
+        self._enable_activation_offloading = (
+            (request.training_config.efficiency_config.enable_activation_offloading)
+            if request.training_config.efficiency_config
+            else False
         )
-        self._enable_activation_offloading = False
 
         self.datasetio_api = datasetio_api
@@ -143,6 +151,9 @@ class LoraFinetuningSingleDevice:
         return checkpoint_dict
 
     async def setup(self, config: MetaReferencePostTrainingConfig) -> None:
+        # temporarily log to local disk; will figure out how to interop with telemetry
+        self._metric_logger = DiskLogger(log_dir=self._output_dir)
+
         checkpoint_dict = await self.load_checkpoint()
 
         self._model = await self._setup_model(
@@ -212,7 +223,7 @@ class LoraFinetuningSingleDevice:
         self._lora_attn_modules = list(self.request.algorithm_config.lora_attn_modules)
         self._apply_lora_to_mlp = self.request.algorithm_config.apply_lora_to_mlp
         self._apply_lora_to_output = self.request.algorithm_config.apply_lora_to_output
-        self._use_dora = self.request.algorithm_config.use_dora
+        self._use_dora = self.request.algorithm_config.use_dora or False
 
         with training.set_default_dtype(self._dtype), self._device:
             model_type = utils.get_model_type(self.model_id)
@@ -272,6 +283,10 @@ class LoraFinetuningSingleDevice:
         self.activations_handling_ctx = training.get_act_offloading_ctx_manager(
             model, enable_activation_offloading
         )
+
+        memory_stats = training.get_memory_stats(device=self._device)
+        training.log_memory_stats(memory_stats)
+
         return model
 
     async def _setup_tokenizer(
@@ -296,7 +311,7 @@ class LoraFinetuningSingleDevice:
     ) -> Tuple[DistributedSampler, DataLoader]:
         async def fetch_rows():
             return await self.datasetio_api.get_rows_paginated(
-                dataset_id=self.request.dataset_id,
+                dataset_id=self.request.training_config.data_config.dataset_id,
                 rows_in_page=-1,
             )
 
@@ -306,7 +321,9 @@ class LoraFinetuningSingleDevice:
         # Curretly only support instruct dataset
         # TODO @markchen1015 make the message_transform swappable and support more dataset types
         ds = SFTDataset(
-            rows, message_transform=InputOutputToMessages(), model_transform=tokenizer
+            rows,
+            message_transform=AlpacaToMessages(train_on_input=False),
+            model_transform=tokenizer,
         )
 
         sampler = DistributedSampler(
@@ -412,6 +429,7 @@ class LoraFinetuningSingleDevice:
         """
         # Initialize tokens count and running loss (for grad accumulation)
         # t0 = time.perf_counter()
+        t0 = time.perf_counter()
         running_loss = 0
         num_tokens = 0
 
@@ -447,7 +465,7 @@ class LoraFinetuningSingleDevice:
                 # Step with optimizer
                 if (idx + 1) % self._gradient_accumulation_steps == 0:
                     training.scale_grads(self._model, 1 / num_tokens)
-                    torch.nn.utils.clip_grad_norm_(
+                    grad_norm = torch.nn.utils.clip_grad_norm_(
                         self._model.parameters(),
                         max_norm=float(self._clip_grad_norm),
                     )
@@ -457,9 +475,25 @@ class LoraFinetuningSingleDevice:
                     # Update the number of steps when the weights are updated
                     self.global_step += 1
 
+                    loss_to_log = running_loss.item() / num_tokens
+                    time_per_step = time.perf_counter() - t0
+                    log_dict = {
+                        "loss": loss_to_log,
+                        "lr": self._optimizer.param_groups[0]["lr"],
+                        "tokens_per_second_per_gpu": num_tokens / time_per_step,
+                    }
+                    log_dict.update(training.get_memory_stats(device=self._device))
+                    if self._clip_grad_norm is not None:
+                        log_dict.update({"grad_norm": grad_norm})
+                    self._metric_logger.log_dict(
+                        log_dict,
+                        step=self.global_step,
+                    )
+
                     # Reset running stats for the next step
                     running_loss = 0
                     num_tokens = 0
+                    t0 = time.perf_counter()
 
             self.epochs_run += 1
             log.info("Starting checkpoint save...")
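
Reviewer note, not part of the patch: below is a minimal usage sketch of the refined API surface. It assumes the config classes above are exported from llama_stack.apis.post_training, that `impl` is an already-constructed PostTraining provider (e.g. MetaReferencePostTrainingImpl wired with its datasetio dependency), and that the protocol method whose signature is changed in the last hunk of the API file is supervised_fine_tune (its name sits just above the hunk context). The dataset id, model id, LoRA target modules, and hyperparameter values are illustrative only.

from llama_stack.apis.post_training import (  # assumed export location
    DataConfig,
    EfficiencyConfig,
    FinetuningAlgorithm,
    LoraFinetuningConfig,
    OptimizerConfig,
    OptimizerType,
    TrainingConfig,
)


async def submit_sft_job(impl) -> None:
    # Dataset knobs now live on DataConfig instead of TrainingConfig.
    training_config = TrainingConfig(
        n_epochs=1,
        max_steps_per_epoch=10,
        gradient_accumulation_steps=1,
        data_config=DataConfig(
            dataset_id="alpaca",  # illustrative dataset id
            batch_size=2,
            shuffle=True,
        ),
        optimizer_config=OptimizerConfig(
            optimizer_type=OptimizerType.adamw,
            lr=3e-4,  # lr_min was removed from OptimizerConfig
            weight_decay=0.01,
            num_warmup_steps=100,
        ),
        # Memory/efficiency knobs are optional and grouped separately now.
        efficiency_config=EfficiencyConfig(enable_activation_checkpointing=True),
        # dtype defaults to "bf16" and can be omitted.
    )

    job = await impl.supervised_fine_tune(
        job_uuid="1234",  # illustrative
        model="Llama3.2-3B-Instruct",  # illustrative model id
        algorithm=FinetuningAlgorithm.lora,
        training_config=training_config,
        hyperparam_search_config={},
        logger_config={},
        # algorithm_config is now optional and last in the signature.
        algorithm_config=LoraFinetuningConfig(
            lora_attn_modules=["q_proj", "v_proj"],  # illustrative targets
            apply_lora_to_mlp=True,
            apply_lora_to_output=False,
            rank=8,
            alpha=16,
        ),
    )
    print(job)

The call passes algorithm_config as a trailing optional keyword, matching its new position in the refined signature, and omits dataset_id/validation_dataset_id, which now travel inside TrainingConfig.data_config.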