From 562d5228de38d89dc65f8955ea60ffa36a4ec300 Mon Sep 17 00:00:00 2001
From: Botao Chen
Date: Thu, 2 Jan 2025 20:23:55 -0800
Subject: [PATCH] pre-commit

---
 .../post_training/torchtune/common/utils.py  |  5 +--
 .../recipes/lora_finetuning_single_device.py | 36 +++++++++----------
 2 files changed, 21 insertions(+), 20 deletions(-)

diff --git a/llama_stack/providers/inline/post_training/torchtune/common/utils.py b/llama_stack/providers/inline/post_training/torchtune/common/utils.py
index a5279cdbe..042451ed1 100644
--- a/llama_stack/providers/inline/post_training/torchtune/common/utils.py
+++ b/llama_stack/providers/inline/post_training/torchtune/common/utils.py
@@ -16,8 +16,6 @@ from typing import Any, Callable, Dict, List
 import torch
 from llama_models.datatypes import Model
 from llama_models.sku_list import resolve_model
-from llama_stack.apis.common.type_system import ParamType, StringType
-from llama_stack.apis.datasets import Datasets
 
 from pydantic import BaseModel
 
@@ -25,6 +23,9 @@ from torchtune.models.llama3 import llama3_tokenizer, lora_llama3_8b
 from torchtune.models.llama3._tokenizer import Llama3Tokenizer
 from torchtune.models.llama3_2 import lora_llama3_2_3b
 
+from llama_stack.apis.common.type_system import ParamType, StringType
+from llama_stack.apis.datasets import Datasets
+
 
 class ColumnName(Enum):
     instruction = "instruction"
diff --git a/llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py b/llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py
index 1b6c508a7..7df442fe8 100644
--- a/llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py
+++ b/llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py
@@ -14,6 +14,24 @@ from typing import Any, Dict, List, Optional, Tuple
 
 import torch
 from llama_models.sku_list import resolve_model
+from torch import nn
+from torch.optim import Optimizer
+from torch.utils.data import DataLoader, DistributedSampler
+from torchtune import modules, training, utils as torchtune_utils
+from torchtune.data import AlpacaToMessages, padded_collate_sft
+
+from torchtune.modules.loss import CEWithChunkedOutputLoss
+from torchtune.modules.peft import (
+    get_adapter_params,
+    get_adapter_state_dict,
+    get_lora_module_names,
+    get_merged_lora_ckpt,
+    set_trainable_params,
+    validate_missing_and_unexpected_for_lora,
+)
+from torchtune.training.lr_schedulers import get_cosine_schedule_with_warmup
+from torchtune.training.metric_logging import DiskLogger
+from tqdm import tqdm
 
 from llama_stack.apis.common.training_types import PostTrainingMetric
 from llama_stack.apis.datasetio import DatasetIO
@@ -38,24 +56,6 @@ from llama_stack.providers.inline.post_training.torchtune.config import (
     TorchtunePostTrainingConfig,
 )
 from llama_stack.providers.inline.post_training.torchtune.datasets.sft import SFTDataset
-from torch import nn
-from torch.optim import Optimizer
-from torch.utils.data import DataLoader, DistributedSampler
-from torchtune import modules, training, utils as torchtune_utils
-from torchtune.data import AlpacaToMessages, padded_collate_sft
-
-from torchtune.modules.loss import CEWithChunkedOutputLoss
-from torchtune.modules.peft import (
-    get_adapter_params,
-    get_adapter_state_dict,
-    get_lora_module_names,
-    get_merged_lora_ckpt,
-    set_trainable_params,
-    validate_missing_and_unexpected_for_lora,
-)
-from torchtune.training.lr_schedulers import get_cosine_schedule_with_warmup
-from torchtune.training.metric_logging import DiskLogger
-from tqdm import tqdm
 
 log = logging.getLogger(__name__)
 
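For context, the change is purely an import reorder: third-party imports (torch, torchtune, tqdm) are grouped ahead of the first-party llama_stack imports, which is the grouping an isort-style pre-commit hook typically enforces (an assumption; the hook configuration itself is not shown in this patch). Below is a minimal sketch of the layout the reordered module ends up with, using only a few representative imports taken from the hunks above:

```python
# Sketch of the import grouping after the reorder (representative subset only;
# the full import list is in the hunks above).

import logging  # standard library first

import torch  # third-party block: torch, torchtune, tqdm, ...
from torchtune.modules.peft import get_adapter_params
from tqdm import tqdm

# first-party llama_stack block comes last, in its own group
from llama_stack.apis.datasetio import DatasetIO
from llama_stack.providers.inline.post_training.torchtune.datasets.sft import SFTDataset

log = logging.getLogger(__name__)
```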