[2/n][torchtune integration] implement job management and return training artifacts (#593)

### Context
In this PR, we:
- Implement the post-training job management and training-artifact retrieval APIs (a minimal usage sketch follows this list):
  - `get_training_jobs`
  - `get_training_job_status`
  - `get_training_job_artifacts`
- Delete `get_training_job_logstream`, since traces can be accessed directly from the UI with Jaeger:
https://llama-stack.readthedocs.io/en/latest/building_applications/telemetry.html#jaeger-to-visualize-traces
- Refactor the post-training and training type definitions to make them more intuitive.
- Rewrite the checkpointer so its output is compatible with the llama-stack filesystem layout and can be recognized during inference.
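
A minimal provider-side sketch of the new surface (illustrative only, not part of the diff; `post_training` is a resolved `PostTraining` provider handle and `"1234"` is a placeholder job id):

```python
from llama_stack.apis.post_training import PostTraining

async def inspect_job(post_training: PostTraining, job_uuid: str = "1234") -> None:
    # List all scheduled/finished jobs known to the provider.
    jobs = await post_training.get_training_jobs()  # List[PostTrainingJob]
    print([j.job_uuid for j in jobs])

    # Status includes scheduling timestamps and, once complete, checkpoints.
    status = await post_training.get_training_job_status(job_uuid)
    if status is not None:  # Optional[PostTrainingJobStatusResponse]
        print(status.status, status.scheduled_at, status.completed_at)

    # Artifacts expose the checkpoints written for this job.
    artifacts = await post_training.get_training_job_artifacts(job_uuid)
    if artifacts is not None:  # Optional[PostTrainingJobArtifactsResponse]
        print([(c.identifier, c.path) for c in artifacts.checkpoints])
```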


### Test
Unit test
`pytest llama_stack/providers/tests/post_training/test_post_training.py
-m "torchtune_post_training_huggingface_datasetio" -v -s --tb=short
--disable-warnings`

<img width="1506" alt="Screenshot 2024-12-10 at 4 06 17 PM"
src="https://github.com/user-attachments/assets/16225029-bdb7-48c4-9d13-e580cc769c0a">


E2E test with a client-side call

<img width="888" alt="Screenshot 2024-12-10 at 4 09 44 PM"
src="https://github.com/user-attachments/assets/de375e4c-ef67-4dcc-a045-4037d9489191">
Botao Chen 2024-12-13 15:00:04 -08:00 committed by GitHub
parent 5764a95912
commit c294a01c4b
8 changed files with 331 additions and 67 deletions

View file

@ -18,3 +18,5 @@ class Job(BaseModel):
class JobStatus(Enum):
completed = "completed"
in_progress = "in_progress"
failed = "failed"
scheduled = "scheduled"

View file

@ -4,13 +4,26 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_models.llama3.api.datatypes import URL
from datetime import datetime
from typing import Optional
from llama_models.schema_utils import json_schema_type
from pydantic import BaseModel
@json_schema_type
class PostTrainingMetric(BaseModel):
epoch: int
train_loss: float
validation_loss: float
perplexity: float
@json_schema_type(schema={"description": "Checkpoint created during training runs"})
class Checkpoint(BaseModel):
iters: int
path: URL
identifier: str
created_at: datetime
epoch: int
post_training_job_id: str
path: str
training_metrics: Optional[PostTrainingMetric] = None

View file

@ -6,6 +6,7 @@
from datetime import datetime
from enum import Enum
from typing import Any, Dict, List, Optional, Protocol, Union
from llama_models.schema_utils import json_schema_type, webmethod
@ -14,6 +15,7 @@ from pydantic import BaseModel, Field
from typing_extensions import Annotated
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.apis.common.job_types import JobStatus
from llama_stack.apis.datasets import * # noqa: F403
from llama_stack.apis.common.training_types import * # noqa: F403
@ -64,6 +66,7 @@ class TrainingConfig(BaseModel):
@json_schema_type
class LoraFinetuningConfig(BaseModel):
type: Literal["LoRA"] = "LoRA"
lora_attn_modules: List[str]
apply_lora_to_mlp: bool
apply_lora_to_output: bool
@ -75,12 +78,13 @@ class LoraFinetuningConfig(BaseModel):
@json_schema_type
class QATFinetuningConfig(BaseModel):
type: Literal["QAT"] = "QAT"
quantizer_name: str
group_size: int
AlgorithmConfig = Annotated[
Union[LoraFinetuningConfig, LoraFinetuningConfig], Field(discriminator="type")
Union[LoraFinetuningConfig, QATFinetuningConfig], Field(discriminator="type")
]
@ -92,14 +96,6 @@ class PostTrainingJobLogStream(BaseModel):
log_lines: List[str]
@json_schema_type
class PostTrainingJobStatus(Enum):
running = "running"
completed = "completed"
failed = "failed"
scheduled = "scheduled"
@json_schema_type
class RLHFAlgorithm(Enum):
dpo = "dpo"
@ -144,7 +140,7 @@ class PostTrainingJobStatusResponse(BaseModel):
"""Status of a finetuning job."""
job_uuid: str
status: PostTrainingJobStatus
status: JobStatus
scheduled_at: Optional[datetime] = None
started_at: Optional[datetime] = None
@ -166,7 +162,7 @@ class PostTrainingJobArtifactsResponse(BaseModel):
class PostTraining(Protocol):
@webmethod(route="/post-training/supervised-fine-tune")
@webmethod(route="/post-training/supervised-fine-tune", method="POST")
async def supervised_fine_tune(
self,
job_uuid: str,
@ -181,7 +177,7 @@ class PostTraining(Protocol):
algorithm_config: Optional[AlgorithmConfig] = None,
) -> PostTrainingJob: ...
@webmethod(route="/post-training/preference-optimize")
@webmethod(route="/post-training/preference-optimize", method="POST")
async def preference_optimize(
self,
job_uuid: str,
@ -192,24 +188,18 @@ class PostTraining(Protocol):
logger_config: Dict[str, Any],
) -> PostTrainingJob: ...
@webmethod(route="/post-training/jobs")
@webmethod(route="/post-training/jobs", method="GET")
async def get_training_jobs(self) -> List[PostTrainingJob]: ...
# sends SSE stream of logs
@webmethod(route="/post-training/job/logs")
async def get_training_job_logstream(
self, job_uuid: str
) -> PostTrainingJobLogStream: ...
@webmethod(route="/post-training/job/status")
@webmethod(route="/post-training/job/status", method="GET")
async def get_training_job_status(
self, job_uuid: str
) -> PostTrainingJobStatusResponse: ...
) -> Optional[PostTrainingJobStatusResponse]: ...
@webmethod(route="/post-training/job/cancel")
@webmethod(route="/post-training/job/cancel", method="POST")
async def cancel_training_job(self, job_uuid: str) -> None: ...
@webmethod(route="/post-training/job/artifacts")
@webmethod(route="/post-training/job/artifacts", method="GET")
async def get_training_job_artifacts(
self, job_uuid: str
) -> PostTrainingJobArtifactsResponse: ...
) -> Optional[PostTrainingJobArtifactsResponse]: ...
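
As a side note, here is a minimal sketch (not part of the diff) of how the `AlgorithmConfig` discriminated union resolves during validation; the config classes are abbreviated to the fields shown above, and the `TypeAdapter` usage assumes pydantic v2:

```python
from typing import Annotated, List, Literal, Union
from pydantic import BaseModel, Field, TypeAdapter

# Abbreviated versions of the configs defined above, for illustration only.
class LoraFinetuningConfig(BaseModel):
    type: Literal["LoRA"] = "LoRA"
    lora_attn_modules: List[str]
    apply_lora_to_mlp: bool
    apply_lora_to_output: bool

class QATFinetuningConfig(BaseModel):
    type: Literal["QAT"] = "QAT"
    quantizer_name: str
    group_size: int

AlgorithmConfig = Annotated[
    Union[LoraFinetuningConfig, QATFinetuningConfig], Field(discriminator="type")
]

# The "type" field picks the concrete model during validation.
cfg = TypeAdapter(AlgorithmConfig).validate_python(
    {"type": "QAT", "quantizer_name": "int4", "group_size": 32}
)
assert isinstance(cfg, QATFinetuningConfig)
```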

View file

@ -0,0 +1,157 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import os
import shutil
from pathlib import Path
from typing import Any, Dict, List
import torch
from torchtune import training
from torchtune.models import convert_weights
from torchtune.training.checkpointing._utils import ModelType, safe_torch_load
from torchtune.utils._logging import get_logger
logger = get_logger("DEBUG")
class TorchtuneCheckpointer:
def __init__(
self,
model_id: str,
training_algorithm: str,
checkpoint_dir: str,
checkpoint_files: List[str],
output_dir: str,
model_type: str,
) -> None:
# Fail fast if ``checkpoint_files`` is invalid
# TODO: support loading more than one file
if len(checkpoint_files) != 1:
raise ValueError(
"Currently we only support reading from a single torchtune checkpoint file. "
f"Got {len(checkpoint_files)} files instead."
)
self._checkpoint_file = checkpoint_files[0]
self._model_id = model_id
self._training_algorithm = training_algorithm
self._checkpoint_dir = Path(checkpoint_dir)
self._model_type = ModelType[model_type]
self._output_dir = output_dir
# get ckpt paths
self._checkpoint_path = Path.joinpath(
self._checkpoint_dir, self._checkpoint_file
)
def load_checkpoint(self) -> Dict[str, Any]:
"""
Load Meta checkpoint from file. Currently only loading from a single file is supported.
"""
state_dict: Dict[str, Any] = {}
model_state_dict = safe_torch_load(self._checkpoint_path)
if self._model_type == ModelType.LLAMA3_VISION:
from torchtune.models.llama3_2_vision._convert_weights import (
llama3_vision_meta_to_tune,
)
state_dict[training.MODEL_KEY] = llama3_vision_meta_to_tune(
model_state_dict
)
else:
state_dict[training.MODEL_KEY] = convert_weights.meta_to_tune(
model_state_dict
)
# llama3_2 has tied weights, so we need to remove the output.weight key
if self._model_type == ModelType.LLAMA3_2:
logger.info(
"Identified model_type = Llama3_2. Ignoring output.weight in"
" checkpoint in favor of the tok_embedding.weight"
" tied weights."
)
state_dict[training.MODEL_KEY].pop("output.weight")
return state_dict
def save_checkpoint(
self,
state_dict: Dict[str, Any],
epoch: int,
adapter_only: bool = False,
) -> str:
model_file_path = (
Path(self._output_dir)
/ f"{self._model_id}-{self._training_algorithm}-{epoch}"
)
model_file_path.mkdir(parents=True, exist_ok=True)
# copy the related files for inference
shutil.copy(
Path.joinpath(self._checkpoint_dir, "params.json"),
Path.joinpath(model_file_path, "params.json"),
)
shutil.copy(
Path.joinpath(self._checkpoint_dir, "tokenizer.model"),
Path.joinpath(model_file_path, "tokenizer.model"),
)
shutil.copy(
Path.joinpath(self._checkpoint_dir, "orig_params.json"),
Path.joinpath(model_file_path, "orig_params.json"),
)
if not adapter_only:
model_state_dict = state_dict[training.MODEL_KEY]
if self._model_type == ModelType.LLAMA3_VISION:
from torchtune.models.llama3_2_vision._convert_weights import (
llama3_vision_tune_to_meta,
)
state_dict[training.MODEL_KEY] = llama3_vision_tune_to_meta(
model_state_dict
)
else:
# llama3_2 has tied weights, so we need to add the output.weight key
if (
self._model_type == ModelType.LLAMA3_2
and "output.weight" not in model_state_dict
):
model_state_dict["output.weight"] = model_state_dict[
"tok_embeddings.weight"
]
state_dict[training.MODEL_KEY] = convert_weights.tune_to_meta(
model_state_dict
)
model_file_name = Path.joinpath(model_file_path, "consolidated.00.pth")
torch.save(state_dict[training.MODEL_KEY], model_file_name)
logger.info(
"Model checkpoint of size "
f"{os.path.getsize(model_file_name) / 1000**3:.2f} GB "
f"saved to {model_file_name}"
)
if training.ADAPTER_KEY in state_dict:
adapter_file_path = model_file_path / "adapter"
adapter_file_path.mkdir(parents=True, exist_ok=True)
adapter_file_name = Path.joinpath(adapter_file_path, "adapter.pth")
torch.save(state_dict[training.ADAPTER_KEY], adapter_file_name)
logger.info(
"Adapter checkpoint of size "
f"{os.path.getsize(adapter_file_name) / 1000**3:.2f} GB "
f"saved to {adapter_file_name}"
)
elif adapter_only:
raise ValueError(
"Adapter checkpoint not found in state_dict. Please ensure that the state_dict contains adapter weights."
)
print("model_file_path", str(model_file_path))
return str(model_file_path)
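
For orientation, a hypothetical usage sketch of the new `TorchtuneCheckpointer` (not part of the diff; the model id, directories, and checkpoint file name are placeholders):

```python
from torchtune import training

from llama_stack.providers.inline.post_training.torchtune.common.checkpointer import (
    TorchtuneCheckpointer,
)

# Placeholders: adjust model_id, directories, and checkpoint file to your setup.
checkpointer = TorchtuneCheckpointer(
    model_id="Llama3.2-3B-Instruct",
    training_algorithm="sft",
    checkpoint_dir="/path/to/Llama3.2-3B-Instruct",
    checkpoint_files=["consolidated.00.pth"],
    output_dir="/path/to/output",
    model_type="LLAMA3_2",
)

# Load Meta-format weights, converted into torchtune's expected key layout.
state_dict = checkpointer.load_checkpoint()

# ... training updates state_dict[training.MODEL_KEY] (and optionally
# state_dict[training.ADAPTER_KEY]) ...

# Write the epoch checkpoint back in the Meta/llama-stack layout; returns the
# directory containing consolidated.00.pth, params.json, tokenizer.model, etc.
saved_dir = checkpointer.save_checkpoint(state_dict, epoch=0)
```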

View file

@ -24,6 +24,11 @@ class TorchtunePostTrainingImpl:
self.datasetio_api = datasetio_api
self.datasets_api = datasets
# TODO: assume sync job, will need jobs API for async scheduling
self.jobs_status = {}
self.jobs_list = []
self.checkpoints_dict = {}
async def supervised_fine_tune(
self,
job_uuid: str,
@ -32,26 +37,57 @@ class TorchtunePostTrainingImpl:
logger_config: Dict[str, Any],
model: str,
checkpoint_dir: Optional[str],
algorithm_config: Optional[Union[LoraFinetuningConfig, QATFinetuningConfig]],
algorithm_config: Optional[AlgorithmConfig],
) -> PostTrainingJob:
for job in self.jobs_list:
if job_uuid == job.job_uuid:
raise ValueError(f"Job {job_uuid} already exists")
post_training_job = PostTrainingJob(job_uuid=job_uuid)
job_status_response = PostTrainingJobStatusResponse(
job_uuid=job_uuid,
status=JobStatus.scheduled,
scheduled_at=datetime.now(),
)
self.jobs_list.append(post_training_job)
if isinstance(algorithm_config, LoraFinetuningConfig):
recipe = LoraFinetuningSingleDevice(
self.config,
training_config,
hyperparam_search_config,
logger_config,
model,
checkpoint_dir,
algorithm_config,
self.datasetio_api,
self.datasets_api,
)
await recipe.setup()
await recipe.train()
try:
recipe = LoraFinetuningSingleDevice(
self.config,
job_uuid,
training_config,
hyperparam_search_config,
logger_config,
model,
checkpoint_dir,
algorithm_config,
self.datasetio_api,
self.datasets_api,
)
job_status_response.status = JobStatus.in_progress
job_status_response.started_at = datetime.now()
await recipe.setup()
resources_allocated, checkpoints = await recipe.train()
self.checkpoints_dict[job_uuid] = checkpoints
job_status_response.resources_allocated = resources_allocated
job_status_response.checkpoints = checkpoints
job_status_response.status = JobStatus.completed
job_status_response.completed_at = datetime.now()
except Exception:
job_status_response.status = JobStatus.failed
raise
else:
raise NotImplementedError()
return PostTrainingJob(job_uuid=job_uuid)
self.jobs_status[job_uuid] = job_status_response
return post_training_job
async def preference_optimize(
self,
@ -63,24 +99,28 @@ class TorchtunePostTrainingImpl:
logger_config: Dict[str, Any],
) -> PostTrainingJob: ...
# TODO @SLR722 impelment below APIs
async def get_training_jobs(self) -> List[PostTrainingJob]: ...
# sends SSE stream of logs
@webmethod(route="/post-training/job/logs")
async def get_training_job_logstream(
self, job_uuid: str
) -> PostTrainingJobLogStream: ...
async def get_training_jobs(self) -> List[PostTrainingJob]:
return self.jobs_list
@webmethod(route="/post-training/job/status")
async def get_training_job_status(
self, job_uuid: str
) -> PostTrainingJobStatusResponse: ...
) -> Optional[PostTrainingJobStatusResponse]:
if job_uuid in self.jobs_status:
return self.jobs_status[job_uuid]
return None
@webmethod(route="/post-training/job/cancel")
async def cancel_training_job(self, job_uuid: str) -> None: ...
async def cancel_training_job(self, job_uuid: str) -> None:
raise NotImplementedError("Job cancel is not implemented yet")
@webmethod(route="/post-training/job/artifacts")
async def get_training_job_artifacts(
self, job_uuid: str
) -> PostTrainingJobArtifactsResponse: ...
) -> Optional[PostTrainingJobArtifactsResponse]:
if job_uuid in self.checkpoints_dict:
checkpoints = self.checkpoints_dict.get(job_uuid, [])
return PostTrainingJobArtifactsResponse(
job_uuid=job_uuid, checkpoints=checkpoints
)
return None
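
Since the provider currently runs jobs synchronously (per the TODO above), the recorded status is already terminal by the time `supervised_fine_tune` returns. A minimal sketch with placeholder arguments (the configs are assumed to be built as in the tests below):

```python
from llama_stack.apis.common.job_types import JobStatus

async def run_and_check(impl, job_uuid, model, training_config, algorithm_config):
    # Blocks until training finishes, because the job runs synchronously for now.
    await impl.supervised_fine_tune(
        job_uuid=job_uuid,
        training_config=training_config,
        hyperparam_search_config={},
        logger_config={},
        model=model,
        checkpoint_dir=None,  # placeholder; point at a real checkpoint dir if needed
        algorithm_config=algorithm_config,
    )
    status = await impl.get_training_job_status(job_uuid)
    assert status is not None and status.status == JobStatus.completed
```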

View file

@ -13,14 +13,20 @@ from typing import Any, Dict, List, Optional, Tuple
import torch
from llama_models.sku_list import resolve_model
from llama_stack.apis.datasetio import DatasetIO
from llama_stack.distribution.utils.config_dirs import DEFAULT_CHECKPOINT_DIR
from llama_stack.providers.inline.post_training.torchtune.common.checkpointer import (
TorchtuneCheckpointer,
)
from torch import nn
from torchtune import utils as torchtune_utils
from torchtune.training.metric_logging import DiskLogger
from llama_stack.apis.post_training import * # noqa
from llama_stack.distribution.utils.model_utils import model_local_dir
from llama_stack.providers.inline.post_training.torchtune import utils
from llama_stack.providers.inline.post_training.torchtune.common import utils
from llama_stack.providers.inline.post_training.torchtune.config import (
TorchtunePostTrainingConfig,
)
@ -62,16 +68,22 @@ class LoraFinetuningSingleDevice:
def __init__(
self,
config: TorchtunePostTrainingConfig,
job_uuid: str,
training_config: TrainingConfig,
hyperparam_search_config: Dict[str, Any],
logger_config: Dict[str, Any],
model: str,
checkpoint_dir: Optional[str],
algorithm_config: Optional[Union[LoraFinetuningConfig, QATFinetuningConfig]],
algorithm_config: Optional[AlgorithmConfig],
datasetio_api: DatasetIO,
datasets_api: Datasets,
) -> None:
self.job_uuid = job_uuid
self.training_config = training_config
if not isinstance(algorithm_config, LoraFinetuningConfig):
raise ValueError(
"You need to speicifc LoraFinetuningConfig for LoRA finetuning"
)
self.algorithm_config = algorithm_config
self._device = torchtune_utils.get_device(device="cuda")
self._dtype = training.get_dtype(training_config.dtype, device=self._device)
@ -99,8 +111,7 @@ class LoraFinetuningSingleDevice:
model = resolve_model(self.model_id)
self.checkpoint_dir = model_checkpoint_dir(model)
# TODO @SLR722 make it work with get_training_job_artifacts
self._output_dir = self.checkpoint_dir + "/posting_training/"
self._output_dir = str(DEFAULT_CHECKPOINT_DIR)
self.seed = training.set_seed(seed=config.torch_seed)
self.epochs_run = 0
@ -140,7 +151,9 @@ class LoraFinetuningSingleDevice:
except FileNotFoundError:
return [f"Error: The directory '{checkpoint_dir}' does not exist."]
self._checkpointer = training.FullModelMetaCheckpointer(
self._checkpointer = TorchtuneCheckpointer(
model_id=self.model_id,
training_algorithm="sft",
checkpoint_dir=self.checkpoint_dir,
checkpoint_files=get_checkpoint_files(self.checkpoint_dir),
output_dir=self._output_dir,
@ -150,8 +163,6 @@ class LoraFinetuningSingleDevice:
return checkpoint_dict
async def setup(self) -> None:
self._metric_logger = DiskLogger(log_dir=self._output_dir)
checkpoint_dict = await self.load_checkpoint()
self._model = await self._setup_model(
@ -370,7 +381,7 @@ class LoraFinetuningSingleDevice:
)
return lr_scheduler
async def save_checkpoint(self, epoch: int) -> None:
async def save_checkpoint(self, epoch: int) -> str:
ckpt_dict = {}
adapter_state_dict = get_adapter_state_dict(self._model.state_dict())
@ -400,7 +411,7 @@ class LoraFinetuningSingleDevice:
}
ckpt_dict.update({training.ADAPTER_CONFIG: adapter_config})
self._checkpointer.save_checkpoint(
return self._checkpointer.save_checkpoint(
ckpt_dict,
epoch=epoch,
)
@ -429,20 +440,26 @@ class LoraFinetuningSingleDevice:
return loss
async def train(self) -> None:
async def train(self) -> Tuple[Dict[str, Any], List[Checkpoint]]:
"""
The core training loop.
"""
# Initialize tokens count and running loss (for grad accumulation)
# t0 = time.perf_counter()
t0 = time.perf_counter()
running_loss = 0
num_tokens = 0
# training artifacts
checkpoints = []
memory_stats = {}
# self.epochs_run should be non-zero when we're resuming from a checkpoint
for curr_epoch in range(self.epochs_run, self.total_epochs):
# Update the sampler to ensure data is correctly shuffled across epochs
# in case shuffle is True
metric_logger = DiskLogger(
log_dir=self._output_dir + f"/{self.model_id}-sft-{curr_epoch}"
)
self._sampler.set_epoch(curr_epoch)
for idx, batch in enumerate(self._dataloader):
@ -488,10 +505,14 @@ class LoraFinetuningSingleDevice:
"lr": self._optimizer.param_groups[0]["lr"],
"tokens_per_second_per_gpu": num_tokens / time_per_step,
}
log_dict.update(training.get_memory_stats(device=self._device))
memory_stats = training.get_memory_stats(device=self._device)
log_dict.update(memory_stats)
if self._clip_grad_norm is not None:
log_dict.update({"grad_norm": grad_norm})
self._metric_logger.log_dict(
metric_logger.log_dict(
log_dict,
step=self.global_step,
)
@ -503,4 +524,14 @@ class LoraFinetuningSingleDevice:
self.epochs_run += 1
log.info("Starting checkpoint save...")
await self.save_checkpoint(epoch=curr_epoch)
checkpoint_path = await self.save_checkpoint(epoch=curr_epoch)
checkpoint = Checkpoint(
identifier=f"{self.model_id}-sft-{curr_epoch}",
created_at=datetime.now(),
epoch=curr_epoch,
post_training_job_id=self.job_uuid,
path=checkpoint_path,
)
checkpoints.append(checkpoint)
return (memory_stats, checkpoints)

View file

@ -19,6 +19,7 @@ class TestPostTraining:
@pytest.mark.asyncio
async def test_supervised_fine_tune(self, post_training_stack):
algorithm_config = LoraFinetuningConfig(
type="LoRA",
lora_attn_modules=["q_proj", "v_proj", "output_proj"],
apply_lora_to_mlp=True,
apply_lora_to_output=False,
@ -59,3 +60,33 @@ class TestPostTraining:
)
assert isinstance(response, PostTrainingJob)
assert response.job_uuid == "1234"
@pytest.mark.asyncio
async def test_get_training_jobs(self, post_training_stack):
post_training_impl = post_training_stack
jobs_list = await post_training_impl.get_training_jobs()
assert isinstance(jobs_list, List)
assert jobs_list[0].job_uuid == "1234"
@pytest.mark.asyncio
async def test_get_training_job_status(self, post_training_stack):
post_training_impl = post_training_stack
job_status = await post_training_impl.get_training_job_status("1234")
assert isinstance(job_status, PostTrainingJobStatusResponse)
assert job_status.job_uuid == "1234"
assert job_status.status == JobStatus.completed
assert isinstance(job_status.checkpoints[0], Checkpoint)
@pytest.mark.asyncio
async def test_get_training_job_artifacts(self, post_training_stack):
post_training_impl = post_training_stack
job_artifacts = await post_training_impl.get_training_job_artifacts("1234")
assert isinstance(job_artifacts, PostTrainingJobArtifactsResponse)
assert job_artifacts.job_uuid == "1234"
assert isinstance(job_artifacts.checkpoints[0], Checkpoint)
assert job_artifacts.checkpoints[0].identifier == "Llama3.2-3B-Instruct-sft-0"
assert job_artifacts.checkpoints[0].epoch == 0
assert (
"/.llama/checkpoints/Llama3.2-3B-Instruct-sft-0"
in job_artifacts.checkpoints[0].path
)