refine api

This commit is contained in:
Botao Chen 2024-12-03 20:01:27 -08:00
parent 5838b7211d
commit 41cf2bb0a7
3 changed files with 74 additions and 40 deletions

View file

@ -7,7 +7,7 @@
from datetime import datetime from datetime import datetime
from enum import Enum from enum import Enum
from typing import Any, Dict, List, Optional, Protocol from typing import Any, Dict, List, Optional, Protocol, Union
from llama_models.schema_utils import json_schema_type, webmethod from llama_models.schema_utils import json_schema_type, webmethod
@ -18,42 +18,55 @@ from llama_stack.apis.datasets import * # noqa: F403
from llama_stack.apis.common.training_types import * # noqa: F403 from llama_stack.apis.common.training_types import * # noqa: F403
@json_schema_type
class OptimizerType(Enum): class OptimizerType(Enum):
adam = "adam" adam = "adam"
adamw = "adamw" adamw = "adamw"
sgd = "sgd" sgd = "sgd"
@json_schema_type
class DataConfig(BaseModel):
dataset_id: str
batch_size: int
shuffle: bool
validation_dataset_id: Optional[str] = None
packed: Optional[bool] = False
train_on_input: Optional[bool] = False
@json_schema_type @json_schema_type
class OptimizerConfig(BaseModel): class OptimizerConfig(BaseModel):
optimizer_type: OptimizerType optimizer_type: OptimizerType
lr: float lr: float
lr_min: float
weight_decay: float weight_decay: float
num_warmup_steps: int num_warmup_steps: int
@json_schema_type
class EfficiencyConfig(BaseModel):
enable_activation_checkpointing: Optional[bool] = False
enable_activation_offloading: Optional[bool] = False
memory_efficient_fsdp_wrap: Optional[bool] = False
fsdp_cpu_offload: Optional[bool] = False
@json_schema_type @json_schema_type
class TrainingConfig(BaseModel): class TrainingConfig(BaseModel):
dtype: str
n_epochs: int n_epochs: int
max_steps_per_epoch: int max_steps_per_epoch: int
gradient_accumulation_steps: int gradient_accumulation_steps: int
batch_size: int data_config: DataConfig
shuffle: bool
optimizer_config: OptimizerConfig optimizer_config: OptimizerConfig
efficiency_config: Optional[EfficiencyConfig] = None
enable_activation_checkpointing: bool dtype: Optional[str] = "bf16"
memory_efficient_fsdp_wrap: Optional[bool]
fsdp_cpu_offload: Optional[bool]
@json_schema_type @json_schema_type
class FinetuningAlgorithm(Enum): class FinetuningAlgorithm(Enum):
full = "full" full = "full"
lora = "lora" lora = "lora"
qlora = "qlora" qat = "qat"
dora = "dora"
@json_schema_type @json_schema_type
@ -63,17 +76,14 @@ class LoraFinetuningConfig(BaseModel):
apply_lora_to_output: bool apply_lora_to_output: bool
rank: int rank: int
alpha: int alpha: int
use_dora: bool use_dora: Optional[bool] = False
quantize_base: Optional[bool] = False
@json_schema_type @json_schema_type
class QLoraFinetuningConfig(LoraFinetuningConfig): class QATFinetuningConfig(BaseModel):
pass quantizer_name: str
group_size: int
@json_schema_type
class DoraFinetuningConfig(LoraFinetuningConfig):
pass
@json_schema_type @json_schema_type
@ -110,13 +120,9 @@ class PostTrainingSFTRequest(BaseModel):
"""Request to finetune a model.""" """Request to finetune a model."""
job_uuid: str job_uuid: str
model: str model: str
dataset_id: str
validation_dataset_id: str
algorithm: FinetuningAlgorithm algorithm: FinetuningAlgorithm
algorithm_config: LoraFinetuningConfig algorithm_config: Optional[Union[LoraFinetuningConfig, QATFinetuningConfig]] = None
training_config: TrainingConfig training_config: TrainingConfig
# TODO: define these # TODO: define these
@ -182,13 +188,11 @@ class PostTraining(Protocol):
self, self,
job_uuid: str, job_uuid: str,
model: str, model: str,
dataset_id: str,
validation_dataset_id: str,
algorithm: FinetuningAlgorithm, algorithm: FinetuningAlgorithm,
algorithm_config: LoraFinetuningConfig,
training_config: TrainingConfig, training_config: TrainingConfig,
hyperparam_search_config: Dict[str, Any], hyperparam_search_config: Dict[str, Any],
logger_config: Dict[str, Any], logger_config: Dict[str, Any],
algorithm_config: Optional[LoraFinetuningConfig] = None,
) -> PostTrainingJob: ... ) -> PostTrainingJob: ...
@webmethod(route="/post-training/preference-optimize") @webmethod(route="/post-training/preference-optimize")

