Merge branch 'main' into HuggingfacePostTrainingConfig-branch

This commit is contained in:
Sarthak Deshpande 2025-08-25 11:59:15 +05:30 committed by GitHub
commit d0d737680f
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
193 changed files with 7108 additions and 881 deletions

View file

@ -6,7 +6,6 @@
import gc
import json
import logging
import multiprocessing
from pathlib import Path
from typing import Any
@ -28,6 +27,7 @@ from llama_stack.apis.post_training import (
LoraFinetuningConfig,
TrainingConfig,
)
from llama_stack.log import get_logger
from llama_stack.providers.inline.post_training.common.utils import evacuate_model_from_device
from ..config import HuggingFacePostTrainingConfig
@ -44,7 +44,7 @@ from ..utils import (
split_dataset,
)
logger = logging.getLogger(__name__)
logger = get_logger(name=__name__, category="post_training")
class HFFinetuningSingleDevice:

View file

@ -5,7 +5,6 @@
# the root directory of this source tree.
import gc
import logging
import multiprocessing
from pathlib import Path
from typing import Any
@ -24,6 +23,7 @@ from llama_stack.apis.post_training import (
DPOAlignmentConfig,
TrainingConfig,
)
from llama_stack.log import get_logger
from llama_stack.providers.inline.post_training.common.utils import evacuate_model_from_device
from ..config import HuggingFacePostTrainingConfig
@ -40,7 +40,7 @@ from ..utils import (
split_dataset,
)
logger = logging.getLogger(__name__)
logger = get_logger(name=__name__, category="post_training")
class HFDPOAlignmentSingleDevice:

View file

@ -4,7 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import logging
import os
import signal
import sys
@ -19,10 +18,11 @@ from transformers import AutoConfig, AutoModelForCausalLM
from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.post_training import Checkpoint, TrainingConfig
from llama_stack.log import get_logger
from .config import HuggingFacePostTrainingConfig
logger = logging.getLogger(__name__)
logger = get_logger(name=__name__, category="post_training")
def setup_environment():

View file

@ -4,7 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import logging
import os
import time
from datetime import UTC, datetime
@ -19,6 +18,7 @@ from torch.utils.data import DataLoader, DistributedSampler
from torchtune import modules, training
from torchtune import utils as torchtune_utils
from torchtune.data import padded_collate_sft
from torchtune.models.llama3._tokenizer import Llama3Tokenizer
from torchtune.modules.loss import CEWithChunkedOutputLoss
from torchtune.modules.peft import (
get_adapter_params,
@ -45,6 +45,7 @@ from llama_stack.apis.post_training import (
)
from llama_stack.core.utils.config_dirs import DEFAULT_CHECKPOINT_DIR
from llama_stack.core.utils.model_utils import model_local_dir
from llama_stack.log import get_logger
from llama_stack.models.llama.sku_list import resolve_model
from llama_stack.providers.inline.post_training.common.utils import evacuate_model_from_device
from llama_stack.providers.inline.post_training.torchtune.common import utils
@ -56,9 +57,7 @@ from llama_stack.providers.inline.post_training.torchtune.config import (
)
from llama_stack.providers.inline.post_training.torchtune.datasets.sft import SFTDataset
log = logging.getLogger(__name__)
from torchtune.models.llama3._tokenizer import Llama3Tokenizer
log = get_logger(name=__name__, category="post_training")
class LoraFinetuningSingleDevice: