Mirror of https://github.com/meta-llama/llama-stack.git, synced 2025-08-02 08:44:44 +00:00

Commit c9a009b5e7 ("temp commit")
Parent: 9c1ae088f9
7 changed files with 268 additions and 53 deletions

@@ -18,3 +18,5 @@ class Job(BaseModel):
 class JobStatus(Enum):
     completed = "completed"
     in_progress = "in_progress"
+    failed = "failed"
+    scheduled = "scheduled"
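
The shared JobStatus enum now carries all four lifecycle states, so provider APIs can reuse it instead of defining their own status enums. A minimal sketch of the intended lifecycle, re-stating the enum locally so the snippet runs on its own; the transition order is an assumption based on how the torchtune provider below sets it:

from enum import Enum


class JobStatus(Enum):
    completed = "completed"
    in_progress = "in_progress"
    failed = "failed"
    scheduled = "scheduled"


# A job is expected to move scheduled -> in_progress -> completed (or failed).
lifecycle = [JobStatus.scheduled, JobStatus.in_progress, JobStatus.completed]
print([s.value for s in lifecycle])  # ['scheduled', 'in_progress', 'completed']
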
@@ -4,13 +4,26 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from llama_models.llama3.api.datatypes import URL
+from datetime import datetime
+from typing import Optional
 
 from llama_models.schema_utils import json_schema_type
 from pydantic import BaseModel
 
 
+@json_schema_type
+class PostTrainingMetric(BaseModel):
+    epoch: int
+    train_loss: float
+    validation_loss: float
+    perplexity: float
+
+
 @json_schema_type(schema={"description": "Checkpoint created during training runs"})
 class Checkpoint(BaseModel):
-    iters: int
-    path: URL
+    identifier: str
+    created_at: datetime
     epoch: int
+    post_training_job_id: str
+    path: str
+    training_metrics: Optional[PostTrainingMetric]
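
The Checkpoint schema changes from an (iters, URL path) pair to an identifier/timestamp record with an optional per-epoch metrics block. A construction sketch for illustration only; the import path is taken from the imports elsewhere in this diff and every field value below is made up:

from datetime import datetime

from llama_stack.apis.common.training_types import Checkpoint, PostTrainingMetric

# Hypothetical values, just to show the new shape of the models.
metric = PostTrainingMetric(epoch=0, train_loss=1.23, validation_loss=1.31, perplexity=3.7)
checkpoint = Checkpoint(
    identifier="Llama3.2-3B-Instruct-sft-0",
    created_at=datetime.now(),
    epoch=0,
    post_training_job_id="job-1234",
    path="/home/user/.llama/checkpoints/Llama3.2-3B-Instruct-sft-0",
    training_metrics=metric,
)
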
@@ -14,6 +14,7 @@ from llama_models.schema_utils import json_schema_type, webmethod
 from pydantic import BaseModel, Field
 
 from llama_models.llama3.api.datatypes import *  # noqa: F403
+from llama_stack.apis.common.job_types import JobStatus
 from llama_stack.apis.datasets import *  # noqa: F403
 from llama_stack.apis.common.training_types import *  # noqa: F403
 
@@ -87,14 +88,6 @@ class PostTrainingJobLogStream(BaseModel):
     log_lines: List[str]
 
 
-@json_schema_type
-class PostTrainingJobStatus(Enum):
-    running = "running"
-    completed = "completed"
-    failed = "failed"
-    scheduled = "scheduled"
-
-
 @json_schema_type
 class RLHFAlgorithm(Enum):
     dpo = "dpo"
@@ -139,7 +132,7 @@ class PostTrainingJobStatusResponse(BaseModel):
     """Status of a finetuning job."""
 
     job_uuid: str
-    status: PostTrainingJobStatus
+    status: JobStatus
 
     scheduled_at: Optional[datetime] = None
    started_at: Optional[datetime] = None
@@ -192,16 +185,10 @@ class PostTraining(Protocol):
     @webmethod(route="/post-training/jobs")
     async def get_training_jobs(self) -> List[PostTrainingJob]: ...
 
-    # sends SSE stream of logs
-    @webmethod(route="/post-training/job/logs")
-    async def get_training_job_logstream(
-        self, job_uuid: str
-    ) -> PostTrainingJobLogStream: ...
-
     @webmethod(route="/post-training/job/status")
     async def get_training_job_status(
         self, job_uuid: str
-    ) -> PostTrainingJobStatusResponse: ...
+    ) -> Optional[PostTrainingJobStatusResponse]: ...
 
     @webmethod(route="/post-training/job/cancel")
     async def cancel_training_job(self, job_uuid: str) -> None: ...
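
With PostTrainingJobStatus removed, the PostTraining protocol reports status through the shared JobStatus, and get_training_job_status may now return None for an unknown job id. A hedged client-side sketch of handling that change; `client` is a placeholder for any object implementing the protocol:

from typing import Optional


async def poll_status(client, job_uuid: str) -> Optional[str]:
    status_response = await client.get_training_job_status(job_uuid=job_uuid)
    if status_response is None:
        return None  # unknown job id is no longer an error
    # .status is now the shared JobStatus enum rather than PostTrainingJobStatus
    return status_response.status.value
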
@@ -0,0 +1,155 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+import os
+import shutil
+from pathlib import Path
+from typing import Any, Dict, List
+
+import torch
+from torchtune import training
+from torchtune.models import convert_weights
+from torchtune.training.checkpointing._utils import ModelType, safe_torch_load
+from torchtune.utils._logging import get_logger
+
+logger = get_logger("DEBUG")
+
+
+class TorchtuneCheckpointer:
+    def __init__(
+        self,
+        model_id: str,
+        training_algorithm: str,
+        checkpoint_dir: str,
+        checkpoint_files: List[str],
+        output_dir: str,
+        model_type: str,
+    ) -> None:
+        # Fail fast if ``checkpoint_files`` is invalid
+        # TODO: support loading more than one file
+        if len(checkpoint_files) != 1:
+            raise ValueError(
+                "Currently we only support reading from a single torchtune checkpoint file. "
+                f"Got {len(checkpoint_files)} files instead."
+            )
+        self._checkpoint_file = checkpoint_files[0]
+        self._model_id = model_id
+        self._training_algorithm = training_algorithm
+        self._checkpoint_dir = Path(checkpoint_dir)
+        self._model_type = ModelType[model_type]
+        self._output_dir = output_dir
+        # get ckpt paths
+        self._checkpoint_path = Path.joinpath(
+            self._checkpoint_dir, self._checkpoint_file
+        )
+
+    def load_checkpoint(self) -> Dict[str, Any]:
+        """
+        Load Meta checkpoint from file. Currently only loading from a single file is supported.
+        """
+        state_dict: Dict[str, Any] = {}
+        model_state_dict = safe_torch_load(self._checkpoint_path)
+        if self._model_type == ModelType.LLAMA3_VISION:
+            from torchtune.models.llama3_2_vision._convert_weights import (
+                llama3_vision_meta_to_tune,
+            )
+
+            state_dict[training.MODEL_KEY] = llama3_vision_meta_to_tune(
+                model_state_dict
+            )
+        else:
+            state_dict[training.MODEL_KEY] = convert_weights.meta_to_tune(
+                model_state_dict
+            )
+
+        # llama3_2 has tied weights, so we need to remove the output.weight key
+        if self._model_type == ModelType.LLAMA3_2:
+            logger.info(
+                "Identified model_type = Llama3_2. Ignoring output.weight in"
+                " checkpoint in favor of the tok_embedding.weight"
+                " tied weights."
+            )
+            state_dict[training.MODEL_KEY].pop("output.weight")
+
+        return state_dict
+
+    def save_checkpoint(
+        self,
+        state_dict: Dict[str, Any],
+        epoch: int,
+        adapter_only: bool = False,
+    ) -> str:
+        model_file_path = (
+            Path(self._output_dir)
+            / f"{self._model_id}-{self._training_algorithm}-{epoch}"
+        )
+
+        model_file_path.mkdir(parents=True, exist_ok=True)
+
+        # copy the related files for inference
+        shutil.copy(
+            Path.joinpath(self._checkpoint_dir, "params.json"),
+            Path.joinpath(model_file_path, "params.json"),
+        )
+        shutil.copy(
+            Path.joinpath(self._checkpoint_dir, "tokenizer.model"),
+            Path.joinpath(model_file_path, "tokenizer.model"),
+        )
+        shutil.copy(
+            Path.joinpath(self._checkpoint_dir, "orig_params.json"),
+            Path.joinpath(model_file_path, "orig_params.json"),
+        )
+
+        if not adapter_only:
+            model_state_dict = state_dict[training.MODEL_KEY]
+            if self._model_type == ModelType.LLAMA3_VISION:
+                from torchtune.models.llama3_2_vision._convert_weights import (
+                    llama3_vision_tune_to_meta,
+                )
+
+                state_dict[training.MODEL_KEY] = llama3_vision_tune_to_meta(
+                    model_state_dict
+                )
+            else:
+                # llama3_2 has tied weights, so we need to add the output.weight key
+                if (
+                    self._model_type == ModelType.LLAMA3_2
+                    and "output.weight" not in model_state_dict
+                ):
+                    model_state_dict["output.weight"] = model_state_dict[
+                        "tok_embeddings.weight"
+                    ]
+
+                state_dict[training.MODEL_KEY] = convert_weights.tune_to_meta(
+                    model_state_dict
+                )
+
+            model_file_name = Path.joinpath(model_file_path, "consolidated.00.pth")
+
+            torch.save(state_dict[training.MODEL_KEY], model_file_name)
+            logger.info(
+                "Model checkpoint of size "
+                f"{os.path.getsize(model_file_name) / 1000**3:.2f} GB "
+                f"saved to {model_file_name}"
+            )
+
+        if training.ADAPTER_KEY in state_dict:
+            adapter_file_path = model_file_path / "adapter"
+            adapter_file_path.mkdir(parents=True, exist_ok=True)
+            adapter_file_name = Path.joinpath(adapter_file_path, "adapter.pth")
+            torch.save(state_dict[training.ADAPTER_KEY], adapter_file_name)
+            logger.info(
+                "Adapter checkpoint of size "
+                f"{os.path.getsize(adapter_file_name) / 1000**3:.2f} GB "
+                f"saved to {adapter_file_name}"
+            )
+
+        elif adapter_only:
+            raise ValueError(
+                "Adapter checkpoint not found in state_dict. Please ensure that the state_dict contains adapter weights."
+            )
+
+        return str(model_file_path)
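
The new TorchtuneCheckpointer reads a single Meta-format checkpoint, converts it into torchtune's key layout on load, and converts back (also copying params.json, tokenizer.model, and orig_params.json alongside the weights) on save. A minimal usage sketch under assumed inputs; the import path mirrors the one used later in this diff, while the model id, directories, and file names below are hypothetical:

from torchtune import training

from llama_stack.providers.inline.post_training.torchtune.common.checkpointer import (
    TorchtuneCheckpointer,
)

# Hypothetical paths and ids, for illustration only.
checkpointer = TorchtuneCheckpointer(
    model_id="Llama3.2-3B-Instruct",
    training_algorithm="sft",
    checkpoint_dir="/home/user/.llama/checkpoints/Llama3.2-3B-Instruct",
    checkpoint_files=["consolidated.00.pth"],
    output_dir="/home/user/.llama/post_training",
    model_type="LLAMA3_2",
)

state_dict = checkpointer.load_checkpoint()  # Meta format -> torchtune format
# ... training updates state_dict[training.MODEL_KEY] ...
out_path = checkpointer.save_checkpoint(state_dict, epoch=0)  # torchtune -> Meta format
print(out_path)  # <output_dir>/Llama3.2-3B-Instruct-sft-0
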
@@ -20,6 +20,11 @@ class TorchtunePostTrainingImpl:
         self.config = config
         self.datasetio_api = datasetio_api
 
+        # TODO: assume sync job, will need jobs API for async scheduling
+        self.jobs_status = {}
+        self.jobs_list = []
+        self.checkpoints_dict = {}
+
     async def supervised_fine_tune(
         self,
         job_uuid: str,
@@ -30,23 +35,49 @@ class TorchtunePostTrainingImpl:
         checkpoint_dir: Optional[str],
         algorithm_config: Optional[Union[LoraFinetuningConfig, QATFinetuningConfig]],
     ) -> PostTrainingJob:
+        post_training_job = PostTrainingJob(job_uuid=job_uuid)
+
+        job_status_response = PostTrainingJobStatusResponse(
+            job_uuid=job_uuid,
+            status=JobStatus.scheduled,
+            scheduled_at=datetime.now(),
+        )
+
+        self.jobs_list.append(post_training_job)
         if isinstance(algorithm_config, LoraFinetuningConfig):
-            recipe = LoraFinetuningSingleDevice(
-                self.config,
-                training_config,
-                hyperparam_search_config,
-                logger_config,
-                model,
-                checkpoint_dir,
-                algorithm_config,
-                self.datasetio_api,
-            )
-            await recipe.setup()
-            await recipe.train()
+            try:
+                recipe = LoraFinetuningSingleDevice(
+                    self.config,
+                    training_config,
+                    hyperparam_search_config,
+                    logger_config,
+                    model,
+                    checkpoint_dir,
+                    algorithm_config,
+                    self.datasetio_api,
+                )
+
+                job_status_response.status = JobStatus.in_progress
+                job_status_response.started_at = datetime.now()
+
+                await recipe.setup()
+                resources_allocated, checkpoints = await recipe.train()
+
+                self.checkpoints_dict[job_uuid] = checkpoints
+                job_status_response.resources_allocated = resources_allocated
+                job_status_response.checkpoints = checkpoints
+                job_status_response.status = JobStatus.completed
+                job_status_response.completed_at = datetime.now()
+
+            except Exception:
+                job_status_response.status = JobStatus.failed
+                raise
         else:
             raise NotImplementedError()
 
-        return PostTrainingJob(job_uuid=job_uuid)
+        self.jobs_status[job_uuid] = job_status_response
+
+        return post_training_job
 
     async def preference_optimize(
         self,
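
supervised_fine_tune still runs the recipe synchronously, but it now records a PostTrainingJobStatusResponse that moves scheduled -> in_progress -> completed, falling to failed if the recipe raises. A compressed, self-contained sketch of that bookkeeping, with the recipe replaced by a stub and the pydantic response replaced by a plain dict so the flow runs on its own; everything named Fake* or stored in the dict is a stand-in:

import asyncio
from datetime import datetime


class FakeRecipe:
    """Stand-in for LoraFinetuningSingleDevice; mimics only the awaited calls used above."""

    async def setup(self) -> None:
        pass

    async def train(self):
        # (resources_allocated, checkpoints), as the real train() returns after this change
        return {"peak_memory_alloc_gb": 0.0}, []


async def run_job(job_uuid: str) -> dict:
    status = {"job_uuid": job_uuid, "status": "scheduled", "scheduled_at": datetime.now()}
    try:
        recipe = FakeRecipe()
        status["status"] = "in_progress"
        status["started_at"] = datetime.now()
        await recipe.setup()
        resources_allocated, checkpoints = await recipe.train()
        status.update(
            status="completed",
            completed_at=datetime.now(),
            resources_allocated=resources_allocated,
            checkpoints=checkpoints,
        )
    except Exception:
        status["status"] = "failed"
        raise
    return status


print(asyncio.run(run_job("job-1"))["status"])  # completed
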
@@ -58,24 +89,26 @@ class TorchtunePostTrainingImpl:
         logger_config: Dict[str, Any],
     ) -> PostTrainingJob: ...
 
-    # TODO @markchen1015 impelment below APIs
-    async def get_training_jobs(self) -> List[PostTrainingJob]: ...
+    async def get_training_jobs(self) -> List[PostTrainingJob]:
+        return self.jobs_list
 
-    # sends SSE stream of logs
-    @webmethod(route="/post-training/job/logs")
-    async def get_training_job_logstream(
-        self, job_uuid: str
-    ) -> PostTrainingJobLogStream: ...
-
     @webmethod(route="/post-training/job/status")
     async def get_training_job_status(
         self, job_uuid: str
-    ) -> PostTrainingJobStatusResponse: ...
+    ) -> Optional[PostTrainingJobStatusResponse]:
+        if job_uuid in self.jobs_status:
+            return self.jobs_status[job_uuid]
+        return None
 
     @webmethod(route="/post-training/job/cancel")
-    async def cancel_training_job(self, job_uuid: str) -> None: ...
+    async def cancel_training_job(self, job_uuid: str) -> None:
+        raise NotImplementedError("Job cancel is not implemented yet")
 
     @webmethod(route="/post-training/job/artifacts")
     async def get_training_job_artifacts(
         self, job_uuid: str
-    ) -> PostTrainingJobArtifactsResponse: ...
+    ) -> PostTrainingJobArtifactsResponse:
+        checkpoints = self.checkpoints_dict.get(job_uuid, [])
+        return PostTrainingJobArtifactsResponse(
+            job_uuid=job_uuid, checkpoints=checkpoints
+        )
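
The previously stubbed query methods now read from the in-memory state populated by supervised_fine_tune: jobs_list for enumeration, jobs_status for per-job status, and checkpoints_dict for artifacts, while cancel still raises NotImplementedError. A hedged sketch of how a caller might combine them; `impl` is a placeholder for a constructed TorchtunePostTrainingImpl:

async def summarize_jobs(impl):
    summaries = []
    for job in await impl.get_training_jobs():
        status = await impl.get_training_job_status(job_uuid=job.job_uuid)
        artifacts = await impl.get_training_job_artifacts(job_uuid=job.job_uuid)
        summaries.append(
            {
                "job_uuid": job.job_uuid,
                "status": status.status.value if status else None,
                "num_checkpoints": len(artifacts.checkpoints),
            }
        )
    return summaries
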
@@ -13,14 +13,20 @@ from typing import Any, Dict, List, Optional, Tuple
 
 import torch
 from llama_models.sku_list import resolve_model
 
 from llama_stack.apis.datasetio import DatasetIO
 
+from llama_stack.distribution.utils.config_dirs import DEFAULT_CHECKPOINT_DIR
+from llama_stack.providers.inline.post_training.torchtune.common.checkpointer import (
+    TorchtuneCheckpointer,
+)
 from torch import nn
 from torchtune import utils as torchtune_utils
 from torchtune.training.metric_logging import DiskLogger
 from llama_stack.apis.post_training import *  # noqa
 from llama_stack.distribution.utils.model_utils import model_local_dir
 
-from llama_stack.providers.inline.post_training.torchtune import utils
+from llama_stack.providers.inline.post_training.torchtune.common import utils
 from llama_stack.providers.inline.post_training.torchtune.config import (
     TorchtunePostTrainingConfig,
 )
@@ -99,7 +105,7 @@ class LoraFinetuningSingleDevice:
         self.checkpoint_dir = model_checkpoint_dir(model)
 
         # TODO @markchen1015 make it work with get_training_job_artifacts
-        self._output_dir = self.checkpoint_dir + "/posting_training/"
+        self._output_dir = str(DEFAULT_CHECKPOINT_DIR)
 
         self.seed = training.set_seed(seed=config.torch_seed)
         self.epochs_run = 0
@@ -138,7 +144,9 @@ class LoraFinetuningSingleDevice:
         except FileNotFoundError:
             return [f"Error: The directory '{checkpoint_dir}' does not exist."]
 
-        self._checkpointer = training.FullModelMetaCheckpointer(
+        self._checkpointer = TorchtuneCheckpointer(
+            model_id=self.model_id,
+            training_algorithm="sft",
             checkpoint_dir=self.checkpoint_dir,
             checkpoint_files=get_checkpoint_files(self.checkpoint_dir),
             output_dir=self._output_dir,
@@ -148,8 +156,6 @@ class LoraFinetuningSingleDevice:
         return checkpoint_dict
 
     async def setup(self) -> None:
-        self._metric_logger = DiskLogger(log_dir=self._output_dir)
-
         checkpoint_dict = await self.load_checkpoint()
 
         self._model = await self._setup_model(
@@ -419,20 +425,26 @@
 
         return loss
 
-    async def train(self) -> None:
+    async def train(self) -> Tuple[Dict[str, Any], List[Checkpoint]]:
         """
         The core training loop.
         """
         # Initialize tokens count and running loss (for grad accumulation)
-        # t0 = time.perf_counter()
         t0 = time.perf_counter()
         running_loss = 0
         num_tokens = 0
 
+        # training artifacts
+        checkpoints = []
+        memory_stats = {}
+
         # self.epochs_run should be non-zero when we're resuming from a checkpoint
         for curr_epoch in range(self.epochs_run, self.total_epochs):
             # Update the sampler to ensure data is correctly shuffled across epochs
             # in case shuffle is True
+            metric_logger = DiskLogger(
+                log_dir=self._output_dir + f"/{self.model_id}-sft-{curr_epoch}"
+            )
             self._sampler.set_epoch(curr_epoch)
 
             for idx, batch in enumerate(self._dataloader):
@@ -478,10 +490,14 @@
                     "lr": self._optimizer.param_groups[0]["lr"],
                     "tokens_per_second_per_gpu": num_tokens / time_per_step,
                 }
-                log_dict.update(training.get_memory_stats(device=self._device))
+
+                memory_stats = training.get_memory_stats(device=self._device)
+                log_dict.update(memory_stats)
+
                 if self._clip_grad_norm is not None:
                     log_dict.update({"grad_norm": grad_norm})
-                self._metric_logger.log_dict(
+
+                metric_logger.log_dict(
                     log_dict,
                     step=self.global_step,
                 )
@@ -493,4 +509,13 @@
 
             self.epochs_run += 1
             log.info("Starting checkpoint save...")
-            await self.save_checkpoint(epoch=curr_epoch)
+            checkpoint_path = await self.save_checkpoint(epoch=curr_epoch)
+            checkpoint = Checkpoint(
+                identifier=f"{self.model_id}-sft-{curr_epoch}",
+                created_at=datetime.now(),
+                epoch=curr_epoch,
+                path=checkpoint_path,
+            )
+            checkpoints.append(checkpoint)
+
+        return (memory_stats, checkpoints)
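
train() now writes metrics through a per-epoch DiskLogger under the output directory, keeps the most recent memory stats, and returns (memory_stats, checkpoints), where each Checkpoint records the path produced by save_checkpoint. A small sketch of consuming that return value; `recipe` is assumed to be an initialized LoraFinetuningSingleDevice with setup() already awaited:

async def run_and_report(recipe):
    memory_stats, checkpoints = await recipe.train()
    for ckpt in checkpoints:
        # identifier looks like "<model_id>-sft-<epoch>"; path is where save_checkpoint wrote it
        print(ckpt.epoch, ckpt.identifier, ckpt.path)
    return memory_stats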