View file

@ -24,8 +24,6 @@ class MetaReferencePostTrainingImpl:
self, self,
job_uuid: str, job_uuid: str,
model: str, model: str,
dataset_id: str,
validation_dataset_id: str,
algorithm: FinetuningAlgorithm, algorithm: FinetuningAlgorithm,
algorithm_config: LoraFinetuningConfig, algorithm_config: LoraFinetuningConfig,
training_config: TrainingConfig, training_config: TrainingConfig,
@ -37,8 +35,6 @@ class MetaReferencePostTrainingImpl:
request = PostTrainingSFTRequest( request = PostTrainingSFTRequest(
job_uuid=job_uuid, job_uuid=job_uuid,
model=model, model=model,
dataset_id=dataset_id,
validation_dataset_id=validation_dataset_id,
algorithm=algorithm, algorithm=algorithm,
algorithm_config=algorithm_config, algorithm_config=algorithm_config,
training_config=training_config, training_config=training_config,

View file

@ -6,6 +6,7 @@
import logging import logging
import os import os
import time
from functools import partial from functools import partial
from pathlib import Path from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple from typing import Any, Dict, List, Optional, Tuple
@ -15,6 +16,7 @@ from llama_models.sku_list import resolve_model
from llama_stack.apis.datasetio import DatasetIO from llama_stack.apis.datasetio import DatasetIO
from torch import nn from torch import nn
from torchtune import utils as torchtune_utils from torchtune import utils as torchtune_utils
from torchtune.training.metric_logging import DiskLogger
from llama_stack.apis.post_training import * # noqa from llama_stack.apis.post_training import * # noqa
from llama_stack.apis.post_training import PostTrainingSFTRequest from llama_stack.apis.post_training import PostTrainingSFTRequest
from llama_stack.distribution.utils.model_utils import model_local_dir from llama_stack.distribution.utils.model_utils import model_local_dir
@ -29,7 +31,7 @@ from llama_stack.providers.inline.post_training.meta_reference.datasets.sft impo
from torch.optim import Optimizer from torch.optim import Optimizer
from torch.utils.data import DataLoader, DistributedSampler from torch.utils.data import DataLoader, DistributedSampler
from torchtune import modules, training from torchtune import modules, training
from torchtune.data import InputOutputToMessages, padded_collate_sft from torchtune.data import AlpacaToMessages, padded_collate_sft
from torchtune.modules.loss import CEWithChunkedOutputLoss from torchtune.modules.loss import CEWithChunkedOutputLoss
from torchtune.modules.peft import ( from torchtune.modules.peft import (
@ -103,8 +105,8 @@ class LoraFinetuningSingleDevice:
self.seed = training.set_seed(seed=config.torch_seed or 42) self.seed = training.set_seed(seed=config.torch_seed or 42)
self.epochs_run = 0 self.epochs_run = 0
self.total_epochs = request.training_config.n_epochs self.total_epochs = request.training_config.n_epochs
self._shuffle = request.training_config.shuffle self._shuffle = request.training_config.data_config.shuffle
self._batch_size = request.training_config.batch_size self._batch_size = request.training_config.data_config.batch_size
# this is important for debugging purpose # this is important for debugging purpose
self.max_steps_per_epoch = request.training_config.max_steps_per_epoch self.max_steps_per_epoch = request.training_config.max_steps_per_epoch
@ -116,9 +118,15 @@ class LoraFinetuningSingleDevice:
self._clip_grad_norm = 1.0 self._clip_grad_norm = 1.0
self._enable_activation_checkpointing = ( self._enable_activation_checkpointing = (
request.training_config.enable_activation_checkpointing (request.training_config.efficiency_config.enable_activation_checkpointing)
if request.training_config.efficiency_config
else False
)
self._enable_activation_offloading = (
(request.training_config.efficiency_config.enable_activation_offloading)
if request.training_config.efficiency_config
else False
) )
self._enable_activation_offloading = False
self.datasetio_api = datasetio_api self.datasetio_api = datasetio_api
@ -143,6 +151,9 @@ class LoraFinetuningSingleDevice:
return checkpoint_dict return checkpoint_dict
async def setup(self, config: MetaReferencePostTrainingConfig) -> None: async def setup(self, config: MetaReferencePostTrainingConfig) -> None:
# temporarily log to local disk, will figure out how to interop with telemetry
self._metric_logger = DiskLogger(log_dir=self._output_dir)
checkpoint_dict = await self.load_checkpoint() checkpoint_dict = await self.load_checkpoint()
self._model = await self._setup_model( self._model = await self._setup_model(
@ -212,7 +223,7 @@ class LoraFinetuningSingleDevice:
self._lora_attn_modules = list(self.request.algorithm_config.lora_attn_modules) self._lora_attn_modules = list(self.request.algorithm_config.lora_attn_modules)
self._apply_lora_to_mlp = self.request.algorithm_config.apply_lora_to_mlp self._apply_lora_to_mlp = self.request.algorithm_config.apply_lora_to_mlp
self._apply_lora_to_output = self.request.algorithm_config.apply_lora_to_output self._apply_lora_to_output = self.request.algorithm_config.apply_lora_to_output
self._use_dora = self.request.algorithm_config.use_dora self._use_dora = self.request.algorithm_config.use_dora or False
with training.set_default_dtype(self._dtype), self._device: with training.set_default_dtype(self._dtype), self._device:
model_type = utils.get_model_type(self.model_id) model_type = utils.get_model_type(self.model_id)
@ -272,6 +283,10 @@ class LoraFinetuningSingleDevice:
self.activations_handling_ctx = training.get_act_offloading_ctx_manager( self.activations_handling_ctx = training.get_act_offloading_ctx_manager(
model, enable_activation_offloading model, enable_activation_offloading
) )
memory_stats = training.get_memory_stats(device=self._device)
training.log_memory_stats(memory_stats)
return model return model
async def _setup_tokenizer( async def _setup_tokenizer(
@ -296,7 +311,7 @@ class LoraFinetuningSingleDevice:
) -> Tuple[DistributedSampler, DataLoader]: ) -> Tuple[DistributedSampler, DataLoader]:
async def fetch_rows(): async def fetch_rows():
return await self.datasetio_api.get_rows_paginated( return await self.datasetio_api.get_rows_paginated(
dataset_id=self.request.dataset_id, dataset_id=self.request.training_config.data_config.dataset_id,
rows_in_page=-1, rows_in_page=-1,
) )
@ -306,7 +321,9 @@ class LoraFinetuningSingleDevice:
# Currently only support instruct dataset # Currently only support instruct dataset
# TODO @markchen1015 make the message_transform swappable and support more dataset types # TODO @markchen1015 make the message_transform swappable and support more dataset types
ds = SFTDataset( ds = SFTDataset(
rows, message_transform=InputOutputToMessages(), model_transform=tokenizer rows,
message_transform=AlpacaToMessages(train_on_input=False),
model_transform=tokenizer,
) )
sampler = DistributedSampler( sampler = DistributedSampler(
@ -412,6 +429,7 @@ class LoraFinetuningSingleDevice:
""" """
# Initialize tokens count and running loss (for grad accumulation) # Initialize tokens count and running loss (for grad accumulation)
# t0 = time.perf_counter() # t0 = time.perf_counter()
t0 = time.perf_counter()
running_loss = 0 running_loss = 0
num_tokens = 0 num_tokens = 0
@ -447,7 +465,7 @@ class LoraFinetuningSingleDevice:
# Step with optimizer # Step with optimizer
if (idx + 1) % self._gradient_accumulation_steps == 0: if (idx + 1) % self._gradient_accumulation_steps == 0:
training.scale_grads(self._model, 1 / num_tokens) training.scale_grads(self._model, 1 / num_tokens)
torch.nn.utils.clip_grad_norm_( grad_norm = torch.nn.utils.clip_grad_norm_(
self._model.parameters(), self._model.parameters(),
max_norm=float(self._clip_grad_norm), max_norm=float(self._clip_grad_norm),
) )
@ -457,9 +475,25 @@ class LoraFinetuningSingleDevice:
# Update the number of steps when the weights are updated # Update the number of steps when the weights are updated
self.global_step += 1 self.global_step += 1
loss_to_log = running_loss.item() / num_tokens
time_per_step = time.perf_counter() - t0
log_dict = {
"loss": loss_to_log,
"lr": self._optimizer.param_groups[0]["lr"],
"tokens_per_second_per_gpu": num_tokens / time_per_step,
}
log_dict.update(training.get_memory_stats(device=self._device))
if self._clip_grad_norm is not None:
log_dict.update({"grad_norm": grad_norm})
self._metric_logger.log_dict(
log_dict,
step=self.global_step,
)
# Reset running stats for the next step # Reset running stats for the next step
running_loss = 0 running_loss = 0
num_tokens = 0 num_tokens = 0
t0 = time.perf_counter()
self.epochs_run += 1 self.epochs_run += 1
log.info("Starting checkpoint save...") log.info("Starting checkpoint save...")