temp commit

Botao Chen 2024-11-26 10:49:03 -08:00
parent d7598c68d7
commit 9a976bcabd
9 changed files with 310 additions and 23 deletions

View file

@@ -9,7 +9,7 @@ from typing import Any, List, Optional, Protocol
from urllib.parse import urlparse
from llama_models.schema_utils import json_schema_type
from pydantic import BaseModel, Field
from llama_stack.apis import post_training
from llama_stack.apis.datasets import Dataset
from llama_stack.apis.eval_tasks import EvalTask
@@ -17,6 +17,7 @@ from llama_stack.apis.memory_banks.memory_banks import MemoryBank
from llama_stack.apis.models import Model
from llama_stack.apis.scoring_functions import ScoringFn
from llama_stack.apis.shields import Shield
+from pydantic import BaseModel, Field
@json_schema_type
@@ -28,6 +29,7 @@ class Api(Enum):
    datasetio = "datasetio"
    scoring = "scoring"
    eval = "eval"
+   post_training = "post_training"
    telemetry = "telemetry"

View file

@@ -0,0 +1,25 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Dict
from llama_stack.distribution.datatypes import Api, ProviderSpec
from .config import MetaReferencePostTrainingConfig
async def get_provider_impl(
    config: MetaReferencePostTrainingConfig,
    deps: Dict[Api, ProviderSpec],
):
    from .post_training import MetaReferencePostTrainingImpl

    impl = MetaReferencePostTrainingImpl(
        config,
        deps[Api.datasetio],
    )
    # await impl.initialize()
    return impl
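For context, a minimal caller-side sketch of how the stack's resolver might await this entry point. The stand-in datasetio provider and the default-constructed config below are assumptions for illustration, not part of this commit:

    import asyncio

    from llama_stack.distribution.datatypes import Api
    from llama_stack.providers.inline.post_training.meta_reference import (
        get_provider_impl,
    )
    from llama_stack.providers.inline.post_training.meta_reference.config import (
        MetaReferencePostTrainingConfig,
    )

    datasetio_impl = object()  # stand-in for an already-built datasetio provider
    deps = {Api.datasetio: datasetio_impl}
    impl = asyncio.run(get_provider_impl(MetaReferencePostTrainingConfig(), deps))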

View file

@@ -5,7 +5,8 @@
# the root directory of this source tree.
from typing import Optional
-from pydantic import BaseModel, Field,
+from pydantic import BaseModel, Field
class MetaReferencePostTrainingConfig(BaseModel):

View file

@@ -4,11 +4,10 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
-from typing import Any, Callable, Dict, List, Mapping, Optional
+from typing import Any, Dict, List, Mapping
import numpy as np
from datasets import load_dataset
from torch.utils.data import Dataset
from torchtune.data._common import CROSS_ENTROPY_IGNORE_IDX
from torchtune.data._messages import validate_messages
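These imports hint at the dataset's masking convention. As a quick illustration (an assumed usage pattern, not this file's code), prompt tokens are typically replaced with CROSS_ENTROPY_IGNORE_IDX so only response tokens contribute to the loss:

    # torchtune's CROSS_ENTROPY_IGNORE_IDX (-100) matches PyTorch's default
    # cross-entropy ignore_index, so masked positions drop out of the loss.
    from torchtune.data._common import CROSS_ENTROPY_IGNORE_IDX

    tokens = [128000, 9906, 1917, 0, 128009]  # made-up prompt + response ids
    prompt_len = 2                            # mask the prompt portion
    labels = list(tokens)
    labels[:prompt_len] = [CROSS_ENTROPY_IGNORE_IDX] * prompt_len
    # labels -> [-100, -100, 1917, 0, 128009]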

View file

@@ -1,6 +1,24 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.datasetio import DatasetIO
from llama_stack.providers.inline.post_training.meta_reference.config import (
    MetaReferencePostTrainingConfig,
)
from llama_stack.apis.post_training import *  # noqa
from llama_stack.providers.inline.post_training.meta_reference.recipes.lora_finetuning_single_device import (
    LoraFinetuningSingleDevice,
)
class MetaReferencePostTrainingImpl:
-    def __init__(self, config: MetaReferenceInferenceConfig) -> None:
+    def __init__(
+        self, config: MetaReferencePostTrainingConfig, datasetio_api: DatasetIO
+    ) -> None:
        self.config = config
        self.datasetio_api = datasetio_api

    def supervised_fine_tune(
        self,
@@ -27,7 +45,12 @@ class MetaReferencePostTrainingImpl:
            logger_config=logger_config,
        )
        if request.algorithm == FinetuningAlgorithm.lora:
-            recipe = LoraFinetuningRecipeSingleDevice(self.config, request)
+            recipe = LoraFinetuningSingleDevice(
+                self.config, request, self.datasetio_api
+            )
+            recipe.setup(self.config)
+            recipe.train()
        else:
            raise NotImplementedError()

        return PostTrainingJob(job_uuid=job_uuid)
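For orientation, a hypothetical caller-side sketch (the argument list is elided because the full signature is not shown in this diff): training runs synchronously inside supervised_fine_tune, so the returned PostTrainingJob is an identifier rather than a handle to a still-running job.

    # Hypothetical usage; supervised_fine_tune blocks until recipe.train() returns.
    impl = MetaReferencePostTrainingImpl(config, datasetio_api)
    job = impl.supervised_fine_tune(...)  # arguments elided in this diff
    print(job.job_uuid)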

View file

@@ -3,15 +3,21 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import asyncio
import logging
import os
from functools import partial
from typing import Any, Dict, List, Optional, Tuple, Union
import torch
from llama_stack.apis.datasetio import DatasetIO
from torch import nn
-from llama_stack.apis.post_training import *  # noqa
+from llama_stack.apis.post_training import PostTrainingSFTRequest
from llama_stack.providers.inline.post_training.meta_reference import utils
-from llama_stack.providers.inline.post_training.meta_reference.configs import (
+from llama_stack.providers.inline.post_training.meta_reference.config import (
    MetaReferencePostTrainingConfig,
)
from llama_stack.providers.inline.post_training.meta_reference.datasets.sft import (
@@ -36,7 +42,7 @@ from torchtune.training.lr_scheduler import get_cosine_schedule_with_warmup
log = logging.getLogger(__name__)

-Tokenizer = Union[Llama3Tokenizer]
+from torchtune.models.llama3._tokenizer import Llama3Tokenizer


class LoraFinetuningSingleDevice:
@@ -44,13 +50,13 @@ class LoraFinetuningSingleDevice:
        self,
        config: MetaReferencePostTrainingConfig,
        request: PostTrainingSFTRequest,
-        datasetio_api: DatasetIOAPI,
+        datasetio_api: DatasetIO,
    ) -> None:
        # to make user config easier, assume the device is 'cuda' only
        # self._device = utils.get_device(device=cfg.device)
        self.config = config
        self.request = request
-        self._device = "cuda"
+        self._device = training.utils.get_device(device="cuda")
        self._dtype = training.get_dtype(
            request.training_config.dtype, device=self._device
        )
@@ -68,6 +74,10 @@
        self._shuffle = request.training_config.shuffle
        self._batch_size = request.training_config.batch_size

+        self.checkpoint_dir = (
+            self.config.checkpoint_dir or f"~/.llama/checkpoints/{self.model_id}"
+        )

        # this is important for debugging purposes
        self.max_steps_per_epoch = request.training_config.max_steps_per_epoch
        self.global_step = 0
@@ -98,7 +108,7 @@
            pth_files = [file for file in files if file.endswith(".pth")]
            return pth_files
        except FileNotFoundError:
-            return f"Error: The directory '{checkpoint_dir}' does not exist."
+            return [f"Error: The directory '{checkpoint_dir}' does not exist."]

        self._checkpointer = training.FullModelMetaCheckpointer(
            checkpoint_dir=self.config.checkpoint_dir,
@@ -133,7 +143,7 @@
        log.info("Tokenizer is initialized from file.")

        self._optimizer = self._setup_optimizer(
-            optimizer_config=self.request.training_config.optimizer, opt_state_dict=None
+            optimizer_config=self.request.training_config.optimizer
        )

        self._loss_fn = CEWithChunkedOutputLoss()
@@ -253,8 +263,8 @@
    def _setup_tokenizer(
        self,
-    ) -> Tokenizer:
-        tokenizer_path = self.config.checkpoint_dir + "/tokenizer.model"
+    ) -> Llama3Tokenizer:
+        tokenizer_path = self.checkpoint_dir + "/tokenizer.model"
        tokenizer_type = utils.get_tokenizer_type(self.model_id)
        return tokenizer_type(path=tokenizer_path)
@@ -270,13 +280,17 @@
        log.info("Optimizer and loss are initialized.")
        return optimizer

-    async def _setup_data(
-        self, tokenizer: Tokenizer, shuffle: bool, batch_size: int
+    def _setup_data(
+        self, tokenizer: Llama3Tokenizer, shuffle: bool, batch_size: int
    ) -> Tuple[DistributedSampler, DataLoader]:
-        all_rows = await self.datasetio_api.get_rows_paginated(
-            dataset_id=self.request.dataset_id,
-            rows_in_page=-1,
-        )
+        async def fetch_rows():
+            return await self.datasetio_api.get_rows_paginated(
+                dataset_id=self.request.dataset_id,
+                rows_in_page=-1,
+            )
+
+        # Run the async function in an event loop
+        all_rows = asyncio.run(fetch_rows())
        rows = all_rows.rows

        ds = SFTDataset(
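A note on the bridge introduced above, as a standalone sketch (an illustration, not this file's code): asyncio.run() creates a fresh event loop per call and raises RuntimeError if the caller is already inside a running loop, so this pattern assumes _setup_data is only invoked from synchronous code.

    import asyncio

    async def fetch(x: int) -> int:
        return x * 2

    result = asyncio.run(fetch(21))  # fine from sync code; fails inside a loop
    assert result == 42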
@@ -323,3 +337,198 @@ class LoraFinetuningSingleDevice:
        log.info("Learning rate scheduler is initialized.")
        return lr_scheduler

    def save_checkpoint(self, epoch: int) -> None:
        """
        Checkpoint the state of the recipe. The constructed checkpoint state dict
        contains the following information:
        - Merged weights with key MODEL_KEY
        - Adapter weights with key ADAPTER_KEY
        - Relevant recipe state if training is not complete
        - If the `self._save_adapter_weights_only` option is True, the checkpointer will save only the adapter weights

        To correctly resume training, the adapter weights and recipe state must be provided along with the base model weights.
        """
        ckpt_dict = {}

        intermediate_checkpoint = epoch + 1 < self.total_epochs

        adapter_state_dict = get_adapter_state_dict(self._model.state_dict())
        ckpt_dict.update({training.ADAPTER_KEY: adapter_state_dict})

        # Construct the full state dict with LoRA weights merged into base LLM weights
        # Move to CPU to avoid a copy on GPU
        state_dict = {k: v.cpu() for k, v in self._model.state_dict().items()}
        merged_state_dict = get_merged_lora_ckpt(
            state_dict,
            rank=self._lora_rank,
            alpha=self._lora_alpha,
        )
        ckpt_dict.update({training.MODEL_KEY: merged_state_dict})

        adapter_config = {
            "r": self._lora_rank,
            "lora_alpha": self._lora_alpha,
            "target_modules": get_lora_module_names(
                self._lora_attn_modules,
                self._apply_lora_to_mlp,
                self._apply_lora_to_output,
            ),
            "peft_type": "LORA",
        }
        ckpt_dict.update({training.ADAPTER_CONFIG: adapter_config})

        self._checkpointer.save_checkpoint(
            ckpt_dict,
            epoch=epoch,
            intermediate_checkpoint=intermediate_checkpoint,
        )
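As an aside on what the merge step produces (a hedged sketch of standard LoRA math with made-up shapes, not this commit's code): get_merged_lora_ckpt folds each adapter back into its base matrix as W + (alpha / r) * B @ A.

    import torch

    r, alpha = 8, 16
    W = torch.randn(64, 64)  # frozen base weight
    A = torch.randn(r, 64)   # LoRA down-projection
    B = torch.zeros(64, r)   # LoRA up-projection (zero-initialized)
    W_merged = W + (alpha / r) * (B @ A)
    # With B still zero-initialized, W_merged equals W exactly.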
    def _loss_step(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
        # Shape [b, s], needed for the loss not the model
        labels = batch.pop("labels")

        # run model
        with self.activations_handling_ctx:
            logits = self._model(**batch)

        # Shift labels to compute loss
        # equivalent to doing labels[..., 1:] and logits[..., :-1, :]
        # But this way we don't need to slice the logits. We just add an ignore index to labels.
        labels = torch.hstack(
            (labels[..., 1:], self.ignore_labels_cache[: labels.shape[0]])
        )
        if not isinstance(logits, list):
            labels = labels.reshape(-1)
            logits = logits.reshape(-1, logits.size(-1))

        loss = self._loss_fn(logits, labels)

        # free logits otherwise it peaks backward memory
        del logits

        return loss
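To make the shift trick concrete, a small worked example (an illustration assuming an ignore index of -100, mirroring ignore_labels_cache):

    import torch

    labels = torch.tensor([[5, 7, 9, 11]])
    ignore_cache = torch.full((1, 1), -100)  # stands in for ignore_labels_cache
    shifted = torch.hstack((labels[..., 1:], ignore_cache[: labels.shape[0]]))
    # shifted == tensor([[7, 9, 11, -100]]): position t now holds the target
    # for logit t, and the final position is ignored by the loss.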
    def train(self) -> None:
        """
        The core training loop.
        """
        # if self._compile:
        #     log.info(
        #         "NOTE: torch.compile is enabled and model is compiled in first forward. Expect a relatively slow first iteration."
        #     )

        # Initialize tokens count and running loss (for grad accumulation)
        # t0 = time.perf_counter()
        running_loss = 0
        num_tokens = 0

        # self.epochs_run should be non-zero when we're resuming from a checkpoint
        for curr_epoch in range(self.epochs_run, self.total_epochs):
            # Update the sampler to ensure data is correctly shuffled across epochs
            # in case shuffle is True
            self._sampler.set_epoch(curr_epoch)

            # pbar = tqdm(total=self._steps_per_epoch)
            for idx, batch in enumerate(self._dataloader):
                if (
                    self.max_steps_per_epoch is not None
                    and (idx // self._gradient_accumulation_steps)
                    == self.max_steps_per_epoch
                ):
                    break

                # Start tracking CUDA memory for active steps for just the first epoch
                # if (
                #     curr_epoch == 0
                #     and self.profiler_profile_memory
                #     and idx == self.profiler_wait_steps + self.profiler_warmup_steps
                # ):
                #     torch.cuda.memory._record_memory_history()

                training.utils.batch_to_device(batch, self._device)

                # Calculate the number of unmasked tokens in the current batch
                # and increment the total number of tokens seen in the step
                current_num_tokens = (
                    batch["labels"] != self._loss_fn.ignore_index
                ).sum()
                num_tokens += current_num_tokens

                # Loss is normalized by default so we multiply by the number of tokens
                # This way we can normalize by the total number of tokens if we're accumulating gradients
                current_loss = self._loss_step(batch) * current_num_tokens
                running_loss += current_loss
                current_loss.backward()

                # Step with optimizer
                if (idx + 1) % self._gradient_accumulation_steps == 0:
                    training.scale_grads(self._model, 1 / num_tokens)
                    grad_norm = torch.nn.utils.clip_grad_norm_(
                        self._model.parameters(),
                        max_norm=float(self._clip_grad_norm),
                    )
                    self._optimizer.step()
                    self._optimizer.zero_grad(set_to_none=True)
                    self._lr_scheduler.step()

                    # Update the number of steps when the weights are updated
                    self.global_step += 1

                    # loss_to_log = running_loss.item() / num_tokens
                    # pbar.update(1)
                    # pbar.set_description(
                    #     f"{curr_epoch + 1}|{self.global_step}|Loss: {loss_to_log}"
                    # )

                    # Log per-step metrics
                    # if self.global_step % self._log_every_n_steps == 0:
                    #     time_per_step = time.perf_counter() - t0
                    #     log_dict = {
                    #         "loss": loss_to_log,
                    #         "lr": self._optimizer.param_groups[0]["lr"],
                    #         "tokens_per_second_per_gpu": num_tokens / time_per_step,
                    #     }
                    #     if self._device.type == "cuda" and self._log_peak_memory_stats:
                    #         log_dict.update(
                    #             training.get_memory_stats(device=self._device)
                    #         )
                    #     if self._clip_grad_norm is not None:
                    #         log_dict.update({"grad_norm": grad_norm})
                    #     self._metric_logger.log_dict(
                    #         log_dict,
                    #         step=self.global_step,
                    #     )

                    # Reset running stats for the next step
                    running_loss = 0
                    num_tokens = 0
                    # t0 = time.perf_counter()

                # Stop tracking CUDA memory now that active steps are complete
                # if (
                #     curr_epoch == 0
                #     and self.profiler_profile_memory
                #     and idx
                #     == self.profiler_wait_steps
                #     + self.profiler_warmup_steps
                #     + self.profiler_active_steps
                # ):
                #     torch.cuda.memory._record_memory_history(enabled=None)

                # Step the profiler
                # Note we are stepping each batch, which might not include optimizer step in the trace
                # if the schedule cycle doesn't align with gradient accumulation.
                # prof.step()

            self.epochs_run += 1
            # start_save_checkpoint = time.perf_counter()
            log.info("Starting checkpoint save...")
            self.save_checkpoint(epoch=curr_epoch)
            # log.info(
            #     "Checkpoint saved in {:.2f} seconds.".format(
            #         time.perf_counter() - start_save_checkpoint
            #     )
            # )
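For readers new to the accumulation scheme in train(), a self-contained sketch of token-normalized gradient accumulation (an illustration with made-up shapes, not this recipe's code): per-token losses are summed across micro-batches, then the accumulated gradients are rescaled by the total token count before the optimizer step.

    import torch

    model = torch.nn.Linear(4, 2)
    opt = torch.optim.SGD(model.parameters(), lr=0.1)
    accum_steps, num_tokens = 2, 0
    for _ in range(accum_steps):
        x, y = torch.randn(3, 4), torch.randn(3, 2)
        tokens = 3  # stand-in for the batch's unmasked token count
        loss = torch.nn.functional.mse_loss(model(x), y)  # mean-reduced loss
        (loss * tokens).backward()  # undo the mean: sum over tokens
        num_tokens += tokens
    for p in model.parameters():  # equivalent of training.scale_grads
        p.grad.mul_(1.0 / num_tokens)
    opt.step()
    opt.zero_grad(set_to_none=True)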

View file

@@ -30,15 +30,13 @@ BuildTokenizerCallable = Callable[..., Llama3Tokenizer]
def get_model_type(
-    self,
    model_id: str,
) -> BuildLoraModelCallable:
    model = resolve_model(model_id)
    return LORA_MODEL_TYPES[model.core_model_id.value]


-def get_tokenizer(
-    self,
+def get_tokenizer_type(
    model_id: str,
) -> BuildTokenizerCallable:
    model = resolve_model(model_id)
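For orientation, a hedged sketch of the lookup tables these helpers index into; the key and the torchtune builder names below are assumptions, not this file's contents:

    from torchtune.models.llama3 import llama3_tokenizer, lora_llama3_8b

    LORA_MODEL_TYPES = {"Llama3-8B-Instruct": lora_llama3_8b}
    TOKENIZER_TYPES = {"Llama3-8B-Instruct": llama3_tokenizer}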

View file

@@ -0,0 +1,28 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import List
from llama_stack.distribution.datatypes import * # noqa: F403
META_REFERENCE_DEPS = [
    "torch",
    "torchtune",
    "numpy",
]


def available_providers() -> List[ProviderSpec]:
    return [
        InlineProviderSpec(
            api=Api.post_training,
            provider_type="inline::meta-reference",
            pip_packages=META_REFERENCE_DEPS,
            module="llama_stack.providers.inline.post_training.meta_reference",
            config_class="llama_stack.providers.inline.post_training.meta_reference.MetaReferencePostTrainingConfig",
        ),
    ]
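A quick hedged check of what this registry entry exposes (assuming InlineProviderSpec surfaces the fields set above as attributes):

    specs = available_providers()
    assert specs[0].api == Api.post_training
    assert specs[0].provider_type == "inline::meta-reference"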

View file

@@ -4,6 +4,8 @@ distribution_spec:
  description: Use Meta Reference for running LLM inference
  docker_image: null
  providers:
+    post_training:
+    - inline::meta-reference
    inference:
    - inline::meta-reference
    memory: