Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-08-02 08:44:44 +00:00)

Commit 79c525be94 ("temp commit"), parent 6c709abc4d
5 changed files with 115 additions and 163 deletions
@@ -41,7 +41,7 @@ class TrainingConfig(BaseModel):
     gradient_accumulation_steps: int
     batch_size: int
     shuffle: bool
-    # n_iters: int
+    optimizer_config: OptimizerConfig

     enable_activation_checkpointing: bool
     memory_efficient_fsdp_wrap: Optional[bool]
@@ -63,6 +63,7 @@ class LoraFinetuningConfig(BaseModel):
     apply_lora_to_output: bool
     rank: int
     alpha: int
+    use_dora: bool


 @json_schema_type
@@ -116,7 +117,6 @@ class PostTrainingSFTRequest(BaseModel):

     algorithm: FinetuningAlgorithm
     algorithm_config: LoraFinetuningConfig
-    optimizer_config: OptimizerConfig
     training_config: TrainingConfig

     # TODO: define these
@@ -178,7 +178,7 @@ class PostTrainingJobArtifactsResponse(BaseModel):

 class PostTraining(Protocol):
     @webmethod(route="/post-training/supervised-fine-tune")
-    def supervised_fine_tune(
+    async def supervised_fine_tune(
         self,
         job_uuid: str,
         model: str,
@@ -186,14 +186,14 @@ class PostTraining(Protocol):
         validation_dataset_id: str,
         algorithm: FinetuningAlgorithm,
         algorithm_config: LoraFinetuningConfig,
-        optimizer_config: OptimizerConfig,
+        # optimizer_config: OptimizerConfig,
         training_config: TrainingConfig,
         hyperparam_search_config: Dict[str, Any],
         logger_config: Dict[str, Any],
     ) -> PostTrainingJob: ...

     @webmethod(route="/post-training/preference-optimize")
-    def preference_optimize(
+    async def preference_optimize(
         self,
         job_uuid: str,
         finetuned_model: URL,
@@ -208,21 +208,23 @@ class PostTraining(Protocol):
     ) -> PostTrainingJob: ...

     @webmethod(route="/post-training/jobs")
-    def get_training_jobs(self) -> List[PostTrainingJob]: ...
+    async def get_training_jobs(self) -> List[PostTrainingJob]: ...

     # sends SSE stream of logs
     @webmethod(route="/post-training/job/logs")
-    def get_training_job_logstream(self, job_uuid: str) -> PostTrainingJobLogStream: ...
+    async def get_training_job_logstream(
+        self, job_uuid: str
+    ) -> PostTrainingJobLogStream: ...

     @webmethod(route="/post-training/job/status")
-    def get_training_job_status(
+    async def get_training_job_status(
         self, job_uuid: str
     ) -> PostTrainingJobStatusResponse: ...

     @webmethod(route="/post-training/job/cancel")
-    def cancel_training_job(self, job_uuid: str) -> None: ...
+    async def cancel_training_job(self, job_uuid: str) -> None: ...

     @webmethod(route="/post-training/job/artifacts")
-    def get_training_job_artifacts(
+    async def get_training_job_artifacts(
         self, job_uuid: str
     ) -> PostTrainingJobArtifactsResponse: ...
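The hunks above turn every method on the PostTraining protocol into a coroutine and fold OptimizerConfig into TrainingConfig. A minimal sketch of what that means for a caller; the client object, job id, and dataset ids below are hypothetical, only the method names and parameters come from the hunks above:

from typing import Any

async def run_sft(client: Any, algorithm: Any, algorithm_config: Any, training_config: Any):
    # Every protocol method is now async and has to be awaited.
    job = await client.supervised_fine_tune(
        job_uuid="job-001",                 # hypothetical id
        model="Llama3.2-3B-Instruct",
        dataset_id="alpaca",                # illustrative dataset ids
        validation_dataset_id="alpaca",
        algorithm=algorithm,                # e.g. FinetuningAlgorithm.lora
        algorithm_config=algorithm_config,  # a LoraFinetuningConfig
        training_config=training_config,    # a TrainingConfig; optimizer_config now lives here
        hyperparam_search_config={},
        logger_config={},
    )
    # Job-management endpoints are coroutines as well.
    return await client.get_training_job_status(job_uuid=job.job_uuid)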
@@ -20,7 +20,7 @@ class MetaReferencePostTrainingImpl:
         self.config = config
         self.datasetio_api = datasetio_api

-    def supervised_fine_tune(
+    async def supervised_fine_tune(
         self,
         job_uuid: str,
         model: str,
@@ -28,11 +28,11 @@ class MetaReferencePostTrainingImpl:
         validation_dataset_id: str,
         algorithm: FinetuningAlgorithm,
         algorithm_config: LoraFinetuningConfig,
         optimizer_config: OptimizerConfig,
         training_config: TrainingConfig,
         hyperparam_search_config: Dict[str, Any],
         logger_config: Dict[str, Any],
     ) -> PostTrainingJob:

         # wrapper request to make it easier to pass around (internal only, not exposed to API)
         request = PostTrainingSFTRequest(
             job_uuid=job_uuid,
@@ -41,7 +41,6 @@ class MetaReferencePostTrainingImpl:
             validation_dataset_id=validation_dataset_id,
             algorithm=algorithm,
             algorithm_config=algorithm_config,
-            optimizer_config=optimizer_config,
             training_config=training_config,
             hyperparam_search_config=hyperparam_search_config,
             logger_config=logger_config,
@@ -50,14 +49,14 @@ class MetaReferencePostTrainingImpl:
             recipe = LoraFinetuningSingleDevice(
                 self.config, request, self.datasetio_api
             )
-            recipe.setup(self.config)
-            recipe.train()
+            await recipe.setup(self.config)
+            await recipe.train()
         else:
             raise NotImplementedError()

         return PostTrainingJob(job_uuid=job_uuid)

-    def preference_optimize(
+    async def preference_optimize(
         self,
         job_uuid: str,
         finetuned_model: URL,
@@ -71,21 +70,24 @@ class MetaReferencePostTrainingImpl:
         logger_config: Dict[str, Any],
     ) -> PostTrainingJob: ...

-    def get_training_jobs(self) -> List[PostTrainingJob]: ...
+    # TODO @markchen1015 implement below APIs
+    async def get_training_jobs(self) -> List[PostTrainingJob]: ...

     # sends SSE stream of logs
     @webmethod(route="/post-training/job/logs")
-    def get_training_job_logstream(self, job_uuid: str) -> PostTrainingJobLogStream: ...
+    async def get_training_job_logstream(
+        self, job_uuid: str
+    ) -> PostTrainingJobLogStream: ...

     @webmethod(route="/post-training/job/status")
-    def get_training_job_status(
+    async def get_training_job_status(
         self, job_uuid: str
     ) -> PostTrainingJobStatusResponse: ...

     @webmethod(route="/post-training/job/cancel")
-    def cancel_training_job(self, job_uuid: str) -> None: ...
+    async def cancel_training_job(self, job_uuid: str) -> None: ...

     @webmethod(route="/post-training/job/artifacts")
-    def get_training_job_artifacts(
+    async def get_training_job_artifacts(
         self, job_uuid: str
     ) -> PostTrainingJobArtifactsResponse: ...
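With both the API surface and the provider implementation coroutine-based, a synchronous caller has to drive them through an event loop. A hedged sketch; impl stands in for a constructed MetaReferencePostTrainingImpl and request_args for the keyword arguments shown above:

import asyncio

async def kick_off(impl, request_args: dict):
    # supervised_fine_tune awaits recipe.setup() and recipe.train() internally,
    # so the entire call chain stays on a single event loop.
    return await impl.supervised_fine_tune(**request_args)

# job = asyncio.run(kick_off(impl, request_args))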
@@ -7,15 +7,20 @@
 import asyncio
 import logging
 import os
 import re
 from functools import partial
 from pathlib import Path
 from typing import Any, Dict, List, Optional, Tuple, Union

 import torch
 from llama_models.sku_list import resolve_model
 from llama_stack.apis.datasetio import DatasetIO
 from torch import nn
 from torchtune import utils as torchtune_utils
 from torchtune.training.checkpointing._utils import ModelType
 from llama_stack.apis.post_training import *  # noqa
 from llama_stack.apis.post_training import PostTrainingSFTRequest
 from llama_stack.distribution.utils.model_utils import model_local_dir

 from llama_stack.providers.inline.post_training.meta_reference import utils
 from llama_stack.providers.inline.post_training.meta_reference.config import (
@@ -47,14 +52,24 @@ from torchtune.models.llama3._tokenizer import Llama3Tokenizer


 class LoraFinetuningSingleDevice:
     # This recipe only supports GPU training

     # This recipe doesn't include several training efficiency settings from the original torchtune repo, including
     # - compile
     # - activation offloading

     # Resume from checkpoint hasn't been supported yet
     # Validation hasn't been supported yet

     # TODO @markchen1015 figure out the logging for this training recipe
     # and make it work with telemetry
     def __init__(
         self,
         config: MetaReferencePostTrainingConfig,
         request: PostTrainingSFTRequest,
         datasetio_api: DatasetIO,
     ) -> None:
         # to make user config easier, assume the device is 'cuda' only
         # self._device = utils.get_device(device=cfg.device)
         # Assume the training only happens on GPU
         self.config = config
         self.request = request
         self._device = torchtune_utils.get_device(device="cuda")
@@ -63,11 +78,30 @@ class LoraFinetuningSingleDevice:
         )
         self.model_id = config.model

-        # hardcode it for now and see how it works with get_training_job_artifacts
-        self._output_dir = f"~/.llama/checkpoints/post_training/{self.model_id}"
+        def model_checkpoint_dir(model) -> str:
+            checkpoint_dir = Path(model_local_dir(model.descriptor()))

-        self._log_every_n_steps = 1
-        self._log_peak_memory_stats = False
+            paths = [
+                Path(checkpoint_dir / f"consolidated.{ext}")
+                for ext in ["pth", "00.pth"]
+            ]
+            if not any(p.exists() for p in paths):
+                checkpoint_dir = checkpoint_dir / "original"
+
+            assert checkpoint_dir.exists(), (
+                f"Could not find checkpoints in: {model_local_dir(model.descriptor())}. "
+                f"Please download model using `llama download --model-id {model.descriptor()}`"
+            )
+            return str(checkpoint_dir)
+
+        if config.checkpoint_dir and config.checkpoint_dir != "null":
+            self.checkpoint_dir = config.checkpoint_dir
+        else:
+            model = resolve_model(self.model_id)
+            self.checkpoint_dir = model_checkpoint_dir(model)
+
+        # TODO @markchen1015 make it work with get_training_job_artifacts
+        self._output_dir = self.checkpoint_dir + "/posting_training/"

         self.seed = training.set_seed(seed=config.torch_seed or 42)
         self.epochs_run = 0
@@ -75,23 +109,15 @@ class LoraFinetuningSingleDevice:
         self._shuffle = request.training_config.shuffle
         self._batch_size = request.training_config.batch_size

-        self.checkpoint_dir = (
-            self.config.checkpoint_dir or f"~/.llama/checkpoints/{self.model_id}"
-        )
-
-        # this is important for debugging purpose
         self.max_steps_per_epoch = request.training_config.max_steps_per_epoch
         self.global_step = 0

-        # not needed in MVP
-        # self._resume_from_checkpoint = cfg.resume_from_checkpoint
-        # self._save_adapter_weights_only = cfg.get("save_adapter_weights_only", False)

         self._gradient_accumulation_steps = (
             request.training_config.gradient_accumulation_steps
         )

-        self._clip_grad_norm = 1.0  # hardcode
+        self._clip_grad_norm = 1.0
         self._enable_activation_checkpointing = (
             request.training_config.enable_activation_checkpointing
         )
@@ -99,12 +125,11 @@ class LoraFinetuningSingleDevice:

         self.datasetio_api = datasetio_api

-    def load_checkpoint(self):
+    async def load_checkpoint(self):
         def get_checkpoint_files(checkpoint_dir: str) -> List[str]:
             try:
                 # List all files in the given directory
                 files = os.listdir(checkpoint_dir)

                 # Filter files that end with .pth
                 pth_files = [file for file in files if file.endswith(".pth")]
                 return pth_files
@@ -115,44 +140,40 @@ class LoraFinetuningSingleDevice:
             checkpoint_dir=self.checkpoint_dir,
             checkpoint_files=get_checkpoint_files(self.checkpoint_dir),
             output_dir=self._output_dir,
-            # todo: automatically get this info from model
-            model_type="LLAMA3",
+            model_type=utils.get_checkpointer_model_type(self.model_id),
         )
         checkpoint_dict = self._checkpointer.load_checkpoint()
         return checkpoint_dict

-    def setup(self, config: MetaReferencePostTrainingConfig) -> None:
-        # todo: figure out how does it works with telemetry
-        # self._metric_logger = config.instantiate(cfg.metric_logger)
-        # self._metric_logger.log_config(cfg)
+    async def setup(self, config: MetaReferencePostTrainingConfig) -> None:
+        checkpoint_dict = await self.load_checkpoint()

-        checkpoint_dict = self.load_checkpoint()

-        # hack to toggle to the low cpu ram version of the reparametrize_as_dtype
-        # hook based on the config.
-        # common_utils._use_low_cpu_ram = cfg.get("low_cpu_ram", False)

         # set up model
-        self._model = self._setup_model(
+        self._model = await self._setup_model(
             enable_activation_checkpointing=self._enable_activation_checkpointing,
             enable_activation_offloading=self._enable_activation_offloading,
             base_model_state_dict=checkpoint_dict[training.MODEL_KEY],
             lora_weights_state_dict=None,
         )
         log.info(f"Model is initialized with precision {self._dtype}.")

-        self._tokenizer = self._setup_tokenizer()
+        self._tokenizer = await self._setup_tokenizer()
         log.info("Tokenizer is initialized from file.")

-        self._optimizer = self._setup_optimizer(
-            optimizer_config=self.request.training_config.optimizer
+        self._optimizer = await self._setup_optimizer(
+            optimizer_config=self.request.training_config.optimizer_config
         )
         log.info("Optimizer is initialized.")

         self._loss_fn = CEWithChunkedOutputLoss()
-        self._sampler, self._dataloader = self._setup_data(
+        self._model.set_num_output_chunks(self._loss_fn.num_output_chunks)
+        log.info("Loss is initialized.")
+
+        self._sampler, self._dataloader = await self._setup_data(
             tokenizer=self._tokenizer,
             shuffle=self._shuffle,
             batch_size=self._batch_size,
         )
         log.info("Dataset and Sampler are initialized.")

         # Number of training steps in each epoch depends on the number of batches produced
         # by the dataloader and the max_steps_per_epoch param set by the user and is used
@@ -170,18 +191,19 @@ class LoraFinetuningSingleDevice:

         # Learning rate scheduler can only be set up after number of steps
         # has been computed
-        self._lr_scheduler = self._setup_lr_scheduler(
-            num_warmup_steps=self.request.optimizer_config.num_warmup_steps,
+        self._lr_scheduler = await self._setup_lr_scheduler(
+            num_warmup_steps=self.request.training_config.optimizer_config.num_warmup_steps,
             num_training_steps=self.total_epochs * self._steps_per_epoch,
             last_epoch=self.global_step - 1,
         )
         log.info("Learning rate scheduler is initialized.")

         # Used to ignore labels for loss computation
         self.ignore_labels_cache = torch.full(
             (self._batch_size, 1), self._loss_fn.ignore_index, device=self._device
         )

-    def _setup_model(
+    async def _setup_model(
         self,
         enable_activation_checkpointing: bool,
         enable_activation_offloading: bool,
@@ -243,9 +265,8 @@ class LoraFinetuningSingleDevice:
             lora_missing=lora_missing,
             lora_unexpected=lora_unexpected,
         )

         # Validate model adapter params were loaded in with the expected dtype
         # TODO (rohan-varma): Further validation to ensure the appropriate base params
         # are NF4 vs bf16 based on the quantization config.
         training.validate_expected_param_dtype(
             self.adapter_params.items(), dtype=self._dtype
         )
@@ -254,22 +275,16 @@ class LoraFinetuningSingleDevice:
         self.activations_handling_ctx = training.get_act_offloading_ctx_manager(
             model, enable_activation_offloading
         )

-        log.info(f"Model is initialized with precision {self._dtype}.")

-        # if self._device.type != "cpu":
-        #     memory_stats = training.get_memory_stats(device=self._device)
-        #     training.log_memory_stats(memory_stats)
         return model

-    def _setup_tokenizer(
+    async def _setup_tokenizer(
         self,
     ) -> Llama3Tokenizer:
         tokenizer_path = self.checkpoint_dir + "/tokenizer.model"
         tokenizer_type = utils.get_tokenizer_type(self.model_id)
         return tokenizer_type(path=tokenizer_path)

-    def _setup_optimizer(self, optimizer_config: OptimizerConfig) -> Optimizer:
+    async def _setup_optimizer(self, optimizer_config: OptimizerConfig) -> Optimizer:
         optimizer = torch.optim.AdamW(
             params=self._model.parameters(),
             lr=optimizer_config.lr,
@@ -277,11 +292,9 @@ class LoraFinetuningSingleDevice:
             eps=1e-8,
             weight_decay=0.1,
         )

-        log.info("Optimizer and loss are initialized.")
         return optimizer

-    def _setup_data(
+    async def _setup_data(
         self, tokenizer: Llama3Tokenizer, shuffle: bool, batch_size: int
     ) -> Tuple[DistributedSampler, DataLoader]:
         async def fetch_rows():
@@ -290,10 +303,11 @@ class LoraFinetuningSingleDevice:
                 rows_in_page=-1,
             )

-        # Run the async function in an event loop
-        all_rows = asyncio.run(fetch_rows())
+        all_rows = await fetch_rows()
         rows = all_rows.rows

+        # Currently only support instruct dataset
+        # TODO @markchen1015 make the message_transform swappable and support more dataset types
         ds = SFTDataset(
             rows, message_transform=InputOutputToMessages(), model_transform=tokenizer
         )
@@ -320,11 +334,9 @@ class LoraFinetuningSingleDevice:
             ),
         )

-        log.info("Dataset and Sampler are initialized.")

         return sampler, dataloader

-    def _setup_lr_scheduler(
+    async def _setup_lr_scheduler(
         self,
         num_warmup_steps: int,
         num_training_steps: int,
@@ -332,33 +344,19 @@ class LoraFinetuningSingleDevice:
     ) -> Optimizer:
         lr_scheduler = get_cosine_schedule_with_warmup(
             self._optimizer,
             num_warmup_steps=num_warmup_steps,
             num_training_steps=num_training_steps,
             last_epoch=last_epoch,
         )

-        log.info("Learning rate scheduler is initialized.")
         return lr_scheduler

-    def save_checkpoint(self, epoch: int) -> None:
-        """
-        Checkpoint the state of the recipe. The constructed checkpoint state dict
-        contains the following information:
-        - Merged weights with key MODEL_KEY
-        - Adapter weights with key ADAPTER_KEY
-        - Relevant recipe state if training is not complete
-        - If the `self._save_adapter_weights_only` option is True, the checkpointer will save only the adapter weights
-
-        To correctly resume from training, the adapter weights and recipe state must be provided along with the base model weights.
-        """
+    async def save_checkpoint(self, epoch: int) -> None:
         ckpt_dict = {}

         intermediate_checkpoint = epoch + 1 < self.total_epochs

         adapter_state_dict = get_adapter_state_dict(self._model.state_dict())
         ckpt_dict.update({training.ADAPTER_KEY: adapter_state_dict})

         # Construct the full state dict with LoRA weights merged into base LLM weights

         # Move to CPU to avoid a copy on GPU
         state_dict = {k: v.cpu() for k, v in self._model.state_dict().items()}
@@ -385,10 +383,9 @@ class LoraFinetuningSingleDevice:
         self._checkpointer.save_checkpoint(
             ckpt_dict,
             epoch=epoch,
             intermediate_checkpoint=intermediate_checkpoint,
         )

-    def _loss_step(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
+    async def _loss_step(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
         # Shape [b, s], needed for the loss not the model
         labels = batch.pop("labels")
         # run model
@@ -412,16 +409,10 @@ class LoraFinetuningSingleDevice:

         return loss

-    def train(self) -> None:
+    async def train(self) -> None:
         """
         The core training loop.
         """
-        # if self._compile:
-        #     log.info(
-        #         "NOTE: torch.compile is enabled and model is compiled in first forward. Expect a relatively slow first iteration."
-        #     )

         # Initialize tokens count and running loss (for grad accumulation)
-        # t0 = time.perf_counter()
         running_loss = 0
@@ -433,7 +424,6 @@ class LoraFinetuningSingleDevice:
             # in case shuffle is True
             self._sampler.set_epoch(curr_epoch)

-            # pbar = tqdm(total=self._steps_per_epoch)
             for idx, batch in enumerate(self._dataloader):
                 if (
                     self.max_steps_per_epoch is not None
@@ -442,14 +432,6 @@ class LoraFinetuningSingleDevice:
                 ):
                     break

-                # Start tracking CUDA memory for active steps for just the first epoch
-                # if (
-                #     curr_epoch == 0
-                #     and self.profiler_profile_memory
-                #     and idx == self.profiler_wait_steps + self.profiler_warmup_steps
-                # ):
-                #     torch.cuda.memory._record_memory_history()

                 torchtune_utils.batch_to_device(batch, self._device)

                 # Calculate the number of unmasked tokens in the current batch
@@ -461,14 +443,14 @@ class LoraFinetuningSingleDevice:

                 # Loss is normalized by default so we multiply by the number of tokens
                 # This way we can normalize by the total number of tokens if we're accumulating gradients
-                current_loss = self._loss_step(batch) * current_num_tokens
+                current_loss = await self._loss_step(batch) * current_num_tokens
                 running_loss += current_loss
                 current_loss.backward()

                 # Step with optimizer
                 if (idx + 1) % self._gradient_accumulation_steps == 0:
                     training.scale_grads(self._model, 1 / num_tokens)
-                    grad_norm = torch.nn.utils.clip_grad_norm_(
+                    torch.nn.utils.clip_grad_norm_(
                         self._model.parameters(),
                         max_norm=float(self._clip_grad_norm),
                     )
@@ -478,58 +460,10 @@ class LoraFinetuningSingleDevice:
                     # Update the number of steps when the weights are updated
                     self.global_step += 1

-                    # loss_to_log = running_loss.item() / num_tokens
-                    # pbar.update(1)
-                    # pbar.set_description(
-                    #     f"{curr_epoch + 1}|{self.global_step}|Loss: {loss_to_log}"
-                    # )
-
-                    # Log per-step metrics
-                    # if self.global_step % self._log_every_n_steps == 0:
-                    #     time_per_step = time.perf_counter() - t0
-                    #     log_dict = {
-                    #         "loss": loss_to_log,
-                    #         "lr": self._optimizer.param_groups[0]["lr"],
-                    #         "tokens_per_second_per_gpu": num_tokens / time_per_step,
-                    #     }
-                    #     if self._device.type == "cuda" and self._log_peak_memory_stats:
-                    #         log_dict.update(
-                    #             training.get_memory_stats(device=self._device)
-                    #         )
-                    #     if self._clip_grad_norm is not None:
-                    #         log_dict.update({"grad_norm": grad_norm})
-                    #     self._metric_logger.log_dict(
-                    #         log_dict,
-                    #         step=self.global_step,
-                    #     )
-
                     # Reset running stats for the next step
                     running_loss = 0
                     num_tokens = 0
-                    # t0 = time.perf_counter()

-                # Stop tracking CUDA memory now that active steps are complete
-                # if (
-                #     curr_epoch == 0
-                #     and self.profiler_profile_memory
-                #     and idx
-                #     == self.profiler_wait_steps
-                #     + self.profiler_warmup_steps
-                #     + self.profiler_active_steps
-                # ):
-                #     torch.cuda.memory._record_memory_history(enabled=None)
-
-                # Step the profiler
-                # Note we are stepping each batch, which might not include optimizer step in the trace
-                # if the schedule cycle doesn't align with gradient accumulation.
-                # prof.step()

             self.epochs_run += 1
-            # start_save_checkpoint = time.perf_counter()
             log.info("Starting checkpoint save...")
-            self.save_checkpoint(epoch=curr_epoch)
-            # log.info(
-            #     "Checkpoint saved in {:.2f} seconds.".format(
-            #         time.perf_counter() - start_save_checkpoint
-            #     )
-            # )
+            await self.save_checkpoint(epoch=curr_epoch)
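The training loop accumulates a token-weighted loss across gradient_accumulation_steps micro-batches, rescales the gradients by the total number of unmasked tokens, then clips and steps. A standalone sketch of that normalization pattern in plain PyTorch; the toy model and batch shapes are made up and this is not the recipe's actual code:

import torch

model = torch.nn.Linear(4, 2)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
loss_fn = torch.nn.CrossEntropyLoss(reduction="mean")

grad_accum_steps = 4
clip_grad_norm = 1.0
num_tokens = 0

for step in range(8):
    x = torch.randn(3, 4)                 # toy micro-batch
    labels = torch.randint(0, 2, (3,))
    current_num_tokens = labels.numel()   # stand-in for the batch's unmasked token count

    # Weight the mean loss by the token count so that, after accumulation,
    # dividing the gradients by the total token count yields a per-token average.
    loss = loss_fn(model(x), labels) * current_num_tokens
    num_tokens += current_num_tokens
    loss.backward()

    if (step + 1) % grad_accum_steps == 0:
        for p in model.parameters():
            if p.grad is not None:
                p.grad.div_(num_tokens)   # mirrors training.scale_grads(model, 1 / num_tokens)
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=clip_grad_norm)
        optimizer.step()
        optimizer.zero_grad(set_to_none=True)
        num_tokens = 0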
@@ -16,15 +16,22 @@ import torch
 from llama_models.sku_list import resolve_model
 from torchtune.models.llama3 import llama3_tokenizer, lora_llama3_8b
 from torchtune.models.llama3._tokenizer import Llama3Tokenizer
 from torchtune.models.llama3_2 import lora_llama3_2_3b

 LORA_MODEL_TYPES: Dict[str, Any] = {
     "Llama3.2-3B-Instruct": lora_llama3_2_3b,
     "Llama-3-8B-Instruct": lora_llama3_8b,
 }

 TOKENIZER_TYPES: Dict[str, Any] = {
     "Llama3.2-3B-Instruct": llama3_tokenizer,
     "Llama-3-8B-Instruct": llama3_tokenizer,
 }

+CHECKPOINT_MODEL_TYPES: Dict[str, str] = {
+    "Llama3.2-3B-Instruct": "LLAMA3_2",
+}
+
 BuildLoraModelCallable = Callable[..., torch.nn.Module]
 BuildTokenizerCallable = Callable[..., Llama3Tokenizer]
@@ -41,3 +48,10 @@ def get_tokenizer_type(
 ) -> BuildTokenizerCallable:
     model = resolve_model(model_id)
     return TOKENIZER_TYPES[model.core_model_id.value]
+
+
+def get_checkpointer_model_type(
+    model_id: str,
+) -> str:
+    model = resolve_model(model_id)
+    return CHECKPOINT_MODEL_TYPES[model.core_model_id.value]
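The helper module maps a llama-stack model id to torchtune builders by resolving the core model id. An illustrative usage sketch; it assumes llama-stack and torchtune are installed and that the id below is one of the keys registered in the maps above:

from llama_stack.providers.inline.post_training.meta_reference import utils

model_id = "Llama3.2-3B-Instruct"

tokenizer_builder = utils.get_tokenizer_type(model_id)                  # -> llama3_tokenizer
checkpointer_model_type = utils.get_checkpointer_model_type(model_id)   # -> "LLAMA3_2"

# tokenizer = tokenizer_builder(path="<checkpoint_dir>/tokenizer.model")
print(checkpointer_model_type)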
@@ -71,7 +71,7 @@ datasets:
       uri: https://huggingface.co/datasets/tatsu-lab/alpaca
     metadata:
       path: tatsu-lab/alpaca
-      name: post_training_alpaca
+      name:
       split: train
     dataset_schema:
       instruction: