temp commit

Botao Chen 2025-01-03 13:28:39 -08:00
parent 96d8375663
commit 346a6c658d
3 changed files with 73 additions and 9 deletions


@@ -27,14 +27,24 @@ class OptimizerType(Enum):
     sgd = "sgd"
 
 
+@json_schema_type
+class DatasetFormat(Enum):
+    alpaca = "alpaca"
+    instruct = "instruct"
+    chat_sharegpt = "chat_sharegpt"
+    chat_openai = "chat_openai"
+
+
 @json_schema_type
 class DataConfig(BaseModel):
     dataset_id: str
     batch_size: int
     shuffle: bool
+    data_format: DatasetFormat
     validation_dataset_id: Optional[str] = None
     packed: Optional[bool] = False
     train_on_input: Optional[bool] = False
+    column_map: Optional[Dict[str, str]] = None
 
 
 @json_schema_type
@@ -58,7 +68,6 @@ class TrainingConfig(BaseModel):
     n_epochs: int
     max_steps_per_epoch: int
     gradient_accumulation_steps: int
-    max_validation_steps: int
     data_config: DataConfig
     optimizer_config: OptimizerConfig
     efficiency_config: Optional[EfficiencyConfig] = None
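
For illustration, a minimal sketch of filling in the new DataConfig fields, assuming DataConfig is exported from llama_stack.apis.post_training alongside DatasetFormat (only the DatasetFormat import appears in this diff); the dataset id and column names are hypothetical:

from llama_stack.apis.post_training import DataConfig, DatasetFormat  # DataConfig import path assumed

data_config = DataConfig(
    dataset_id="my_sharegpt_dataset",         # hypothetical dataset id
    batch_size=4,
    shuffle=True,
    data_format=DatasetFormat.chat_sharegpt,  # new field: selects the message transform
    # new field: maps each expected column name to the dataset's actual column name
    column_map={"conversations": "dialogue"},
)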


@@ -11,19 +11,28 @@
 # the root directory of this source tree.
 from enum import Enum
-from typing import Any, Callable, Dict, List
+from typing import Any, Callable, Dict, List, Optional
 
 import torch
 from llama_models.datatypes import Model
 from llama_models.sku_list import resolve_model
 from llama_stack.apis.common.type_system import ParamType, StringType
 from llama_stack.apis.datasets import Datasets
+from llama_stack.apis.post_training import DatasetFormat
 from pydantic import BaseModel
+from torchtune.data._messages import (
+    AlpacaToMessages,
+    InputOutputToMessages,
+    OpenAIToMessages,
+    ShareGPTToMessages,
+)
 from torchtune.models.llama3 import llama3_tokenizer, lora_llama3_8b
 from torchtune.models.llama3._tokenizer import Llama3Tokenizer
 from torchtune.models.llama3_2 import lora_llama3_2_3b
+from torchtune.modules.transforms import Transform
 
 
 class ColumnName(Enum):
@@ -31,6 +40,8 @@ class ColumnName(Enum):
     input = "input"
     output = "output"
     text = "text"
+    conversations = "conversations"
+    messages = "messages"
 
 
 class ModelConfig(BaseModel):
@@ -41,6 +52,9 @@ class ModelConfig(BaseModel):
 class DatasetSchema(BaseModel):
     alpaca: List[Dict[str, ParamType]]
+    instruct: Dict[str, ParamType]
+    chat_sharegpt: Dict[str, ParamType]
+    chat_openai: Dict[str, ParamType]
 
 
 MODEL_CONFIGS: Dict[str, ModelConfig] = {
@@ -56,6 +70,13 @@ MODEL_CONFIGS: Dict[str, ModelConfig] = {
     ),
 }
 
+DATA_FORMATS: Dict[str, Transform] = {
+    "alpaca": AlpacaToMessages,
+    "instruct": InputOutputToMessages,
+    "chat_sharegpt": ShareGPTToMessages,
+    "chat_openai": OpenAIToMessages,
+}
+
 EXPECTED_DATASET_SCHEMA = DatasetSchema(
     alpaca=[
@@ -74,7 +95,17 @@ EXPECTED_DATASET_SCHEMA = DatasetSchema(
             ColumnName.instruction.value: StringType(),
             ColumnName.output.value: StringType(),
         },
-    ]
+    ],
+    instruct={
+        ColumnName.input.value: StringType(),
+        ColumnName.output.value: StringType(),
+    },
+    chat_sharegpt={
+        ColumnName.conversations.value: StringType(),
+    },
+    chat_openai={
+        ColumnName.messages.value: StringType(),
+    },
 )
 
 BuildLoraModelCallable = Callable[..., torch.nn.Module]
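
For reference, one illustrative row per supported format; the column names match the ColumnName values used in EXPECTED_DATASET_SCHEMA above, while the contents and the per-message dict keys follow the usual alpaca/ShareGPT/OpenAI conventions and are assumptions, not part of this diff:

alpaca_row = {
    "instruction": "Summarize the text.",
    "input": "Llama Stack adds a data_format option.",
    "output": "A data_format option was added.",
}
instruct_row = {"input": "What is 2 + 2?", "output": "4"}
chat_sharegpt_row = {
    "conversations": [
        {"from": "human", "value": "Hi"},
        {"from": "gpt", "value": "Hello!"},
    ]
}
chat_openai_row = {
    "messages": [
        {"role": "user", "content": "Hi"},
        {"role": "assistant", "content": "Hello!"},
    ]
}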
@@ -122,10 +153,15 @@ async def get_checkpointer_model_type(
     return model_config.checkpoint_type
 
 
+async def get_data_transform(data_format: DatasetFormat) -> Transform:
+    return DATA_FORMATS[data_format.value]
+
+
 async def validate_input_dataset_schema(
     datasets_api: Datasets,
     dataset_id: str,
     dataset_type: str,
+    column_map: Optional[Dict[str, str]] = None,
 ) -> None:
     dataset_def = await datasets_api.get_dataset(dataset_id=dataset_id)
     if not dataset_def.dataset_schema or len(dataset_def.dataset_schema) == 0:
@@ -134,7 +170,21 @@ async def validate_input_dataset_schema(
     if not hasattr(EXPECTED_DATASET_SCHEMA, dataset_type):
         raise ValueError(f"Dataset type {dataset_type} is not supported.")
 
-    if dataset_def.dataset_schema not in getattr(EXPECTED_DATASET_SCHEMA, dataset_type):
+    dataset_schema = {}
+    if column_map:
+        for old_col_name in dataset_def.dataset_schema.keys():
+            if old_col_name in column_map.values():
+                new_col_name = next(
+                    k for k, v in column_map.items() if v == old_col_name
+                )
+                dataset_schema[new_col_name] = dataset_def.dataset_schema[old_col_name]
+            else:
+                dataset_schema[old_col_name] = dataset_def.dataset_schema[old_col_name]
+    else:
+        dataset_schema = dataset_def.dataset_schema
+
+    if dataset_schema not in getattr(EXPECTED_DATASET_SCHEMA, dataset_type):
         raise ValueError(
             f"Dataset {dataset_id} does not have a correct input schema in {getattr(EXPECTED_DATASET_SCHEMA, dataset_type)}"
         )
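
The new remapping branch above reverses column_map before comparing against the expected schema: column_map maps an expected column name to the dataset's actual column name, so a dataset column found among the map's values is renamed back to the corresponding key. A self-contained sketch of the same idea, with hypothetical column names:

from typing import Dict, Optional

def remap_schema(
    schema: Dict[str, str], column_map: Optional[Dict[str, str]] = None
) -> Dict[str, str]:
    # column_map: expected column name -> actual column name in the dataset,
    # so actual names are rewritten back to the expected ones here.
    if not column_map:
        return schema
    reverse = {actual: expected for expected, actual in column_map.items()}
    return {reverse.get(col, col): col_type for col, col_type in schema.items()}

# remap_schema({"dialogue": "string"}, {"conversations": "dialogue"})
# -> {"conversations": "string"}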


@@ -42,7 +42,7 @@ from torch import nn
 from torch.optim import Optimizer
 from torch.utils.data import DataLoader, DistributedSampler
 from torchtune import modules, training, utils as torchtune_utils
-from torchtune.data import AlpacaToMessages, padded_collate_sft
+from torchtune.data import padded_collate_sft
 from torchtune.modules.loss import CEWithChunkedOutputLoss
 from torchtune.modules.peft import (
@@ -129,15 +129,16 @@ class LoraFinetuningSingleDevice:
         self.seed = training.set_seed(seed=config.torch_seed)
         self.epochs_run = 0
         self.total_epochs = training_config.n_epochs
+        self._data_format = training_config.data_config.data_format
         self._shuffle = training_config.data_config.shuffle
         self._batch_size = training_config.data_config.batch_size
+        self._column_map = training_config.data_config.column_map
 
         # this is important for debugging purpose
         self.max_steps_per_epoch = training_config.max_steps_per_epoch
         self.global_step = 0
 
         self._gradient_accumulation_steps = training_config.gradient_accumulation_steps
-        self.max_validation_steps = training_config.max_validation_steps
 
         self._clip_grad_norm = 1.0
         self._enable_activation_checkpointing = (
@@ -360,11 +361,15 @@ class LoraFinetuningSingleDevice:
         await utils.validate_input_dataset_schema(
             datasets_api=self.datasets_api,
             dataset_id=dataset_id,
-            dataset_type="alpaca",
+            dataset_type=self._data_format.value,
+            column_map=self._column_map,
         )
 
+        data_transform = await utils.get_data_transform(self._data_format)
         ds = SFTDataset(
             rows,
-            message_transform=AlpacaToMessages(train_on_input=False),
+            message_transform=data_transform(
+                train_on_input=False, column_map=self._column_map
+            ),
             model_transform=tokenizer,
         )
@@ -584,7 +589,7 @@ class LoraFinetuningSingleDevice:
         log.info("Starting validation...")
         pbar = tqdm(total=len(self._validation_dataloader))
         for idx, batch in enumerate(self._validation_dataloader):
-            if idx == self.max_validation_steps:
+            if idx == 10:
                 break
             torchtune_utils.batch_to_device(batch, self._device)
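
Taken together, data_format selects one of the torchtune message transforms registered in DATA_FORMATS, and column_map is forwarded to that transform so it can read the dataset's real column names. A rough sketch of what the recipe ends up doing for a ShareGPT-style row; the sample data and the "dialogue" column are hypothetical, and the exact shape of the transform's output is an assumption based on torchtune's message transforms:

from torchtune.data._messages import ShareGPTToMessages

transform = ShareGPTToMessages(
    train_on_input=False,
    column_map={"conversations": "dialogue"},  # hypothetical source column
)
sample = {
    "dialogue": [
        {"from": "human", "value": "What is the capital of France?"},
        {"from": "gpt", "value": "Paris."},
    ]
}
converted = transform(sample)
# converted["messages"] is expected to hold torchtune Message objects,
# roughly one user turn and one assistant turn for this sample.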