Add init files to post training folders (#711)

Add `__init__.py` files to the post training folders so that the package build picks those
files up.
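
For context, the build was missing these folders because setuptools-style package discovery only treats directories that contain an `__init__.py` as packages. A minimal sketch of that behavior (assuming the project uses `find_packages()`; the actual build configuration may differ):

```python
# Minimal sketch (assumes a setuptools-based build using find_packages()):
# only directories containing an __init__.py are discovered as packages,
# so before this change the post training folders were left out of the build.
from setuptools import find_packages

packages = find_packages()

# With the new __init__.py files, these sub-packages now appear in the list.
post_training_pkgs = [p for p in packages if "post_training" in p]
print("\n".join(post_training_pkgs))
```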

## Test
WIP Colab notebook
https://colab.research.google.com/drive/1K4Q2wZq232_Bpy2ud4zL9aRxvCWAwyQs?usp=sharing
to showcase the post training APIs.
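
Independent of the notebook, a quick local check that an installed build now ships the post training modules (a sketch; the exact set of sub-packages listed below is an assumption based on the import paths touched in this diff):

```python
# Sanity check (sketch): import the post training packages from an installed
# build. Before this change these imports could fail with ModuleNotFoundError
# because the folders were not packaged. The sub-package names below are
# assumed from the import paths seen in this diff.
import importlib

for name in [
    "llama_stack.providers.inline.post_training",
    "llama_stack.providers.inline.post_training.torchtune",
    "llama_stack.providers.inline.post_training.torchtune.datasets",
]:
    importlib.import_module(name)
    print(f"importable: {name}")
```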
Botao Chen 2025-01-13 20:19:18 -08:00 committed by GitHub
parent f320eede2b
commit 747683a8a2
6 changed files with 41 additions and 20 deletions

View file

@@ -0,0 +1,5 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.

View file

@@ -0,0 +1,5 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.

View file

@@ -16,8 +16,6 @@ from typing import Any, Callable, Dict, List
 import torch
 from llama_models.datatypes import Model
 from llama_models.sku_list import resolve_model
-from llama_stack.apis.common.type_system import ParamType, StringType
-from llama_stack.apis.datasets import Datasets
 from pydantic import BaseModel
@@ -26,6 +24,9 @@ from torchtune.models.llama3._tokenizer import Llama3Tokenizer
 from torchtune.models.llama3_1 import lora_llama3_1_8b
 from torchtune.models.llama3_2 import lora_llama3_2_3b
+from llama_stack.apis.common.type_system import ParamType, StringType
+from llama_stack.apis.datasets import Datasets
 class ColumnName(Enum):
     instruction = "instruction"

View file

@@ -0,0 +1,5 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.

View file

@@ -0,0 +1,5 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.

View file

@@ -14,6 +14,24 @@ from typing import Any, Dict, List, Optional, Tuple
 import torch
 from llama_models.sku_list import resolve_model
+from torch import nn
+from torch.optim import Optimizer
+from torch.utils.data import DataLoader, DistributedSampler
+from torchtune import modules, training, utils as torchtune_utils
+from torchtune.data import AlpacaToMessages, padded_collate_sft
+from torchtune.modules.loss import CEWithChunkedOutputLoss
+from torchtune.modules.peft import (
+    get_adapter_params,
+    get_adapter_state_dict,
+    get_lora_module_names,
+    get_merged_lora_ckpt,
+    set_trainable_params,
+    validate_missing_and_unexpected_for_lora,
+)
+from torchtune.training.lr_schedulers import get_cosine_schedule_with_warmup
+from torchtune.training.metric_logging import DiskLogger
+from tqdm import tqdm
 from llama_stack.apis.common.training_types import PostTrainingMetric
 from llama_stack.apis.datasetio import DatasetIO
@@ -38,24 +56,6 @@ from llama_stack.providers.inline.post_training.torchtune.config import (
     TorchtunePostTrainingConfig,
 )
 from llama_stack.providers.inline.post_training.torchtune.datasets.sft import SFTDataset
-from torch import nn
-from torch.optim import Optimizer
-from torch.utils.data import DataLoader, DistributedSampler
-from torchtune import modules, training, utils as torchtune_utils
-from torchtune.data import AlpacaToMessages, padded_collate_sft
-from torchtune.modules.loss import CEWithChunkedOutputLoss
-from torchtune.modules.peft import (
-    get_adapter_params,
-    get_adapter_state_dict,
-    get_lora_module_names,
-    get_merged_lora_ckpt,
-    set_trainable_params,
-    validate_missing_and_unexpected_for_lora,
-)
-from torchtune.training.lr_schedulers import get_cosine_schedule_with_warmup
-from torchtune.training.metric_logging import DiskLogger
-from tqdm import tqdm
 log = logging.getLogger(__name__)