Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-08-06 10:42:39 +00:00)
commit a57f46e363
parent 09e9445a11

    refine

2 changed files with 20 additions and 42 deletions
@@ -17,19 +17,11 @@ from torchtune.models import convert_weights
 from torchtune.training.checkpointing._utils import (
     ADAPTER_CONFIG_FNAME,
     ADAPTER_MODEL_FNAME,
-    check_outdir_not_in_ckptdir,
     copy_files,
-    get_adapter_checkpoint_path,
-    get_model_checkpoint_path,
-    get_recipe_checkpoint_path,
     ModelType,
-    RECIPE_STATE_DIRNAME,
     REPO_ID_FNAME,
     safe_torch_load,
-    SAFETENSOR_INDEX_FNAME,
-    SHARD_FNAME,
     SUFFIXES_TO_NOT_COPY,
-    TORCH_INDEX_FNAME,
 )
 from torchtune.utils._logging import get_logger
 
@@ -176,8 +168,6 @@ class TorchtuneCheckpointer:
                 raise ValueError(
                     "Adapter checkpoint not found in state_dict. Please ensure that the state_dict contains adapter weights."
                 )
 
-            print("model_file_path", str(model_file_path))
-
         elif checkpoint_format == "hf":
             # Note: for saving hugging face format checkpoints, we only suppport saving adapter weights now
@@ -238,7 +228,7 @@ class TorchtuneCheckpointer:
                         f"{os.path.getsize(output_path) / 1024**3:.2f} GiB "
                         f"saved to {output_path}"
                     )
-            elif adapter_only:
+            else:
                 raise ValueError(
                     "Adapter checkpoint not found in state_dict. Please ensure that the state_dict contains adapter weights."
                 )
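
The elif adapter_only: -> else: change above makes a missing adapter checkpoint an unconditional error in that branch, instead of one that only fires when adapter_only happens to be set. A minimal sketch of the resulting control flow, assuming the adapter weights sit under a single key in state_dict; the helper name, key, and path are illustrative, not taken from the file:

import torch

ADAPTER_KEY = "adapter"  # assumed key name for the adapter weights


def save_adapter_checkpoint(state_dict: dict, save_path: str) -> str:
    """Save adapter weights if present; fail loudly instead of silently skipping."""
    if ADAPTER_KEY in state_dict:
        torch.save(state_dict[ADAPTER_KEY], save_path)
        return save_path
    # A plain `else` means a missing adapter always raises, regardless of adapter_only.
    raise ValueError(
        "Adapter checkpoint not found in state_dict. "
        "Please ensure that the state_dict contains adapter weights."
    )
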
@@ -269,18 +259,6 @@ class TorchtuneCheckpointer:
                 model_file_path,
                 ignore_suffixes=SUFFIXES_TO_NOT_COPY,
             )
-            logger.info("Saving final epoch checkpoint.")
-            if adapter_only:
-                logger.info(
-                    "Please note that you have set adapter_only=True, so only adapter weights will be saved."
-                    "You need to merge the adapter weights into your base model for further use. "
-                    f"See {self.__class__.__name__}.save_checkpoint for more details."
-                )
-            else:
-                logger.info(
-                    "The full model checkpoint, including all weights and configurations, has been saved successfully."
-                    "You can now use this checkpoint for further training or inference."
-                )
         else:
             raise ValueError(f"Unsupported checkpoint format: {format}")
 
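
The deleted log message told adapter-only users to merge the adapter weights into the base model before further use. For reference, torchtune ships get_merged_lora_ckpt (imported in the recipe file below) for that step; a rough sketch in which the checkpoint paths, rank, and alpha are placeholders, not project defaults:

import torch
from torchtune.modules.peft import get_merged_lora_ckpt

base_sd = torch.load("base_model.pt", map_location="cpu")        # assumed base checkpoint
adapter_sd = torch.load("adapter_model.pt", map_location="cpu")  # assumed adapter checkpoint

# get_merged_lora_ckpt expects one state_dict holding both base and LoRA A/B weights;
# it folds the LoRA deltas into the base weights and drops the adapter keys.
base_sd.update(adapter_sd)
merged_sd = get_merged_lora_ckpt(base_sd, rank=8, alpha=16)
torch.save(merged_sd, "merged_model.pt")
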
@@ -15,6 +15,24 @@ from typing import Any, Dict, List, Optional, Tuple
 
 import torch
 from llama_models.sku_list import resolve_model
+from torch import nn
+from torch.optim import Optimizer
+from torch.utils.data import DataLoader, DistributedSampler
+from torchtune import modules, training, utils as torchtune_utils
+from torchtune.data import padded_collate_sft
+
+from torchtune.modules.loss import CEWithChunkedOutputLoss
+from torchtune.modules.peft import (
+    get_adapter_params,
+    get_adapter_state_dict,
+    get_lora_module_names,
+    get_merged_lora_ckpt,
+    set_trainable_params,
+    validate_missing_and_unexpected_for_lora,
+)
+from torchtune.training.lr_schedulers import get_cosine_schedule_with_warmup
+from torchtune.training.metric_logging import DiskLogger
+from tqdm import tqdm
 
 from llama_stack.apis.common.training_types import PostTrainingMetric
 from llama_stack.apis.datasetio import DatasetIO
@@ -42,24 +60,6 @@ from llama_stack.providers.inline.post_training.torchtune.config import (
     TorchtunePostTrainingConfig,
 )
 from llama_stack.providers.inline.post_training.torchtune.datasets.sft import SFTDataset
-from torch import nn
-from torch.optim import Optimizer
-from torch.utils.data import DataLoader, DistributedSampler
-from torchtune import modules, training, utils as torchtune_utils
-from torchtune.data import padded_collate_sft
-
-from torchtune.modules.loss import CEWithChunkedOutputLoss
-from torchtune.modules.peft import (
-    get_adapter_params,
-    get_adapter_state_dict,
-    get_lora_module_names,
-    get_merged_lora_ckpt,
-    set_trainable_params,
-    validate_missing_and_unexpected_for_lora,
-)
-from torchtune.training.lr_schedulers import get_cosine_schedule_with_warmup
-from torchtune.training.metric_logging import DiskLogger
-from tqdm import tqdm
 
 log = logging.getLogger(__name__)
 
@@ -490,7 +490,7 @@ class LoraFinetuningSingleDevice:
             # Update the sampler to ensure data is correctly shuffled across epochs
             # in case shuffle is True
             metric_logger = DiskLogger(
-                log_dir=self._output_dir + f"/{self.model_id}-sft-{curr_epoch}"
+                log_dir=self._output_dir + f"/{self.model_id}-sft-{curr_epoch}/log"
             )
             self._training_sampler.set_epoch(curr_epoch)
             loss_to_log = 0.0
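
On the /log suffix: torchtune's DiskLogger appends metrics to a plain-text file inside whatever log_dir it is given, creating the directory if needed, so pointing it at a dedicated log/ subdirectory keeps the metric files apart from anything else written under the per-epoch output directory. A standalone usage sketch with made-up paths and values:

from torchtune.training.metric_logging import DiskLogger

output_dir = "/tmp/llama-stack-post-training"  # assumed output dir
model_id = "my-model"                          # assumed model id
curr_epoch = 0

# Mirrors the updated call: metric files land under .../{model_id}-sft-{epoch}/log
metric_logger = DiskLogger(log_dir=output_dir + f"/{model_id}-sft-{curr_epoch}/log")
metric_logger.log("loss", 1.234, step=0)
metric_logger.log_dict({"lr": 3e-4, "tokens_per_second": 1500.0}, step=0)
metric_logger.close()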