temp commit

Botao Chen 2025-01-03 13:28:39 -08:00
parent 96d8375663
commit 346a6c658d
3 changed files with 73 additions and 9 deletions


@@ -27,14 +27,24 @@ class OptimizerType(Enum):
     sgd = "sgd"
 
 
+@json_schema_type
+class DatasetFormat(Enum):
+    alpaca = "alpaca"
+    instruct = "instruct"
+    chat_sharegpt = "chat_sharegpt"
+    chat_openai = "chat_openai"
+
+
 @json_schema_type
 class DataConfig(BaseModel):
     dataset_id: str
     batch_size: int
     shuffle: bool
+    data_format: DatasetFormat
     validation_dataset_id: Optional[str] = None
     packed: Optional[bool] = False
     train_on_input: Optional[bool] = False
+    column_map: Optional[Dict[str, str]] = None
 
 
 @json_schema_type
@@ -58,7 +68,6 @@ class TrainingConfig(BaseModel):
     n_epochs: int
     max_steps_per_epoch: int
     gradient_accumulation_steps: int
-    max_validation_steps: int
     data_config: DataConfig
     optimizer_config: OptimizerConfig
     efficiency_config: Optional[EfficiencyConfig] = None
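
Taken together, the API-side changes above let a fine-tuning job declare how its dataset rows are laid out and how its columns map onto the expected names. A minimal sketch of filling in the new DataConfig fields (the dataset id and column names are hypothetical, and it assumes DataConfig is exported from the same module as DatasetFormat):

from llama_stack.apis.post_training import DataConfig, DatasetFormat

data_config = DataConfig(
    dataset_id="my-qa-dataset",  # hypothetical dataset registered via the Datasets API
    batch_size=4,
    shuffle=True,
    data_format=DatasetFormat.instruct,  # alpaca, instruct, chat_sharegpt, or chat_openai
    train_on_input=False,
    column_map={"input": "question", "output": "answer"},  # expected name -> actual column name
)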


@@ -11,19 +11,28 @@
 # the root directory of this source tree.
 from enum import Enum
-from typing import Any, Callable, Dict, List
+from typing import Any, Callable, Dict, List, Optional
 
 import torch
 from llama_models.datatypes import Model
 from llama_models.sku_list import resolve_model
 from llama_stack.apis.common.type_system import ParamType, StringType
 from llama_stack.apis.datasets import Datasets
+from llama_stack.apis.post_training import DatasetFormat
 from pydantic import BaseModel
+from torchtune.data._messages import (
+    AlpacaToMessages,
+    InputOutputToMessages,
+    OpenAIToMessages,
+    ShareGPTToMessages,
+)
 from torchtune.models.llama3 import llama3_tokenizer, lora_llama3_8b
 from torchtune.models.llama3._tokenizer import Llama3Tokenizer
 from torchtune.models.llama3_2 import lora_llama3_2_3b
+from torchtune.modules.transforms import Transform
 
 
 class ColumnName(Enum):
@@ -31,6 +40,8 @@ class ColumnName(Enum):
     input = "input"
     output = "output"
     text = "text"
+    conversations = "conversations"
+    messages = "messages"
 
 
 class ModelConfig(BaseModel):
@@ -41,6 +52,9 @@ class ModelConfig(BaseModel):
 
 class DatasetSchema(BaseModel):
     alpaca: List[Dict[str, ParamType]]
+    instruct: Dict[str, ParamType]
+    chat_sharegpt: Dict[str, ParamType]
+    chat_openai: Dict[str, ParamType]
 
 
 MODEL_CONFIGS: Dict[str, ModelConfig] = {
@@ -56,6 +70,13 @@ MODEL_CONFIGS: Dict[str, ModelConfig] = {
     ),
 }
 
+DATA_FORMATS: Dict[str, Transform] = {
+    "alpaca": AlpacaToMessages,
+    "instruct": InputOutputToMessages,
+    "chat_sharegpt": ShareGPTToMessages,
+    "chat_openai": OpenAIToMessages,
+}
+
 EXPECTED_DATASET_SCHEMA = DatasetSchema(
     alpaca=[
@@ -74,7 +95,17 @@ EXPECTED_DATASET_SCHEMA = DatasetSchema(
             ColumnName.instruction.value: StringType(),
             ColumnName.output.value: StringType(),
         },
-    ]
+    ],
+    instruct={
+        ColumnName.input.value: StringType(),
+        ColumnName.output.value: StringType(),
+    },
+    chat_sharegpt={
+        ColumnName.conversations.value: StringType(),
+    },
+    chat_openai={
+        ColumnName.messages.value: StringType(),
+    },
 )
 
 BuildLoraModelCallable = Callable[..., torch.nn.Module]
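
The DATA_FORMATS table above is what turns a DataConfig.data_format value into a concrete torchtune message transform, and the new DatasetSchema entries describe the columns each format expects. A rough, hedged illustration of the lookup (the chosen format is arbitrary):

from llama_stack.apis.post_training import DatasetFormat

fmt = DatasetFormat.chat_sharegpt
transform_cls = DATA_FORMATS[fmt.value]  # resolves to ShareGPTToMessages per the mapping above
message_transform = transform_cls(train_on_input=False, column_map=None)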
@@ -122,10 +153,15 @@ async def get_checkpointer_model_type(
     return model_config.checkpoint_type
 
 
+async def get_data_transform(data_format: DatasetFormat) -> Transform:
+    return DATA_FORMATS[data_format.value]
+
+
 async def validate_input_dataset_schema(
     datasets_api: Datasets,
     dataset_id: str,
     dataset_type: str,
+    column_map: Optional[Dict[str, str]] = None,
 ) -> None:
     dataset_def = await datasets_api.get_dataset(dataset_id=dataset_id)
     if not dataset_def.dataset_schema or len(dataset_def.dataset_schema) == 0:
@@ -134,7 +170,21 @@ async def validate_input_dataset_schema(
     if not hasattr(EXPECTED_DATASET_SCHEMA, dataset_type):
         raise ValueError(f"Dataset type {dataset_type} is not supported.")
 
-    if dataset_def.dataset_schema not in getattr(EXPECTED_DATASET_SCHEMA, dataset_type):
+    dataset_schema = {}
+    if column_map:
+        for old_col_name in dataset_def.dataset_schema.keys():
+            if old_col_name in column_map.values():
+                new_col_name = next(
+                    k for k, v in column_map.items() if v == old_col_name
+                )
+                dataset_schema[new_col_name] = dataset_def.dataset_schema[old_col_name]
+            else:
+                dataset_schema[old_col_name] = dataset_def.dataset_schema[old_col_name]
+    else:
+        dataset_schema = dataset_def.dataset_schema
+
+    if dataset_schema not in getattr(EXPECTED_DATASET_SCHEMA, dataset_type):
         raise ValueError(
             f"Dataset {dataset_id} does not have a correct input schema in {getattr(EXPECTED_DATASET_SCHEMA, dataset_type)}"
         )
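
The column_map branch in validate_input_dataset_schema renames the registered dataset's columns back to the names the expected schema uses before comparing. A small standalone walkthrough with made-up column names (it mirrors the loop above; StringType as in the schemas earlier in this file):

from llama_stack.apis.common.type_system import StringType

stored_schema = {"question": StringType(), "answer": StringType()}  # hypothetical registered dataset
column_map = {"input": "question", "output": "answer"}  # expected name -> actual column name

remapped = {}
for old_col_name in stored_schema:
    if old_col_name in column_map.values():
        new_col_name = next(k for k, v in column_map.items() if v == old_col_name)
        remapped[new_col_name] = stored_schema[old_col_name]
    else:
        remapped[old_col_name] = stored_schema[old_col_name]

# remapped == {"input": StringType(), "output": StringType()},
# matching the instruct entry of EXPECTED_DATASET_SCHEMA.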


@@ -42,7 +42,7 @@ from torch import nn
 from torch.optim import Optimizer
 from torch.utils.data import DataLoader, DistributedSampler
 from torchtune import modules, training, utils as torchtune_utils
-from torchtune.data import AlpacaToMessages, padded_collate_sft
+from torchtune.data import padded_collate_sft
 from torchtune.modules.loss import CEWithChunkedOutputLoss
 from torchtune.modules.peft import (
@@ -129,15 +129,16 @@ class LoraFinetuningSingleDevice:
         self.seed = training.set_seed(seed=config.torch_seed)
         self.epochs_run = 0
         self.total_epochs = training_config.n_epochs
+        self._data_format = training_config.data_config.data_format
         self._shuffle = training_config.data_config.shuffle
         self._batch_size = training_config.data_config.batch_size
+        self._column_map = training_config.data_config.column_map
 
         # this is important for debugging purpose
         self.max_steps_per_epoch = training_config.max_steps_per_epoch
         self.global_step = 0
 
         self._gradient_accumulation_steps = training_config.gradient_accumulation_steps
-        self.max_validation_steps = training_config.max_validation_steps
 
         self._clip_grad_norm = 1.0
         self._enable_activation_checkpointing = (
@@ -360,11 +361,15 @@
         await utils.validate_input_dataset_schema(
             datasets_api=self.datasets_api,
             dataset_id=dataset_id,
-            dataset_type="alpaca",
+            dataset_type=self._data_format.value,
+            column_map=self._column_map,
         )
+
+        data_transform = await utils.get_data_transform(self._data_format)
         ds = SFTDataset(
             rows,
-            message_transform=AlpacaToMessages(train_on_input=False),
+            message_transform=data_transform(
+                train_on_input=False, column_map=self._column_map
+            ),
             model_transform=tokenizer,
         )
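
With the recipe no longer hard-coded to AlpacaToMessages, the transform chosen by get_data_transform parses each row according to the configured format before SFTDataset tokenizes it. A hedged, standalone sketch of what that looks like for the instruct format (sample content and column names are invented):

from torchtune.data._messages import InputOutputToMessages

transform = InputOutputToMessages(
    train_on_input=False,
    column_map={"input": "question", "output": "answer"},  # hypothetical mapping
)
sample = {"question": "What is LoRA?", "answer": "A parameter-efficient fine-tuning method."}
out = transform(sample)
# out["messages"] holds the user/assistant Message pair that SFTDataset
# then tokenizes via model_transform.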
@@ -584,7 +589,7 @@ class LoraFinetuningSingleDevice:
         log.info("Starting validation...")
         pbar = tqdm(total=len(self._validation_dataloader))
         for idx, batch in enumerate(self._validation_dataloader):
-            if idx == self.max_validation_steps:
+            if idx == 10:
                 break
             torchtune_utils.batch_to_device(batch, self._device)