address comment

Botao Chen 2024-12-04 15:19:54 -08:00
parent 2a15a8a005
commit 12eef58543
6 changed files with 58 additions and 87 deletions


@@ -7,7 +7,7 @@
from datetime import datetime
from enum import Enum
from typing import Any, Dict, List, Optional, Protocol
from typing import Any, Dict, List, Optional, Protocol, Union
from llama_models.schema_utils import json_schema_type, webmethod
@@ -62,13 +62,6 @@ class TrainingConfig(BaseModel):
dtype: Optional[str] = "bf16"
@json_schema_type
class FinetuningAlgorithm(Enum):
full = "full"
lora = "lora"
qat = "qat"
@json_schema_type
class LoraFinetuningConfig(BaseModel):
lora_attn_modules: List[str]
@@ -172,12 +165,17 @@ class PostTraining(Protocol):
async def supervised_fine_tune(
self,
job_uuid: str,
model: str,
algorithm: FinetuningAlgorithm,
training_config: TrainingConfig,
hyperparam_search_config: Dict[str, Any],
logger_config: Dict[str, Any],
algorithm_config: Optional[LoraFinetuningConfig] = None,
model: str = Field(
default="Llama3.2-3B-Instruct",
description="Model descriptor from `llama model list`",
),
checkpoint_dir: Optional[str] = None,
algorithm_config: Optional[
Union[LoraFinetuningConfig, QATFinetuningConfig]
] = None,
) -> PostTrainingJob: ...
@webmethod(route="/post-training/preference-optimize")

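With this change, `model`, `checkpoint_dir`, and the concrete algorithm config all travel with the request instead of living in provider config. Below is a minimal, hedged sketch of a call against the updated signature; `post_training` stands in for whatever object implements the PostTraining protocol, and the nested `data_config` / `optimizer_config` shapes are passed as dicts because their class names are not shown in this diff.

# Hedged sketch of a caller using the updated supervised_fine_tune signature.
# Field names inside training_config are inferred from reads elsewhere in this diff.
job = await post_training.supervised_fine_tune(
    job_uuid="sft-job-001",
    training_config=TrainingConfig(
        n_epochs=1,
        max_steps_per_epoch=10,
        gradient_accumulation_steps=1,
        dtype="bf16",
        data_config={"dataset_id": "alpaca", "batch_size": 4, "shuffle": True},
        optimizer_config={"num_warmup_steps": 100},
    ),
    hyperparam_search_config={},
    logger_config={},
    model="Llama3.2-3B-Instruct",   # API default, per the Field() above
    checkpoint_dir=None,            # provider falls back to ~/.llama/checkpoints/<model>
    algorithm_config=LoraFinetuningConfig(
        lora_attn_modules=["q_proj", "v_proj"],
        rank=8,
        alpha=16,
        apply_lora_to_mlp=False,
        apply_lora_to_output=False,
    ),
)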

@@ -6,15 +6,8 @@
from typing import Optional
from pydantic import BaseModel, Field
from pydantic import BaseModel
class TorchtunePostTrainingConfig(BaseModel):
model: str = Field(
default="Llama3.2-3B-Instruct",
description="Model descriptor from `llama model list`",
)
torch_seed: Optional[int] = None
# By default, the implementation will look at ~/.llama/checkpoints/<model> but you
# can override by specifying the directory explicitly
checkpoint_dir: Optional[str] = None

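After this change the torchtune provider config carries only a runtime knob; a sketch of the class as it now reads, with `model` and `checkpoint_dir` gone because they arrive with each request.

from typing import Optional

from pydantic import BaseModel


class TorchtunePostTrainingConfig(BaseModel):
    # model and checkpoint_dir were dropped from provider config in this
    # commit; both are now supplied on each supervised_fine_tune call.
    torch_seed: Optional[int] = None

The practical effect is that a single provider instance can fine-tune different models without editing the run config or restarting the stack.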

@@ -13,18 +13,6 @@ from llama_stack.providers.inline.post_training.torchtune.recipes.lora_finetunin
)
class PostTrainingSFTRequest(BaseModel):
job_uuid: str
model: str
algorithm: FinetuningAlgorithm
algorithm_config: Optional[Union[LoraFinetuningConfig, QATFinetuningConfig]] = None
training_config: TrainingConfig
# TODO: define these
hyperparam_search_config: Dict[str, Any]
logger_config: Dict[str, Any]
class TorchtunePostTrainingImpl:
def __init__(
self, config: TorchtunePostTrainingConfig, datasetio_api: DatasetIO
@@ -35,29 +23,25 @@ class TorchtunePostTrainingImpl:
async def supervised_fine_tune(
self,
job_uuid: str,
model: str,
algorithm: FinetuningAlgorithm,
algorithm_config: LoraFinetuningConfig,
training_config: TrainingConfig,
hyperparam_search_config: Dict[str, Any],
logger_config: Dict[str, Any],
model: str,
checkpoint_dir: Optional[str],
algorithm_config: Optional[Union[LoraFinetuningConfig, QATFinetuningConfig]],
) -> PostTrainingJob:
# wrapper request to make it easier to pass around (internal only, not exposed to API)
request = PostTrainingSFTRequest(
job_uuid=job_uuid,
model=model,
algorithm=algorithm,
algorithm_config=algorithm_config,
training_config=training_config,
hyperparam_search_config=hyperparam_search_config,
logger_config=logger_config,
)
if request.algorithm == FinetuningAlgorithm.lora:
if isinstance(algorithm_config, LoraFinetuningConfig):
recipe = LoraFinetuningSingleDevice(
self.config, request, self.datasetio_api
self.config,
training_config,
hyperparam_search_config,
logger_config,
model,
checkpoint_dir,
algorithm_config,
self.datasetio_api,
)
await recipe.setup(self.config)
await recipe.setup()
await recipe.train()
else:
raise NotImplementedError()

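The implementation now dispatches on the concrete type of `algorithm_config` rather than on a separate `FinetuningAlgorithm` enum, so the config object alone identifies the algorithm. A condensed sketch of that dispatch follows, with a hypothetical QAT branch added only to illustrate how the `Union` is meant to grow; no QAT recipe exists in this commit.

# Sketch of the isinstance-based dispatch inside supervised_fine_tune.
# The QAT branch is hypothetical; only LoRA is wired up in this commit.
if isinstance(algorithm_config, LoraFinetuningConfig):
    recipe = LoraFinetuningSingleDevice(
        self.config,
        training_config,
        hyperparam_search_config,
        logger_config,
        model,
        checkpoint_dir,
        algorithm_config,
        self.datasetio_api,
    )
    await recipe.setup()
    await recipe.train()
elif isinstance(algorithm_config, QATFinetuningConfig):
    raise NotImplementedError("QAT fine-tuning is not implemented yet")
else:
    raise NotImplementedError(f"Unsupported algorithm config: {type(algorithm_config)}")

Dropping the enum also removes the chance of the enum and the config object disagreeing about which algorithm to run.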

@@ -22,12 +22,9 @@ from llama_stack.distribution.utils.model_utils import model_local_dir
from llama_stack.providers.inline.post_training.torchtune import utils
from llama_stack.providers.inline.post_training.torchtune.config import (
MetaReferencePostTrainingConfig,
TorchtunePostTrainingConfig,
)
from llama_stack.providers.inline.post_training.torchtune.datasets.sft import SFTDataset
from llama_stack.providers.inline.post_training.torchtune.post_training import (
PostTrainingSFTRequest,
)
from torch.optim import Optimizer
from torch.utils.data import DataLoader, DistributedSampler
from torchtune import modules, training
@@ -64,18 +61,21 @@ class LoraFinetuningSingleDevice:
# and make it work with telemetry
def __init__(
self,
config: MetaReferencePostTrainingConfig,
request: PostTrainingSFTRequest,
config: TorchtunePostTrainingConfig,
training_config: TrainingConfig,
hyperparam_search_config: Dict[str, Any],
logger_config: Dict[str, Any],
model: str,
checkpoint_dir: Optional[str],
algorithm_config: Optional[Union[LoraFinetuningConfig, QATFinetuningConfig]],
datasetio_api: DatasetIO,
) -> None:
# Assume the training only happens on GPU
self.config = config
self.request = request
self.training_config = training_config
self.algorithm_config = algorithm_config
self._device = torchtune_utils.get_device(device="cuda")
self._dtype = training.get_dtype(
request.training_config.dtype, device=self._device
)
self.model_id = config.model
self._dtype = training.get_dtype(training_config.dtype, device=self._device)
self.model_id = model
def model_checkpoint_dir(model) -> str:
checkpoint_dir = Path(model_local_dir(model.descriptor()))
@@ -93,7 +93,7 @@ class LoraFinetuningSingleDevice:
)
return str(checkpoint_dir)
if config.checkpoint_dir and config.checkpoint_dir != "null":
if checkpoint_dir and checkpoint_dir != "null":
self.checkpoint_dir = config.checkpoint_dir
else:
model = resolve_model(self.model_id)
@@ -102,29 +102,27 @@ class LoraFinetuningSingleDevice:
# TODO @markchen1015 make it work with get_training_job_artifacts
self._output_dir = self.checkpoint_dir + "/posting_training/"
self.seed = training.set_seed(seed=config.torch_seed or 42)
self.seed = training.set_seed(seed=config.torch_seed)
self.epochs_run = 0
self.total_epochs = request.training_config.n_epochs
self._shuffle = request.training_config.data_config.shuffle
self._batch_size = request.training_config.data_config.batch_size
self.total_epochs = training_config.n_epochs
self._shuffle = training_config.data_config.shuffle
self._batch_size = training_config.data_config.batch_size
# this is important for debugging purpose
self.max_steps_per_epoch = request.training_config.max_steps_per_epoch
self.max_steps_per_epoch = training_config.max_steps_per_epoch
self.global_step = 0
self._gradient_accumulation_steps = (
request.training_config.gradient_accumulation_steps
)
self._gradient_accumulation_steps = training_config.gradient_accumulation_steps
self._clip_grad_norm = 1.0
self._enable_activation_checkpointing = (
(request.training_config.efficiency_config.enable_activation_checkpointing)
if request.training_config.efficiency_config
(training_config.efficiency_config.enable_activation_checkpointing)
if training_config.efficiency_config
else False
)
self._enable_activation_offloading = (
(request.training_config.efficiency_config.enable_activation_offloading)
if request.training_config.efficiency_config
(training_config.efficiency_config.enable_activation_offloading)
if training_config.efficiency_config
else False
)
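Collected in one place, these are the `TrainingConfig` fields the recipe consumes in this hunk. The sketch below is a reconstruction from the attribute accesses above, not the real schema; the nested class names and defaults are assumptions.

from typing import Optional

from pydantic import BaseModel


class DataConfig(BaseModel):          # name assumed; fields from the data_config reads above
    dataset_id: str
    batch_size: int
    shuffle: bool


class OptimizerConfig(BaseModel):     # name assumed; num_warmup_steps is read in setup()
    num_warmup_steps: int


class EfficiencyConfig(BaseModel):    # name assumed; fields from the efficiency_config reads
    enable_activation_checkpointing: bool = False
    enable_activation_offloading: bool = False


class TrainingConfig(BaseModel):
    n_epochs: int
    max_steps_per_epoch: int
    gradient_accumulation_steps: int
    data_config: DataConfig
    optimizer_config: OptimizerConfig
    efficiency_config: Optional[EfficiencyConfig] = None
    dtype: Optional[str] = "bf16"
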
@@ -150,7 +148,7 @@ class LoraFinetuningSingleDevice:
checkpoint_dict = self._checkpointer.load_checkpoint()
return checkpoint_dict
async def setup(self, config: MetaReferencePostTrainingConfig) -> None:
async def setup(self) -> None:
# temporily log to local disk, will figure out how to interop with telemetry
self._metric_logger = DiskLogger(log_dir=self._output_dir)
@@ -168,7 +166,7 @@ class LoraFinetuningSingleDevice:
log.info("Tokenizer is initialized from file.")
self._optimizer = await self._setup_optimizer(
optimizer_config=self.request.training_config.optimizer_config
optimizer_config=self.training_config.optimizer_config
)
log.info("Optimizer is initialized.")
@@ -200,7 +198,7 @@ class LoraFinetuningSingleDevice:
# Learning rate scheduler can only be set up after number of steps
# has been computed
self._lr_scheduler = await self._setup_lr_scheduler(
num_warmup_steps=self.request.training_config.optimizer_config.num_warmup_steps,
num_warmup_steps=self.training_config.optimizer_config.num_warmup_steps,
num_training_steps=self.total_epochs * self._steps_per_epoch,
last_epoch=self.global_step - 1,
)
@@ -218,12 +216,12 @@ class LoraFinetuningSingleDevice:
base_model_state_dict: Dict[str, Any],
lora_weights_state_dict: Optional[Dict[str, Any]] = None,
) -> nn.Module:
self._lora_rank = self.request.algorithm_config.rank
self._lora_alpha = self.request.algorithm_config.alpha
self._lora_attn_modules = list(self.request.algorithm_config.lora_attn_modules)
self._apply_lora_to_mlp = self.request.algorithm_config.apply_lora_to_mlp
self._apply_lora_to_output = self.request.algorithm_config.apply_lora_to_output
self._use_dora = self.request.algorithm_config.use_dora or False
self._lora_rank = self.algorithm_config.rank
self._lora_alpha = self.algorithm_config.alpha
self._lora_attn_modules = list(self.algorithm_config.lora_attn_modules)
self._apply_lora_to_mlp = self.algorithm_config.apply_lora_to_mlp
self._apply_lora_to_output = self.algorithm_config.apply_lora_to_output
self._use_dora = self.algorithm_config.use_dora or False
with training.set_default_dtype(self._dtype), self._device:
model_type = utils.get_model_type(self.model_id)
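The attribute reads above imply the shape of `LoraFinetuningConfig`. A reconstruction for reference only; the real class is declared in the post_training API module, and the defaults shown here are guesses rather than the declared ones.

from typing import List, Optional

from pydantic import BaseModel


class LoraFinetuningConfig(BaseModel):
    # Field names taken from the reads above; defaults are illustrative.
    lora_attn_modules: List[str]
    rank: int
    alpha: int
    apply_lora_to_mlp: bool = False
    apply_lora_to_output: bool = False
    use_dora: Optional[bool] = False
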
@@ -311,7 +309,7 @@ class LoraFinetuningSingleDevice:
) -> Tuple[DistributedSampler, DataLoader]:
async def fetch_rows():
return await self.datasetio_api.get_rows_paginated(
dataset_id=self.request.training_config.data_config.dataset_id,
dataset_id=self.training_config.data_config.dataset_id,
rows_in_page=-1,
)


@@ -16,7 +16,7 @@ def available_providers() -> List[ProviderSpec]:
provider_type="inline::torchtune",
pip_packages=["torch", "torchtune", "torchao", "numpy"],
module="llama_stack.providers.inline.post_training.torchtune",
config_class="llama_stack.providers.inline.post_training.torchtune.torchtunePostTrainingConfig",
config_class="llama_stack.providers.inline.post_training.torchtune.TorchtunePostTrainingConfig",
api_dependencies=[
Api.datasetio,
],

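The registry fix above is just capitalization, but it matters because `config_class` is presumably resolved by importing the dotted string when the provider is instantiated, along the lines of the illustrative sketch below (not the actual loader code); with the lowercase name the lookup fails.

import importlib


def resolve_config_class(dotted_path: str):
    """Illustrative resolver: import the module, then look up the class by name."""
    module_path, _, class_name = dotted_path.rpartition(".")
    module = importlib.import_module(module_path)
    # A misspelled or wrongly-cased class name surfaces here as AttributeError.
    return getattr(module, class_name)


config_cls = resolve_config_class(
    "llama_stack.providers.inline.post_training.torchtune.TorchtunePostTrainingConfig"
)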

@@ -49,9 +49,7 @@ providers:
post_training:
- provider_id: meta-reference-post-training
provider_type: inline::torchtune
config:
model: ${env.POST_TRAINING_MODEL}
checkpoint_dir: ${env.INFERENCE_CHECKPOINT_DIR:null}
config: {}
metadata_store:
namespace: null