From 12eef585432c6b3adeb4a213bb98ec09c8f1929c Mon Sep 17 00:00:00 2001
From: Botao Chen
Date: Wed, 4 Dec 2024 15:19:54 -0800
Subject: [PATCH] address comment

---
 .../apis/post_training/post_training.py       | 20 +++---
 .../inline/post_training/torchtune/config.py  |  9 +--
 .../post_training/torchtune/post_training.py  | 42 ++++--------
 .../recipes/lora_finetuning_single_device.py  | 68 +++++++++----------
 .../providers/registry/post_training.py       |  2 +-
 .../experimental-post-training/run.yaml       |  4 +-
 6 files changed, 58 insertions(+), 87 deletions(-)

diff --git a/llama_stack/apis/post_training/post_training.py b/llama_stack/apis/post_training/post_training.py
index 9202e1753..f593749e2 100644
--- a/llama_stack/apis/post_training/post_training.py
+++ b/llama_stack/apis/post_training/post_training.py
@@ -7,7 +7,7 @@
 
 from datetime import datetime
 from enum import Enum
-from typing import Any, Dict, List, Optional, Protocol
+from typing import Any, Dict, List, Optional, Protocol, Union
 
 from llama_models.schema_utils import json_schema_type, webmethod
 
@@ -62,13 +62,6 @@ class TrainingConfig(BaseModel):
     dtype: Optional[str] = "bf16"
 
 
-@json_schema_type
-class FinetuningAlgorithm(Enum):
-    full = "full"
-    lora = "lora"
-    qat = "qat"
-
-
 @json_schema_type
 class LoraFinetuningConfig(BaseModel):
     lora_attn_modules: List[str]
@@ -172,12 +165,17 @@ class PostTraining(Protocol):
     async def supervised_fine_tune(
         self,
         job_uuid: str,
-        model: str,
-        algorithm: FinetuningAlgorithm,
         training_config: TrainingConfig,
         hyperparam_search_config: Dict[str, Any],
         logger_config: Dict[str, Any],
-        algorithm_config: Optional[LoraFinetuningConfig] = None,
+        model: str = Field(
+            default="Llama3.2-3B-Instruct",
+            description="Model descriptor from `llama model list`",
+        ),
+        checkpoint_dir: Optional[str] = None,
+        algorithm_config: Optional[
+            Union[LoraFinetuningConfig, QATFinetuningConfig]
+        ] = None,
     ) -> PostTrainingJob: ...
 
     @webmethod(route="/post-training/preference-optimize")
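With the FinetuningAlgorithm enum removed, the algorithm is selected by the concrete type of algorithm_config. A minimal sketch of the new call shape; the post_training handle, the prebuilt training_config, and the LoRA field values here are illustrative assumptions rather than part of the patch:

    # Hypothetical caller; only the parameter list is taken from the patch.
    from llama_stack.apis.post_training.post_training import LoraFinetuningConfig

    async def kick_off_sft(post_training, training_config):
        # Algorithm selection now rides on the type of `algorithm_config`;
        # there is no separate `algorithm=` enum argument anymore.
        return await post_training.supervised_fine_tune(
            job_uuid="sft-job-001",
            training_config=training_config,
            hyperparam_search_config={},
            logger_config={},
            model="Llama3.2-3B-Instruct",
            checkpoint_dir=None,  # fall back to the local checkpoint cache
            algorithm_config=LoraFinetuningConfig(
                lora_attn_modules=["q_proj", "v_proj"],
                apply_lora_to_mlp=False,
                apply_lora_to_output=False,
                rank=8,
                alpha=16,
            ),
        )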
diff --git a/llama_stack/providers/inline/post_training/torchtune/config.py b/llama_stack/providers/inline/post_training/torchtune/config.py
index 7bfea0f9d..3ffa55c70 100644
--- a/llama_stack/providers/inline/post_training/torchtune/config.py
+++ b/llama_stack/providers/inline/post_training/torchtune/config.py
@@ -6,15 +6,8 @@
 
 from typing import Optional
 
-from pydantic import BaseModel, Field
+from pydantic import BaseModel
 
 
 class TorchtunePostTrainingConfig(BaseModel):
-    model: str = Field(
-        default="Llama3.2-3B-Instruct",
-        description="Model descriptor from `llama model list`",
-    )
     torch_seed: Optional[int] = None
-    # By default, the implementation will look at ~/.llama/checkpoints/ but you
-    # can override by specifying the directory explicitly
-    checkpoint_dir: Optional[str] = None
diff --git a/llama_stack/providers/inline/post_training/torchtune/post_training.py b/llama_stack/providers/inline/post_training/torchtune/post_training.py
index 83a8ef02f..74124325a 100644
--- a/llama_stack/providers/inline/post_training/torchtune/post_training.py
+++ b/llama_stack/providers/inline/post_training/torchtune/post_training.py
@@ -13,18 +13,6 @@ from llama_stack.providers.inline.post_training.torchtune.recipes.lora_finetunin
 )
 
 
-class PostTrainingSFTRequest(BaseModel):
-    job_uuid: str
-    model: str
-    algorithm: FinetuningAlgorithm
-    algorithm_config: Optional[Union[LoraFinetuningConfig, QATFinetuningConfig]] = None
-    training_config: TrainingConfig
-
-    # TODO: define these
-    hyperparam_search_config: Dict[str, Any]
-    logger_config: Dict[str, Any]
-
-
 class TorchtunePostTrainingImpl:
     def __init__(
         self, config: TorchtunePostTrainingConfig, datasetio_api: DatasetIO
@@ -35,29 +23,25 @@ class TorchtunePostTrainingImpl:
     async def supervised_fine_tune(
         self,
         job_uuid: str,
-        model: str,
-        algorithm: FinetuningAlgorithm,
-        algorithm_config: LoraFinetuningConfig,
         training_config: TrainingConfig,
         hyperparam_search_config: Dict[str, Any],
         logger_config: Dict[str, Any],
+        model: str,
+        checkpoint_dir: Optional[str],
+        algorithm_config: Optional[Union[LoraFinetuningConfig, QATFinetuningConfig]],
     ) -> PostTrainingJob:
-
-        # wrapper request to make it easier to pass around (internal only, not exposed to API)
-        request = PostTrainingSFTRequest(
-            job_uuid=job_uuid,
-            model=model,
-            algorithm=algorithm,
-            algorithm_config=algorithm_config,
-            training_config=training_config,
-            hyperparam_search_config=hyperparam_search_config,
-            logger_config=logger_config,
-        )
-        if request.algorithm == FinetuningAlgorithm.lora:
+        if isinstance(algorithm_config, LoraFinetuningConfig):
             recipe = LoraFinetuningSingleDevice(
-                self.config, request, self.datasetio_api
+                self.config,
+                training_config,
+                hyperparam_search_config,
+                logger_config,
+                model,
+                checkpoint_dir,
+                algorithm_config,
+                self.datasetio_api,
             )
-            await recipe.setup(self.config)
+            await recipe.setup()
             await recipe.train()
         else:
             raise NotImplementedError()
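Dropping the PostTrainingSFTRequest wrapper means the provider now dispatches on the runtime type of algorithm_config. A self-contained sketch of that pattern with toy stand-in configs (the real classes live in the post_training API module); the QAT branch is hypothetical:

    from typing import Optional, Union

    from pydantic import BaseModel

    class LoraFinetuningConfig(BaseModel):  # toy stand-in for the API class
        rank: int = 8

    class QATFinetuningConfig(BaseModel):  # toy stand-in for the API class
        quantizer_name: str = "int4"

    def pick_recipe(
        cfg: Optional[Union[LoraFinetuningConfig, QATFinetuningConfig]],
    ) -> str:
        # The concrete class of the config, not an enum field, picks the recipe.
        if isinstance(cfg, LoraFinetuningConfig):
            return "LoraFinetuningSingleDevice"
        # A QAT recipe would slot in here as another isinstance() branch.
        raise NotImplementedError(f"no recipe for {type(cfg).__name__}")

    assert pick_recipe(LoraFinetuningConfig()) == "LoraFinetuningSingleDevice"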
diff --git a/llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py b/llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py
index 0603347bb..ce8f10503 100644
--- a/llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py
+++ b/llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py
@@ -22,12 +22,9 @@
 from llama_stack.distribution.utils.model_utils import model_local_dir
 from llama_stack.providers.inline.post_training.torchtune import utils
 from llama_stack.providers.inline.post_training.torchtune.config import (
-    MetaReferencePostTrainingConfig,
+    TorchtunePostTrainingConfig,
 )
 from llama_stack.providers.inline.post_training.torchtune.datasets.sft import SFTDataset
-from llama_stack.providers.inline.post_training.torchtune.post_training import (
-    PostTrainingSFTRequest,
-)
 from torch.optim import Optimizer
 from torch.utils.data import DataLoader, DistributedSampler
 from torchtune import modules, training
@@ -64,18 +61,21 @@ class LoraFinetuningSingleDevice:
     # and make it work with telemetry
     def __init__(
         self,
-        config: MetaReferencePostTrainingConfig,
-        request: PostTrainingSFTRequest,
+        config: TorchtunePostTrainingConfig,
+        training_config: TrainingConfig,
+        hyperparam_search_config: Dict[str, Any],
+        logger_config: Dict[str, Any],
+        model: str,
+        checkpoint_dir: Optional[str],
+        algorithm_config: Optional[Union[LoraFinetuningConfig, QATFinetuningConfig]],
         datasetio_api: DatasetIO,
     ) -> None:
         # Assume the training only happens on GPU
-        self.config = config
-        self.request = request
+        self.training_config = training_config
+        self.algorithm_config = algorithm_config
         self._device = torchtune_utils.get_device(device="cuda")
-        self._dtype = training.get_dtype(
-            request.training_config.dtype, device=self._device
-        )
-        self.model_id = config.model
+        self._dtype = training.get_dtype(training_config.dtype, device=self._device)
+        self.model_id = model
 
         def model_checkpoint_dir(model) -> str:
             checkpoint_dir = Path(model_local_dir(model.descriptor()))
@@ -93,7 +93,7 @@ class LoraFinetuningSingleDevice:
             )
             return str(checkpoint_dir)
 
-        if config.checkpoint_dir and config.checkpoint_dir != "null":
-            self.checkpoint_dir = config.checkpoint_dir
+        if checkpoint_dir and checkpoint_dir != "null":
+            self.checkpoint_dir = checkpoint_dir
         else:
             model = resolve_model(self.model_id)
@@ -102,29 +102,27 @@ class LoraFinetuningSingleDevice:
 
         # TODO @markchen1015 make it work with get_training_job_artifacts
         self._output_dir = self.checkpoint_dir + "/posting_training/"
 
-        self.seed = training.set_seed(seed=config.torch_seed or 42)
+        self.seed = training.set_seed(seed=config.torch_seed)
         self.epochs_run = 0
-        self.total_epochs = request.training_config.n_epochs
-        self._shuffle = request.training_config.data_config.shuffle
-        self._batch_size = request.training_config.data_config.batch_size
+        self.total_epochs = training_config.n_epochs
+        self._shuffle = training_config.data_config.shuffle
+        self._batch_size = training_config.data_config.batch_size
 
         # this is important for debugging purpose
-        self.max_steps_per_epoch = request.training_config.max_steps_per_epoch
+        self.max_steps_per_epoch = training_config.max_steps_per_epoch
         self.global_step = 0
 
-        self._gradient_accumulation_steps = (
-            request.training_config.gradient_accumulation_steps
-        )
+        self._gradient_accumulation_steps = training_config.gradient_accumulation_steps
 
         self._clip_grad_norm = 1.0
         self._enable_activation_checkpointing = (
-            (request.training_config.efficiency_config.enable_activation_checkpointing)
-            if request.training_config.efficiency_config
+            (training_config.efficiency_config.enable_activation_checkpointing)
+            if training_config.efficiency_config
             else False
         )
         self._enable_activation_offloading = (
-            (request.training_config.efficiency_config.enable_activation_offloading)
-            if request.training_config.efficiency_config
+            (training_config.efficiency_config.enable_activation_offloading)
+            if training_config.efficiency_config
            else False
        )
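The `checkpoint_dir != "null"` guard matches the old `${env.INFERENCE_CHECKPOINT_DIR:null}` template default, which arrives as the literal string "null" rather than a missing value. A standalone restatement of the fallback, assuming the default location named in the comment deleted from config.py:

    from pathlib import Path
    from typing import Optional

    def resolve_checkpoint_dir(checkpoint_dir: Optional[str], descriptor: str) -> str:
        # An explicit per-request directory wins; the string "null" means unset.
        if checkpoint_dir and checkpoint_dir != "null":
            return checkpoint_dir
        # Otherwise fall back to checkpoints downloaded under ~/.llama/.
        return str(Path.home() / ".llama" / "checkpoints" / descriptor)

    print(resolve_checkpoint_dir(None, "Llama3.2-3B-Instruct"))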
@@ -150,7 +148,7 @@ class LoraFinetuningSingleDevice:
         checkpoint_dict = self._checkpointer.load_checkpoint()
         return checkpoint_dict
 
-    async def setup(self, config: MetaReferencePostTrainingConfig) -> None:
+    async def setup(self) -> None:
         # temporarily log to local disk, will figure out how to interop with telemetry
         self._metric_logger = DiskLogger(log_dir=self._output_dir)
 
@@ -168,7 +166,7 @@ class LoraFinetuningSingleDevice:
         log.info("Tokenizer is initialized from file.")
 
         self._optimizer = await self._setup_optimizer(
-            optimizer_config=self.request.training_config.optimizer_config
+            optimizer_config=self.training_config.optimizer_config
         )
         log.info("Optimizer is initialized.")
 
@@ -200,7 +198,7 @@ class LoraFinetuningSingleDevice:
         # Learning rate scheduler can only be set up after number of steps
         # has been computed
         self._lr_scheduler = await self._setup_lr_scheduler(
-            num_warmup_steps=self.request.training_config.optimizer_config.num_warmup_steps,
+            num_warmup_steps=self.training_config.optimizer_config.num_warmup_steps,
             num_training_steps=self.total_epochs * self._steps_per_epoch,
             last_epoch=self.global_step - 1,
         )
@@ -218,12 +216,12 @@ class LoraFinetuningSingleDevice:
         base_model_state_dict: Dict[str, Any],
         lora_weights_state_dict: Optional[Dict[str, Any]] = None,
     ) -> nn.Module:
-        self._lora_rank = self.request.algorithm_config.rank
-        self._lora_alpha = self.request.algorithm_config.alpha
-        self._lora_attn_modules = list(self.request.algorithm_config.lora_attn_modules)
-        self._apply_lora_to_mlp = self.request.algorithm_config.apply_lora_to_mlp
-        self._apply_lora_to_output = self.request.algorithm_config.apply_lora_to_output
-        self._use_dora = self.request.algorithm_config.use_dora or False
+        self._lora_rank = self.algorithm_config.rank
+        self._lora_alpha = self.algorithm_config.alpha
+        self._lora_attn_modules = list(self.algorithm_config.lora_attn_modules)
+        self._apply_lora_to_mlp = self.algorithm_config.apply_lora_to_mlp
+        self._apply_lora_to_output = self.algorithm_config.apply_lora_to_output
+        self._use_dora = self.algorithm_config.use_dora or False
 
         with training.set_default_dtype(self._dtype), self._device:
             model_type = utils.get_model_type(self.model_id)
@@ -311,7 +309,7 @@ class LoraFinetuningSingleDevice:
     ) -> Tuple[DistributedSampler, DataLoader]:
         async def fetch_rows():
             return await self.datasetio_api.get_rows_paginated(
-                dataset_id=self.request.training_config.data_config.dataset_id,
+                dataset_id=self.training_config.data_config.dataset_id,
                 rows_in_page=-1,
             )
 
diff --git a/llama_stack/providers/registry/post_training.py b/llama_stack/providers/registry/post_training.py
index fc4b93c40..2c9fdd43d 100644
--- a/llama_stack/providers/registry/post_training.py
+++ b/llama_stack/providers/registry/post_training.py
@@ -16,7 +16,7 @@ def available_providers() -> List[ProviderSpec]:
             provider_type="inline::torchtune",
             pip_packages=["torch", "torchtune", "torchao", "numpy"],
             module="llama_stack.providers.inline.post_training.torchtune",
-            config_class="llama_stack.providers.inline.post_training.torchtune.torchtunePostTrainingConfig",
+            config_class="llama_stack.providers.inline.post_training.torchtune.TorchtunePostTrainingConfig",
             api_dependencies=[
                 Api.datasetio,
             ],
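The config_class rename from torchtunePostTrainingConfig to TorchtunePostTrainingConfig matters because the string is resolved to a class by name at runtime, so a wrongly capitalized name fails attribute lookup. A sketch of such a resolver, assuming the usual importlib mechanism rather than quoting llama_stack's actual loader:

    import importlib

    def load_class(fully_qualified_name: str) -> type:
        # "pkg.module.ClassName" -> class object; getattr() raises
        # AttributeError when the class name's capitalization is wrong.
        module_path, _, class_name = fully_qualified_name.rpartition(".")
        return getattr(importlib.import_module(module_path), class_name)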
diff --git a/llama_stack/templates/experimental-post-training/run.yaml b/llama_stack/templates/experimental-post-training/run.yaml
index e50280401..3cda9c062 100644
--- a/llama_stack/templates/experimental-post-training/run.yaml
+++ b/llama_stack/templates/experimental-post-training/run.yaml
@@ -49,9 +49,7 @@ providers:
   post_training:
   - provider_id: meta-reference-post-training
     provider_type: inline::torchtune
-    config:
-      model: ${env.POST_TRAINING_MODEL}
-      checkpoint_dir: ${env.INFERENCE_CHECKPOINT_DIR:null}
+    config: {}
 
 metadata_store:
   namespace: null
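`config: {}` is valid after this change because the only field left on TorchtunePostTrainingConfig has a default; the model and checkpoint directory now travel with each supervised_fine_tune request instead of the deployment. A quick check restating the trimmed class:

    from typing import Optional

    from pydantic import BaseModel

    class TorchtunePostTrainingConfig(BaseModel):
        torch_seed: Optional[int] = None

    # An empty mapping, which is what `config: {}` deserializes to, is valid.
    assert TorchtunePostTrainingConfig(**{}).torch_seed is None
    assert TorchtunePostTrainingConfig(torch_seed=42).torch_seed == 42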