refine api

Botao Chen 2024-12-03 20:01:27 -08:00
parent 5838b7211d
commit 41cf2bb0a7
3 changed files with 74 additions and 40 deletions

View file

@@ -7,7 +7,7 @@
from datetime import datetime
from enum import Enum
from typing import Any, Dict, List, Optional, Protocol
from typing import Any, Dict, List, Optional, Protocol, Union
from llama_models.schema_utils import json_schema_type, webmethod
@@ -18,42 +18,55 @@ from llama_stack.apis.datasets import * # noqa: F403
from llama_stack.apis.common.training_types import * # noqa: F403
@json_schema_type
class OptimizerType(Enum):
adam = "adam"
adamw = "adamw"
sgd = "sgd"
@json_schema_type
class DataConfig(BaseModel):
dataset_id: str
batch_size: int
shuffle: bool
validation_dataset_id: Optional[str] = None
packed: Optional[bool] = False
train_on_input: Optional[bool] = False
@json_schema_type
class OptimizerConfig(BaseModel):
optimizer_type: OptimizerType
lr: float
lr_min: float
weight_decay: float
num_warmup_steps: int
@json_schema_type
class EfficiencyConfig(BaseModel):
enable_activation_checkpointing: Optional[bool] = False
enable_activation_offloading: Optional[bool] = False
memory_efficient_fsdp_wrap: Optional[bool] = False
fsdp_cpu_offload: Optional[bool] = False
@json_schema_type
class TrainingConfig(BaseModel):
dtype: str
n_epochs: int
max_steps_per_epoch: int
gradient_accumulation_steps: int
batch_size: int
shuffle: bool
data_config: DataConfig
optimizer_config: OptimizerConfig
enable_activation_checkpointing: bool
memory_efficient_fsdp_wrap: Optional[bool]
fsdp_cpu_offload: Optional[bool]
efficiency_config: Optional[EfficiencyConfig] = None
dtype: Optional[str] = "bf16"
@json_schema_type
class FinetuningAlgorithm(Enum):
full = "full"
lora = "lora"
qlora = "qlora"
dora = "dora"
qat = "qat"
@json_schema_type
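The hunk above replaces the flat TrainingConfig with nested DataConfig, OptimizerConfig, and EfficiencyConfig objects, with dtype now optional and defaulting to "bf16". A minimal construction sketch, using only field names visible in this diff (the dataset id and numeric values are illustrative):

from llama_stack.apis.post_training import (
    DataConfig,
    EfficiencyConfig,
    OptimizerConfig,
    OptimizerType,
    TrainingConfig,
)

training_config = TrainingConfig(
    n_epochs=1,
    max_steps_per_epoch=100,
    gradient_accumulation_steps=1,
    data_config=DataConfig(
        dataset_id="alpaca",  # illustrative dataset id
        batch_size=2,
        shuffle=True,
    ),
    optimizer_config=OptimizerConfig(
        optimizer_type=OptimizerType.adamw,
        lr=3e-4,
        lr_min=3e-5,
        weight_decay=0.1,
        num_warmup_steps=100,
    ),
    efficiency_config=EfficiencyConfig(
        enable_activation_checkpointing=True,
    ),
    # dtype is omitted and falls back to the "bf16" default
)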
@@ -63,17 +76,14 @@ class LoraFinetuningConfig(BaseModel):
apply_lora_to_output: bool
rank: int
alpha: int
use_dora: bool
use_dora: Optional[bool] = False
quantize_base: Optional[bool] = False
@json_schema_type
class QLoraFinetuningConfig(LoraFinetuningConfig):
pass
@json_schema_type
class DoraFinetuningConfig(LoraFinetuningConfig):
pass
class QATFinetuningConfig(BaseModel):
quantizer_name: str
group_size: int
@json_schema_type
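use_dora and quantize_base now default to False, so a plain LoRA config needs only the structural fields, and QAT gets a standalone config rather than subclassing LoRA. A sketch of both shapes (the module list and quantizer settings are illustrative, not from this commit):

from llama_stack.apis.post_training import (
    LoraFinetuningConfig,
    QATFinetuningConfig,
)

lora_config = LoraFinetuningConfig(
    lora_attn_modules=["q_proj", "v_proj"],  # illustrative module names
    apply_lora_to_mlp=False,
    apply_lora_to_output=False,
    rank=8,
    alpha=16,
    # use_dora and quantize_base default to False
)

qat_config = QATFinetuningConfig(
    quantizer_name="8da4w",  # illustrative quantizer name
    group_size=256,
)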
@@ -110,13 +120,9 @@ class PostTrainingSFTRequest(BaseModel):
"""Request to finetune a model."""
job_uuid: str
model: str
dataset_id: str
validation_dataset_id: str
algorithm: FinetuningAlgorithm
algorithm_config: LoraFinetuningConfig
algorithm_config: Optional[Union[LoraFinetuningConfig, QATFinetuningConfig]] = None
training_config: TrainingConfig
# TODO: define these
@@ -182,13 +188,11 @@ class PostTraining(Protocol):
self,
job_uuid: str,
model: str,
dataset_id: str,
validation_dataset_id: str,
algorithm: FinetuningAlgorithm,
algorithm_config: LoraFinetuningConfig,
training_config: TrainingConfig,
hyperparam_search_config: Dict[str, Any],
logger_config: Dict[str, Any],
algorithm_config: Optional[LoraFinetuningConfig] = None,
) -> PostTrainingJob: ...
@webmethod(route="/post-training/preference-optimize")
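With dataset ids folded into TrainingConfig.data_config, supervised_fine_tune drops the two dataset parameters and takes algorithm_config as an optional keyword. A hedged usage sketch, where post_training_impl is a hypothetical object implementing the PostTraining protocol, and training_config/lora_config come from the earlier sketches:

from llama_stack.apis.post_training import FinetuningAlgorithm

job = await post_training_impl.supervised_fine_tune(
    job_uuid="1234",  # illustrative
    model="Llama3.2-3B-Instruct",  # illustrative model id
    algorithm=FinetuningAlgorithm.lora,
    training_config=training_config,  # carries data_config.dataset_id
    hyperparam_search_config={},
    logger_config={},
    algorithm_config=lora_config,  # now optional; defaults to None
)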

View file

@@ -24,8 +24,6 @@ class MetaReferencePostTrainingImpl:
self,
job_uuid: str,
model: str,
dataset_id: str,
validation_dataset_id: str,
algorithm: FinetuningAlgorithm,
algorithm_config: LoraFinetuningConfig,
training_config: TrainingConfig,
@@ -37,8 +35,6 @@ class MetaReferencePostTrainingImpl:
request = PostTrainingSFTRequest(
job_uuid=job_uuid,
model=model,
dataset_id=dataset_id,
validation_dataset_id=validation_dataset_id,
algorithm=algorithm,
algorithm_config=algorithm_config,
training_config=training_config,
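The reference implementation mirrors the API change: dataset_id and validation_dataset_id disappear from both the method signature and the forwarded request, since dataset selection now travels inside training_config.data_config. A sketch of the resulting request construction (values illustrative, remaining fields elided):

from llama_stack.apis.post_training import (
    FinetuningAlgorithm,
    PostTrainingSFTRequest,
)

request = PostTrainingSFTRequest(
    job_uuid="1234",
    model="Llama3.2-3B-Instruct",
    algorithm=FinetuningAlgorithm.lora,
    algorithm_config=lora_config,
    training_config=training_config,
)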

View file

@@ -6,6 +6,7 @@
import logging
import os
import time
from functools import partial
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple
@@ -15,6 +16,7 @@ from llama_models.sku_list import resolve_model
from llama_stack.apis.datasetio import DatasetIO
from torch import nn
from torchtune import utils as torchtune_utils
from torchtune.training.metric_logging import DiskLogger
from llama_stack.apis.post_training import * # noqa
from llama_stack.apis.post_training import PostTrainingSFTRequest
from llama_stack.distribution.utils.model_utils import model_local_dir
@@ -29,7 +31,7 @@ from llama_stack.providers.inline.post_training.meta_reference.datasets.sft import
from torch.optim import Optimizer
from torch.utils.data import DataLoader, DistributedSampler
from torchtune import modules, training
from torchtune.data import InputOutputToMessages, padded_collate_sft
from torchtune.data import AlpacaToMessages, padded_collate_sft
from torchtune.modules.loss import CEWithChunkedOutputLoss
from torchtune.modules.peft import (
@@ -103,8 +105,8 @@ class LoraFinetuningSingleDevice:
self.seed = training.set_seed(seed=config.torch_seed or 42)
self.epochs_run = 0
self.total_epochs = request.training_config.n_epochs
self._shuffle = request.training_config.shuffle
self._batch_size = request.training_config.batch_size
self._shuffle = request.training_config.data_config.shuffle
self._batch_size = request.training_config.data_config.batch_size
# this is important for debugging purposes
self.max_steps_per_epoch = request.training_config.max_steps_per_epoch
@@ -116,9 +118,15 @@ class LoraFinetuningSingleDevice:
self._clip_grad_norm = 1.0
self._enable_activation_checkpointing = (
request.training_config.enable_activation_checkpointing
(request.training_config.efficiency_config.enable_activation_checkpointing)
if request.training_config.efficiency_config
else False
)
self._enable_activation_offloading = (
(request.training_config.efficiency_config.enable_activation_offloading)
if request.training_config.efficiency_config
else False
)
self._enable_activation_offloading = False
self.datasetio_api = datasetio_api
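Both activation flags now resolve through the optional EfficiencyConfig and fall back to False when it is absent. A compact near-equivalent of the gating above (sketch only; the recipe keeps the explicit conditionals):

# eff is None when no EfficiencyConfig was supplied, so both flags stay False
eff = request.training_config.efficiency_config
enable_ckpt = bool(eff and eff.enable_activation_checkpointing)
enable_offload = bool(eff and eff.enable_activation_offloading)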
@@ -143,6 +151,9 @@
return checkpoint_dict
async def setup(self, config: MetaReferencePostTrainingConfig) -> None:
# temporarily log to local disk; will figure out how to interop with telemetry
self._metric_logger = DiskLogger(log_dir=self._output_dir)
checkpoint_dict = await self.load_checkpoint()
self._model = await self._setup_model(
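DiskLogger is torchtune's file-backed metric logger; the recipe points it at the job's output directory as a stopgap until telemetry interop is figured out. A standalone usage sketch (path and values illustrative):

from torchtune.training.metric_logging import DiskLogger

metric_logger = DiskLogger(log_dir="/tmp/sft_logs")  # illustrative log dir
metric_logger.log_dict({"loss": 1.23, "lr": 3e-4}, step=1)
metric_logger.close()  # flushes the plain-text log file under log_dir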
@@ -212,7 +223,7 @@ class LoraFinetuningSingleDevice:
self._lora_attn_modules = list(self.request.algorithm_config.lora_attn_modules)
self._apply_lora_to_mlp = self.request.algorithm_config.apply_lora_to_mlp
self._apply_lora_to_output = self.request.algorithm_config.apply_lora_to_output
self._use_dora = self.request.algorithm_config.use_dora
self._use_dora = self.request.algorithm_config.use_dora or False
with training.set_default_dtype(self._dtype), self._device:
model_type = utils.get_model_type(self.model_id)
@@ -272,6 +283,10 @@ class LoraFinetuningSingleDevice:
self.activations_handling_ctx = training.get_act_offloading_ctx_manager(
model, enable_activation_offloading
)
memory_stats = training.get_memory_stats(device=self._device)
training.log_memory_stats(memory_stats)
return model
async def _setup_tokenizer(
@@ -296,7 +311,7 @@
) -> Tuple[DistributedSampler, DataLoader]:
async def fetch_rows():
return await self.datasetio_api.get_rows_paginated(
dataset_id=self.request.dataset_id,
dataset_id=self.request.training_config.data_config.dataset_id,
rows_in_page=-1,
)
@@ -306,7 +321,9 @@
# Currently only supports instruct datasets
# TODO @markchen1015 make the message_transform swappable and support more dataset types
ds = SFTDataset(
rows, message_transform=InputOutputToMessages(), model_transform=tokenizer
rows,
message_transform=AlpacaToMessages(train_on_input=False),
model_transform=tokenizer,
)
sampler = DistributedSampler(
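AlpacaToMessages consumes Alpaca-style rows (instruction, optional input, output) and emits a user/assistant message pair; train_on_input=False masks the prompt tokens from the loss. An illustrative transform call (the sample row is made up):

from torchtune.data import AlpacaToMessages

transform = AlpacaToMessages(train_on_input=False)
sample = {
    "instruction": "Summarize the text.",
    "input": "Llama Stack refined its post-training API.",
    "output": "The post-training API was refined.",
}
messages = transform(sample)["messages"]
# messages[0] is the user turn built from instruction + input (masked from loss)
# messages[1] is the assistant turn holding the output (trained on)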
@@ -412,6 +429,7 @@
"""
# Initialize token count and running loss (for grad accumulation)
# t0 = time.perf_counter()
t0 = time.perf_counter()
running_loss = 0
num_tokens = 0
@@ -447,7 +465,7 @@
# Step with optimizer
if (idx + 1) % self._gradient_accumulation_steps == 0:
training.scale_grads(self._model, 1 / num_tokens)
torch.nn.utils.clip_grad_norm_(
grad_norm = torch.nn.utils.clip_grad_norm_(
self._model.parameters(),
max_norm=float(self._clip_grad_norm),
)
@@ -457,9 +475,25 @@
# Update the number of steps when the weights are updated
self.global_step += 1
loss_to_log = running_loss.item() / num_tokens
time_per_step = time.perf_counter() - t0
log_dict = {
"loss": loss_to_log,
"lr": self._optimizer.param_groups[0]["lr"],
"tokens_per_second_per_gpu": num_tokens / time_per_step,
}
log_dict.update(training.get_memory_stats(device=self._device))
if self._clip_grad_norm is not None:
log_dict.update({"grad_norm": grad_norm})
self._metric_logger.log_dict(
log_dict,
step=self.global_step,
)
# Reset running stats for the next step
running_loss = 0
num_tokens = 0
t0 = time.perf_counter()
self.epochs_run += 1
log.info("Starting checkpoint save...")
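Net effect of the logging changes: every optimizer step records loss, learning rate, token throughput, memory stats, and (when clipping is enabled) the gradient norm. Assuming torchtune's usual peak-memory keys from training.get_memory_stats, a per-step record handed to DiskLogger looks roughly like this (numbers illustrative):

{
    "loss": 1.23,
    "lr": 3e-4,
    "tokens_per_second_per_gpu": 1500.0,
    "peak_memory_active": 6.2,  # GiB, illustrative
    "peak_memory_alloc": 5.9,
    "peak_memory_reserved": 7.0,
    "grad_norm": 0.87,
}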