temp commit

Botao Chen 2024-12-09 20:24:30 -08:00
parent 9c1ae088f9
commit c9a009b5e7
7 changed files with 268 additions and 53 deletions

View file

@@ -18,3 +18,5 @@ class Job(BaseModel):
class JobStatus(Enum):
completed = "completed"
in_progress = "in_progress"
failed = "failed"
scheduled = "scheduled"
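These four statuses cover the whole synchronous lifecycle the torchtune provider drives later in this commit (scheduled, then in_progress, then completed or failed). A minimal standalone sketch of branching on them; the is_finished helper is illustrative and not part of the change:

from enum import Enum

class JobStatus(Enum):
    completed = "completed"
    in_progress = "in_progress"
    failed = "failed"
    scheduled = "scheduled"

# Hypothetical helper: scheduled and in_progress jobs are still pending,
# completed and failed are terminal.
def is_finished(status: JobStatus) -> bool:
    return status in (JobStatus.completed, JobStatus.failed)

assert is_finished(JobStatus.failed)
assert not is_finished(JobStatus.in_progress)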

View file

@@ -4,13 +4,26 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_models.llama3.api.datatypes import URL
from datetime import datetime
from typing import Optional
from llama_models.schema_utils import json_schema_type
from pydantic import BaseModel
@json_schema_type
class PostTrainingMetric(BaseModel):
epoch: int
train_loss: float
validation_loss: float
perplexity: float
@json_schema_type(schema={"description": "Checkpoint created during training runs"})
class Checkpoint(BaseModel):
iters: int
path: URL
identifier: str
created_at: datetime
epoch: int
post_training_job_id: str
path: str
training_metrics: Optional[PostTrainingMetric]
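A standalone pydantic sketch of the new checkpoint schema, mirroring the fields above; the @json_schema_type decorator is omitted so the snippet runs without llama_models installed, and every field value is a made-up placeholder:

from datetime import datetime
from typing import Optional

from pydantic import BaseModel

class PostTrainingMetric(BaseModel):
    epoch: int
    train_loss: float
    validation_loss: float
    perplexity: float

class Checkpoint(BaseModel):
    identifier: str
    created_at: datetime
    epoch: int
    post_training_job_id: str
    path: str
    training_metrics: Optional[PostTrainingMetric]

ckpt = Checkpoint(
    identifier="Llama3.2-3B-Instruct-sft-0",                # placeholder identifier
    created_at=datetime.now(),
    epoch=0,
    post_training_job_id="job-1234",                        # placeholder job id
    path="/tmp/post_training/Llama3.2-3B-Instruct-sft-0",   # placeholder path
    training_metrics=PostTrainingMetric(
        epoch=0, train_loss=1.25, validation_loss=1.31, perplexity=3.7
    ),
)
print(ckpt)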

View file

@@ -14,6 +14,7 @@ from llama_models.schema_utils import json_schema_type, webmethod
from pydantic import BaseModel, Field
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.apis.common.job_types import JobStatus
from llama_stack.apis.datasets import * # noqa: F403
from llama_stack.apis.common.training_types import * # noqa: F403
@@ -87,14 +88,6 @@ class PostTrainingJobLogStream(BaseModel):
log_lines: List[str]
@json_schema_type
class PostTrainingJobStatus(Enum):
running = "running"
completed = "completed"
failed = "failed"
scheduled = "scheduled"
@json_schema_type
class RLHFAlgorithm(Enum):
dpo = "dpo"
@@ -139,7 +132,7 @@ class PostTrainingJobStatusResponse(BaseModel):
"""Status of a finetuning job."""
job_uuid: str
status: PostTrainingJobStatus
status: JobStatus
scheduled_at: Optional[datetime] = None
started_at: Optional[datetime] = None
@@ -192,16 +185,10 @@ class PostTraining(Protocol):
@webmethod(route="/post-training/jobs")
async def get_training_jobs(self) -> List[PostTrainingJob]: ...
# sends SSE stream of logs
@webmethod(route="/post-training/job/logs")
async def get_training_job_logstream(
self, job_uuid: str
) -> PostTrainingJobLogStream: ...
@webmethod(route="/post-training/job/status")
async def get_training_job_status(
self, job_uuid: str
) -> PostTrainingJobStatusResponse: ...
) -> Optional[PostTrainingJobStatusResponse]: ...
@webmethod(route="/post-training/job/cancel")
async def cancel_training_job(self, job_uuid: str) -> None: ...
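With the status route now returning Optional, callers have to treat None as "unknown job" instead of expecting an exception. A caller-side sketch; the impl handle and the polling helper are assumptions, and only the method name and the Optional return come from this change:

import asyncio

async def wait_for_job(impl, job_uuid: str, poll_seconds: float = 5.0):
    # Poll get_training_job_status until the job reaches a terminal state.
    while True:
        status = await impl.get_training_job_status(job_uuid=job_uuid)
        if status is None:
            # None now signals an unknown job_uuid rather than an error being raised.
            raise RuntimeError(f"no such training job: {job_uuid}")
        if status.status.value in ("completed", "failed"):
            return status
        await asyncio.sleep(poll_seconds)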

View file

@@ -0,0 +1,155 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import os
import shutil
from pathlib import Path
from typing import Any, Dict, List
import torch
from torchtune import training
from torchtune.models import convert_weights
from torchtune.training.checkpointing._utils import ModelType, safe_torch_load
from torchtune.utils._logging import get_logger
logger = get_logger("DEBUG")
class TorchtuneCheckpointer:
def __init__(
self,
model_id: str,
training_algorithm: str,
checkpoint_dir: str,
checkpoint_files: List[str],
output_dir: str,
model_type: str,
) -> None:
# Fail fast if ``checkpoint_files`` is invalid
# TODO: support loading more than one file
if len(checkpoint_files) != 1:
raise ValueError(
"Currently we only support reading from a single torchtune checkpoint file. "
f"Got {len(checkpoint_files)} files instead."
)
self._checkpoint_file = checkpoint_files[0]
self._model_id = model_id
self._training_algorithm = training_algorithm
self._checkpoint_dir = Path(checkpoint_dir)
self._model_type = ModelType[model_type]
self._output_dir = output_dir
# get ckpt paths
self._checkpoint_path = Path.joinpath(
self._checkpoint_dir, self._checkpoint_file
)
def load_checkpoint(self) -> Dict[str, Any]:
"""
Load Meta checkpoint from file. Currently only loading from a single file is supported.
"""
state_dict: Dict[str, Any] = {}
model_state_dict = safe_torch_load(self._checkpoint_path)
if self._model_type == ModelType.LLAMA3_VISION:
from torchtune.models.llama3_2_vision._convert_weights import (
llama3_vision_meta_to_tune,
)
state_dict[training.MODEL_KEY] = llama3_vision_meta_to_tune(
model_state_dict
)
else:
state_dict[training.MODEL_KEY] = convert_weights.meta_to_tune(
model_state_dict
)
# llama3_2 has tied weights, so we need to remove the output.weight key
if self._model_type == ModelType.LLAMA3_2:
logger.info(
"Identified model_type = Llama3_2. Ignoring output.weight in"
" checkpoint in favor of the tok_embedding.weight"
" tied weights."
)
state_dict[training.MODEL_KEY].pop("output.weight")
return state_dict
def save_checkpoint(
self,
state_dict: Dict[str, Any],
epoch: int,
adapter_only: bool = False,
) -> str:
model_file_path = (
Path(self._output_dir)
/ f"{self._model_id}-{self._training_algorithm}-{epoch}"
)
model_file_path.mkdir(parents=True, exist_ok=True)
# copy the related files for inference
shutil.copy(
Path.joinpath(self._checkpoint_dir, "params.json"),
Path.joinpath(model_file_path, "params.json"),
)
shutil.copy(
Path.joinpath(self._checkpoint_dir, "tokenizer.model"),
Path.joinpath(model_file_path, "tokenizer.model"),
)
shutil.copy(
Path.joinpath(self._checkpoint_dir, "orig_params.json"),
Path.joinpath(model_file_path, "orig_params.json"),
)
if not adapter_only:
model_state_dict = state_dict[training.MODEL_KEY]
if self._model_type == ModelType.LLAMA3_VISION:
from torchtune.models.llama3_2_vision._convert_weights import (
llama3_vision_tune_to_meta,
)
state_dict[training.MODEL_KEY] = llama3_vision_tune_to_meta(
model_state_dict
)
else:
# llama3_2 has tied weights, so we need to add the output.weight key
if (
self._model_type == ModelType.LLAMA3_2
and "output.weight" not in model_state_dict
):
model_state_dict["output.weight"] = model_state_dict[
"tok_embeddings.weight"
]
state_dict[training.MODEL_KEY] = convert_weights.tune_to_meta(
model_state_dict
)
model_file_name = Path.joinpath(model_file_path, "consolidated.00.pth")
torch.save(state_dict[training.MODEL_KEY], model_file_name)
logger.info(
"Model checkpoint of size "
f"{os.path.getsize(model_file_name) / 1000**3:.2f} GB "
f"saved to {model_file_name}"
)
if training.ADAPTER_KEY in state_dict:
adapter_file_path = model_file_path / "adapter"
adapter_file_path.mkdir(parents=True, exist_ok=True)
adapter_file_name = Path.joinpath(adapter_file_path, "adapter.pth")
torch.save(state_dict[training.ADAPTER_KEY], adapter_file_name)
logger.info(
"Adapter checkpoint of size "
f"{os.path.getsize(adapter_file_name) / 1000**3:.2f} GB "
f"saved to {adapter_file_name}"
)
elif adapter_only:
raise ValueError(
"Adapter checkpoint not found in state_dict. Please ensure that the state_dict contains adapter weights."
)
return str(model_file_path)
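A usage sketch for the new checkpointer, assuming torchtune is installed and a Meta-format checkpoint already sits on disk; the model id and all paths below are placeholders, and the import path matches the one used further down in this commit:

from llama_stack.providers.inline.post_training.torchtune.common.checkpointer import (
    TorchtuneCheckpointer,
)

checkpointer = TorchtuneCheckpointer(
    model_id="Llama3.2-3B-Instruct",                    # placeholder model id
    training_algorithm="sft",
    checkpoint_dir="/data/llama/Llama3.2-3B-Instruct",  # must contain the files copied above
    checkpoint_files=["consolidated.00.pth"],           # exactly one file is supported
    output_dir="/data/post_training_output",            # placeholder output root
    model_type="LLAMA3_2",                              # key into torchtune's ModelType enum
)

# Convert the Meta-format weights into torchtune's native layout.
state_dict = checkpointer.load_checkpoint()

# ... train and update state_dict ...

# Writes <output_dir>/<model_id>-sft-<epoch>/ containing consolidated.00.pth
# plus the copied params.json, tokenizer.model and orig_params.json.
out_path = checkpointer.save_checkpoint(state_dict, epoch=0)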

View file

@@ -20,6 +20,11 @@ class TorchtunePostTrainingImpl:
self.config = config
self.datasetio_api = datasetio_api
# TODO: assume sync job, will need jobs API for async scheduling
self.jobs_status = {}
self.jobs_list = []
self.checkpoints_dict = {}
async def supervised_fine_tune(
self,
job_uuid: str,
@@ -30,23 +35,49 @@ class TorchtunePostTrainingImpl:
checkpoint_dir: Optional[str],
algorithm_config: Optional[Union[LoraFinetuningConfig, QATFinetuningConfig]],
) -> PostTrainingJob:
post_training_job = PostTrainingJob(job_uuid=job_uuid)
job_status_response = PostTrainingJobStatusResponse(
job_uuid=job_uuid,
status=JobStatus.scheduled,
scheduled_at=datetime.now(),
)
self.jobs_list.append(post_training_job)
if isinstance(algorithm_config, LoraFinetuningConfig):
recipe = LoraFinetuningSingleDevice(
self.config,
training_config,
hyperparam_search_config,
logger_config,
model,
checkpoint_dir,
algorithm_config,
self.datasetio_api,
)
await recipe.setup()
await recipe.train()
try:
recipe = LoraFinetuningSingleDevice(
self.config,
training_config,
hyperparam_search_config,
logger_config,
model,
checkpoint_dir,
algorithm_config,
self.datasetio_api,
)
job_status_response.status = JobStatus.in_progress
job_status_response.started_at = datetime.now()
await recipe.setup()
resources_allocated, checkpoints = await recipe.train()
self.checkpoints_dict[job_uuid] = checkpoints
job_status_response.resources_allocated = resources_allocated
job_status_response.checkpoints = checkpoints
job_status_response.status = JobStatus.completed
job_status_response.completed_at = datetime.now()
except Exception:
job_status_response.status = JobStatus.failed
raise
else:
raise NotImplementedError()
return PostTrainingJob(job_uuid=job_uuid)
self.jobs_status[job_uuid] = job_status_response
return post_training_job
async def preference_optimize(
self,
@@ -58,24 +89,26 @@ class TorchtunePostTrainingImpl:
logger_config: Dict[str, Any],
) -> PostTrainingJob: ...
# TODO @markchen1015 implement below APIs
async def get_training_jobs(self) -> List[PostTrainingJob]: ...
# sends SSE stream of logs
@webmethod(route="/post-training/job/logs")
async def get_training_job_logstream(
self, job_uuid: str
) -> PostTrainingJobLogStream: ...
async def get_training_jobs(self) -> List[PostTrainingJob]:
return self.jobs_list
@webmethod(route="/post-training/job/status")
async def get_training_job_status(
self, job_uuid: str
) -> PostTrainingJobStatusResponse: ...
) -> Optional[PostTrainingJobStatusResponse]:
if job_uuid in self.jobs_status:
return self.jobs_status[job_uuid]
return None
@webmethod(route="/post-training/job/cancel")
async def cancel_training_job(self, job_uuid: str) -> None: ...
async def cancel_training_job(self, job_uuid: str) -> None:
raise NotImplementedError("Job cancel is not implemented yet")
@webmethod(route="/post-training/job/artifacts")
async def get_training_job_artifacts(
self, job_uuid: str
) -> PostTrainingJobArtifactsResponse: ...
) -> PostTrainingJobArtifactsResponse:
checkpoints = self.checkpoints_dict.get(job_uuid, [])
return PostTrainingJobArtifactsResponse(
job_uuid=job_uuid, checkpoints=checkpoints
)
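Because scheduling is still synchronous (see the TODO above), supervised_fine_tune only returns after training has finished, so status and artifacts are readable immediately. A caller-side sketch; impl and the config arguments stand in for a wired-up TorchtunePostTrainingImpl and are not part of this diff:

async def run_sft(impl, job_uuid, model, training_config, algorithm_config):
    job = await impl.supervised_fine_tune(
        job_uuid=job_uuid,
        model=model,
        training_config=training_config,
        hyperparam_search_config={},
        logger_config={},
        checkpoint_dir=None,
        algorithm_config=algorithm_config,  # e.g. a LoraFinetuningConfig
    )
    # The job ran inline, so its status is already terminal here.
    status = await impl.get_training_job_status(job_uuid=job.job_uuid)
    artifacts = await impl.get_training_job_artifacts(job_uuid=job.job_uuid)
    return status, artifacts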

View file

@@ -13,14 +13,20 @@ from typing import Any, Dict, List, Optional, Tuple
import torch
from llama_models.sku_list import resolve_model
from llama_stack.apis.datasetio import DatasetIO
from llama_stack.distribution.utils.config_dirs import DEFAULT_CHECKPOINT_DIR
from llama_stack.providers.inline.post_training.torchtune.common.checkpointer import (
TorchtuneCheckpointer,
)
from torch import nn
from torchtune import utils as torchtune_utils
from torchtune.training.metric_logging import DiskLogger
from llama_stack.apis.post_training import * # noqa
from llama_stack.distribution.utils.model_utils import model_local_dir
from llama_stack.providers.inline.post_training.torchtune import utils
from llama_stack.providers.inline.post_training.torchtune.common import utils
from llama_stack.providers.inline.post_training.torchtune.config import (
TorchtunePostTrainingConfig,
)
@@ -99,7 +105,7 @@ class LoraFinetuningSingleDevice:
self.checkpoint_dir = model_checkpoint_dir(model)
# TODO @markchen1015 make it work with get_training_job_artifacts
self._output_dir = self.checkpoint_dir + "/posting_training/"
self._output_dir = str(DEFAULT_CHECKPOINT_DIR)
self.seed = training.set_seed(seed=config.torch_seed)
self.epochs_run = 0
@@ -138,7 +144,9 @@ class LoraFinetuningSingleDevice:
except FileNotFoundError:
return [f"Error: The directory '{checkpoint_dir}' does not exist."]
self._checkpointer = training.FullModelMetaCheckpointer(
self._checkpointer = TorchtuneCheckpointer(
model_id=self.model_id,
training_algorithm="sft",
checkpoint_dir=self.checkpoint_dir,
checkpoint_files=get_checkpoint_files(self.checkpoint_dir),
output_dir=self._output_dir,
@@ -148,8 +156,6 @@ class LoraFinetuningSingleDevice:
return checkpoint_dict
async def setup(self) -> None:
self._metric_logger = DiskLogger(log_dir=self._output_dir)
checkpoint_dict = await self.load_checkpoint()
self._model = await self._setup_model(
@@ -419,20 +425,26 @@ class LoraFinetuningSingleDevice:
return loss
async def train(self) -> None:
async def train(self) -> Tuple[Dict[str, Any], List[Checkpoint]]:
"""
The core training loop.
"""
# Initialize tokens count and running loss (for grad accumulation)
# t0 = time.perf_counter()
t0 = time.perf_counter()
running_loss = 0
num_tokens = 0
# training artifacts
checkpoints = []
memory_stats = {}
# self.epochs_run should be non-zero when we're resuming from a checkpoint
for curr_epoch in range(self.epochs_run, self.total_epochs):
# Update the sampler to ensure data is correctly shuffled across epochs
# in case shuffle is True
metric_logger = DiskLogger(
log_dir=self._output_dir + f"/{self.model_id}-sft-{curr_epoch}"
)
self._sampler.set_epoch(curr_epoch)
for idx, batch in enumerate(self._dataloader):
@@ -478,10 +490,14 @@ class LoraFinetuningSingleDevice:
"lr": self._optimizer.param_groups[0]["lr"],
"tokens_per_second_per_gpu": num_tokens / time_per_step,
}
log_dict.update(training.get_memory_stats(device=self._device))
memory_stats = training.get_memory_stats(device=self._device)
log_dict.update(memory_stats)
if self._clip_grad_norm is not None:
log_dict.update({"grad_norm": grad_norm})
self._metric_logger.log_dict(
metric_logger.log_dict(
log_dict,
step=self.global_step,
)
@@ -493,4 +509,13 @@
self.epochs_run += 1
log.info("Starting checkpoint save...")
await self.save_checkpoint(epoch=curr_epoch)
checkpoint_path = await self.save_checkpoint(epoch=curr_epoch)
checkpoint = Checkpoint(
identifier=f"{self.model_id}-sft-{curr_epoch}",
created_at=datetime.now(),
epoch=curr_epoch,
path=checkpoint_path,
)
checkpoints.append(checkpoint)
return (memory_stats, checkpoints)
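The new return value pairs the run's last memory snapshot with one Checkpoint per epoch. A short sketch of consuming that contract, where recipe stands in for a configured LoraFinetuningSingleDevice:

async def run_recipe(recipe):
    await recipe.setup()
    memory_stats, checkpoints = await recipe.train()

    # memory_stats is the last training.get_memory_stats() snapshot taken during the run.
    print(memory_stats)

    # One Checkpoint per epoch, identified as "<model_id>-sft-<epoch>".
    for ckpt in checkpoints:
        print(ckpt.identifier, ckpt.epoch, ckpt.path)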