[2/n][torchtune integration] implement job management and return training artifacts (#593)
### Context
In this PR, we:
- Implement the post-training job management and training-artifact APIs:
  - get_training_jobs
  - get_training_job_status
  - get_training_job_artifacts
- Delete get_training_job_logstream, since the trace can be accessed directly from the UI with Jaeger: https://llama-stack.readthedocs.io/en/latest/building_applications/telemetry.html#jaeger-to-visualize-traces
- Refactor the post-training and training type definitions to make them more intuitive.
- Rewrite the checkpointer so that its output is compatible with the llama-stack file system layout and can be recognized during inference.

### Test
Unit test:
`pytest llama_stack/providers/tests/post_training/test_post_training.py -m "torchtune_post_training_huggingface_datasetio" -v -s --tb=short --disable-warnings`

<img width="1506" alt="Screenshot 2024-12-10 at 4 06 17 PM" src="https://github.com/user-attachments/assets/16225029-bdb7-48c4-9d13-e580cc769c0a">

End-to-end test with a client-side call:

<img width="888" alt="Screenshot 2024-12-10 at 4 09 44 PM" src="https://github.com/user-attachments/assets/de375e4c-ef67-4dcc-a045-4037d9489191">
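For orientation, a rough sketch of how the three job-management calls added here could be exercised; the async `post_training` handle and the response field names (`job_uuid`, `status`, `checkpoints`) are illustrative assumptions, not the exact client API:

```python
# Illustrative sketch only: assumes an async handle `post_training` exposing
# the three APIs named in this PR; response field names are placeholders.
jobs = await post_training.get_training_jobs()
for job in jobs:
    status = await post_training.get_training_job_status(job_uuid=job.job_uuid)
    print(f"{job.job_uuid}: {status.status}")

    artifacts = await post_training.get_training_job_artifacts(job_uuid=job.job_uuid)
    for ckpt in artifacts.checkpoints:
        # Each checkpoint carries the identifier and on-disk path produced by
        # the training recipe (see the train() changes in the diff below).
        print(ckpt.identifier, ckpt.path)
```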
parent 5764a95912
commit c294a01c4b

8 changed files with 331 additions and 67 deletions
```diff
@@ -13,14 +13,20 @@ from typing import Any, Dict, List, Optional, Tuple
 import torch
 from llama_models.sku_list import resolve_model

 from llama_stack.apis.datasetio import DatasetIO

+from llama_stack.distribution.utils.config_dirs import DEFAULT_CHECKPOINT_DIR
+from llama_stack.providers.inline.post_training.torchtune.common.checkpointer import (
+    TorchtuneCheckpointer,
+)
 from torch import nn
 from torchtune import utils as torchtune_utils
 from torchtune.training.metric_logging import DiskLogger
 from llama_stack.apis.post_training import *  # noqa
 from llama_stack.distribution.utils.model_utils import model_local_dir

-from llama_stack.providers.inline.post_training.torchtune import utils
+from llama_stack.providers.inline.post_training.torchtune.common import utils
 from llama_stack.providers.inline.post_training.torchtune.config import (
     TorchtunePostTrainingConfig,
 )
```
```diff
@@ -62,16 +68,22 @@ class LoraFinetuningSingleDevice:
     def __init__(
         self,
         config: TorchtunePostTrainingConfig,
+        job_uuid: str,
         training_config: TrainingConfig,
         hyperparam_search_config: Dict[str, Any],
         logger_config: Dict[str, Any],
         model: str,
         checkpoint_dir: Optional[str],
-        algorithm_config: Optional[Union[LoraFinetuningConfig, QATFinetuningConfig]],
+        algorithm_config: Optional[AlgorithmConfig],
         datasetio_api: DatasetIO,
         datasets_api: Datasets,
     ) -> None:
+        self.job_uuid = job_uuid
         self.training_config = training_config
+        if not isinstance(algorithm_config, LoraFinetuningConfig):
+            raise ValueError(
+                "You need to specify LoraFinetuningConfig for LoRA finetuning"
+            )
         self.algorithm_config = algorithm_config
         self._device = torchtune_utils.get_device(device="cuda")
         self._dtype = training.get_dtype(training_config.dtype, device=self._device)
```
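The constructor now accepts the generic `AlgorithmConfig` union and narrows it to `LoraFinetuningConfig` at runtime. A sketch of how a caller might instantiate the recipe for LoRA; the `LoraFinetuningConfig` field values and the surrounding variables (`provider_config`, `training_config`, the datasets/datasetio handles, and the model id) are illustrative placeholders, not the exact schema:

```python
# Sketch only: field names and values are illustrative.
algorithm_config = LoraFinetuningConfig(
    type="LoRA",
    lora_attn_modules=["q_proj", "v_proj"],
    apply_lora_to_mlp=False,
    apply_lora_to_output=False,
    rank=8,
    alpha=16,
)

recipe = LoraFinetuningSingleDevice(
    config=provider_config,            # TorchtunePostTrainingConfig (placeholder)
    job_uuid=job_uuid,                 # id assigned by the job scheduler (placeholder)
    training_config=training_config,   # TrainingConfig (placeholder)
    hyperparam_search_config={},
    logger_config={},
    model="Llama3.2-3B-Instruct",      # placeholder model id
    checkpoint_dir=None,
    algorithm_config=algorithm_config,
    datasetio_api=datasetio_api,
    datasets_api=datasets_api,
)
```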
```diff
@@ -99,8 +111,7 @@ class LoraFinetuningSingleDevice:
         model = resolve_model(self.model_id)
         self.checkpoint_dir = model_checkpoint_dir(model)

-        # TODO @SLR722 make it work with get_training_job_artifacts
-        self._output_dir = self.checkpoint_dir + "/posting_training/"
+        self._output_dir = str(DEFAULT_CHECKPOINT_DIR)

         self.seed = training.set_seed(seed=config.torch_seed)
         self.epochs_run = 0
```
```diff
@@ -140,7 +151,9 @@ class LoraFinetuningSingleDevice:
             except FileNotFoundError:
                 return [f"Error: The directory '{checkpoint_dir}' does not exist."]

-        self._checkpointer = training.FullModelMetaCheckpointer(
+        self._checkpointer = TorchtuneCheckpointer(
+            model_id=self.model_id,
+            training_algorithm="sft",
             checkpoint_dir=self.checkpoint_dir,
             checkpoint_files=get_checkpoint_files(self.checkpoint_dir),
             output_dir=self._output_dir,
```
```diff
@@ -150,8 +163,6 @@ class LoraFinetuningSingleDevice:
         return checkpoint_dict

     async def setup(self) -> None:
-        self._metric_logger = DiskLogger(log_dir=self._output_dir)
-
         checkpoint_dict = await self.load_checkpoint()

         self._model = await self._setup_model(
```
```diff
@@ -370,7 +381,7 @@ class LoraFinetuningSingleDevice:
         )
         return lr_scheduler

-    async def save_checkpoint(self, epoch: int) -> None:
+    async def save_checkpoint(self, epoch: int) -> str:
         ckpt_dict = {}

         adapter_state_dict = get_adapter_state_dict(self._model.state_dict())
```
```diff
@@ -400,7 +411,7 @@ class LoraFinetuningSingleDevice:
         }
         ckpt_dict.update({training.ADAPTER_CONFIG: adapter_config})

-        self._checkpointer.save_checkpoint(
+        return self._checkpointer.save_checkpoint(
             ckpt_dict,
             epoch=epoch,
         )
```
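`save_checkpoint` now bubbles up whatever path the checkpointer reports for the files it wrote, which `train()` later records on the `Checkpoint` artifact. A minimal stand-alone sketch of that contract; the per-epoch directory naming and the file name are assumptions for illustration, not the actual `TorchtuneCheckpointer` logic:

```python
import os
from typing import Any, Dict

import torch


def save_checkpoint_sketch(
    output_dir: str, model_id: str, epoch: int, state_dict: Dict[str, Any]
) -> str:
    # Write the epoch's adapter/model state under a per-epoch directory and
    # return that location so the caller can attach it to a Checkpoint.
    checkpoint_path = os.path.join(output_dir, f"{model_id}-sft-{epoch}")
    os.makedirs(checkpoint_path, exist_ok=True)
    torch.save(state_dict, os.path.join(checkpoint_path, "adapter_model.pt"))  # placeholder file name
    return checkpoint_path
```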
```diff
@@ -429,20 +440,26 @@ class LoraFinetuningSingleDevice:
         return loss

-    async def train(self) -> None:
+    async def train(self) -> Tuple[Dict[str, Any], List[Checkpoint]]:
         """
         The core training loop.
         """
         # Initialize tokens count and running loss (for grad accumulation)
-        # t0 = time.perf_counter()
+        t0 = time.perf_counter()
         running_loss = 0
         num_tokens = 0

+        # training artifacts
+        checkpoints = []
+        memory_stats = {}
+
         # self.epochs_run should be non-zero when we're resuming from a checkpoint
         for curr_epoch in range(self.epochs_run, self.total_epochs):
             # Update the sampler to ensure data is correctly shuffled across epochs
             # in case shuffle is True
+            metric_logger = DiskLogger(
+                log_dir=self._output_dir + f"/{self.model_id}-sft-{curr_epoch}"
+            )
             self._sampler.set_epoch(curr_epoch)

             for idx, batch in enumerate(self._dataloader):
```
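Each epoch now gets its own `DiskLogger` directory, named with the same `{model_id}-sft-{epoch}` pattern later used for the checkpoint identifier. A tiny sketch of the resulting layout; the output directory and model id below are placeholders:

```python
# Illustrative only: shows how the per-epoch metric-log directory is derived.
output_dir = "/home/user/.llama/checkpoints"  # placeholder for self._output_dir
model_id = "Llama3.2-3B-Instruct"             # placeholder model id

log_dirs = [output_dir + f"/{model_id}-sft-{epoch}" for epoch in range(2)]
# ['/home/user/.llama/checkpoints/Llama3.2-3B-Instruct-sft-0',
#  '/home/user/.llama/checkpoints/Llama3.2-3B-Instruct-sft-1']
```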
```diff
@@ -488,10 +505,14 @@ class LoraFinetuningSingleDevice:
                         "lr": self._optimizer.param_groups[0]["lr"],
                         "tokens_per_second_per_gpu": num_tokens / time_per_step,
                     }
-                    log_dict.update(training.get_memory_stats(device=self._device))
+
+                    memory_stats = training.get_memory_stats(device=self._device)
+                    log_dict.update(memory_stats)
+
                     if self._clip_grad_norm is not None:
                         log_dict.update({"grad_norm": grad_norm})
-                    self._metric_logger.log_dict(
+                    metric_logger.log_dict(
                         log_dict,
                         step=self.global_step,
                     )
```
```diff
@@ -503,4 +524,14 @@ class LoraFinetuningSingleDevice:

             self.epochs_run += 1
             log.info("Starting checkpoint save...")
-            await self.save_checkpoint(epoch=curr_epoch)
+            checkpoint_path = await self.save_checkpoint(epoch=curr_epoch)
+            checkpoint = Checkpoint(
+                identifier=f"{self.model_id}-sft-{curr_epoch}",
+                created_at=datetime.now(),
+                epoch=curr_epoch,
+                post_training_job_id=self.job_uuid,
+                path=checkpoint_path,
+            )
+            checkpoints.append(checkpoint)
+
+        return (memory_stats, checkpoints)
```
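`train()` now returns the collected memory stats together with the list of `Checkpoint` artifacts. A rough sketch of how the provider layer might fold these into the job status and artifacts it serves back; the surrounding names (`recipe`, `job_uuid`) and the response dictionaries are assumptions for illustration, not the provider's actual response types:

```python
# Hypothetical glue in the post-training provider, shown only to illustrate
# how the (memory_stats, checkpoints) tuple can back the new job APIs.
memory_stats, checkpoints = await recipe.train()

job_status = {
    "job_uuid": job_uuid,
    "status": "completed",
    "resources_allocated": memory_stats,             # assumed field name
    "checkpoints": [c.identifier for c in checkpoints],
}
job_artifacts = {
    "job_uuid": job_uuid,
    "checkpoints": checkpoints,                      # the Checkpoint objects built in train()
}
```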