[2/n][torchtune integration] implement job management and return training artifacts (#593)

### Context 
In this PR, we:
- Implement the post-training job-management and training-artifact APIs (a sketch of the resulting surface follows this list):
  - get_training_jobs
  - get_training_job_status
  - get_training_job_artifacts
- Delete get_training_job_logstream, since traces can be accessed directly in the UI via Jaeger:
https://llama-stack.readthedocs.io/en/latest/building_applications/telemetry.html#jaeger-to-visualize-traces
- Refactor the post-training and training type definitions to make them more intuitive.
- Rewrite the checkpointer so that its output is compatible with the llama-stack file system and can be recognized during inference.
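A minimal sketch of the job-management surface this adds; the `Checkpoint` fields mirror the diff below, while the store class and its attribute names are illustrative assumptions rather than the actual provider code:

```python
# Sketch only: Checkpoint fields come from the diff below; the store class,
# its attributes, and return types are assumptions for illustration.
from dataclasses import dataclass
from datetime import datetime
from typing import Any, Dict, List, Optional


@dataclass
class Checkpoint:
    # Per-epoch training artifact, as constructed at the end of train().
    identifier: str
    created_at: datetime
    epoch: int
    post_training_job_id: str
    path: str


class PostTrainingJobStore:
    """Assumed in-memory bookkeeping behind the three new APIs."""

    def __init__(self) -> None:
        self._status: Dict[str, Dict[str, Any]] = {}
        self._checkpoints: Dict[str, List[Checkpoint]] = {}

    async def get_training_jobs(self) -> List[str]:
        return list(self._status.keys())

    async def get_training_job_status(self, job_uuid: str) -> Optional[Dict[str, Any]]:
        return self._status.get(job_uuid)

    async def get_training_job_artifacts(self, job_uuid: str) -> List[Checkpoint]:
        return self._checkpoints.get(job_uuid, [])
```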


### Test
Unit test:
`pytest llama_stack/providers/tests/post_training/test_post_training.py -m "torchtune_post_training_huggingface_datasetio" -v -s --tb=short --disable-warnings`

<img width="1506" alt="Screenshot 2024-12-10 at 4 06 17 PM"
src="https://github.com/user-attachments/assets/16225029-bdb7-48c4-9d13-e580cc769c0a">


E2E test with a client-side call (a sketch of the call follows the screenshot):

<img width="888" alt="Screenshot 2024-12-10 at 4 09 44 PM"
src="https://github.com/user-attachments/assets/de375e4c-ef67-4dcc-a045-4037d9489191">
Botao Chen 2024-12-13 15:00:04 -08:00 committed by GitHub
parent 5764a95912
commit c294a01c4b
8 changed files with 331 additions and 67 deletions

@@ -13,14 +13,20 @@ from typing import Any, Dict, List, Optional, Tuple
import torch
from llama_models.sku_list import resolve_model
from llama_stack.apis.datasetio import DatasetIO
from llama_stack.distribution.utils.config_dirs import DEFAULT_CHECKPOINT_DIR
from llama_stack.providers.inline.post_training.torchtune.common.checkpointer import (
TorchtuneCheckpointer,
)
from torch import nn
from torchtune import utils as torchtune_utils
from torchtune.training.metric_logging import DiskLogger
from llama_stack.apis.post_training import * # noqa
from llama_stack.distribution.utils.model_utils import model_local_dir
from llama_stack.providers.inline.post_training.torchtune import utils
from llama_stack.providers.inline.post_training.torchtune.common import utils
from llama_stack.providers.inline.post_training.torchtune.config import (
TorchtunePostTrainingConfig,
)
@@ -62,16 +68,22 @@ class LoraFinetuningSingleDevice:
def __init__(
self,
config: TorchtunePostTrainingConfig,
job_uuid: str,
training_config: TrainingConfig,
hyperparam_search_config: Dict[str, Any],
logger_config: Dict[str, Any],
model: str,
checkpoint_dir: Optional[str],
algorithm_config: Optional[Union[LoraFinetuningConfig, QATFinetuningConfig]],
algorithm_config: Optional[AlgorithmConfig],
datasetio_api: DatasetIO,
datasets_api: Datasets,
) -> None:
self.job_uuid = job_uuid
self.training_config = training_config
if not isinstance(algorithm_config, LoraFinetuningConfig):
raise ValueError(
"You need to speicifc LoraFinetuningConfig for LoRA finetuning"
)
self.algorithm_config = algorithm_config
self._device = torchtune_utils.get_device(device="cuda")
self._dtype = training.get_dtype(training_config.dtype, device=self._device)
@@ -99,8 +111,7 @@ class LoraFinetuningSingleDevice:
model = resolve_model(self.model_id)
self.checkpoint_dir = model_checkpoint_dir(model)
# TODO @SLR722 make it work with get_training_job_artifacts
self._output_dir = self.checkpoint_dir + "/posting_training/"
self._output_dir = str(DEFAULT_CHECKPOINT_DIR)
self.seed = training.set_seed(seed=config.torch_seed)
self.epochs_run = 0
@@ -140,7 +151,9 @@ class LoraFinetuningSingleDevice:
except FileNotFoundError:
return [f"Error: The directory '{checkpoint_dir}' does not exist."]
self._checkpointer = training.FullModelMetaCheckpointer(
self._checkpointer = TorchtuneCheckpointer(
model_id=self.model_id,
training_algorithm="sft",
checkpoint_dir=self.checkpoint_dir,
checkpoint_files=get_checkpoint_files(self.checkpoint_dir),
output_dir=self._output_dir,
@@ -150,8 +163,6 @@ class LoraFinetuningSingleDevice:
return checkpoint_dict
async def setup(self) -> None:
self._metric_logger = DiskLogger(log_dir=self._output_dir)
checkpoint_dict = await self.load_checkpoint()
self._model = await self._setup_model(
@@ -370,7 +381,7 @@ class LoraFinetuningSingleDevice:
)
return lr_scheduler
async def save_checkpoint(self, epoch: int) -> None:
async def save_checkpoint(self, epoch: int) -> str:
ckpt_dict = {}
adapter_state_dict = get_adapter_state_dict(self._model.state_dict())
@@ -400,7 +411,7 @@ class LoraFinetuningSingleDevice:
}
ckpt_dict.update({training.ADAPTER_CONFIG: adapter_config})
self._checkpointer.save_checkpoint(
return self._checkpointer.save_checkpoint(
ckpt_dict,
epoch=epoch,
)
@@ -429,20 +440,26 @@ class LoraFinetuningSingleDevice:
return loss
async def train(self) -> None:
async def train(self) -> Tuple[Dict[str, Any], List[Checkpoint]]:
"""
The core training loop.
"""
# Initialize tokens count and running loss (for grad accumulation)
# t0 = time.perf_counter()
t0 = time.perf_counter()
running_loss = 0
num_tokens = 0
# training artifacts
checkpoints = []
memory_stats = {}
# self.epochs_run should be non-zero when we're resuming from a checkpoint
for curr_epoch in range(self.epochs_run, self.total_epochs):
# Update the sampler to ensure data is correctly shuffled across epochs
# in case shuffle is True
metric_logger = DiskLogger(
log_dir=self._output_dir + f"/{self.model_id}-sft-{curr_epoch}"
)
self._sampler.set_epoch(curr_epoch)
for idx, batch in enumerate(self._dataloader):
@@ -488,10 +505,14 @@ class LoraFinetuningSingleDevice:
"lr": self._optimizer.param_groups[0]["lr"],
"tokens_per_second_per_gpu": num_tokens / time_per_step,
}
log_dict.update(training.get_memory_stats(device=self._device))
memory_stats = training.get_memory_stats(device=self._device)
log_dict.update(memory_stats)
if self._clip_grad_norm is not None:
log_dict.update({"grad_norm": grad_norm})
self._metric_logger.log_dict(
metric_logger.log_dict(
log_dict,
step=self.global_step,
)
@@ -503,4 +524,14 @@ class LoraFinetuningSingleDevice:
self.epochs_run += 1
log.info("Starting checkpoint save...")
await self.save_checkpoint(epoch=curr_epoch)
checkpoint_path = await self.save_checkpoint(epoch=curr_epoch)
checkpoint = Checkpoint(
identifier=f"{self.model_id}-sft-{curr_epoch}",
created_at=datetime.now(),
epoch=curr_epoch,
post_training_job_id=self.job_uuid,
path=checkpoint_path,
)
checkpoints.append(checkpoint)
return (memory_stats, checkpoints)
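
With `train()` now returning `(memory_stats, checkpoints)`, the provider can record both against the job; a minimal sketch of that wiring, with the surrounding job-runner code assumed rather than taken from this PR:

```python
# Sketch of how a provider-side job runner could consume train()'s return value
# (the function and parameter names here are assumptions, not the PR's code).
from typing import Any, Dict, List


async def run_job(
    recipe: "LoraFinetuningSingleDevice",
    job_uuid: str,
    jobs_status: Dict[str, Dict[str, Any]],
    jobs_checkpoints: Dict[str, List["Checkpoint"]],
) -> None:
    await recipe.setup()
    memory_stats, checkpoints = await recipe.train()

    # Store artifacts so get_training_job_artifacts(job_uuid) can return them,
    # and keep the memory stats alongside the job status.
    jobs_checkpoints[job_uuid] = checkpoints
    jobs_status[job_uuid] = {
        "status": "completed",
        "resources_allocated": memory_stats,
        "checkpoints": [c.identifier for c in checkpoints],
    }
```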