mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-05 18:22:41 +00:00
refine
This commit is contained in:
parent
09e9445a11
commit
a57f46e363
2 changed files with 20 additions and 42 deletions
|
@ -17,19 +17,11 @@ from torchtune.models import convert_weights
|
|||
from torchtune.training.checkpointing._utils import (
|
||||
ADAPTER_CONFIG_FNAME,
|
||||
ADAPTER_MODEL_FNAME,
|
||||
check_outdir_not_in_ckptdir,
|
||||
copy_files,
|
||||
get_adapter_checkpoint_path,
|
||||
get_model_checkpoint_path,
|
||||
get_recipe_checkpoint_path,
|
||||
ModelType,
|
||||
RECIPE_STATE_DIRNAME,
|
||||
REPO_ID_FNAME,
|
||||
safe_torch_load,
|
||||
SAFETENSOR_INDEX_FNAME,
|
||||
SHARD_FNAME,
|
||||
SUFFIXES_TO_NOT_COPY,
|
||||
TORCH_INDEX_FNAME,
|
||||
)
|
||||
from torchtune.utils._logging import get_logger
|
||||
|
||||
|
@ -176,8 +168,6 @@ class TorchtuneCheckpointer:
|
|||
raise ValueError(
|
||||
"Adapter checkpoint not found in state_dict. Please ensure that the state_dict contains adapter weights."
|
||||
)
|
||||
|
||||
print("model_file_path", str(model_file_path))
|
||||
elif checkpoint_format == "hf":
|
||||
# Note: for saving hugging face format checkpoints, we only suppport saving adapter weights now
|
||||
|
||||
|
@ -238,7 +228,7 @@ class TorchtuneCheckpointer:
|
|||
f"{os.path.getsize(output_path) / 1024**3:.2f} GiB "
|
||||
f"saved to {output_path}"
|
||||
)
|
||||
elif adapter_only:
|
||||
else:
|
||||
raise ValueError(
|
||||
"Adapter checkpoint not found in state_dict. Please ensure that the state_dict contains adapter weights."
|
||||
)
|
||||
|
@ -269,18 +259,6 @@ class TorchtuneCheckpointer:
|
|||
model_file_path,
|
||||
ignore_suffixes=SUFFIXES_TO_NOT_COPY,
|
||||
)
|
||||
logger.info("Saving final epoch checkpoint.")
|
||||
if adapter_only:
|
||||
logger.info(
|
||||
"Please note that you have set adapter_only=True, so only adapter weights will be saved."
|
||||
"You need to merge the adapter weights into your base model for further use. "
|
||||
f"See {self.__class__.__name__}.save_checkpoint for more details."
|
||||
)
|
||||
else:
|
||||
logger.info(
|
||||
"The full model checkpoint, including all weights and configurations, has been saved successfully."
|
||||
"You can now use this checkpoint for further training or inference."
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported checkpoint format: {format}")
|
||||
|
||||
|
|
|
@ -15,6 +15,24 @@ from typing import Any, Dict, List, Optional, Tuple
|
|||
|
||||
import torch
|
||||
from llama_models.sku_list import resolve_model
|
||||
from torch import nn
|
||||
from torch.optim import Optimizer
|
||||
from torch.utils.data import DataLoader, DistributedSampler
|
||||
from torchtune import modules, training, utils as torchtune_utils
|
||||
from torchtune.data import padded_collate_sft
|
||||
|
||||
from torchtune.modules.loss import CEWithChunkedOutputLoss
|
||||
from torchtune.modules.peft import (
|
||||
get_adapter_params,
|
||||
get_adapter_state_dict,
|
||||
get_lora_module_names,
|
||||
get_merged_lora_ckpt,
|
||||
set_trainable_params,
|
||||
validate_missing_and_unexpected_for_lora,
|
||||
)
|
||||
from torchtune.training.lr_schedulers import get_cosine_schedule_with_warmup
|
||||
from torchtune.training.metric_logging import DiskLogger
|
||||
from tqdm import tqdm
|
||||
|
||||
from llama_stack.apis.common.training_types import PostTrainingMetric
|
||||
from llama_stack.apis.datasetio import DatasetIO
|
||||
|
@ -42,24 +60,6 @@ from llama_stack.providers.inline.post_training.torchtune.config import (
|
|||
TorchtunePostTrainingConfig,
|
||||
)
|
||||
from llama_stack.providers.inline.post_training.torchtune.datasets.sft import SFTDataset
|
||||
from torch import nn
|
||||
from torch.optim import Optimizer
|
||||
from torch.utils.data import DataLoader, DistributedSampler
|
||||
from torchtune import modules, training, utils as torchtune_utils
|
||||
from torchtune.data import padded_collate_sft
|
||||
|
||||
from torchtune.modules.loss import CEWithChunkedOutputLoss
|
||||
from torchtune.modules.peft import (
|
||||
get_adapter_params,
|
||||
get_adapter_state_dict,
|
||||
get_lora_module_names,
|
||||
get_merged_lora_ckpt,
|
||||
set_trainable_params,
|
||||
validate_missing_and_unexpected_for_lora,
|
||||
)
|
||||
from torchtune.training.lr_schedulers import get_cosine_schedule_with_warmup
|
||||
from torchtune.training.metric_logging import DiskLogger
|
||||
from tqdm import tqdm
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
@ -490,7 +490,7 @@ class LoraFinetuningSingleDevice:
|
|||
# Update the sampler to ensure data is correctly shuffled across epochs
|
||||
# in case shuffle is True
|
||||
metric_logger = DiskLogger(
|
||||
log_dir=self._output_dir + f"/{self.model_id}-sft-{curr_epoch}"
|
||||
log_dir=self._output_dir + f"/{self.model_id}-sft-{curr_epoch}/log"
|
||||
)
|
||||
self._training_sampler.set_epoch(curr_epoch)
|
||||
loss_to_log = 0.0
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue