diff --git a/llama_stack/apis/common/job_types.py b/llama_stack/apis/common/job_types.py
index ab8ab22dc..c945bd8ff 100644
--- a/llama_stack/apis/common/job_types.py
+++ b/llama_stack/apis/common/job_types.py
@@ -18,3 +18,5 @@ class Job(BaseModel):
 class JobStatus(Enum):
     completed = "completed"
     in_progress = "in_progress"
+    failed = "failed"
+    scheduled = "scheduled"
diff --git a/llama_stack/apis/common/training_types.py b/llama_stack/apis/common/training_types.py
index fd74293eb..670877c8f 100644
--- a/llama_stack/apis/common/training_types.py
+++ b/llama_stack/apis/common/training_types.py
@@ -4,13 +4,26 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from llama_models.llama3.api.datatypes import URL
+from datetime import datetime
+from typing import Optional
+
 from llama_models.schema_utils import json_schema_type
 from pydantic import BaseModel
 
 
+@json_schema_type
+class PostTrainingMetric(BaseModel):
+    epoch: int
+    train_loss: float
+    validation_loss: float
+    perplexity: float
+
+
 @json_schema_type(schema={"description": "Checkpoint created during training runs"})
 class Checkpoint(BaseModel):
-    iters: int
-    path: URL
+    identifier: str
+    created_at: datetime
     epoch: int
+    post_training_job_id: Optional[str] = None
+    path: str
+    training_metrics: Optional[PostTrainingMetric] = None
diff --git a/llama_stack/apis/post_training/post_training.py b/llama_stack/apis/post_training/post_training.py
index 758817ac0..b77747e3f 100644
--- a/llama_stack/apis/post_training/post_training.py
+++ b/llama_stack/apis/post_training/post_training.py
@@ -14,6 +14,7 @@ from llama_models.schema_utils import json_schema_type, webmethod
 from pydantic import BaseModel, Field
 
 from llama_models.llama3.api.datatypes import *  # noqa: F403
+from llama_stack.apis.common.job_types import JobStatus
 from llama_stack.apis.datasets import *  # noqa: F403
 from llama_stack.apis.common.training_types import *  # noqa: F403
 
@@ -87,14 +88,6 @@ class PostTrainingJobLogStream(BaseModel):
     log_lines: List[str]
 
 
-@json_schema_type
-class PostTrainingJobStatus(Enum):
-    running = "running"
-    completed = "completed"
-    failed = "failed"
-    scheduled = "scheduled"
-
-
 @json_schema_type
 class RLHFAlgorithm(Enum):
     dpo = "dpo"
@@ -139,7 +132,7 @@ class PostTrainingJobStatusResponse(BaseModel):
     """Status of a finetuning job."""
 
     job_uuid: str
-    status: PostTrainingJobStatus
+    status: JobStatus
 
     scheduled_at: Optional[datetime] = None
     started_at: Optional[datetime] = None
@@ -192,16 +185,10 @@ class PostTraining(Protocol):
     @webmethod(route="/post-training/jobs")
    async def get_training_jobs(self) -> List[PostTrainingJob]: ...
 
-    # sends SSE stream of logs
-    @webmethod(route="/post-training/job/logs")
-    async def get_training_job_logstream(
-        self, job_uuid: str
-    ) -> PostTrainingJobLogStream: ...
-
     @webmethod(route="/post-training/job/status")
     async def get_training_job_status(
         self, job_uuid: str
-    ) -> PostTrainingJobStatusResponse: ...
+    ) -> Optional[PostTrainingJobStatusResponse]: ...
 
     @webmethod(route="/post-training/job/cancel")
     async def cancel_training_job(self, job_uuid: str) -> None: ...
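
The API changes above replace the post-training-specific status enum with the shared JobStatus and turn Checkpoint into a self-describing artifact. A minimal sketch of how the new types compose (not part of the diff; the values are invented, and the identifier follows the {model_id}-sft-{epoch} convention the torchtune recipe below uses):

    from datetime import datetime

    from llama_stack.apis.common.job_types import JobStatus
    from llama_stack.apis.common.training_types import Checkpoint, PostTrainingMetric

    metric = PostTrainingMetric(
        epoch=0, train_loss=1.32, validation_loss=1.41, perplexity=4.10
    )
    checkpoint = Checkpoint(
        identifier="Llama3.2-3B-Instruct-sft-0",
        created_at=datetime.now(),
        epoch=0,
        post_training_job_id="1234",
        path="/tmp/post_training_output/Llama3.2-3B-Instruct-sft-0",
        training_metrics=metric,  # optional; the recipe below omits it for now
    )
    assert JobStatus.scheduled.value == "scheduled"
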
diff --git a/llama_stack/providers/inline/post_training/torchtune/common/checkpointer.py b/llama_stack/providers/inline/post_training/torchtune/common/checkpointer.py
new file mode 100644
index 000000000..9642f6ecc
--- /dev/null
+++ b/llama_stack/providers/inline/post_training/torchtune/common/checkpointer.py
@@ -0,0 +1,155 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+import os
+import shutil
+from pathlib import Path
+from typing import Any, Dict, List
+
+import torch
+from torchtune import training
+from torchtune.models import convert_weights
+from torchtune.training.checkpointing._utils import ModelType, safe_torch_load
+from torchtune.utils._logging import get_logger
+
+logger = get_logger("DEBUG")
+
+
+class TorchtuneCheckpointer:
+    def __init__(
+        self,
+        model_id: str,
+        training_algorithm: str,
+        checkpoint_dir: str,
+        checkpoint_files: List[str],
+        output_dir: str,
+        model_type: str,
+    ) -> None:
+        # Fail fast if ``checkpoint_files`` is invalid
+        # TODO: support loading more than one file
+        if len(checkpoint_files) != 1:
+            raise ValueError(
+                "Currently we only support reading from a single torchtune checkpoint file. "
+                f"Got {len(checkpoint_files)} files instead."
+            )
+        self._checkpoint_file = checkpoint_files[0]
+        self._model_id = model_id
+        self._training_algorithm = training_algorithm
+        self._checkpoint_dir = Path(checkpoint_dir)
+        self._model_type = ModelType[model_type]
+        self._output_dir = output_dir
+        # get ckpt paths
+        self._checkpoint_path = Path.joinpath(
+            self._checkpoint_dir, self._checkpoint_file
+        )
+
+    def load_checkpoint(self) -> Dict[str, Any]:
+        """
+        Load Meta checkpoint from file. Currently only loading from a single file is supported.
+        """
+        state_dict: Dict[str, Any] = {}
+        model_state_dict = safe_torch_load(self._checkpoint_path)
+        if self._model_type == ModelType.LLAMA3_VISION:
+            from torchtune.models.llama3_2_vision._convert_weights import (
+                llama3_vision_meta_to_tune,
+            )
+
+            state_dict[training.MODEL_KEY] = llama3_vision_meta_to_tune(
+                model_state_dict
+            )
+        else:
+            state_dict[training.MODEL_KEY] = convert_weights.meta_to_tune(
+                model_state_dict
+            )
+
+        # llama3_2 has tied weights, so we need to remove the output.weight key
+        if self._model_type == ModelType.LLAMA3_2:
+            logger.info(
+                "Identified model_type = Llama3_2. Ignoring output.weight in"
+                " checkpoint in favor of the tok_embedding.weight"
+                " tied weights."
+            )
+            state_dict[training.MODEL_KEY].pop("output.weight")
+
+        return state_dict
+
+    def save_checkpoint(
+        self,
+        state_dict: Dict[str, Any],
+        epoch: int,
+        adapter_only: bool = False,
+    ) -> str:
+        model_file_path = (
+            Path(self._output_dir)
+            / f"{self._model_id}-{self._training_algorithm}-{epoch}"
+        )
+
+        model_file_path.mkdir(parents=True, exist_ok=True)
+
+        # copy the related files for inference
+        shutil.copy(
+            Path.joinpath(self._checkpoint_dir, "params.json"),
+            Path.joinpath(model_file_path, "params.json"),
+        )
+        shutil.copy(
+            Path.joinpath(self._checkpoint_dir, "tokenizer.model"),
+            Path.joinpath(model_file_path, "tokenizer.model"),
+        )
+        shutil.copy(
+            Path.joinpath(self._checkpoint_dir, "orig_params.json"),
+            Path.joinpath(model_file_path, "orig_params.json"),
+        )
+
+        if not adapter_only:
+            model_state_dict = state_dict[training.MODEL_KEY]
+            if self._model_type == ModelType.LLAMA3_VISION:
+                from torchtune.models.llama3_2_vision._convert_weights import (
+                    llama3_vision_tune_to_meta,
+                )
+
+                state_dict[training.MODEL_KEY] = llama3_vision_tune_to_meta(
+                    model_state_dict
+                )
+            else:
+                # llama3_2 has tied weights, so we need to add the output.weight key
+                if (
+                    self._model_type == ModelType.LLAMA3_2
+                    and "output.weight" not in model_state_dict
+                ):
+                    model_state_dict["output.weight"] = model_state_dict[
+                        "tok_embeddings.weight"
+                    ]
+
+                state_dict[training.MODEL_KEY] = convert_weights.tune_to_meta(
+                    model_state_dict
+                )
+
+            model_file_name = Path.joinpath(model_file_path, "consolidated.00.pth")
+
+            torch.save(state_dict[training.MODEL_KEY], model_file_name)
+            logger.info(
+                "Model checkpoint of size "
+                f"{os.path.getsize(model_file_name) / 1000**3:.2f} GB "
+                f"saved to {model_file_name}"
+            )
+
+        if training.ADAPTER_KEY in state_dict:
+            adapter_file_path = model_file_path / "adapter"
+            adapter_file_path.mkdir(parents=True, exist_ok=True)
+            adapter_file_name = Path.joinpath(adapter_file_path, "adapter.pth")
+            torch.save(state_dict[training.ADAPTER_KEY], adapter_file_name)
+            logger.info(
+                "Adapter checkpoint of size "
+                f"{os.path.getsize(adapter_file_name) / 1000**3:.2f} GB "
+                f"saved to {adapter_file_name}"
+            )
+
+        elif adapter_only:
+            raise ValueError(
+                "Adapter checkpoint not found in state_dict. Please ensure that the state_dict contains adapter weights."
+            )
+
+        return str(model_file_path)
diff --git a/llama_stack/providers/inline/post_training/torchtune/utils.py b/llama_stack/providers/inline/post_training/torchtune/common/utils.py
similarity index 100%
rename from llama_stack/providers/inline/post_training/torchtune/utils.py
rename to llama_stack/providers/inline/post_training/torchtune/common/utils.py
diff --git a/llama_stack/providers/inline/post_training/torchtune/post_training.py b/llama_stack/providers/inline/post_training/torchtune/post_training.py
index f33ca059a..1144d75c5 100644
--- a/llama_stack/providers/inline/post_training/torchtune/post_training.py
+++ b/llama_stack/providers/inline/post_training/torchtune/post_training.py
@@ -20,6 +20,11 @@ class TorchtunePostTrainingImpl:
         self.config = config
         self.datasetio_api = datasetio_api
 
+        # TODO: jobs run synchronously for now; a jobs API is needed for async scheduling
+        self.jobs_status = {}
+        self.jobs_list = []
+        self.checkpoints_dict = {}
+
     async def supervised_fine_tune(
         self,
         job_uuid: str,
@@ -30,23 +35,51 @@ class TorchtunePostTrainingImpl:
         checkpoint_dir: Optional[str],
         algorithm_config: Optional[Union[LoraFinetuningConfig, QATFinetuningConfig]],
     ) -> PostTrainingJob:
+        post_training_job = PostTrainingJob(job_uuid=job_uuid)
+
+        job_status_response = PostTrainingJobStatusResponse(
+            job_uuid=job_uuid,
+            status=JobStatus.scheduled,
+            scheduled_at=datetime.now(),
+        )
+
+        self.jobs_list.append(post_training_job)
         if isinstance(algorithm_config, LoraFinetuningConfig):
-            recipe = LoraFinetuningSingleDevice(
-                self.config,
-                training_config,
-                hyperparam_search_config,
-                logger_config,
-                model,
-                checkpoint_dir,
-                algorithm_config,
-                self.datasetio_api,
-            )
-            await recipe.setup()
-            await recipe.train()
+            try:
+                recipe = LoraFinetuningSingleDevice(
+                    self.config,
+                    training_config,
+                    hyperparam_search_config,
+                    logger_config,
+                    model,
+                    checkpoint_dir,
+                    algorithm_config,
+                    self.datasetio_api,
+                )
+
+                job_status_response.status = JobStatus.in_progress
+                job_status_response.started_at = datetime.now()
+
+                await recipe.setup()
+                resources_allocated, checkpoints = await recipe.train()
+
+                self.checkpoints_dict[job_uuid] = checkpoints
+                job_status_response.resources_allocated = resources_allocated
+                job_status_response.checkpoints = checkpoints
+                job_status_response.status = JobStatus.completed
+                job_status_response.completed_at = datetime.now()
+
+            except Exception:
+                job_status_response.status = JobStatus.failed
+                # record the failure so later status queries can observe it
+                self.jobs_status[job_uuid] = job_status_response
+                raise
         else:
             raise NotImplementedError()
 
-        return PostTrainingJob(job_uuid=job_uuid)
+        self.jobs_status[job_uuid] = job_status_response
+
+        return post_training_job
 
     async def preference_optimize(
         self,
@@ -58,24 +91,26 @@
         logger_config: Dict[str, Any],
     ) -> PostTrainingJob: ...
 
-    # TODO @markchen1015 impelment below APIs
-    async def get_training_jobs(self) -> List[PostTrainingJob]: ...
-
-    # sends SSE stream of logs
-    @webmethod(route="/post-training/job/logs")
-    async def get_training_job_logstream(
-        self, job_uuid: str
-    ) -> PostTrainingJobLogStream: ...
+    async def get_training_jobs(self) -> List[PostTrainingJob]:
+        return self.jobs_list
 
     @webmethod(route="/post-training/job/status")
     async def get_training_job_status(
         self, job_uuid: str
-    ) -> PostTrainingJobStatusResponse: ...
+    ) -> Optional[PostTrainingJobStatusResponse]:
+        if job_uuid in self.jobs_status:
+            return self.jobs_status[job_uuid]
+        return None
 
     @webmethod(route="/post-training/job/cancel")
-    async def cancel_training_job(self, job_uuid: str) -> None: ...
+    async def cancel_training_job(self, job_uuid: str) -> None:
+        raise NotImplementedError("Job cancel is not implemented yet")
 
     @webmethod(route="/post-training/job/artifacts")
     async def get_training_job_artifacts(
         self, job_uuid: str
-    ) -> PostTrainingJobArtifactsResponse: ...
+    ) -> PostTrainingJobArtifactsResponse:
+        checkpoints = self.checkpoints_dict.get(job_uuid, [])
+        return PostTrainingJobArtifactsResponse(
+            job_uuid=job_uuid, checkpoints=checkpoints
+        )
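
From a caller's point of view, the bookkeeping added above enables this flow (a sketch under assumptions: impl is a wired-up TorchtunePostTrainingImpl, and lora_config / training_config are config objects built elsewhere):

    job = await impl.supervised_fine_tune(
        job_uuid="1234",
        model="Llama3.2-3B-Instruct",
        checkpoint_dir=None,
        algorithm_config=lora_config,      # LoraFinetuningConfig, assumed prebuilt
        training_config=training_config,   # TrainingConfig, assumed prebuilt
        hyperparam_search_config={},
        logger_config={},
    )

    status = await impl.get_training_job_status(job.job_uuid)  # now Optional
    if status is not None and status.status == JobStatus.completed:
        artifacts = await impl.get_training_job_artifacts(job.job_uuid)
        for checkpoint in artifacts.checkpoints:
            print(checkpoint.identifier, checkpoint.path)
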
diff --git a/llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py b/llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py
index 17d3cbc2c..5b8b8e30f 100644
--- a/llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py
+++ b/llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py
@@ -13,14 +13,20 @@ from typing import Any, Dict, List, Optional, Tuple
 
 import torch
 from llama_models.sku_list import resolve_model
+
 from llama_stack.apis.datasetio import DatasetIO
+
+from llama_stack.distribution.utils.config_dirs import DEFAULT_CHECKPOINT_DIR
+from llama_stack.providers.inline.post_training.torchtune.common.checkpointer import (
+    TorchtuneCheckpointer,
+)
 from torch import nn
 from torchtune import utils as torchtune_utils
 from torchtune.training.metric_logging import DiskLogger
 
 from llama_stack.apis.post_training import *  # noqa
 from llama_stack.distribution.utils.model_utils import model_local_dir
-from llama_stack.providers.inline.post_training.torchtune import utils
+from llama_stack.providers.inline.post_training.torchtune.common import utils
 from llama_stack.providers.inline.post_training.torchtune.config import (
     TorchtunePostTrainingConfig,
 )
@@ -99,7 +105,7 @@
             self.checkpoint_dir = model_checkpoint_dir(model)
 
         # TODO @markchen1015 make it work with get_training_job_artifacts
-        self._output_dir = self.checkpoint_dir + "/posting_training/"
+        self._output_dir = str(DEFAULT_CHECKPOINT_DIR)
 
         self.seed = training.set_seed(seed=config.torch_seed)
         self.epochs_run = 0
@@ -138,7 +144,9 @@
             except FileNotFoundError:
                 return [f"Error: The directory '{checkpoint_dir}' does not exist."]
 
-        self._checkpointer = training.FullModelMetaCheckpointer(
+        self._checkpointer = TorchtuneCheckpointer(
+            model_id=self.model_id,
+            training_algorithm="sft",
             checkpoint_dir=self.checkpoint_dir,
             checkpoint_files=get_checkpoint_files(self.checkpoint_dir),
             output_dir=self._output_dir,
@@ -148,8 +156,6 @@
         return checkpoint_dict
 
     async def setup(self) -> None:
-        self._metric_logger = DiskLogger(log_dir=self._output_dir)
-
         checkpoint_dict = await self.load_checkpoint()
 
         self._model = await self._setup_model(
""" # Initialize tokens count and running loss (for grad accumulation) - # t0 = time.perf_counter() t0 = time.perf_counter() running_loss = 0 num_tokens = 0 + # training artifacts + checkpoints = [] + memory_stats = {} + # self.epochs_run should be non-zero when we're resuming from a checkpoint for curr_epoch in range(self.epochs_run, self.total_epochs): # Update the sampler to ensure data is correctly shuffled across epochs # in case shuffle is True + metric_logger = DiskLogger( + log_dir=self._output_dir + f"/{self.model_id}-sft-{curr_epoch}" + ) self._sampler.set_epoch(curr_epoch) for idx, batch in enumerate(self._dataloader): @@ -478,10 +490,14 @@ class LoraFinetuningSingleDevice: "lr": self._optimizer.param_groups[0]["lr"], "tokens_per_second_per_gpu": num_tokens / time_per_step, } - log_dict.update(training.get_memory_stats(device=self._device)) + + memory_stats = training.get_memory_stats(device=self._device) + log_dict.update(memory_stats) + if self._clip_grad_norm is not None: log_dict.update({"grad_norm": grad_norm}) - self._metric_logger.log_dict( + + metric_logger.log_dict( log_dict, step=self.global_step, ) @@ -493,4 +509,13 @@ class LoraFinetuningSingleDevice: self.epochs_run += 1 log.info("Starting checkpoint save...") - await self.save_checkpoint(epoch=curr_epoch) + checkpoint_path = await self.save_checkpoint(epoch=curr_epoch) + checkpoint = Checkpoint( + identifier=f"{self.model_id}-sft-{curr_epoch}", + created_at=datetime.now(), + epoch=curr_epoch, + path=checkpoint_path, + ) + checkpoints.append(checkpoint) + + return (memory_stats, checkpoints)