address comment

Botao Chen 2024-12-04 15:19:54 -08:00
parent 2a15a8a005
commit 12eef58543
6 changed files with 58 additions and 87 deletions

View file

@@ -7,7 +7,7 @@
 from datetime import datetime
 from enum import Enum
-from typing import Any, Dict, List, Optional, Protocol
+from typing import Any, Dict, List, Optional, Protocol, Union

 from llama_models.schema_utils import json_schema_type, webmethod
@@ -62,13 +62,6 @@ class TrainingConfig(BaseModel):
     dtype: Optional[str] = "bf16"


-@json_schema_type
-class FinetuningAlgorithm(Enum):
-    full = "full"
-    lora = "lora"
-    qat = "qat"
-
-
 @json_schema_type
 class LoraFinetuningConfig(BaseModel):
     lora_attn_modules: List[str]
@@ -172,12 +165,17 @@ class PostTraining(Protocol):
     async def supervised_fine_tune(
         self,
         job_uuid: str,
-        model: str,
-        algorithm: FinetuningAlgorithm,
         training_config: TrainingConfig,
         hyperparam_search_config: Dict[str, Any],
         logger_config: Dict[str, Any],
-        algorithm_config: Optional[LoraFinetuningConfig] = None,
+        model: str = Field(
+            default="Llama3.2-3B-Instruct",
+            description="Model descriptor from `llama model list`",
+        ),
+        checkpoint_dir: Optional[str] = None,
+        algorithm_config: Optional[
+            Union[LoraFinetuningConfig, QATFinetuningConfig]
+        ] = None,
     ) -> PostTrainingJob: ...

     @webmethod(route="/post-training/preference-optimize")
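
Note: with the FinetuningAlgorithm enum gone, the algorithm is implied by the concrete type of algorithm_config. A minimal sketch of a call against the new signature; the LoraFinetuningConfig field values are illustrative (only the field names appear in this diff), and post_training_impl stands in for any object implementing the PostTraining protocol:

    # Assumes the types are imported from llama_stack.apis.post_training.
    async def start_lora_job(
        post_training_impl, training_config: TrainingConfig
    ) -> PostTrainingJob:
        # TrainingConfig construction is elided; its nested config
        # classes are not part of this diff.
        return await post_training_impl.supervised_fine_tune(
            job_uuid="1234",
            training_config=training_config,
            hyperparam_search_config={},
            logger_config={},
            model="Llama3.2-3B-Instruct",  # the default from the Field() above
            checkpoint_dir=None,  # fall back to ~/.llama/checkpoints/<model>
            algorithm_config=LoraFinetuningConfig(
                lora_attn_modules=["q_proj", "v_proj"],  # illustrative values
                apply_lora_to_mlp=False,
                apply_lora_to_output=False,
                rank=8,
                alpha=16,
            ),
        )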

View file

@@ -6,15 +6,8 @@
 from typing import Optional

-from pydantic import BaseModel, Field
+from pydantic import BaseModel


 class TorchtunePostTrainingConfig(BaseModel):
-    model: str = Field(
-        default="Llama3.2-3B-Instruct",
-        description="Model descriptor from `llama model list`",
-    )
     torch_seed: Optional[int] = None
-
-    # By default, the implementation will look at ~/.llama/checkpoints/<model> but you
-    # can override by specifying the directory explicitly
-    checkpoint_dir: Optional[str] = None
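
After this removal the provider config carries only an optional seed; model and checkpoint_dir now arrive per request through supervised_fine_tune. The resulting class, reconstructed from the diff:

    from typing import Optional

    from pydantic import BaseModel


    class TorchtunePostTrainingConfig(BaseModel):
        torch_seed: Optional[int] = None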

View file

@@ -13,18 +13,6 @@ from llama_stack.providers.inline.post_training.torchtune.recipes.lora_finetunin
 )

-class PostTrainingSFTRequest(BaseModel):
-    job_uuid: str
-    model: str
-    algorithm: FinetuningAlgorithm
-    algorithm_config: Optional[Union[LoraFinetuningConfig, QATFinetuningConfig]] = None
-    training_config: TrainingConfig
-
-    # TODO: define these
-    hyperparam_search_config: Dict[str, Any]
-    logger_config: Dict[str, Any]
-
-
 class TorchtunePostTrainingImpl:
     def __init__(
         self, config: TorchtunePostTrainingConfig, datasetio_api: DatasetIO
@@ -35,29 +23,25 @@ class TorchtunePostTrainingImpl:
     async def supervised_fine_tune(
         self,
         job_uuid: str,
-        model: str,
-        algorithm: FinetuningAlgorithm,
-        algorithm_config: LoraFinetuningConfig,
         training_config: TrainingConfig,
         hyperparam_search_config: Dict[str, Any],
         logger_config: Dict[str, Any],
+        model: str,
+        checkpoint_dir: Optional[str],
+        algorithm_config: Optional[Union[LoraFinetuningConfig, QATFinetuningConfig]],
     ) -> PostTrainingJob:
-        # wrapper request to make it easier to pass around (internal only, not exposed to API)
-        request = PostTrainingSFTRequest(
-            job_uuid=job_uuid,
-            model=model,
-            algorithm=algorithm,
-            algorithm_config=algorithm_config,
-            training_config=training_config,
-            hyperparam_search_config=hyperparam_search_config,
-            logger_config=logger_config,
-        )
-
-        if request.algorithm == FinetuningAlgorithm.lora:
+        if isinstance(algorithm_config, LoraFinetuningConfig):
             recipe = LoraFinetuningSingleDevice(
-                self.config, request, self.datasetio_api
+                self.config,
+                training_config,
+                hyperparam_search_config,
+                logger_config,
+                model,
+                checkpoint_dir,
+                algorithm_config,
+                self.datasetio_api,
             )
-            await recipe.setup(self.config)
+            await recipe.setup()
             await recipe.train()
         else:
             raise NotImplementedError()
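
The enum comparison is replaced by narrowing on the config type itself, so adding an algorithm means adding an isinstance branch rather than a new enum member. A hedged sketch of how a QAT branch could slot in later; QATFinetuningSingleDevice is hypothetical and does not exist in this commit:

    if isinstance(algorithm_config, LoraFinetuningConfig):
        recipe = LoraFinetuningSingleDevice(...)  # as in the diff above
    elif isinstance(algorithm_config, QATFinetuningConfig):
        recipe = QATFinetuningSingleDevice(...)  # hypothetical future recipe
    else:
        raise NotImplementedError(
            f"Unsupported algorithm config: {type(algorithm_config)}"
        )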

View file

@@ -22,12 +22,9 @@ from llama_stack.distribution.utils.model_utils import model_local_dir
 from llama_stack.providers.inline.post_training.torchtune import utils
 from llama_stack.providers.inline.post_training.torchtune.config import (
-    MetaReferencePostTrainingConfig,
+    TorchtunePostTrainingConfig,
 )
 from llama_stack.providers.inline.post_training.torchtune.datasets.sft import SFTDataset
-from llama_stack.providers.inline.post_training.torchtune.post_training import (
-    PostTrainingSFTRequest,
-)
 from torch.optim import Optimizer
 from torch.utils.data import DataLoader, DistributedSampler
 from torchtune import modules, training
@@ -64,18 +61,21 @@ class LoraFinetuningSingleDevice:
     # and make it work with telemetry
     def __init__(
         self,
-        config: MetaReferencePostTrainingConfig,
-        request: PostTrainingSFTRequest,
+        config: TorchtunePostTrainingConfig,
+        training_config: TrainingConfig,
+        hyperparam_search_config: Dict[str, Any],
+        logger_config: Dict[str, Any],
+        model: str,
+        checkpoint_dir: Optional[str],
+        algorithm_config: Optional[Union[LoraFinetuningConfig, QATFinetuningConfig]],
         datasetio_api: DatasetIO,
     ) -> None:
         # Assume the training only happens on GPU
-        self.config = config
-        self.request = request
+        self.training_config = training_config
+        self.algorithm_config = algorithm_config
         self._device = torchtune_utils.get_device(device="cuda")
-        self._dtype = training.get_dtype(
-            request.training_config.dtype, device=self._device
-        )
-        self.model_id = config.model
+        self._dtype = training.get_dtype(training_config.dtype, device=self._device)
+        self.model_id = model

         def model_checkpoint_dir(model) -> str:
             checkpoint_dir = Path(model_local_dir(model.descriptor()))
@@ -93,7 +93,7 @@ class LoraFinetuningSingleDevice:
             )
             return str(checkpoint_dir)

-        if config.checkpoint_dir and config.checkpoint_dir != "null":
+        if checkpoint_dir and checkpoint_dir != "null":
             self.checkpoint_dir = config.checkpoint_dir
         else:
             model = resolve_model(self.model_id)
@@ -102,29 +102,27 @@ class LoraFinetuningSingleDevice:
         # TODO: @markchen1015 make it work with get_training_job_artifacts
         self._output_dir = self.checkpoint_dir + "/posting_training/"

-        self.seed = training.set_seed(seed=config.torch_seed or 42)
+        self.seed = training.set_seed(seed=config.torch_seed)
         self.epochs_run = 0
-        self.total_epochs = request.training_config.n_epochs
-        self._shuffle = request.training_config.data_config.shuffle
-        self._batch_size = request.training_config.data_config.batch_size
+        self.total_epochs = training_config.n_epochs
+        self._shuffle = training_config.data_config.shuffle
+        self._batch_size = training_config.data_config.batch_size

         # this is important for debugging purpose
-        self.max_steps_per_epoch = request.training_config.max_steps_per_epoch
+        self.max_steps_per_epoch = training_config.max_steps_per_epoch
         self.global_step = 0
-        self._gradient_accumulation_steps = (
-            request.training_config.gradient_accumulation_steps
-        )
+        self._gradient_accumulation_steps = training_config.gradient_accumulation_steps

         self._clip_grad_norm = 1.0
         self._enable_activation_checkpointing = (
-            (request.training_config.efficiency_config.enable_activation_checkpointing)
-            if request.training_config.efficiency_config
+            (training_config.efficiency_config.enable_activation_checkpointing)
+            if training_config.efficiency_config
             else False
         )
         self._enable_activation_offloading = (
-            (request.training_config.efficiency_config.enable_activation_offloading)
-            if request.training_config.efficiency_config
+            (training_config.efficiency_config.enable_activation_offloading)
+            if training_config.efficiency_config
             else False
         )
@@ -150,7 +148,7 @@ class LoraFinetuningSingleDevice:
         checkpoint_dict = self._checkpointer.load_checkpoint()
         return checkpoint_dict

-    async def setup(self, config: MetaReferencePostTrainingConfig) -> None:
+    async def setup(self) -> None:
         # temporily log to local disk, will figure out how to interop with telemetry
         self._metric_logger = DiskLogger(log_dir=self._output_dir)
@@ -168,7 +166,7 @@ class LoraFinetuningSingleDevice:
         log.info("Tokenizer is initialized from file.")

         self._optimizer = await self._setup_optimizer(
-            optimizer_config=self.request.training_config.optimizer_config
+            optimizer_config=self.training_config.optimizer_config
         )
         log.info("Optimizer is initialized.")
@@ -200,7 +198,7 @@ class LoraFinetuningSingleDevice:
         # Learning rate scheduler can only be set up after number of steps
         # has been computed
         self._lr_scheduler = await self._setup_lr_scheduler(
-            num_warmup_steps=self.request.training_config.optimizer_config.num_warmup_steps,
+            num_warmup_steps=self.training_config.optimizer_config.num_warmup_steps,
             num_training_steps=self.total_epochs * self._steps_per_epoch,
             last_epoch=self.global_step - 1,
         )
@@ -218,12 +216,12 @@ class LoraFinetuningSingleDevice:
         base_model_state_dict: Dict[str, Any],
         lora_weights_state_dict: Optional[Dict[str, Any]] = None,
     ) -> nn.Module:
-        self._lora_rank = self.request.algorithm_config.rank
-        self._lora_alpha = self.request.algorithm_config.alpha
-        self._lora_attn_modules = list(self.request.algorithm_config.lora_attn_modules)
-        self._apply_lora_to_mlp = self.request.algorithm_config.apply_lora_to_mlp
-        self._apply_lora_to_output = self.request.algorithm_config.apply_lora_to_output
-        self._use_dora = self.request.algorithm_config.use_dora or False
+        self._lora_rank = self.algorithm_config.rank
+        self._lora_alpha = self.algorithm_config.alpha
+        self._lora_attn_modules = list(self.algorithm_config.lora_attn_modules)
+        self._apply_lora_to_mlp = self.algorithm_config.apply_lora_to_mlp
+        self._apply_lora_to_output = self.algorithm_config.apply_lora_to_output
+        self._use_dora = self.algorithm_config.use_dora or False

         with training.set_default_dtype(self._dtype), self._device:
             model_type = utils.get_model_type(self.model_id)
@@ -311,7 +309,7 @@ class LoraFinetuningSingleDevice:
     ) -> Tuple[DistributedSampler, DataLoader]:
         async def fetch_rows():
             return await self.datasetio_api.get_rows_paginated(
-                dataset_id=self.request.training_config.data_config.dataset_id,
+                dataset_id=self.training_config.data_config.dataset_id,
                 rows_in_page=-1,
             )
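
One caveat: _setup_model reads LoRA-only fields (rank, alpha, lora_attn_modules, apply_lora_to_mlp, apply_lora_to_output, use_dora) from a value typed Optional[Union[LoraFinetuningConfig, QATFinetuningConfig]]. That is safe today only because the impl raises NotImplementedError for non-LoRA configs before the recipe is built. A hardening sketch, not part of this commit:

    # Narrow the Union before touching LoRA-specific fields.
    if not isinstance(self.algorithm_config, LoraFinetuningConfig):
        raise ValueError(
            "LoraFinetuningSingleDevice requires a LoraFinetuningConfig"
        )
    self._lora_rank = self.algorithm_config.rank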

View file

@@ -16,7 +16,7 @@ def available_providers() -> List[ProviderSpec]:
             provider_type="inline::torchtune",
             pip_packages=["torch", "torchtune", "torchao", "numpy"],
             module="llama_stack.providers.inline.post_training.torchtune",
-            config_class="llama_stack.providers.inline.post_training.torchtune.torchtunePostTrainingConfig",
+            config_class="llama_stack.providers.inline.post_training.torchtune.TorchtunePostTrainingConfig",
             api_dependencies=[
                 Api.datasetio,
             ],
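
The capitalization fix matters because config_class is a dotted string resolved by name when the provider loads, so the lowercase torchtunePostTrainingConfig would fail only at runtime, not at import time. A minimal sketch of that kind of string-to-class resolution, assuming the loader works roughly like this (the real one lives in llama_stack's distribution code):

    import importlib


    def resolve_config_class(fqname: str) -> type:
        module_name, _, class_name = fqname.rpartition(".")
        return getattr(importlib.import_module(module_name), class_name)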

View file

@@ -49,9 +49,7 @@ providers:
   post_training:
   - provider_id: meta-reference-post-training
     provider_type: inline::torchtune
-    config:
-      model: ${env.POST_TRAINING_MODEL}
-      checkpoint_dir: ${env.INFERENCE_CHECKPOINT_DIR:null}
+    config: {}
 metadata_store:
   namespace: null
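
With the provider config emptied, this template no longer reads POST_TRAINING_MODEL or INFERENCE_CHECKPOINT_DIR from the environment; the model descriptor and checkpoint directory are supplied per job through the supervised_fine_tune parameters shown in the first file.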