mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-02 08:44:44 +00:00
refine api
This commit is contained in:
parent
5838b7211d
commit
41cf2bb0a7
3 changed files with 74 additions and 40 deletions
|
@ -7,7 +7,7 @@
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
|
|
||||||
from typing import Any, Dict, List, Optional, Protocol
|
from typing import Any, Dict, List, Optional, Protocol, Union
|
||||||
|
|
||||||
from llama_models.schema_utils import json_schema_type, webmethod
|
from llama_models.schema_utils import json_schema_type, webmethod
|
||||||
|
|
||||||
|
@ -18,42 +18,55 @@ from llama_stack.apis.datasets import * # noqa: F403
|
||||||
from llama_stack.apis.common.training_types import * # noqa: F403
|
from llama_stack.apis.common.training_types import * # noqa: F403
|
||||||
|
|
||||||
|
|
||||||
|
@json_schema_type
|
||||||
class OptimizerType(Enum):
|
class OptimizerType(Enum):
|
||||||
adam = "adam"
|
adam = "adam"
|
||||||
adamw = "adamw"
|
adamw = "adamw"
|
||||||
sgd = "sgd"
|
sgd = "sgd"
|
||||||
|
|
||||||
|
|
||||||
|
@json_schema_type
|
||||||
|
class DataConfig(BaseModel):
|
||||||
|
dataset_id: str
|
||||||
|
batch_size: int
|
||||||
|
shuffle: bool
|
||||||
|
validation_dataset_id: Optional[str] = None
|
||||||
|
packed: Optional[bool] = False
|
||||||
|
train_on_input: Optional[bool] = False
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
class OptimizerConfig(BaseModel):
|
class OptimizerConfig(BaseModel):
|
||||||
optimizer_type: OptimizerType
|
optimizer_type: OptimizerType
|
||||||
lr: float
|
lr: float
|
||||||
lr_min: float
|
|
||||||
weight_decay: float
|
weight_decay: float
|
||||||
num_warmup_steps: int
|
num_warmup_steps: int
|
||||||
|
|
||||||
|
|
||||||
|
@json_schema_type
|
||||||
|
class EfficiencyConfig(BaseModel):
|
||||||
|
enable_activation_checkpointing: Optional[bool] = False
|
||||||
|
enable_activation_offloading: Optional[bool] = False
|
||||||
|
memory_efficient_fsdp_wrap: Optional[bool] = False
|
||||||
|
fsdp_cpu_offload: Optional[bool] = False
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
class TrainingConfig(BaseModel):
|
class TrainingConfig(BaseModel):
|
||||||
dtype: str
|
|
||||||
n_epochs: int
|
n_epochs: int
|
||||||
max_steps_per_epoch: int
|
max_steps_per_epoch: int
|
||||||
gradient_accumulation_steps: int
|
gradient_accumulation_steps: int
|
||||||
batch_size: int
|
data_config: DataConfig
|
||||||
shuffle: bool
|
|
||||||
optimizer_config: OptimizerConfig
|
optimizer_config: OptimizerConfig
|
||||||
|
efficiency_config: Optional[EfficiencyConfig] = None
|
||||||
enable_activation_checkpointing: bool
|
dtype: Optional[str] = "bf16"
|
||||||
memory_efficient_fsdp_wrap: Optional[bool]
|
|
||||||
fsdp_cpu_offload: Optional[bool]
|
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
class FinetuningAlgorithm(Enum):
|
class FinetuningAlgorithm(Enum):
|
||||||
full = "full"
|
full = "full"
|
||||||
lora = "lora"
|
lora = "lora"
|
||||||
qlora = "qlora"
|
qat = "qat"
|
||||||
dora = "dora"
|
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
|
@ -63,17 +76,14 @@ class LoraFinetuningConfig(BaseModel):
|
||||||
apply_lora_to_output: bool
|
apply_lora_to_output: bool
|
||||||
rank: int
|
rank: int
|
||||||
alpha: int
|
alpha: int
|
||||||
use_dora: bool
|
use_dora: Optional[bool] = False
|
||||||
|
quantize_base: Optional[bool] = False
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
class QLoraFinetuningConfig(LoraFinetuningConfig):
|
class QATFinetuningConfig(BaseModel):
|
||||||
pass
|
quantizer_name: str
|
||||||
|
group_size: int
|
||||||
|
|
||||||
@json_schema_type
|
|
||||||
class DoraFinetuningConfig(LoraFinetuningConfig):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
|
@ -110,13 +120,9 @@ class PostTrainingSFTRequest(BaseModel):
|
||||||
"""Request to finetune a model."""
|
"""Request to finetune a model."""
|
||||||
|
|
||||||
job_uuid: str
|
job_uuid: str
|
||||||
|
|
||||||
model: str
|
model: str
|
||||||
dataset_id: str
|
|
||||||
validation_dataset_id: str
|
|
||||||
|
|
||||||
algorithm: FinetuningAlgorithm
|
algorithm: FinetuningAlgorithm
|
||||||
algorithm_config: LoraFinetuningConfig
|
algorithm_config: Optional[Union[LoraFinetuningConfig, QATFinetuningConfig]] = None
|
||||||
training_config: TrainingConfig
|
training_config: TrainingConfig
|
||||||
|
|
||||||
# TODO: define these
|
# TODO: define these
|
||||||
|
@ -182,13 +188,11 @@ class PostTraining(Protocol):
|
||||||
self,
|
self,
|
||||||
job_uuid: str,
|
job_uuid: str,
|
||||||
model: str,
|
model: str,
|
||||||
dataset_id: str,
|
|
||||||
validation_dataset_id: str,
|
|
||||||
algorithm: FinetuningAlgorithm,
|
algorithm: FinetuningAlgorithm,
|
||||||
algorithm_config: LoraFinetuningConfig,
|
|
||||||
training_config: TrainingConfig,
|
training_config: TrainingConfig,
|
||||||
hyperparam_search_config: Dict[str, Any],
|
hyperparam_search_config: Dict[str, Any],
|
||||||
logger_config: Dict[str, Any],
|
logger_config: Dict[str, Any],
|
||||||
|
algorithm_config: Optional[LoraFinetuningConfig] = None,
|
||||||
) -> PostTrainingJob: ...
|
) -> PostTrainingJob: ...
|
||||||
|
|
||||||
@webmethod(route="/post-training/preference-optimize")
|
@webmethod(route="/post-training/preference-optimize")
|
||||||
|
|
|
@ -24,8 +24,6 @@ class MetaReferencePostTrainingImpl:
|
||||||
self,
|
self,
|
||||||
job_uuid: str,
|
job_uuid: str,
|
||||||
model: str,
|
model: str,
|
||||||
dataset_id: str,
|
|
||||||
validation_dataset_id: str,
|
|
||||||
algorithm: FinetuningAlgorithm,
|
algorithm: FinetuningAlgorithm,
|
||||||
algorithm_config: LoraFinetuningConfig,
|
algorithm_config: LoraFinetuningConfig,
|
||||||
training_config: TrainingConfig,
|
training_config: TrainingConfig,
|
||||||
|
@ -37,8 +35,6 @@ class MetaReferencePostTrainingImpl:
|
||||||
request = PostTrainingSFTRequest(
|
request = PostTrainingSFTRequest(
|
||||||
job_uuid=job_uuid,
|
job_uuid=job_uuid,
|
||||||
model=model,
|
model=model,
|
||||||
dataset_id=dataset_id,
|
|
||||||
validation_dataset_id=validation_dataset_id,
|
|
||||||
algorithm=algorithm,
|
algorithm=algorithm,
|
||||||
algorithm_config=algorithm_config,
|
algorithm_config=algorithm_config,
|
||||||
training_config=training_config,
|
training_config=training_config,
|
||||||
|
|
|
@ -6,6 +6,7 @@
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
|
import time
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Dict, List, Optional, Tuple
|
from typing import Any, Dict, List, Optional, Tuple
|
||||||
|
@ -15,6 +16,7 @@ from llama_models.sku_list import resolve_model
|
||||||
from llama_stack.apis.datasetio import DatasetIO
|
from llama_stack.apis.datasetio import DatasetIO
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from torchtune import utils as torchtune_utils
|
from torchtune import utils as torchtune_utils
|
||||||
|
from torchtune.training.metric_logging import DiskLogger
|
||||||
from llama_stack.apis.post_training import * # noqa
|
from llama_stack.apis.post_training import * # noqa
|
||||||
from llama_stack.apis.post_training import PostTrainingSFTRequest
|
from llama_stack.apis.post_training import PostTrainingSFTRequest
|
||||||
from llama_stack.distribution.utils.model_utils import model_local_dir
|
from llama_stack.distribution.utils.model_utils import model_local_dir
|
||||||
|
@ -29,7 +31,7 @@ from llama_stack.providers.inline.post_training.meta_reference.datasets.sft impo
|
||||||
from torch.optim import Optimizer
|
from torch.optim import Optimizer
|
||||||
from torch.utils.data import DataLoader, DistributedSampler
|
from torch.utils.data import DataLoader, DistributedSampler
|
||||||
from torchtune import modules, training
|
from torchtune import modules, training
|
||||||
from torchtune.data import InputOutputToMessages, padded_collate_sft
|
from torchtune.data import AlpacaToMessages, padded_collate_sft
|
||||||
|
|
||||||
from torchtune.modules.loss import CEWithChunkedOutputLoss
|
from torchtune.modules.loss import CEWithChunkedOutputLoss
|
||||||
from torchtune.modules.peft import (
|
from torchtune.modules.peft import (
|
||||||
|
@ -103,8 +105,8 @@ class LoraFinetuningSingleDevice:
|
||||||
self.seed = training.set_seed(seed=config.torch_seed or 42)
|
self.seed = training.set_seed(seed=config.torch_seed or 42)
|
||||||
self.epochs_run = 0
|
self.epochs_run = 0
|
||||||
self.total_epochs = request.training_config.n_epochs
|
self.total_epochs = request.training_config.n_epochs
|
||||||
self._shuffle = request.training_config.shuffle
|
self._shuffle = request.training_config.data_config.shuffle
|
||||||
self._batch_size = request.training_config.batch_size
|
self._batch_size = request.training_config.data_config.batch_size
|
||||||
|
|
||||||
# this is important for debugging purpose
|
# this is important for debugging purpose
|
||||||
self.max_steps_per_epoch = request.training_config.max_steps_per_epoch
|
self.max_steps_per_epoch = request.training_config.max_steps_per_epoch
|
||||||
|
@ -116,9 +118,15 @@ class LoraFinetuningSingleDevice:
|
||||||
|
|
||||||
self._clip_grad_norm = 1.0
|
self._clip_grad_norm = 1.0
|
||||||
self._enable_activation_checkpointing = (
|
self._enable_activation_checkpointing = (
|
||||||
request.training_config.enable_activation_checkpointing
|
(request.training_config.efficiency_config.enable_activation_checkpointing)
|
||||||
|
if request.training_config.efficiency_config
|
||||||
|
else False
|
||||||
|
)
|
||||||
|
self._enable_activation_offloading = (
|
||||||
|
(request.training_config.efficiency_config.enable_activation_offloading)
|
||||||
|
if request.training_config.efficiency_config
|
||||||
|
else False
|
||||||
)
|
)
|
||||||
self._enable_activation_offloading = False
|
|
||||||
|
|
||||||
self.datasetio_api = datasetio_api
|
self.datasetio_api = datasetio_api
|
||||||
|
|
||||||
|
@ -143,6 +151,9 @@ class LoraFinetuningSingleDevice:
|
||||||
return checkpoint_dict
|
return checkpoint_dict
|
||||||
|
|
||||||
async def setup(self, config: MetaReferencePostTrainingConfig) -> None:
|
async def setup(self, config: MetaReferencePostTrainingConfig) -> None:
|
||||||
|
# temporily log to local disk, will figure out how to interop with telemetry
|
||||||
|
self._metric_logger = DiskLogger(log_dir=self._output_dir)
|
||||||
|
|
||||||
checkpoint_dict = await self.load_checkpoint()
|
checkpoint_dict = await self.load_checkpoint()
|
||||||
|
|
||||||
self._model = await self._setup_model(
|
self._model = await self._setup_model(
|
||||||
|
@ -212,7 +223,7 @@ class LoraFinetuningSingleDevice:
|
||||||
self._lora_attn_modules = list(self.request.algorithm_config.lora_attn_modules)
|
self._lora_attn_modules = list(self.request.algorithm_config.lora_attn_modules)
|
||||||
self._apply_lora_to_mlp = self.request.algorithm_config.apply_lora_to_mlp
|
self._apply_lora_to_mlp = self.request.algorithm_config.apply_lora_to_mlp
|
||||||
self._apply_lora_to_output = self.request.algorithm_config.apply_lora_to_output
|
self._apply_lora_to_output = self.request.algorithm_config.apply_lora_to_output
|
||||||
self._use_dora = self.request.algorithm_config.use_dora
|
self._use_dora = self.request.algorithm_config.use_dora or False
|
||||||
|
|
||||||
with training.set_default_dtype(self._dtype), self._device:
|
with training.set_default_dtype(self._dtype), self._device:
|
||||||
model_type = utils.get_model_type(self.model_id)
|
model_type = utils.get_model_type(self.model_id)
|
||||||
|
@ -272,6 +283,10 @@ class LoraFinetuningSingleDevice:
|
||||||
self.activations_handling_ctx = training.get_act_offloading_ctx_manager(
|
self.activations_handling_ctx = training.get_act_offloading_ctx_manager(
|
||||||
model, enable_activation_offloading
|
model, enable_activation_offloading
|
||||||
)
|
)
|
||||||
|
|
||||||
|
memory_stats = training.get_memory_stats(device=self._device)
|
||||||
|
training.log_memory_stats(memory_stats)
|
||||||
|
|
||||||
return model
|
return model
|
||||||
|
|
||||||
async def _setup_tokenizer(
|
async def _setup_tokenizer(
|
||||||
|
@ -296,7 +311,7 @@ class LoraFinetuningSingleDevice:
|
||||||
) -> Tuple[DistributedSampler, DataLoader]:
|
) -> Tuple[DistributedSampler, DataLoader]:
|
||||||
async def fetch_rows():
|
async def fetch_rows():
|
||||||
return await self.datasetio_api.get_rows_paginated(
|
return await self.datasetio_api.get_rows_paginated(
|
||||||
dataset_id=self.request.dataset_id,
|
dataset_id=self.request.training_config.data_config.dataset_id,
|
||||||
rows_in_page=-1,
|
rows_in_page=-1,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -306,7 +321,9 @@ class LoraFinetuningSingleDevice:
|
||||||
# Curretly only support instruct dataset
|
# Curretly only support instruct dataset
|
||||||
# TODO @markchen1015 make the message_transform swappable and support more dataset types
|
# TODO @markchen1015 make the message_transform swappable and support more dataset types
|
||||||
ds = SFTDataset(
|
ds = SFTDataset(
|
||||||
rows, message_transform=InputOutputToMessages(), model_transform=tokenizer
|
rows,
|
||||||
|
message_transform=AlpacaToMessages(train_on_input=False),
|
||||||
|
model_transform=tokenizer,
|
||||||
)
|
)
|
||||||
|
|
||||||
sampler = DistributedSampler(
|
sampler = DistributedSampler(
|
||||||
|
@ -412,6 +429,7 @@ class LoraFinetuningSingleDevice:
|
||||||
"""
|
"""
|
||||||
# Initialize tokens count and running loss (for grad accumulation)
|
# Initialize tokens count and running loss (for grad accumulation)
|
||||||
# t0 = time.perf_counter()
|
# t0 = time.perf_counter()
|
||||||
|
t0 = time.perf_counter()
|
||||||
running_loss = 0
|
running_loss = 0
|
||||||
num_tokens = 0
|
num_tokens = 0
|
||||||
|
|
||||||
|
@ -447,7 +465,7 @@ class LoraFinetuningSingleDevice:
|
||||||
# Step with optimizer
|
# Step with optimizer
|
||||||
if (idx + 1) % self._gradient_accumulation_steps == 0:
|
if (idx + 1) % self._gradient_accumulation_steps == 0:
|
||||||
training.scale_grads(self._model, 1 / num_tokens)
|
training.scale_grads(self._model, 1 / num_tokens)
|
||||||
torch.nn.utils.clip_grad_norm_(
|
grad_norm = torch.nn.utils.clip_grad_norm_(
|
||||||
self._model.parameters(),
|
self._model.parameters(),
|
||||||
max_norm=float(self._clip_grad_norm),
|
max_norm=float(self._clip_grad_norm),
|
||||||
)
|
)
|
||||||
|
@ -457,9 +475,25 @@ class LoraFinetuningSingleDevice:
|
||||||
# Update the number of steps when the weights are updated
|
# Update the number of steps when the weights are updated
|
||||||
self.global_step += 1
|
self.global_step += 1
|
||||||
|
|
||||||
|
loss_to_log = running_loss.item() / num_tokens
|
||||||
|
time_per_step = time.perf_counter() - t0
|
||||||
|
log_dict = {
|
||||||
|
"loss": loss_to_log,
|
||||||
|
"lr": self._optimizer.param_groups[0]["lr"],
|
||||||
|
"tokens_per_second_per_gpu": num_tokens / time_per_step,
|
||||||
|
}
|
||||||
|
log_dict.update(training.get_memory_stats(device=self._device))
|
||||||
|
if self._clip_grad_norm is not None:
|
||||||
|
log_dict.update({"grad_norm": grad_norm})
|
||||||
|
self._metric_logger.log_dict(
|
||||||
|
log_dict,
|
||||||
|
step=self.global_step,
|
||||||
|
)
|
||||||
|
|
||||||
# Reset running stats for the next step
|
# Reset running stats for the next step
|
||||||
running_loss = 0
|
running_loss = 0
|
||||||
num_tokens = 0
|
num_tokens = 0
|
||||||
|
t0 = time.perf_counter()
|
||||||
|
|
||||||
self.epochs_run += 1
|
self.epochs_run += 1
|
||||||
log.info("Starting checkpoint save...")
|
log.info("Starting checkpoint save...")
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue