temp commit

This commit is contained in:
Botao Chen 2024-12-02 17:24:25 -08:00
parent 6c709abc4d
commit 79c525be94
5 changed files with 115 additions and 163 deletions

View file

@ -41,7 +41,7 @@ class TrainingConfig(BaseModel):
gradient_accumulation_steps: int
batch_size: int
shuffle: bool
# n_iters: int
optimizer_config: OptimizerConfig
enable_activation_checkpointing: bool
memory_efficient_fsdp_wrap: Optional[bool]
@ -63,6 +63,7 @@ class LoraFinetuningConfig(BaseModel):
apply_lora_to_output: bool
rank: int
alpha: int
use_dora: bool
@json_schema_type
@ -116,7 +117,6 @@ class PostTrainingSFTRequest(BaseModel):
algorithm: FinetuningAlgorithm
algorithm_config: LoraFinetuningConfig
optimizer_config: OptimizerConfig
training_config: TrainingConfig
# TODO: define these
@ -178,7 +178,7 @@ class PostTrainingJobArtifactsResponse(BaseModel):
class PostTraining(Protocol):
@webmethod(route="/post-training/supervised-fine-tune")
def supervised_fine_tune(
async def supervised_fine_tune(
self,
job_uuid: str,
model: str,
@ -186,14 +186,14 @@ class PostTraining(Protocol):
validation_dataset_id: str,
algorithm: FinetuningAlgorithm,
algorithm_config: LoraFinetuningConfig,
optimizer_config: OptimizerConfig,
# optimizer_config: OptimizerConfig,
training_config: TrainingConfig,
hyperparam_search_config: Dict[str, Any],
logger_config: Dict[str, Any],
) -> PostTrainingJob: ...
@webmethod(route="/post-training/preference-optimize")
def preference_optimize(
async def preference_optimize(
self,
job_uuid: str,
finetuned_model: URL,
@ -208,21 +208,23 @@ class PostTraining(Protocol):
) -> PostTrainingJob: ...
@webmethod(route="/post-training/jobs")
def get_training_jobs(self) -> List[PostTrainingJob]: ...
async def get_training_jobs(self) -> List[PostTrainingJob]: ...
# sends SSE stream of logs
@webmethod(route="/post-training/job/logs")
def get_training_job_logstream(self, job_uuid: str) -> PostTrainingJobLogStream: ...
async def get_training_job_logstream(
self, job_uuid: str
) -> PostTrainingJobLogStream: ...
@webmethod(route="/post-training/job/status")
def get_training_job_status(
async def get_training_job_status(
self, job_uuid: str
) -> PostTrainingJobStatusResponse: ...
@webmethod(route="/post-training/job/cancel")
def cancel_training_job(self, job_uuid: str) -> None: ...
async def cancel_training_job(self, job_uuid: str) -> None: ...
@webmethod(route="/post-training/job/artifacts")
def get_training_job_artifacts(
async def get_training_job_artifacts(
self, job_uuid: str
) -> PostTrainingJobArtifactsResponse: ...

View file

@ -20,7 +20,7 @@ class MetaReferencePostTrainingImpl:
self.config = config
self.datasetio_api = datasetio_api
def supervised_fine_tune(
async def supervised_fine_tune(
self,
job_uuid: str,
model: str,
@ -28,11 +28,11 @@ class MetaReferencePostTrainingImpl:
validation_dataset_id: str,
algorithm: FinetuningAlgorithm,
algorithm_config: LoraFinetuningConfig,
optimizer_config: OptimizerConfig,
training_config: TrainingConfig,
hyperparam_search_config: Dict[str, Any],
logger_config: Dict[str, Any],
) -> PostTrainingJob:
# wrapper request to make it easier to pass around (internal only, not exposed to API)
request = PostTrainingSFTRequest(
job_uuid=job_uuid,
@ -41,7 +41,6 @@ class MetaReferencePostTrainingImpl:
validation_dataset_id=validation_dataset_id,
algorithm=algorithm,
algorithm_config=algorithm_config,
optimizer_config=optimizer_config,
training_config=training_config,
hyperparam_search_config=hyperparam_search_config,
logger_config=logger_config,
@ -50,14 +49,14 @@ class MetaReferencePostTrainingImpl:
recipe = LoraFinetuningSingleDevice(
self.config, request, self.datasetio_api
)
recipe.setup(self.config)
recipe.train()
await recipe.setup(self.config)
await recipe.train()
else:
raise NotImplementedError()
return PostTrainingJob(job_uuid=job_uuid)
def preference_optimize(
async def preference_optimize(
self,
job_uuid: str,
finetuned_model: URL,
@ -71,21 +70,24 @@ class MetaReferencePostTrainingImpl:
logger_config: Dict[str, Any],
) -> PostTrainingJob: ...
def get_training_jobs(self) -> List[PostTrainingJob]: ...
# TODO @markchen1015 impelment below APIs
async def get_training_jobs(self) -> List[PostTrainingJob]: ...
# sends SSE stream of logs
@webmethod(route="/post-training/job/logs")
def get_training_job_logstream(self, job_uuid: str) -> PostTrainingJobLogStream: ...
async def get_training_job_logstream(
self, job_uuid: str
) -> PostTrainingJobLogStream: ...
@webmethod(route="/post-training/job/status")
def get_training_job_status(
async def get_training_job_status(
self, job_uuid: str
) -> PostTrainingJobStatusResponse: ...
@webmethod(route="/post-training/job/cancel")
def cancel_training_job(self, job_uuid: str) -> None: ...
async def cancel_training_job(self, job_uuid: str) -> None: ...
@webmethod(route="/post-training/job/artifacts")
def get_training_job_artifacts(
async def get_training_job_artifacts(
self, job_uuid: str
) -> PostTrainingJobArtifactsResponse: ...

View file

@ -7,15 +7,20 @@
import asyncio
import logging
import os
import re
from functools import partial
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union
import torch
from llama_models.sku_list import resolve_model
from llama_stack.apis.datasetio import DatasetIO
from torch import nn
from torchtune import utils as torchtune_utils
from torchtune.training.checkpointing._utils import ModelType
from llama_stack.apis.post_training import * # noqa
from llama_stack.apis.post_training import PostTrainingSFTRequest
from llama_stack.distribution.utils.model_utils import model_local_dir
from llama_stack.providers.inline.post_training.meta_reference import utils
from llama_stack.providers.inline.post_training.meta_reference.config import (
@ -47,14 +52,24 @@ from torchtune.models.llama3._tokenizer import Llama3Tokenizer
class LoraFinetuningSingleDevice:
# This recipe only supports GPU training
# This recipe doesn't include several training efficiency setting within origin torchtune repo, including
# - compile
# - activation offloading
# Resume from checkpoint hasn't been supported yet
# Validation hasn't been supported yet
# TODO @markchen1015 figure out the logging for this training recipe
# and make it work with telemetry
def __init__(
self,
config: MetaReferencePostTrainingConfig,
request: PostTrainingSFTRequest,
datasetio_api: DatasetIO,
) -> None:
# to make user config easier, assume the device is 'cuda' only
# self._device = utils.get_device(device=cfg.device)
# Assume the training only happens on GPU
self.config = config
self.request = request
self._device = torchtune_utils.get_device(device="cuda")
@ -63,11 +78,30 @@ class LoraFinetuningSingleDevice:
)
self.model_id = config.model
# hardcode it for now and see how it works with get_training_job_artifacts
self._output_dir = f"~/.llama/checkpoints/post_training/{self.model_id}"
def model_checkpoint_dir(model) -> str:
checkpoint_dir = Path(model_local_dir(model.descriptor()))
self._log_every_n_steps = 1
self._log_peak_memory_stats = False
paths = [
Path(checkpoint_dir / f"consolidated.{ext}")
for ext in ["pth", "00.pth"]
]
if not any(p.exists() for p in paths):
checkpoint_dir = checkpoint_dir / "original"
assert checkpoint_dir.exists(), (
f"Could not find checkpoints in: {model_local_dir(model.descriptor())}. "
f"Please download model using `llama download --model-id {model.descriptor()}`"
)
return str(checkpoint_dir)
if config.checkpoint_dir and config.checkpoint_dir != "null":
self.checkpoint_dir = config.checkpoint_dir
else:
model = resolve_model(self.model_id)
self.checkpoint_dir = model_checkpoint_dir(model)
# TODO @markchen1015 make it work with get_training_job_artifacts
self._output_dir = self.checkpoint_dir + "/posting_training/"
self.seed = training.set_seed(seed=config.torch_seed or 42)
self.epochs_run = 0
@ -75,23 +109,15 @@ class LoraFinetuningSingleDevice:
self._shuffle = request.training_config.shuffle
self._batch_size = request.training_config.batch_size
self.checkpoint_dir = (
self.config.checkpoint_dir or f"~/.llama/checkpoints/{self.model_id}"
)
# this is important for debugging purpose
self.max_steps_per_epoch = request.training_config.max_steps_per_epoch
self.global_step = 0
# not needed in MVP
# self._resume_from_checkpoint = cfg.resume_from_checkpoint
# self._save_adapter_weights_only = cfg.get("save_adapter_weights_only", False)
self._gradient_accumulation_steps = (
request.training_config.gradient_accumulation_steps
)
self._clip_grad_norm = 1.0 # hardcode
self._clip_grad_norm = 1.0
self._enable_activation_checkpointing = (
request.training_config.enable_activation_checkpointing
)
@ -99,12 +125,11 @@ class LoraFinetuningSingleDevice:
self.datasetio_api = datasetio_api
def load_checkpoint(self):
async def load_checkpoint(self):
def get_checkpoint_files(checkpoint_dir: str) -> List[str]:
try:
# List all files in the given directory
files = os.listdir(checkpoint_dir)
# Filter files that end with .pth
pth_files = [file for file in files if file.endswith(".pth")]
return pth_files
@ -115,44 +140,40 @@ class LoraFinetuningSingleDevice:
checkpoint_dir=self.checkpoint_dir,
checkpoint_files=get_checkpoint_files(self.checkpoint_dir),
output_dir=self._output_dir,
# todo: automatically get this info from model
model_type="LLAMA3",
model_type=utils.get_checkpointer_model_type(self.model_id),
)
checkpoint_dict = self._checkpointer.load_checkpoint()
return checkpoint_dict
def setup(self, config: MetaReferencePostTrainingConfig) -> None:
# todo: figure out how does it works with telemetry
# self._metric_logger = config.instantiate(cfg.metric_logger)
# self._metric_logger.log_config(cfg)
async def setup(self, config: MetaReferencePostTrainingConfig) -> None:
checkpoint_dict = await self.load_checkpoint()
checkpoint_dict = self.load_checkpoint()
# hack to toggle to the low cpu ram version of the reparametrize_as_dtype
# hook based on the config.
# common_utils._use_low_cpu_ram = cfg.get("low_cpu_ram", False)
# set up model
self._model = self._setup_model(
self._model = await self._setup_model(
enable_activation_checkpointing=self._enable_activation_checkpointing,
enable_activation_offloading=self._enable_activation_offloading,
base_model_state_dict=checkpoint_dict[training.MODEL_KEY],
lora_weights_state_dict=None,
)
log.info(f"Model is initialized with precision {self._dtype}.")
self._tokenizer = self._setup_tokenizer()
self._tokenizer = await self._setup_tokenizer()
log.info("Tokenizer is initialized from file.")
self._optimizer = self._setup_optimizer(
optimizer_config=self.request.training_config.optimizer
self._optimizer = await self._setup_optimizer(
optimizer_config=self.request.training_config.optimizer_config
)
log.info("Optimizer is initialized.")
self._loss_fn = CEWithChunkedOutputLoss()
self._sampler, self._dataloader = self._setup_data(
self._model.set_num_output_chunks(self._loss_fn.num_output_chunks)
log.info("Loss is initialized.")
self._sampler, self._dataloader = await self._setup_data(
tokenizer=self._tokenizer,
shuffle=self._shuffle,
batch_size=self._batch_size,
)
log.info("Dataset and Sampler are initialized.")
# Number of training steps in each epoch depends on the number of batches produced
# by the dataloader and the max_steps_per_epoch param set by the user and is used
@ -170,18 +191,19 @@ class LoraFinetuningSingleDevice:
# Learning rate scheduler can only be set up after number of steps
# has been computed
self._lr_scheduler = self._setup_lr_scheduler(
num_warmup_steps=self.request.optimizer_config.num_warmup_steps,
self._lr_scheduler = await self._setup_lr_scheduler(
num_warmup_steps=self.request.training_config.optimizer_config.num_warmup_steps,
num_training_steps=self.total_epochs * self._steps_per_epoch,
last_epoch=self.global_step - 1,
)
log.info("Learning rate scheduler is initialized.")
# Used to ignore labels for loss computation
self.ignore_labels_cache = torch.full(
(self._batch_size, 1), self._loss_fn.ignore_index, device=self._device
)
def _setup_model(
async def _setup_model(
self,
enable_activation_checkpointing: bool,
enable_activation_offloading: bool,
@ -243,9 +265,8 @@ class LoraFinetuningSingleDevice:
lora_missing=lora_missing,
lora_unexpected=lora_unexpected,
)
# Validate model adapter params were loaded in with the expected dtype
# TODO (rohan-varma): Further validation to ensure the appropriate base params
# are NF4 vs bf16 based on the quantization config.
training.validate_expected_param_dtype(
self.adapter_params.items(), dtype=self._dtype
)
@ -254,22 +275,16 @@ class LoraFinetuningSingleDevice:
self.activations_handling_ctx = training.get_act_offloading_ctx_manager(
model, enable_activation_offloading
)
log.info(f"Model is initialized with precision {self._dtype}.")
# if self._device.type != "cpu":
# memory_stats = training.get_memory_stats(device=self._device)
# training.log_memory_stats(memory_stats)
return model
def _setup_tokenizer(
async def _setup_tokenizer(
self,
) -> Llama3Tokenizer:
tokenizer_path = self.checkpoint_dir + "/tokenizer.model"
tokenizer_type = utils.get_tokenizer_type(self.model_id)
return tokenizer_type(path=tokenizer_path)
def _setup_optimizer(self, optimizer_config: OptimizerConfig) -> Optimizer:
async def _setup_optimizer(self, optimizer_config: OptimizerConfig) -> Optimizer:
optimizer = torch.optim.AdamW(
params=self._model.parameters(),
lr=optimizer_config.lr,
@ -277,11 +292,9 @@ class LoraFinetuningSingleDevice:
eps=1e-8,
weight_decay=0.1,
)
log.info("Optimizer and loss are initialized.")
return optimizer
def _setup_data(
async def _setup_data(
self, tokenizer: Llama3Tokenizer, shuffle: bool, batch_size: int
) -> Tuple[DistributedSampler, DataLoader]:
async def fetch_rows():
@ -290,10 +303,11 @@ class LoraFinetuningSingleDevice:
rows_in_page=-1,
)
# Run the async function in an event loop
all_rows = asyncio.run(fetch_rows())
all_rows = await fetch_rows()
rows = all_rows.rows
# Curretly only support instruct dataset
# TODO @markchen1015 make the message_transform swappable and support more dataset types
ds = SFTDataset(
rows, message_transform=InputOutputToMessages(), model_transform=tokenizer
)
@ -320,11 +334,9 @@ class LoraFinetuningSingleDevice:
),
)
log.info("Dataset and Sampler are initialized.")
return sampler, dataloader
def _setup_lr_scheduler(
async def _setup_lr_scheduler(
self,
num_warmup_steps: int,
num_training_steps: int,
@ -332,33 +344,19 @@ class LoraFinetuningSingleDevice:
) -> Optimizer:
lr_scheduler = get_cosine_schedule_with_warmup(
self._optimizer,
num_warmup_steps=num_warmup_steps,
num_training_steps=num_training_steps,
last_epoch=last_epoch,
)
log.info("Learning rate scheduler is initialized.")
return lr_scheduler
def save_checkpoint(self, epoch: int) -> None:
"""
Checkpoint the state of the recipe. The constructed checkpoint state dict
contains the following information:
- Merged weights with key MODEL_KEY
- Adapter weights with key ADAPTER_KEY
- Relevant recipe state if training is not complete
- If the `self._save_adapter_weights_only` option is True, the checkpointer will save only the adapter weights
To correctly resume from training, the adapter weights and recipe state must be provided along with the base model weights.
"""
async def save_checkpoint(self, epoch: int) -> None:
ckpt_dict = {}
intermediate_checkpoint = epoch + 1 < self.total_epochs
adapter_state_dict = get_adapter_state_dict(self._model.state_dict())
ckpt_dict.update({training.ADAPTER_KEY: adapter_state_dict})
# Construct the full state dict with LoRA weights merged into base LLM weights
# Move to CPU to avoid a copy on GPU
state_dict = {k: v.cpu() for k, v in self._model.state_dict().items()}
@ -385,10 +383,9 @@ class LoraFinetuningSingleDevice:
self._checkpointer.save_checkpoint(
ckpt_dict,
epoch=epoch,
intermediate_checkpoint=intermediate_checkpoint,
)
def _loss_step(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
async def _loss_step(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
# Shape [b, s], needed for the loss not the model
labels = batch.pop("labels")
# run model
@ -412,16 +409,10 @@ class LoraFinetuningSingleDevice:
return loss
def train(self) -> None:
async def train(self) -> None:
"""
The core training loop.
"""
# if self._compile:
# log.info(
# "NOTE: torch.compile is enabled and model is compiled in first forward. Expect a relatively slow first iteration."
# )
# Initialize tokens count and running loss (for grad accumulation)
# t0 = time.perf_counter()
running_loss = 0
@ -433,7 +424,6 @@ class LoraFinetuningSingleDevice:
# in case shuffle is True
self._sampler.set_epoch(curr_epoch)
# pbar = tqdm(total=self._steps_per_epoch)
for idx, batch in enumerate(self._dataloader):
if (
self.max_steps_per_epoch is not None
@ -442,14 +432,6 @@ class LoraFinetuningSingleDevice:
):
break
# Start tracking CUDA memory for active steps for just the first epoch
# if (
# curr_epoch == 0
# and self.profiler_profile_memory
# and idx == self.profiler_wait_steps + self.profiler_warmup_steps
# ):
# torch.cuda.memory._record_memory_history()
torchtune_utils.batch_to_device(batch, self._device)
# Calculate the number of unmasked tokens in the current batch
@ -461,14 +443,14 @@ class LoraFinetuningSingleDevice:
# Loss is normalized by default so we multiply by the number of tokens
# This way we can normalize by the total number of tokens if we're accumulating gradients
current_loss = self._loss_step(batch) * current_num_tokens
current_loss = await self._loss_step(batch) * current_num_tokens
running_loss += current_loss
current_loss.backward()
# Step with optimizer
if (idx + 1) % self._gradient_accumulation_steps == 0:
training.scale_grads(self._model, 1 / num_tokens)
grad_norm = torch.nn.utils.clip_grad_norm_(
torch.nn.utils.clip_grad_norm_(
self._model.parameters(),
max_norm=float(self._clip_grad_norm),
)
@ -478,58 +460,10 @@ class LoraFinetuningSingleDevice:
# Update the number of steps when the weights are updated
self.global_step += 1
# loss_to_log = running_loss.item() / num_tokens
# pbar.update(1)
# pbar.set_description(
# f"{curr_epoch + 1}|{self.global_step}|Loss: {loss_to_log}"
# )
# Log per-step metrics
# if self.global_step % self._log_every_n_steps == 0:
# time_per_step = time.perf_counter() - t0
# log_dict = {
# "loss": loss_to_log,
# "lr": self._optimizer.param_groups[0]["lr"],
# "tokens_per_second_per_gpu": num_tokens / time_per_step,
# }
# if self._device.type == "cuda" and self._log_peak_memory_stats:
# log_dict.update(
# training.get_memory_stats(device=self._device)
# )
# if self._clip_grad_norm is not None:
# log_dict.update({"grad_norm": grad_norm})
# self._metric_logger.log_dict(
# log_dict,
# step=self.global_step,
# )
# Reset running stats for the next step
running_loss = 0
num_tokens = 0
# t0 = time.perf_counter()
# Stop tracking CUDA memory now that active steps are complete
# if (
# curr_epoch == 0
# and self.profiler_profile_memory
# and idx
# == self.profiler_wait_steps
# + self.profiler_warmup_steps
# + self.profiler_active_steps
# ):
# torch.cuda.memory._record_memory_history(enabled=None)
# Step the profiler
# Note we are stepping each batch, which might not include optimizer step in the trace
# if the schedule cycle doesn't align with gradient accumulation.
# prof.step()
self.epochs_run += 1
# start_save_checkpoint = time.perf_counter()
log.info("Starting checkpoint save...")
self.save_checkpoint(epoch=curr_epoch)
# log.info(
# "Checkpoint saved in {:.2f} seconds.".format(
# time.perf_counter() - start_save_checkpoint
# )
# )
await self.save_checkpoint(epoch=curr_epoch)

View file

@ -16,15 +16,22 @@ import torch
from llama_models.sku_list import resolve_model
from torchtune.models.llama3 import llama3_tokenizer, lora_llama3_8b
from torchtune.models.llama3._tokenizer import Llama3Tokenizer
from torchtune.models.llama3_2 import lora_llama3_2_3b
LORA_MODEL_TYPES: Dict[str, Any] = {
"Llama3.2-3B-Instruct": lora_llama3_2_3b,
"Llama-3-8B-Instruct": lora_llama3_8b,
}
TOKENIZER_TYPES: Dict[str, Any] = {
"Llama3.2-3B-Instruct": llama3_tokenizer,
"Llama-3-8B-Instruct": llama3_tokenizer,
}
CHECKPOINT_MODEL_TYPES: Dict[str, str] = {
"Llama3.2-3B-Instruct": "LLAMA3_2",
}
BuildLoraModelCallable = Callable[..., torch.nn.Module]
BuildTokenizerCallable = Callable[..., Llama3Tokenizer]
@ -41,3 +48,10 @@ def get_tokenizer_type(
) -> BuildTokenizerCallable:
model = resolve_model(model_id)
return TOKENIZER_TYPES[model.core_model_id.value]
def get_checkpointer_model_type(
model_id: str,
) -> str:
model = resolve_model(model_id)
return CHECKPOINT_MODEL_TYPES[model.core_model_id.value]

View file

@ -71,7 +71,7 @@ datasets:
uri: https://huggingface.co/datasets/tatsu-lab/alpaca
metadata:
path: tatsu-lab/alpaca
name: post_training_alpaca
name:
split: train
dataset_schema:
instruction: