Mirror of https://github.com/meta-llama/llama-stack.git, synced 2025-12-17 16:12:46 +00:00
refine api

This commit is contained in: parent 5838b7211d, commit 41cf2bb0a7
3 changed files with 74 additions and 40 deletions
@@ -24,8 +24,6 @@ class MetaReferencePostTrainingImpl:
         self,
         job_uuid: str,
         model: str,
-        dataset_id: str,
-        validation_dataset_id: str,
         algorithm: FinetuningAlgorithm,
         algorithm_config: LoraFinetuningConfig,
         training_config: TrainingConfig,
@@ -37,8 +35,6 @@ class MetaReferencePostTrainingImpl:
         request = PostTrainingSFTRequest(
             job_uuid=job_uuid,
             model=model,
-            dataset_id=dataset_id,
-            validation_dataset_id=validation_dataset_id,
             algorithm=algorithm,
             algorithm_config=algorithm_config,
             training_config=training_config,
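The effect of these two hunks is that dataset selection moves off the supervised_fine_tune signature and into the training config. A minimal sketch of the new shape, using illustrative stand-in dataclasses (not the llama-stack definitions) whose field names mirror the attributes this commit reads elsewhere in the diff:

from dataclasses import dataclass
from typing import Optional

# Illustrative stand-ins only: fields match what the commit reads
# (training_config.data_config.*, training_config.efficiency_config.*).
@dataclass
class DataConfig:
    dataset_id: str
    batch_size: int
    shuffle: bool

@dataclass
class EfficiencyConfig:
    enable_activation_checkpointing: bool = False
    enable_activation_offloading: bool = False

@dataclass
class TrainingConfig:
    n_epochs: int
    max_steps_per_epoch: int
    data_config: DataConfig
    efficiency_config: Optional[EfficiencyConfig] = None

# dataset_id is no longer a separate argument; it rides inside the config.
cfg = TrainingConfig(
    n_epochs=1,
    max_steps_per_epoch=10,
    data_config=DataConfig(dataset_id="alpaca", batch_size=2, shuffle=False),
)
print(cfg.data_config.dataset_id)  # -> alpaca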
@@ -6,6 +6,7 @@
 
 import logging
 import os
+import time
 from functools import partial
 from pathlib import Path
 from typing import Any, Dict, List, Optional, Tuple
@@ -15,6 +16,7 @@ from llama_models.sku_list import resolve_model
 from llama_stack.apis.datasetio import DatasetIO
 from torch import nn
 from torchtune import utils as torchtune_utils
+from torchtune.training.metric_logging import DiskLogger
 from llama_stack.apis.post_training import *  # noqa
 from llama_stack.apis.post_training import PostTrainingSFTRequest
 from llama_stack.distribution.utils.model_utils import model_local_dir
@@ -29,7 +31,7 @@ from llama_stack.providers.inline.post_training.meta_reference.datasets.sft impo
 from torch.optim import Optimizer
 from torch.utils.data import DataLoader, DistributedSampler
 from torchtune import modules, training
-from torchtune.data import InputOutputToMessages, padded_collate_sft
+from torchtune.data import AlpacaToMessages, padded_collate_sft
 
 from torchtune.modules.loss import CEWithChunkedOutputLoss
 from torchtune.modules.peft import (
@@ -103,8 +105,8 @@ class LoraFinetuningSingleDevice:
         self.seed = training.set_seed(seed=config.torch_seed or 42)
         self.epochs_run = 0
         self.total_epochs = request.training_config.n_epochs
-        self._shuffle = request.training_config.shuffle
-        self._batch_size = request.training_config.batch_size
+        self._shuffle = request.training_config.data_config.shuffle
+        self._batch_size = request.training_config.data_config.batch_size
 
         # this is important for debugging purposes
         self.max_steps_per_epoch = request.training_config.max_steps_per_epoch
@@ -116,9 +118,15 @@ class LoraFinetuningSingleDevice:
 
         self._clip_grad_norm = 1.0
         self._enable_activation_checkpointing = (
-            request.training_config.enable_activation_checkpointing
+            (request.training_config.efficiency_config.enable_activation_checkpointing)
+            if request.training_config.efficiency_config
+            else False
         )
+        self._enable_activation_offloading = (
+            (request.training_config.efficiency_config.enable_activation_offloading)
+            if request.training_config.efficiency_config
+            else False
+        )
-        self._enable_activation_offloading = False
 
         self.datasetio_api = datasetio_api
 
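Both flags now come from an optional efficiency_config sub-config and fall back to False when it is absent. A condensed sketch of the same None-guard pattern:

def read_efficiency_flags(efficiency_config) -> tuple:
    # When the optional sub-config is missing, both features stay disabled.
    ckpt = bool(efficiency_config and efficiency_config.enable_activation_checkpointing)
    offload = bool(efficiency_config and efficiency_config.enable_activation_offloading)
    return ckpt, offload

print(read_efficiency_flags(None))  # -> (False, False)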
@@ -143,6 +151,9 @@ class LoraFinetuningSingleDevice:
         return checkpoint_dict
 
     async def setup(self, config: MetaReferencePostTrainingConfig) -> None:
+        # temporarily log to local disk, will figure out how to interop with telemetry
+        self._metric_logger = DiskLogger(log_dir=self._output_dir)
+
         checkpoint_dict = await self.load_checkpoint()
 
         self._model = await self._setup_model(
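DiskLogger is torchtune's file-backed metric logger: it appends metrics to a log file under log_dir, which makes it a workable stopgap until telemetry integration. A small usage sketch (the log directory is illustrative):

from torchtune.training.metric_logging import DiskLogger

logger = DiskLogger(log_dir="/tmp/sft_logs")  # illustrative path
logger.log_dict({"loss": 0.73, "lr": 3e-4}, step=1)
logger.close()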
@@ -212,7 +223,7 @@ class LoraFinetuningSingleDevice:
         self._lora_attn_modules = list(self.request.algorithm_config.lora_attn_modules)
         self._apply_lora_to_mlp = self.request.algorithm_config.apply_lora_to_mlp
         self._apply_lora_to_output = self.request.algorithm_config.apply_lora_to_output
-        self._use_dora = self.request.algorithm_config.use_dora
+        self._use_dora = self.request.algorithm_config.use_dora or False
 
         with training.set_default_dtype(self._dtype), self._device:
             model_type = utils.get_model_type(self.model_id)
@@ -272,6 +283,10 @@ class LoraFinetuningSingleDevice:
         self.activations_handling_ctx = training.get_act_offloading_ctx_manager(
             model, enable_activation_offloading
         )
+
+        memory_stats = training.get_memory_stats(device=self._device)
+        training.log_memory_stats(memory_stats)
+
         return model
 
     async def _setup_tokenizer(
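training.get_memory_stats returns a plain dict of CUDA memory counters for the given device, which is why the training loop later in this diff can merge it straight into its metrics payload. A sketch, assuming a CUDA device is present:

import torch
from torchtune import training

if torch.cuda.is_available():
    device = torch.device("cuda")
    # Dict of peak-memory counters for this device; exact keys per torchtune.
    memory_stats = training.get_memory_stats(device=device)
    training.log_memory_stats(memory_stats)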
@@ -296,7 +311,7 @@ class LoraFinetuningSingleDevice:
     ) -> Tuple[DistributedSampler, DataLoader]:
         async def fetch_rows():
             return await self.datasetio_api.get_rows_paginated(
-                dataset_id=self.request.dataset_id,
+                dataset_id=self.request.training_config.data_config.dataset_id,
                 rows_in_page=-1,
             )
 
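The dataset id is now resolved through training_config.data_config, but the rows still come from the DatasetIO API, with rows_in_page=-1 requesting the entire dataset in one page. A sketch of the call shape (the .rows field on the result is an assumption about the datasetio API):

async def fetch_all_rows(datasetio_api, dataset_id: str):
    # rows_in_page=-1 asks for every row in a single page.
    result = await datasetio_api.get_rows_paginated(
        dataset_id=dataset_id,
        rows_in_page=-1,
    )
    return result.rows  # field name assumed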
@@ -306,7 +321,9 @@ class LoraFinetuningSingleDevice:
         # Currently only support instruct dataset
         # TODO @markchen1015 make the message_transform swappable and support more dataset types
         ds = SFTDataset(
-            rows, message_transform=InputOutputToMessages(), model_transform=tokenizer
+            rows,
+            message_transform=AlpacaToMessages(train_on_input=False),
+            model_transform=tokenizer,
         )
 
         sampler = DistributedSampler(
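AlpacaToMessages converts Alpaca-style rows (instruction/input/output columns) into torchtune Message objects, and train_on_input=False masks the prompt tokens out of the loss. A sketch of the transform on a single row:

from torchtune.data import AlpacaToMessages

transform = AlpacaToMessages(train_on_input=False)
sample = {
    "instruction": "Summarize the text.",
    "input": "Llama Stack exposes post-training as an API.",
    "output": "Llama Stack has a post-training API.",
}
transformed = transform(sample)
print(transformed["messages"])  # user + assistant messages; prompt masked from loss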
@@ -412,6 +429,7 @@ class LoraFinetuningSingleDevice:
         """
         # Initialize tokens count and running loss (for grad accumulation)
         # t0 = time.perf_counter()
+        t0 = time.perf_counter()
         running_loss = 0
         num_tokens = 0
 
@@ -447,7 +465,7 @@ class LoraFinetuningSingleDevice:
                 # Step with optimizer
                 if (idx + 1) % self._gradient_accumulation_steps == 0:
                     training.scale_grads(self._model, 1 / num_tokens)
-                    torch.nn.utils.clip_grad_norm_(
+                    grad_norm = torch.nn.utils.clip_grad_norm_(
                         self._model.parameters(),
                         max_norm=float(self._clip_grad_norm),
                     )
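Capturing the return value is free: torch.nn.utils.clip_grad_norm_ already computes and returns the total gradient norm (measured before clipping), which feeds the grad_norm metric logged below. For example:

import torch

model = torch.nn.Linear(4, 2)
model(torch.randn(3, 4)).sum().backward()
grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
print(float(grad_norm))  # total norm before clipping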
|
@ -457,9 +475,25 @@ class LoraFinetuningSingleDevice:
|
|||
# Update the number of steps when the weights are updated
|
||||
self.global_step += 1
|
||||
|
||||
loss_to_log = running_loss.item() / num_tokens
|
||||
time_per_step = time.perf_counter() - t0
|
||||
log_dict = {
|
||||
"loss": loss_to_log,
|
||||
"lr": self._optimizer.param_groups[0]["lr"],
|
||||
"tokens_per_second_per_gpu": num_tokens / time_per_step,
|
||||
}
|
||||
log_dict.update(training.get_memory_stats(device=self._device))
|
||||
if self._clip_grad_norm is not None:
|
||||
log_dict.update({"grad_norm": grad_norm})
|
||||
self._metric_logger.log_dict(
|
||||
log_dict,
|
||||
step=self.global_step,
|
||||
)
|
||||
|
||||
# Reset running stats for the next step
|
||||
running_loss = 0
|
||||
num_tokens = 0
|
||||
t0 = time.perf_counter()
|
||||
|
||||
self.epochs_run += 1
|
||||
log.info("Starting checkpoint save...")
|
||||
|
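Since gradients are scaled by 1 / num_tokens over the accumulation window, the logged loss is the matching quantity: the accumulated token loss divided by the token count, i.e. a per-token average for the window rather than a per-batch value (assuming running_loss accumulates summed token losses, as the 1 / num_tokens gradient scaling suggests). A tiny numeric illustration:

# Two micro-batches in one accumulation window (numbers are made up):
running_loss = 310.0 + 152.0  # summed per-token losses
num_tokens = 100 + 48         # tokens contributing to the loss
loss_to_log = running_loss / num_tokens
print(round(loss_to_log, 4))  # -> 3.1216, per-token average for the window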