mirror of
https://github.com/meta-llama/llama-stack.git
synced 2026-01-03 19:22:16 +00:00
Merge remote-tracking branch 'origin/main' into support_more_data_format
This commit is contained in:
commit
a3b1c3438b
171 changed files with 14529 additions and 5612 deletions
|
|
@ -0,0 +1,5 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
|
@ -16,8 +16,6 @@ import torch
|
|||
from llama_models.datatypes import Model
|
||||
from llama_models.sku_list import resolve_model
|
||||
|
||||
from llama_stack.apis.post_training import DatasetFormat
|
||||
|
||||
from pydantic import BaseModel
|
||||
from torchtune.data._messages import InputOutputToMessages, ShareGPTToMessages
|
||||
|
||||
|
|
@ -27,6 +25,8 @@ from torchtune.models.llama3_1 import lora_llama3_1_8b
|
|||
from torchtune.models.llama3_2 import lora_llama3_2_3b
|
||||
from torchtune.modules.transforms import Transform
|
||||
|
||||
from llama_stack.apis.post_training import DatasetFormat
|
||||
|
||||
|
||||
class ModelConfig(BaseModel):
|
||||
model_definition: Any
|
||||
|
|
|
|||
|
|
@ -0,0 +1,5 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
|
@ -0,0 +1,5 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
|
@ -14,6 +14,24 @@ from typing import Any, Dict, List, Optional, Tuple
|
|||
|
||||
import torch
|
||||
from llama_models.sku_list import resolve_model
|
||||
from torch import nn
|
||||
from torch.optim import Optimizer
|
||||
from torch.utils.data import DataLoader, DistributedSampler
|
||||
from torchtune import modules, training, utils as torchtune_utils
|
||||
from torchtune.data import padded_collate_sft
|
||||
|
||||
from torchtune.modules.loss import CEWithChunkedOutputLoss
|
||||
from torchtune.modules.peft import (
|
||||
get_adapter_params,
|
||||
get_adapter_state_dict,
|
||||
get_lora_module_names,
|
||||
get_merged_lora_ckpt,
|
||||
set_trainable_params,
|
||||
validate_missing_and_unexpected_for_lora,
|
||||
)
|
||||
from torchtune.training.lr_schedulers import get_cosine_schedule_with_warmup
|
||||
from torchtune.training.metric_logging import DiskLogger
|
||||
from tqdm import tqdm
|
||||
|
||||
from llama_stack.apis.common.training_types import PostTrainingMetric
|
||||
from llama_stack.apis.datasetio import DatasetIO
|
||||
|
|
@ -41,24 +59,6 @@ from llama_stack.providers.inline.post_training.torchtune.config import (
|
|||
TorchtunePostTrainingConfig,
|
||||
)
|
||||
from llama_stack.providers.inline.post_training.torchtune.datasets.sft import SFTDataset
|
||||
from torch import nn
|
||||
from torch.optim import Optimizer
|
||||
from torch.utils.data import DataLoader, DistributedSampler
|
||||
from torchtune import modules, training, utils as torchtune_utils
|
||||
from torchtune.data import padded_collate_sft
|
||||
|
||||
from torchtune.modules.loss import CEWithChunkedOutputLoss
|
||||
from torchtune.modules.peft import (
|
||||
get_adapter_params,
|
||||
get_adapter_state_dict,
|
||||
get_lora_module_names,
|
||||
get_merged_lora_ckpt,
|
||||
set_trainable_params,
|
||||
validate_missing_and_unexpected_for_lora,
|
||||
)
|
||||
from torchtune.training.lr_schedulers import get_cosine_schedule_with_warmup
|
||||
from torchtune.training.metric_logging import DiskLogger
|
||||
from tqdm import tqdm
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue