diff --git a/llama_stack/providers/datatypes.py b/llama_stack/providers/datatypes.py index 080204e45..4837d9b38 100644 --- a/llama_stack/providers/datatypes.py +++ b/llama_stack/providers/datatypes.py @@ -9,7 +9,7 @@ from typing import Any, List, Optional, Protocol from urllib.parse import urlparse from llama_models.schema_utils import json_schema_type -from pydantic import BaseModel, Field +from llama_stack.apis import post_training from llama_stack.apis.datasets import Dataset from llama_stack.apis.eval_tasks import EvalTask @@ -17,6 +17,7 @@ from llama_stack.apis.memory_banks.memory_banks import MemoryBank from llama_stack.apis.models import Model from llama_stack.apis.scoring_functions import ScoringFn from llama_stack.apis.shields import Shield +from pydantic import BaseModel, Field @json_schema_type @@ -28,6 +29,7 @@ class Api(Enum): datasetio = "datasetio" scoring = "scoring" eval = "eval" + post_training = "post_training" telemetry = "telemetry" diff --git a/llama_stack/providers/inline/post_training/meta_reference/__init__.py b/llama_stack/providers/inline/post_training/meta_reference/__init__.py new file mode 100644 index 000000000..d700fbb0a --- /dev/null +++ b/llama_stack/providers/inline/post_training/meta_reference/__init__.py @@ -0,0 +1,25 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from typing import Dict + +from llama_stack.distribution.datatypes import Api, ProviderSpec + +from .config import MetaReferencePostTrainingConfig + + +async def get_provider_impl( + config: MetaReferencePostTrainingConfig, + deps: Dict[Api, ProviderSpec], +): + from .post_training import MetaReferencePostTrainingImpl + + impl = MetaReferencePostTrainingImpl( + config, + deps[Api.datasetio], + ) + # await impl.initialize() + return impl diff --git a/llama_stack/providers/inline/post_training/meta_reference/config.py b/llama_stack/providers/inline/post_training/meta_reference/config.py index 808613ae9..880fb6070 100644 --- a/llama_stack/providers/inline/post_training/meta_reference/config.py +++ b/llama_stack/providers/inline/post_training/meta_reference/config.py @@ -5,7 +5,8 @@ # the root directory of this source tree. from typing import Optional -from pydantic import BaseModel, Field, + +from pydantic import BaseModel, Field class MetaReferencePostTrainingConfig(BaseModel): diff --git a/llama_stack/providers/inline/post_training/meta_reference/datasets/sft.py b/llama_stack/providers/inline/post_training/meta_reference/datasets/sft.py index 996a96109..c035fbedb 100644 --- a/llama_stack/providers/inline/post_training/meta_reference/datasets/sft.py +++ b/llama_stack/providers/inline/post_training/meta_reference/datasets/sft.py @@ -4,11 +4,10 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from typing import Any, Callable, Dict, List, Mapping, Optional +from typing import Any, Dict, List, Mapping import numpy as np -from datasets import load_dataset from torch.utils.data import Dataset from torchtune.data._common import CROSS_ENTROPY_IGNORE_IDX from torchtune.data._messages import validate_messages diff --git a/llama_stack/providers/inline/post_training/meta_reference/post_training.py b/llama_stack/providers/inline/post_training/meta_reference/post_training.py index 2311676ea..31ff9786c 100644 --- a/llama_stack/providers/inline/post_training/meta_reference/post_training.py +++ b/llama_stack/providers/inline/post_training/meta_reference/post_training.py @@ -1,6 +1,24 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. +from llama_stack.apis.datasetio import DatasetIO +from llama_stack.providers.inline.post_training.meta_reference.config import ( + MetaReferencePostTrainingConfig, +) +from llama_stack.apis.post_training import * # noqa +from llama_stack.providers.inline.post_training.meta_reference.recipes.lora_finetuning_single_device import ( + LoraFinetuningSingleDevice, +) + + class MetaReferencePostTrainingImpl: - def __init__(self, config: MetaReferenceInferenceConfig) -> None: + def __init__( + self, config: MetaReferencePostTrainingConfig, datasetio_api: DatasetIO + ) -> None: self.config = config + self.datasetio_api = datasetio_api def supervised_fine_tune( self, @@ -27,7 +45,12 @@ class MetaReferencePostTrainingImpl: logger_config=logger_config, ) if request.algorithm == FinetuningAlgorithm.lora: - recipe = LoraFinetuningRecipeSingleDevice(self.config, request) + recipe = LoraFinetuningSingleDevice( + self.config, request, self.datasetio_api + ) + recipe.setup(self.config) recipe.train() else: raise NotImplementedError() + + return PostTrainingJob(job_uuid=job_uuid) diff --git a/llama_stack/providers/inline/post_training/meta_reference/recipes/lora_finetuning_single_device.py b/llama_stack/providers/inline/post_training/meta_reference/recipes/lora_finetuning_single_device.py index 9d595d4e7..acf302220 100644 --- a/llama_stack/providers/inline/post_training/meta_reference/recipes/lora_finetuning_single_device.py +++ b/llama_stack/providers/inline/post_training/meta_reference/recipes/lora_finetuning_single_device.py @@ -3,15 +3,21 @@ # # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. + +import asyncio import logging import os from functools import partial from typing import Any, Dict, List, Optional, Tuple, Union import torch +from llama_stack.apis.datasetio import DatasetIO +from torch import nn +from llama_stack.apis.post_training import * # noqa +from llama_stack.apis.post_training import PostTrainingSFTRequest from llama_stack.providers.inline.post_training.meta_reference import utils -from llama_stack.providers.inline.post_training.meta_reference.configs import ( +from llama_stack.providers.inline.post_training.meta_reference.config import ( MetaReferencePostTrainingConfig, ) from llama_stack.providers.inline.post_training.meta_reference.datasets.sft import ( @@ -36,7 +42,7 @@ from torchtune.training.lr_scheduler import get_cosine_schedule_with_warmup log = logging.getLogger(__name__) -Tokenizer = Union[Llama3Tokenizer] +from torchtune.models.llama3._tokenizer import Llama3Tokenizer class LoraFinetuningSingleDevice: @@ -44,13 +50,13 @@ class LoraFinetuningSingleDevice: self, config: MetaReferencePostTrainingConfig, request: PostTrainingSFTRequest, - datasetio_api: DatasetIOAPI, + datasetio_api: DatasetIO, ) -> None: # to make user config easier, assume the device is 'cuda' only # self._device = utils.get_device(device=cfg.device) self.config = config self.request = request - self._device = "cuda" + self._device = training.utils.get_device(device="cuda") self._dtype = training.get_dtype( request.training_config.dtype, device=self._device ) @@ -68,6 +74,10 @@ class LoraFinetuningSingleDevice: self._shuffle = request.training_config.shuffle self._batch_size = request.training_config.batch_size + self.checkpoint_dir = ( + self.config.checkpoint_dir or f"~/.llama/checkpoints/{self.model_id}" + ) + # this is important for debugging purpose self.max_steps_per_epoch = request.training_config.max_steps_per_epoch self.global_step = 0 @@ -98,7 +108,7 @@ class LoraFinetuningSingleDevice: pth_files = [file for file in files if file.endswith(".pth")] return pth_files except FileNotFoundError: - return f"Error: The directory '{checkpoint_dir}' does not exist." + return [f"Error: The directory '{checkpoint_dir}' does not exist."] self._checkpointer = training.FullModelMetaCheckpointer( checkpoint_dir=self.config.checkpoint_dir, @@ -133,7 +143,7 @@ class LoraFinetuningSingleDevice: log.info("Tokenizer is initialized from file.") self._optimizer = self._setup_optimizer( - optimizer_config=self.request.training_config.optimizer, opt_state_dict=None + optimizer_config=self.request.training_config.optimizer ) self._loss_fn = CEWithChunkedOutputLoss() @@ -253,8 +263,8 @@ class LoraFinetuningSingleDevice: def _setup_tokenizer( self, - ) -> Tokenizer: - tokenizer_path = self.config.checkpoint_dir + "/tokenizer.model" + ) -> Llama3Tokenizer: + tokenizer_path = self.checkpoint_dir + "/tokenizer.model" tokenizer_type = utils.get_tokenizer_type(self.model_id) return tokenizer_type(path=tokenizer_path) @@ -270,13 +280,17 @@ class LoraFinetuningSingleDevice: log.info("Optimizer and loss are initialized.") return optimizer - async def _setup_data( - self, tokenizer: Tokenizer, shuffle: bool, batch_size: int + def _setup_data( + self, tokenizer: Llama3Tokenizer, shuffle: bool, batch_size: int ) -> Tuple[DistributedSampler, DataLoader]: - all_rows = await self.datasetio_api.get_rows_paginated( - dataset_id=self.request.dataset_id, - rows_in_page=-1, - ) + async def fetch_rows(): + return await self.datasetio_api.get_rows_paginated( + dataset_id=self.request.dataset_id, + rows_in_page=-1, + ) + + # Run the async function in an event loop + all_rows = asyncio.run(fetch_rows()) rows = all_rows.rows ds = SFTDataset( @@ -323,3 +337,198 @@ class LoraFinetuningSingleDevice: log.info("Learning rate scheduler is initialized.") return lr_scheduler + + def save_checkpoint(self, epoch: int) -> None: + """ + Checkpoint the state of the recipe. The constructed checkpoint state dict + contains the following information: + - Merged weights with key MODEL_KEY + - Adapter weights with key ADAPTER_KEY + - Relevant recipe state if training is not complete + - If the `self._save_adapter_weights_only` option is True, the checkpointer will save only the adapter weights + + To correctly resume from training, the adapter weights and recipe state must be provided along with the base model weights. + """ + ckpt_dict = {} + + intermediate_checkpoint = epoch + 1 < self.total_epochs + + adapter_state_dict = get_adapter_state_dict(self._model.state_dict()) + ckpt_dict.update({training.ADAPTER_KEY: adapter_state_dict}) + + # Construct the full state dict with LoRA weights merged into base LLM weights + + # Move to CPU to avoid a copy on GPU + state_dict = {k: v.cpu() for k, v in self._model.state_dict().items()} + + merged_state_dict = get_merged_lora_ckpt( + state_dict, + rank=self._lora_rank, + alpha=self._lora_alpha, + ) + + ckpt_dict.update({training.MODEL_KEY: merged_state_dict}) + + adapter_config = { + "r": self._lora_rank, + "lora_alpha": self._lora_alpha, + "target_modules": get_lora_module_names( + self._lora_attn_modules, + self._apply_lora_to_mlp, + self._apply_lora_to_output, + ), + "peft_type": "LORA", + } + ckpt_dict.update({training.ADAPTER_CONFIG: adapter_config}) + + self._checkpointer.save_checkpoint( + ckpt_dict, + epoch=epoch, + intermediate_checkpoint=intermediate_checkpoint, + ) + + def _loss_step(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor: + # Shape [b, s], needed for the loss not the model + labels = batch.pop("labels") + # run model + with self.activations_handling_ctx: + logits = self._model(**batch) + + # Shift labels to compute loss + # equivalent to doing labels[..., 1:] and logits[..., :-1, :] + # But this way we dont need to slice the logits. We just add an ignore index to labels. + labels = torch.hstack( + (labels[..., 1:], self.ignore_labels_cache[: labels.shape[0]]) + ) + if not isinstance(logits, list): + labels = labels.reshape(-1) + logits = logits.reshape(-1, logits.size(-1)) + + loss = self._loss_fn(logits, labels) + + # free logits otherwise it peaks backward memory + del logits + + return loss + + def train(self) -> None: + """ + The core training loop. + """ + + # if self._compile: + # log.info( + # "NOTE: torch.compile is enabled and model is compiled in first forward. Expect a relatively slow first iteration." + # ) + + # Initialize tokens count and running loss (for grad accumulation) + # t0 = time.perf_counter() + running_loss = 0 + num_tokens = 0 + + # self.epochs_run should be non-zero when we're resuming from a checkpoint + for curr_epoch in range(self.epochs_run, self.total_epochs): + # Update the sampler to ensure data is correctly shuffled across epochs + # in case shuffle is True + self._sampler.set_epoch(curr_epoch) + + # pbar = tqdm(total=self._steps_per_epoch) + for idx, batch in enumerate(self._dataloader): + if ( + self.max_steps_per_epoch is not None + and (idx // self._gradient_accumulation_steps) + == self.max_steps_per_epoch + ): + break + + # Start tracking CUDA memory for active steps for just the first epoch + # if ( + # curr_epoch == 0 + # and self.profiler_profile_memory + # and idx == self.profiler_wait_steps + self.profiler_warmup_steps + # ): + # torch.cuda.memory._record_memory_history() + + training.utils.batch_to_device(batch, self._device) + + # Calculate the number of unmasked tokens in the current batch + # and increment the total number of tokens seen in the step + current_num_tokens = ( + batch["labels"] != self._loss_fn.ignore_index + ).sum() + num_tokens += current_num_tokens + + # Loss is normalized by default so we multiply by the number of tokens + # This way we can normalize by the total number of tokens if we're accumulating gradients + current_loss = self._loss_step(batch) * current_num_tokens + running_loss += current_loss + current_loss.backward() + + # Step with optimizer + if (idx + 1) % self._gradient_accumulation_steps == 0: + training.scale_grads(self._model, 1 / num_tokens) + grad_norm = torch.nn.utils.clip_grad_norm_( + self._model.parameters(), + max_norm=float(self._clip_grad_norm), + ) + self._optimizer.step() + self._optimizer.zero_grad(set_to_none=True) + self._lr_scheduler.step() + # Update the number of steps when the weights are updated + self.global_step += 1 + + # loss_to_log = running_loss.item() / num_tokens + # pbar.update(1) + # pbar.set_description( + # f"{curr_epoch + 1}|{self.global_step}|Loss: {loss_to_log}" + # ) + + # Log per-step metrics + # if self.global_step % self._log_every_n_steps == 0: + # time_per_step = time.perf_counter() - t0 + # log_dict = { + # "loss": loss_to_log, + # "lr": self._optimizer.param_groups[0]["lr"], + # "tokens_per_second_per_gpu": num_tokens / time_per_step, + # } + # if self._device.type == "cuda" and self._log_peak_memory_stats: + # log_dict.update( + # training.get_memory_stats(device=self._device) + # ) + # if self._clip_grad_norm is not None: + # log_dict.update({"grad_norm": grad_norm}) + # self._metric_logger.log_dict( + # log_dict, + # step=self.global_step, + # ) + + # Reset running stats for the next step + running_loss = 0 + num_tokens = 0 + # t0 = time.perf_counter() + + # Stop tracking CUDA memory now that active steps are complete + # if ( + # curr_epoch == 0 + # and self.profiler_profile_memory + # and idx + # == self.profiler_wait_steps + # + self.profiler_warmup_steps + # + self.profiler_active_steps + # ): + # torch.cuda.memory._record_memory_history(enabled=None) + + # Step the profiler + # Note we are stepping each batch, which might not include optimizer step in the trace + # if the schedule cycle doesn't align with gradient accumulation. + # prof.step() + + self.epochs_run += 1 + # start_save_checkpoint = time.perf_counter() + log.info("Starting checkpoint save...") + self.save_checkpoint(epoch=curr_epoch) + # log.info( + # "Checkpoint saved in {:.2f} seconds.".format( + # time.perf_counter() - start_save_checkpoint + # ) + # ) diff --git a/llama_stack/providers/inline/post_training/meta_reference/utils.py b/llama_stack/providers/inline/post_training/meta_reference/utils.py index 5563b2a4c..4db5ab6df 100644 --- a/llama_stack/providers/inline/post_training/meta_reference/utils.py +++ b/llama_stack/providers/inline/post_training/meta_reference/utils.py @@ -30,15 +30,13 @@ BuildTokenizerCallable = Callable[..., Llama3Tokenizer] def get_model_type( - self, model_id: str, ) -> BuildLoraModelCallable: model = resolve_model(model_id) return LORA_MODEL_TYPES[model.core_model_id.value] -def get_tokenizer( - self, +def get_tokenizer_type( model_id: str, ) -> BuildTokenizerCallable: model = resolve_model(model_id) diff --git a/llama_stack/providers/registry/post_training.py b/llama_stack/providers/registry/post_training.py new file mode 100644 index 000000000..6c4554bb7 --- /dev/null +++ b/llama_stack/providers/registry/post_training.py @@ -0,0 +1,28 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from typing import List + +from llama_stack.distribution.datatypes import * # noqa: F403 + + +META_REFERENCE_DEPS = [ + "torch", + "torchtune", + "numpy", +] + + +def available_providers() -> List[ProviderSpec]: + return [ + InlineProviderSpec( + api=Api.post_training, + provider_type="inline::meta-reference", + pip_packages=META_REFERENCE_DEPS, + module="llama_stack.providers.inline.post_training.meta_reference", + config_class="llama_stack.providers.inline.post_training.meta_reference.MetaReferencePostTrainingConfig", + ), + ] diff --git a/llama_stack/templates/meta-reference-gpu/build.yaml b/llama_stack/templates/meta-reference-gpu/build.yaml index ef075d098..459a1b96c 100644 --- a/llama_stack/templates/meta-reference-gpu/build.yaml +++ b/llama_stack/templates/meta-reference-gpu/build.yaml @@ -4,6 +4,8 @@ distribution_spec: description: Use Meta Reference for running LLM inference docker_image: null providers: + post_training: + - inline::meta-reference inference: - inline::meta-reference memory: