Botao Chen 2025-01-03 16:27:32 -08:00
parent 280581a4a3
commit bbe190a085
3 changed files with 23 additions and 21 deletions


@@ -68,6 +68,7 @@ class TrainingConfig(BaseModel):
     n_epochs: int
     max_steps_per_epoch: int
     gradient_accumulation_steps: int
+    max_validation_steps: int
     data_config: DataConfig
     optimizer_config: OptimizerConfig
     efficiency_config: Optional[EfficiencyConfig] = None
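The new max_validation_steps field has no default, so any caller that builds a TrainingConfig now has to supply it. A minimal sketch of what that looks like, using stand-in nested configs (the real DataConfig/OptimizerConfig fields are not shown in this commit and are assumptions here):

# Hedged sketch only: the nested config fields below are placeholders, not
# the project's actual DataConfig/OptimizerConfig definitions.
from typing import Optional
from pydantic import BaseModel

class DataConfig(BaseModel):          # stand-in
    dataset_id: str
    data_format: str

class OptimizerConfig(BaseModel):     # stand-in
    lr: float

class EfficiencyConfig(BaseModel):    # stand-in
    enable_activation_checkpointing: bool = False

class TrainingConfig(BaseModel):
    n_epochs: int
    max_steps_per_epoch: int
    gradient_accumulation_steps: int
    max_validation_steps: int         # new in this commit
    data_config: DataConfig
    optimizer_config: OptimizerConfig
    efficiency_config: Optional[EfficiencyConfig] = None

config = TrainingConfig(
    n_epochs=1,
    max_steps_per_epoch=100,
    gradient_accumulation_steps=1,
    max_validation_steps=10,          # caps the validation loop (see the last hunk below)
    data_config=DataConfig(dataset_id="alpaca", data_format="alpaca"),
    optimizer_config=OptimizerConfig(lr=3e-4),
)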


@@ -52,9 +52,9 @@ class ModelConfig(BaseModel):

 class DatasetSchema(BaseModel):
     alpaca: List[Dict[str, ParamType]]
-    instruct: Dict[str, ParamType]
-    chat_sharegpt: Dict[str, ParamType]
-    chat_openai: Dict[str, ParamType]
+    instruct: List[Dict[str, ParamType]]
+    chat_sharegpt: List[Dict[str, ParamType]]
+    chat_openai: List[Dict[str, ParamType]]


 MODEL_CONFIGS: Dict[str, ModelConfig] = {
@@ -96,16 +96,22 @@ EXPECTED_DATASET_SCHEMA = DatasetSchema(
             ColumnName.output.value: StringType(),
         },
     ],
-    instruct={
-        ColumnName.input.value: StringType(),
-        ColumnName.output.value: StringType(),
-    },
-    chat_sharegpt={
-        ColumnName.conversations.value: StringType(),
-    },
-    chat_openai={
-        ColumnName.messages.value: StringType(),
-    },
+    instruct=[
+        {
+            ColumnName.input.value: StringType(),
+            ColumnName.output.value: StringType(),
+        }
+    ],
+    chat_sharegpt=[
+        {
+            ColumnName.conversations.value: StringType(),
+        }
+    ],
+    chat_openai=[
+        {
+            ColumnName.messages.value: StringType(),
+        }
+    ],
 )

 BuildLoraModelCallable = Callable[..., torch.nn.Module]
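With every format now declared as List[Dict[str, ParamType]] (matching how alpaca already was), each format's expected schema becomes a list of acceptable column layouts rather than a single one. A hedged sketch of how a validator could use that shape; columns_match_any is a hypothetical helper, not the project's actual validation code:

# Hypothetical helper, for illustration only: accept a dataset if its columns
# cover at least one of the candidate column layouts for that format.
from typing import Dict, List

def columns_match_any(dataset_columns: List[str], candidates: List[Dict[str, str]]) -> bool:
    return any(all(col in dataset_columns for col in candidate) for candidate in candidates)

# Column names taken from the hunk above; the dataset columns are made up.
instruct_candidates = [{"input": "string", "output": "string"}]
print(columns_match_any(["input", "output", "source"], instruct_candidates))  # True
print(columns_match_any(["prompt", "completion"], instruct_candidates))       # False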


@@ -126,8 +126,7 @@ class LoraFinetuningSingleDevice:
         self._output_dir = str(DEFAULT_CHECKPOINT_DIR)

-        # self.seed = training.set_seed(seed=config.torch_seed)
-        self.seed = 42
+        self.seed = training.set_seed(seed=config.torch_seed)

         self.epochs_run = 0
         self.total_epochs = training_config.n_epochs
         self._data_format = training_config.data_config.data_format
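Instead of the hard-coded self.seed = 42, the seed now comes from the provider config via torchtune's training.set_seed helper. Roughly what such a helper does, sketched with plain PyTorch calls as an assumption (torchtune's own implementation may seed additional generators):

# Assumed approximation of a set_seed helper; not torchtune's actual code.
import random

import torch

def set_seed(seed: int) -> int:
    random.seed(seed)
    torch.manual_seed(seed)
    return seed

seed = set_seed(42)  # the recipe passes config.torch_seed instead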
@@ -140,6 +139,7 @@ class LoraFinetuningSingleDevice:
         self.global_step = 0

         self._gradient_accumulation_steps = training_config.gradient_accumulation_steps
+        self.max_validation_steps = training_config.max_validation_steps

         self._clip_grad_norm = 1.0
         self._enable_activation_checkpointing = (
@@ -366,7 +366,6 @@ class LoraFinetuningSingleDevice:
             column_map=self._column_map,
         )
         data_transform = await utils.get_data_transform(self._data_format)
-        print("data_transform", data_transform.__name__)
         ds = SFTDataset(
             rows,
             message_transform=data_transform(
@@ -450,10 +449,6 @@ class LoraFinetuningSingleDevice:

     async def _loss_step(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
         # Shape [b, s], needed for the loss not the model
-        # print("tokens", batch["tokens"])
-        torch.save(batch["tokens"], "/home/markchen1015/new_alpaca_tokens.pth")
-        # print("labels", batch["labels"])
-        torch.save(batch["labels"], "/home/markchen1015/new_alpaca_labels.pth")
         labels = batch.pop("labels")
         # run model
         with self.activations_handling_ctx:
@@ -595,7 +590,7 @@ class LoraFinetuningSingleDevice:

         log.info("Starting validation...")
         pbar = tqdm(total=len(self._validation_dataloader))
         for idx, batch in enumerate(self._validation_dataloader):
-            if idx == 10:
+            if idx == self.max_validation_steps:
                 break
             torchtune_utils.batch_to_device(batch, self._device)
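Taken together, the last two hunks replace the hard-coded if idx == 10: break with the configured max_validation_steps read in __init__. A minimal sketch of the capped loop in isolation, with device transfer and loss computation elided:

# Sketch of the capped validation loop; the real recipe moves each batch to
# the device and accumulates loss where the comment is.
def run_validation(dataloader, max_validation_steps: int) -> int:
    steps_run = 0
    for idx, _batch in enumerate(dataloader):
        if idx == max_validation_steps:
            break
        # ... batch_to_device(_batch, device); loss = loss_step(_batch) ...
        steps_run += 1
    return steps_run

assert run_validation(range(100), max_validation_steps=10) == 10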