temp commit

Botao Chen 2024-12-09 20:24:30 -08:00
parent 9c1ae088f9
commit c9a009b5e7
7 changed files with 268 additions and 53 deletions


@@ -18,3 +18,5 @@ class Job(BaseModel):
 class JobStatus(Enum):
     completed = "completed"
     in_progress = "in_progress"
+    failed = "failed"
+    scheduled = "scheduled"

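With the two added members, the shared JobStatus enum now covers the full job lifecycle (scheduled, in_progress, completed, failed), which the post-training provider further down reuses for its status responses. A minimal caller-side sketch, not part of the commit; the is_terminal helper is hypothetical:

from llama_stack.apis.common.job_types import JobStatus

def is_terminal(status: JobStatus) -> bool:
    # completed and failed are end states; scheduled and in_progress are not
    return status in (JobStatus.completed, JobStatus.failed)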

@@ -4,13 +4,26 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from llama_models.llama3.api.datatypes import URL
+from datetime import datetime
+from typing import Optional
+
 from llama_models.schema_utils import json_schema_type
 from pydantic import BaseModel
 
 
+@json_schema_type
+class PostTrainingMetric(BaseModel):
+    epoch: int
+    train_loss: float
+    validation_loss: float
+    perplexity: float
+
+
 @json_schema_type(schema={"description": "Checkpoint created during training runs"})
 class Checkpoint(BaseModel):
-    iters: int
-    path: URL
+    identifier: str
+    created_at: datetime
     epoch: int
+    post_training_job_id: str
+    path: str
+    training_metrics: Optional[PostTrainingMetric]

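Checkpoint now records provenance (identifier, creation time, owning job, filesystem path) plus optional per-epoch metrics through the new PostTrainingMetric model. A minimal construction sketch, not part of the commit; the import path follows the training_types import used later in this diff, and all field values are illustrative:

from datetime import datetime

from llama_stack.apis.common.training_types import Checkpoint, PostTrainingMetric

metric = PostTrainingMetric(
    epoch=0,
    train_loss=1.23,
    validation_loss=1.45,
    perplexity=4.27,
)
checkpoint = Checkpoint(
    identifier="my-model-sft-0",              # hypothetical identifier
    created_at=datetime.now(),
    epoch=0,
    post_training_job_id="job-1234",          # hypothetical job id
    path="/tmp/post_training/my-model-sft-0",  # hypothetical output path
    training_metrics=metric,
)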

@@ -14,6 +14,7 @@ from llama_models.schema_utils import json_schema_type, webmethod
 from pydantic import BaseModel, Field
 
 from llama_models.llama3.api.datatypes import *  # noqa: F403
+from llama_stack.apis.common.job_types import JobStatus
 from llama_stack.apis.datasets import *  # noqa: F403
 from llama_stack.apis.common.training_types import *  # noqa: F403
 
@@ -87,14 +88,6 @@ class PostTrainingJobLogStream(BaseModel):
     log_lines: List[str]
 
 
-@json_schema_type
-class PostTrainingJobStatus(Enum):
-    running = "running"
-    completed = "completed"
-    failed = "failed"
-    scheduled = "scheduled"
-
-
 @json_schema_type
 class RLHFAlgorithm(Enum):
     dpo = "dpo"
 
@@ -139,7 +132,7 @@ class PostTrainingJobStatusResponse(BaseModel):
     """Status of a finetuning job."""
 
     job_uuid: str
-    status: PostTrainingJobStatus
+    status: JobStatus
 
     scheduled_at: Optional[datetime] = None
     started_at: Optional[datetime] = None
 
@@ -192,16 +185,10 @@ class PostTraining(Protocol):
     @webmethod(route="/post-training/jobs")
     async def get_training_jobs(self) -> List[PostTrainingJob]: ...
 
-    # sends SSE stream of logs
-    @webmethod(route="/post-training/job/logs")
-    async def get_training_job_logstream(
-        self, job_uuid: str
-    ) -> PostTrainingJobLogStream: ...
-
     @webmethod(route="/post-training/job/status")
     async def get_training_job_status(
         self, job_uuid: str
-    ) -> PostTrainingJobStatusResponse: ...
+    ) -> Optional[PostTrainingJobStatusResponse]: ...
 
     @webmethod(route="/post-training/job/cancel")
     async def cancel_training_job(self, job_uuid: str) -> None: ...

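On the API surface, get_training_job_status now returns Optional[PostTrainingJobStatusResponse] (None for an unknown job), its status field uses the shared JobStatus enum, and the SSE log-stream endpoint is dropped. A minimal client-side polling sketch, not part of the commit; wait_for_job is hypothetical and the import paths follow the star imports used in this module:

import asyncio

from llama_stack.apis.common.job_types import JobStatus
from llama_stack.apis.post_training import PostTraining, PostTrainingJobStatusResponse

async def wait_for_job(client: PostTraining, job_uuid: str) -> PostTrainingJobStatusResponse:
    while True:
        response = await client.get_training_job_status(job_uuid=job_uuid)
        # None means the provider does not know this job_uuid
        if response is not None and response.status in (JobStatus.completed, JobStatus.failed):
            return response
        await asyncio.sleep(5)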

@@ -0,0 +1,155 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+import os
+import shutil
+from pathlib import Path
+from typing import Any, Dict, List
+
+import torch
+from torchtune import training
+from torchtune.models import convert_weights
+from torchtune.training.checkpointing._utils import ModelType, safe_torch_load
+from torchtune.utils._logging import get_logger
+
+logger = get_logger("DEBUG")
+
+
+class TorchtuneCheckpointer:
+    def __init__(
+        self,
+        model_id: str,
+        training_algorithm: str,
+        checkpoint_dir: str,
+        checkpoint_files: List[str],
+        output_dir: str,
+        model_type: str,
+    ) -> None:
+        # Fail fast if ``checkpoint_files`` is invalid
+        # TODO: support loading more than one file
+        if len(checkpoint_files) != 1:
+            raise ValueError(
+                "Currently we only support reading from a single torchtune checkpoint file. "
+                f"Got {len(checkpoint_files)} files instead."
+            )
+        self._checkpoint_file = checkpoint_files[0]
+        self._model_id = model_id
+        self._training_algorithm = training_algorithm
+        self._checkpoint_dir = Path(checkpoint_dir)
+        self._model_type = ModelType[model_type]
+        self._output_dir = output_dir
+        # get ckpt paths
+        self._checkpoint_path = Path.joinpath(
+            self._checkpoint_dir, self._checkpoint_file
+        )
+
+    def load_checkpoint(self) -> Dict[str, Any]:
+        """
+        Load Meta checkpoint from file. Currently only loading from a single file is supported.
+        """
+        state_dict: Dict[str, Any] = {}
+        model_state_dict = safe_torch_load(self._checkpoint_path)
+        if self._model_type == ModelType.LLAMA3_VISION:
+            from torchtune.models.llama3_2_vision._convert_weights import (
+                llama3_vision_meta_to_tune,
+            )
+
+            state_dict[training.MODEL_KEY] = llama3_vision_meta_to_tune(
+                model_state_dict
+            )
+        else:
+            state_dict[training.MODEL_KEY] = convert_weights.meta_to_tune(
+                model_state_dict
+            )
+
+        # llama3_2 has tied weights, so we need to remove the output.weight key
+        if self._model_type == ModelType.LLAMA3_2:
+            logger.info(
+                "Identified model_type = Llama3_2. Ignoring output.weight in"
+                " checkpoint in favor of the tok_embedding.weight"
+                " tied weights."
+            )
+            state_dict[training.MODEL_KEY].pop("output.weight")
+
+        return state_dict
+
+    def save_checkpoint(
+        self,
+        state_dict: Dict[str, Any],
+        epoch: int,
+        adapter_only: bool = False,
+    ) -> str:
+        model_file_path = (
+            Path(self._output_dir)
+            / f"{self._model_id}-{self._training_algorithm}-{epoch}"
+        )
+        model_file_path.mkdir(parents=True, exist_ok=True)
+
+        # copy the related files for inference
+        shutil.copy(
+            Path.joinpath(self._checkpoint_dir, "params.json"),
+            Path.joinpath(model_file_path, "params.json"),
+        )
+        shutil.copy(
+            Path.joinpath(self._checkpoint_dir, "tokenizer.model"),
+            Path.joinpath(model_file_path, "tokenizer.model"),
+        )
+        shutil.copy(
+            Path.joinpath(self._checkpoint_dir, "orig_params.json"),
+            Path.joinpath(model_file_path, "orig_params.json"),
+        )
+
+        if not adapter_only:
+            model_state_dict = state_dict[training.MODEL_KEY]
+            if self._model_type == ModelType.LLAMA3_VISION:
+                from torchtune.models.llama3_2_vision._convert_weights import (
+                    llama3_vision_tune_to_meta,
+                )
+
+                state_dict[training.MODEL_KEY] = llama3_vision_tune_to_meta(
+                    model_state_dict
+                )
+            else:
+                # llama3_2 has tied weights, so we need to add the output.weight key
+                if (
+                    self._model_type == ModelType.LLAMA3_2
+                    and "output.weight" not in model_state_dict
+                ):
+                    model_state_dict["output.weight"] = model_state_dict[
+                        "tok_embeddings.weight"
+                    ]
+
+                state_dict[training.MODEL_KEY] = convert_weights.tune_to_meta(
+                    model_state_dict
+                )
+
+            model_file_name = Path.joinpath(model_file_path, "consolidated.00.pth")
+            torch.save(state_dict[training.MODEL_KEY], model_file_name)
+            logger.info(
+                "Model checkpoint of size "
+                f"{os.path.getsize(model_file_name) / 1000**3:.2f} GB "
+                f"saved to {model_file_name}"
+            )
+
+        if training.ADAPTER_KEY in state_dict:
+            adapter_file_path = model_file_path / "adapter"
+            adapter_file_path.mkdir(parents=True, exist_ok=True)
+            adapter_file_name = Path.joinpath(adapter_file_path, "adapter.pth")
+            torch.save(state_dict[training.ADAPTER_KEY], adapter_file_name)
+            logger.info(
+                "Adapter checkpoint of size "
+                f"{os.path.getsize(adapter_file_name) / 1000**3:.2f} GB "
+                f"saved to {adapter_file_name}"
+            )
+        elif adapter_only:
+            raise ValueError(
+                "Adapter checkpoint not found in state_dict. Please ensure that the state_dict contains adapter weights."
+            )
+
+        return str(model_file_path)

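A minimal usage sketch for the new checkpointer, not part of the commit. The model id, directories, and model_type value are illustrative; checkpoint_dir must contain the single .pth file plus the params.json, tokenizer.model, and orig_params.json files that save_checkpoint() copies alongside the weights:

from llama_stack.providers.inline.post_training.torchtune.common.checkpointer import (
    TorchtuneCheckpointer,
)

checkpointer = TorchtuneCheckpointer(
    model_id="Llama3.2-3B-Instruct",                      # hypothetical model id
    training_algorithm="sft",
    checkpoint_dir="/data/models/Llama3.2-3B-Instruct",   # hypothetical source dir
    checkpoint_files=["consolidated.00.pth"],             # exactly one file is supported
    output_dir="/data/post_training",                     # hypothetical output dir
    model_type="LLAMA3_2",                                # must name a torchtune ModelType member
)

state_dict = checkpointer.load_checkpoint()               # Meta -> torchtune key conversion
# ... training mutates state_dict[training.MODEL_KEY] ...
saved_path = checkpointer.save_checkpoint(state_dict, epoch=0)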

@@ -20,6 +20,11 @@ class TorchtunePostTrainingImpl:
         self.config = config
         self.datasetio_api = datasetio_api
 
+        # TODO: assume sync job, will need jobs API for async scheduling
+        self.jobs_status = {}
+        self.jobs_list = []
+        self.checkpoints_dict = {}
+
     async def supervised_fine_tune(
         self,
         job_uuid: str,
@@ -30,23 +35,49 @@
         checkpoint_dir: Optional[str],
         algorithm_config: Optional[Union[LoraFinetuningConfig, QATFinetuningConfig]],
     ) -> PostTrainingJob:
+        post_training_job = PostTrainingJob(job_uuid=job_uuid)
+
+        job_status_response = PostTrainingJobStatusResponse(
+            job_uuid=job_uuid,
+            status=JobStatus.scheduled,
+            scheduled_at=datetime.now(),
+        )
+        self.jobs_list.append(post_training_job)
+
         if isinstance(algorithm_config, LoraFinetuningConfig):
-            recipe = LoraFinetuningSingleDevice(
-                self.config,
-                training_config,
-                hyperparam_search_config,
-                logger_config,
-                model,
-                checkpoint_dir,
-                algorithm_config,
-                self.datasetio_api,
-            )
-            await recipe.setup()
-            await recipe.train()
+            try:
+                recipe = LoraFinetuningSingleDevice(
+                    self.config,
+                    training_config,
+                    hyperparam_search_config,
+                    logger_config,
+                    model,
+                    checkpoint_dir,
+                    algorithm_config,
+                    self.datasetio_api,
+                )
+
+                job_status_response.status = JobStatus.in_progress
+                job_status_response.started_at = datetime.now()
+
+                await recipe.setup()
+                resources_allocated, checkpoints = await recipe.train()
+
+                self.checkpoints_dict[job_uuid] = checkpoints
+                job_status_response.resources_allocated = resources_allocated
+                job_status_response.checkpoints = checkpoints
+                job_status_response.status = JobStatus.completed
+                job_status_response.completed_at = datetime.now()
+            except Exception:
+                job_status_response.status = JobStatus.failed
+                raise
         else:
             raise NotImplementedError()
 
-        return PostTrainingJob(job_uuid=job_uuid)
+        self.jobs_status[job_uuid] = job_status_response
+        return post_training_job
 
     async def preference_optimize(
         self,
@@ -58,24 +89,26 @@
         logger_config: Dict[str, Any],
     ) -> PostTrainingJob: ...
 
-    # TODO @markchen1015 impelment below APIs
-    async def get_training_jobs(self) -> List[PostTrainingJob]: ...
-
-    # sends SSE stream of logs
-    @webmethod(route="/post-training/job/logs")
-    async def get_training_job_logstream(
-        self, job_uuid: str
-    ) -> PostTrainingJobLogStream: ...
+    async def get_training_jobs(self) -> List[PostTrainingJob]:
+        return self.jobs_list
 
     @webmethod(route="/post-training/job/status")
     async def get_training_job_status(
         self, job_uuid: str
-    ) -> PostTrainingJobStatusResponse: ...
+    ) -> Optional[PostTrainingJobStatusResponse]:
+        if job_uuid in self.jobs_status:
+            return self.jobs_status[job_uuid]
+        return None
 
     @webmethod(route="/post-training/job/cancel")
-    async def cancel_training_job(self, job_uuid: str) -> None: ...
+    async def cancel_training_job(self, job_uuid: str) -> None:
+        raise NotImplementedError("Job cancel is not implemented yet")
 
     @webmethod(route="/post-training/job/artifacts")
     async def get_training_job_artifacts(
         self, job_uuid: str
-    ) -> PostTrainingJobArtifactsResponse: ...
+    ) -> PostTrainingJobArtifactsResponse:
+        checkpoints = self.checkpoints_dict.get(job_uuid, [])
+        return PostTrainingJobArtifactsResponse(
+            job_uuid=job_uuid, checkpoints=checkpoints
+        )

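Because the job runs synchronously inside supervised_fine_tune() (see the TODO above), status, job list, and checkpoints are tracked in plain in-memory dictionaries. A minimal sketch of reading that state back through the new endpoints, not part of the commit; impl stands for an already-constructed TorchtunePostTrainingImpl and the job_uuid value is hypothetical:

async def report(impl, job_uuid: str = "job-1234") -> None:
    status = await impl.get_training_job_status(job_uuid=job_uuid)
    if status is None:
        print(f"unknown job: {job_uuid}")
        return
    print(f"{job_uuid}: {status.status}")  # scheduled / in_progress / completed / failed
    artifacts = await impl.get_training_job_artifacts(job_uuid=job_uuid)
    for ckpt in artifacts.checkpoints:
        print(f"  checkpoint {ckpt.identifier} -> {ckpt.path}")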

@@ -13,14 +13,20 @@ from typing import Any, Dict, List, Optional, Tuple
 import torch
 from llama_models.sku_list import resolve_model
 
 from llama_stack.apis.datasetio import DatasetIO
+from llama_stack.distribution.utils.config_dirs import DEFAULT_CHECKPOINT_DIR
+
+from llama_stack.providers.inline.post_training.torchtune.common.checkpointer import (
+    TorchtuneCheckpointer,
+)
+
 from torch import nn
 from torchtune import utils as torchtune_utils
 from torchtune.training.metric_logging import DiskLogger
 
 from llama_stack.apis.post_training import *  # noqa
 from llama_stack.distribution.utils.model_utils import model_local_dir
-from llama_stack.providers.inline.post_training.torchtune import utils
+from llama_stack.providers.inline.post_training.torchtune.common import utils
 from llama_stack.providers.inline.post_training.torchtune.config import (
     TorchtunePostTrainingConfig,
 )
@@ -99,7 +105,7 @@ class LoraFinetuningSingleDevice:
         self.checkpoint_dir = model_checkpoint_dir(model)
 
         # TODO @markchen1015 make it work with get_training_job_artifacts
-        self._output_dir = self.checkpoint_dir + "/posting_training/"
+        self._output_dir = str(DEFAULT_CHECKPOINT_DIR)
 
         self.seed = training.set_seed(seed=config.torch_seed)
         self.epochs_run = 0
@@ -138,7 +144,9 @@ class LoraFinetuningSingleDevice:
             except FileNotFoundError:
                 return [f"Error: The directory '{checkpoint_dir}' does not exist."]
 
-        self._checkpointer = training.FullModelMetaCheckpointer(
+        self._checkpointer = TorchtuneCheckpointer(
+            model_id=self.model_id,
+            training_algorithm="sft",
             checkpoint_dir=self.checkpoint_dir,
             checkpoint_files=get_checkpoint_files(self.checkpoint_dir),
             output_dir=self._output_dir,
@@ -148,8 +156,6 @@
         return checkpoint_dict
 
     async def setup(self) -> None:
-        self._metric_logger = DiskLogger(log_dir=self._output_dir)
-
         checkpoint_dict = await self.load_checkpoint()
 
         self._model = await self._setup_model(
@@ -419,20 +425,26 @@ class LoraFinetuningSingleDevice:
         return loss
 
-    async def train(self) -> None:
+    async def train(self) -> Tuple[Dict[str, Any], List[Checkpoint]]:
         """
        The core training loop.
         """
 
         # Initialize tokens count and running loss (for grad accumulation)
-        # t0 = time.perf_counter()
         t0 = time.perf_counter()
         running_loss = 0
         num_tokens = 0
 
+        # training artifacts
+        checkpoints = []
+        memory_stats = {}
+
         # self.epochs_run should be non-zero when we're resuming from a checkpoint
         for curr_epoch in range(self.epochs_run, self.total_epochs):
             # Update the sampler to ensure data is correctly shuffled across epochs
             # in case shuffle is True
+            metric_logger = DiskLogger(
+                log_dir=self._output_dir + f"/{self.model_id}-sft-{curr_epoch}"
+            )
             self._sampler.set_epoch(curr_epoch)
 
             for idx, batch in enumerate(self._dataloader):
@@ -478,10 +490,14 @@
                         "lr": self._optimizer.param_groups[0]["lr"],
                         "tokens_per_second_per_gpu": num_tokens / time_per_step,
                     }
-                    log_dict.update(training.get_memory_stats(device=self._device))
+
+                    memory_stats = training.get_memory_stats(device=self._device)
+                    log_dict.update(memory_stats)
+
                     if self._clip_grad_norm is not None:
                         log_dict.update({"grad_norm": grad_norm})
-                    self._metric_logger.log_dict(
+
+                    metric_logger.log_dict(
                         log_dict,
                         step=self.global_step,
                     )
@@ -493,4 +509,13 @@
             self.epochs_run += 1
 
             log.info("Starting checkpoint save...")
-            await self.save_checkpoint(epoch=curr_epoch)
+            checkpoint_path = await self.save_checkpoint(epoch=curr_epoch)
+            checkpoint = Checkpoint(
+                identifier=f"{self.model_id}-sft-{curr_epoch}",
+                created_at=datetime.now(),
+                epoch=curr_epoch,
+                path=checkpoint_path,
+            )
+            checkpoints.append(checkpoint)
+
+        return (memory_stats, checkpoints)
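Worth noting: the per-epoch DiskLogger directory and the checkpoint written by save_checkpoint() are both derived from the same {model_id}-sft-{epoch} pattern under the output directory, so metrics and weights for an epoch land under matching names. A minimal sketch with hypothetical values, not part of the commit:

from pathlib import Path

model_id, epoch = "Llama3.2-3B-Instruct", 0   # hypothetical values
output_dir = Path("/data/post_training")      # hypothetical _output_dir

tag = f"{model_id}-sft-{epoch}"
metrics_dir = output_dir / tag    # DiskLogger(log_dir=...) for this epoch
weights_dir = output_dir / tag    # model_file_path written by save_checkpoint()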