mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-02 08:44:44 +00:00
refine api
This commit is contained in:
parent
5838b7211d
commit
41cf2bb0a7
3 changed files with 74 additions and 40 deletions
|
@ -7,7 +7,7 @@
|
|||
from datetime import datetime
|
||||
from enum import Enum
|
||||
|
||||
from typing import Any, Dict, List, Optional, Protocol
|
||||
from typing import Any, Dict, List, Optional, Protocol, Union
|
||||
|
||||
from llama_models.schema_utils import json_schema_type, webmethod
|
||||
|
||||
|
@ -18,42 +18,55 @@ from llama_stack.apis.datasets import * # noqa: F403
|
|||
from llama_stack.apis.common.training_types import * # noqa: F403
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class OptimizerType(Enum):
|
||||
adam = "adam"
|
||||
adamw = "adamw"
|
||||
sgd = "sgd"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class DataConfig(BaseModel):
|
||||
dataset_id: str
|
||||
batch_size: int
|
||||
shuffle: bool
|
||||
validation_dataset_id: Optional[str] = None
|
||||
packed: Optional[bool] = False
|
||||
train_on_input: Optional[bool] = False
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class OptimizerConfig(BaseModel):
|
||||
optimizer_type: OptimizerType
|
||||
lr: float
|
||||
lr_min: float
|
||||
weight_decay: float
|
||||
num_warmup_steps: int
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class EfficiencyConfig(BaseModel):
|
||||
enable_activation_checkpointing: Optional[bool] = False
|
||||
enable_activation_offloading: Optional[bool] = False
|
||||
memory_efficient_fsdp_wrap: Optional[bool] = False
|
||||
fsdp_cpu_offload: Optional[bool] = False
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class TrainingConfig(BaseModel):
|
||||
dtype: str
|
||||
n_epochs: int
|
||||
max_steps_per_epoch: int
|
||||
gradient_accumulation_steps: int
|
||||
batch_size: int
|
||||
shuffle: bool
|
||||
data_config: DataConfig
|
||||
optimizer_config: OptimizerConfig
|
||||
|
||||
enable_activation_checkpointing: bool
|
||||
memory_efficient_fsdp_wrap: Optional[bool]
|
||||
fsdp_cpu_offload: Optional[bool]
|
||||
efficiency_config: Optional[EfficiencyConfig] = None
|
||||
dtype: Optional[str] = "bf16"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class FinetuningAlgorithm(Enum):
|
||||
full = "full"
|
||||
lora = "lora"
|
||||
qlora = "qlora"
|
||||
dora = "dora"
|
||||
qat = "qat"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
@ -63,17 +76,14 @@ class LoraFinetuningConfig(BaseModel):
|
|||
apply_lora_to_output: bool
|
||||
rank: int
|
||||
alpha: int
|
||||
use_dora: bool
|
||||
use_dora: Optional[bool] = False
|
||||
quantize_base: Optional[bool] = False
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class QLoraFinetuningConfig(LoraFinetuningConfig):
|
||||
pass
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class DoraFinetuningConfig(LoraFinetuningConfig):
|
||||
pass
|
||||
class QATFinetuningConfig(BaseModel):
|
||||
quantizer_name: str
|
||||
group_size: int
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
@ -110,13 +120,9 @@ class PostTrainingSFTRequest(BaseModel):
|
|||
"""Request to finetune a model."""
|
||||
|
||||
job_uuid: str
|
||||
|
||||
model: str
|
||||
dataset_id: str
|
||||
validation_dataset_id: str
|
||||
|
||||
algorithm: FinetuningAlgorithm
|
||||
algorithm_config: LoraFinetuningConfig
|
||||
algorithm_config: Optional[Union[LoraFinetuningConfig, QATFinetuningConfig]] = None
|
||||
training_config: TrainingConfig
|
||||
|
||||
# TODO: define these
|
||||
|
@ -182,13 +188,11 @@ class PostTraining(Protocol):
|
|||
self,
|
||||
job_uuid: str,
|
||||
model: str,
|
||||
dataset_id: str,
|
||||
validation_dataset_id: str,
|
||||
algorithm: FinetuningAlgorithm,
|
||||
algorithm_config: LoraFinetuningConfig,
|
||||
training_config: TrainingConfig,
|
||||
hyperparam_search_config: Dict[str, Any],
|
||||
logger_config: Dict[str, Any],
|
||||
algorithm_config: Optional[LoraFinetuningConfig] = None,
|
||||
) -> PostTrainingJob: ...
|
||||
|
||||
@webmethod(route="/post-training/preference-optimize")
|
||||
|
|
|
@ -24,8 +24,6 @@ class MetaReferencePostTrainingImpl:
|
|||
self,
|
||||
job_uuid: str,
|
||||
model: str,
|
||||
dataset_id: str,
|
||||
validation_dataset_id: str,
|
||||
algorithm: FinetuningAlgorithm,
|
||||
algorithm_config: LoraFinetuningConfig,
|
||||
training_config: TrainingConfig,
|
||||
|
@ -37,8 +35,6 @@ class MetaReferencePostTrainingImpl:
|
|||
request = PostTrainingSFTRequest(
|
||||
job_uuid=job_uuid,
|
||||
model=model,
|
||||
dataset_id=dataset_id,
|
||||
validation_dataset_id=validation_dataset_id,
|
||||
algorithm=algorithm,
|
||||
algorithm_config=algorithm_config,
|
||||
training_config=training_config,
|
||||
|
|
|
@ -6,6 +6,7 @@
|
|||
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
from functools import partial
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
@ -15,6 +16,7 @@ from llama_models.sku_list import resolve_model
|
|||
from llama_stack.apis.datasetio import DatasetIO
|
||||
from torch import nn
|
||||
from torchtune import utils as torchtune_utils
|
||||
from torchtune.training.metric_logging import DiskLogger
|
||||
from llama_stack.apis.post_training import * # noqa
|
||||
from llama_stack.apis.post_training import PostTrainingSFTRequest
|
||||
from llama_stack.distribution.utils.model_utils import model_local_dir
|
||||
|
@ -29,7 +31,7 @@ from llama_stack.providers.inline.post_training.meta_reference.datasets.sft impo
|
|||
from torch.optim import Optimizer
|
||||
from torch.utils.data import DataLoader, DistributedSampler
|
||||
from torchtune import modules, training
|
||||
from torchtune.data import InputOutputToMessages, padded_collate_sft
|
||||
from torchtune.data import AlpacaToMessages, padded_collate_sft
|
||||
|
||||
from torchtune.modules.loss import CEWithChunkedOutputLoss
|
||||
from torchtune.modules.peft import (
|
||||
|
@ -103,8 +105,8 @@ class LoraFinetuningSingleDevice:
|
|||
self.seed = training.set_seed(seed=config.torch_seed or 42)
|
||||
self.epochs_run = 0
|
||||
self.total_epochs = request.training_config.n_epochs
|
||||
self._shuffle = request.training_config.shuffle
|
||||
self._batch_size = request.training_config.batch_size
|
||||
self._shuffle = request.training_config.data_config.shuffle
|
||||
self._batch_size = request.training_config.data_config.batch_size
|
||||
|
||||
# this is important for debugging purpose
|
||||
self.max_steps_per_epoch = request.training_config.max_steps_per_epoch
|
||||
|
@ -116,9 +118,15 @@ class LoraFinetuningSingleDevice:
|
|||
|
||||
self._clip_grad_norm = 1.0
|
||||
self._enable_activation_checkpointing = (
|
||||
request.training_config.enable_activation_checkpointing
|
||||
(request.training_config.efficiency_config.enable_activation_checkpointing)
|
||||
if request.training_config.efficiency_config
|
||||
else False
|
||||
)
|
||||
self._enable_activation_offloading = (
|
||||
(request.training_config.efficiency_config.enable_activation_offloading)
|
||||
if request.training_config.efficiency_config
|
||||
else False
|
||||
)
|
||||
self._enable_activation_offloading = False
|
||||
|
||||
self.datasetio_api = datasetio_api
|
||||
|
||||
|
@ -143,6 +151,9 @@ class LoraFinetuningSingleDevice:
|
|||
return checkpoint_dict
|
||||
|
||||
async def setup(self, config: MetaReferencePostTrainingConfig) -> None:
|
||||
# temporily log to local disk, will figure out how to interop with telemetry
|
||||
self._metric_logger = DiskLogger(log_dir=self._output_dir)
|
||||
|
||||
checkpoint_dict = await self.load_checkpoint()
|
||||
|
||||
self._model = await self._setup_model(
|
||||
|
@ -212,7 +223,7 @@ class LoraFinetuningSingleDevice:
|
|||
self._lora_attn_modules = list(self.request.algorithm_config.lora_attn_modules)
|
||||
self._apply_lora_to_mlp = self.request.algorithm_config.apply_lora_to_mlp
|
||||
self._apply_lora_to_output = self.request.algorithm_config.apply_lora_to_output
|
||||
self._use_dora = self.request.algorithm_config.use_dora
|
||||
self._use_dora = self.request.algorithm_config.use_dora or False
|
||||
|
||||
with training.set_default_dtype(self._dtype), self._device:
|
||||
model_type = utils.get_model_type(self.model_id)
|
||||
|
@ -272,6 +283,10 @@ class LoraFinetuningSingleDevice:
|
|||
self.activations_handling_ctx = training.get_act_offloading_ctx_manager(
|
||||
model, enable_activation_offloading
|
||||
)
|
||||
|
||||
memory_stats = training.get_memory_stats(device=self._device)
|
||||
training.log_memory_stats(memory_stats)
|
||||
|
||||
return model
|
||||
|
||||
async def _setup_tokenizer(
|
||||
|
@ -296,7 +311,7 @@ class LoraFinetuningSingleDevice:
|
|||
) -> Tuple[DistributedSampler, DataLoader]:
|
||||
async def fetch_rows():
|
||||
return await self.datasetio_api.get_rows_paginated(
|
||||
dataset_id=self.request.dataset_id,
|
||||
dataset_id=self.request.training_config.data_config.dataset_id,
|
||||
rows_in_page=-1,
|
||||
)
|
||||
|
||||
|
@ -306,7 +321,9 @@ class LoraFinetuningSingleDevice:
|
|||
# Curretly only support instruct dataset
|
||||
# TODO @markchen1015 make the message_transform swappable and support more dataset types
|
||||
ds = SFTDataset(
|
||||
rows, message_transform=InputOutputToMessages(), model_transform=tokenizer
|
||||
rows,
|
||||
message_transform=AlpacaToMessages(train_on_input=False),
|
||||
model_transform=tokenizer,
|
||||
)
|
||||
|
||||
sampler = DistributedSampler(
|
||||
|
@ -412,6 +429,7 @@ class LoraFinetuningSingleDevice:
|
|||
"""
|
||||
# Initialize tokens count and running loss (for grad accumulation)
|
||||
# t0 = time.perf_counter()
|
||||
t0 = time.perf_counter()
|
||||
running_loss = 0
|
||||
num_tokens = 0
|
||||
|
||||
|
@ -447,7 +465,7 @@ class LoraFinetuningSingleDevice:
|
|||
# Step with optimizer
|
||||
if (idx + 1) % self._gradient_accumulation_steps == 0:
|
||||
training.scale_grads(self._model, 1 / num_tokens)
|
||||
torch.nn.utils.clip_grad_norm_(
|
||||
grad_norm = torch.nn.utils.clip_grad_norm_(
|
||||
self._model.parameters(),
|
||||
max_norm=float(self._clip_grad_norm),
|
||||
)
|
||||
|
@ -457,9 +475,25 @@ class LoraFinetuningSingleDevice:
|
|||
# Update the number of steps when the weights are updated
|
||||
self.global_step += 1
|
||||
|
||||
loss_to_log = running_loss.item() / num_tokens
|
||||
time_per_step = time.perf_counter() - t0
|
||||
log_dict = {
|
||||
"loss": loss_to_log,
|
||||
"lr": self._optimizer.param_groups[0]["lr"],
|
||||
"tokens_per_second_per_gpu": num_tokens / time_per_step,
|
||||
}
|
||||
log_dict.update(training.get_memory_stats(device=self._device))
|
||||
if self._clip_grad_norm is not None:
|
||||
log_dict.update({"grad_norm": grad_norm})
|
||||
self._metric_logger.log_dict(
|
||||
log_dict,
|
||||
step=self.global_step,
|
||||
)
|
||||
|
||||
# Reset running stats for the next step
|
||||
running_loss = 0
|
||||
num_tokens = 0
|
||||
t0 = time.perf_counter()
|
||||
|
||||
self.epochs_run += 1
|
||||
log.info("Starting checkpoint save...")
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue