Botao Chen 2025-01-03 16:27:32 -08:00
parent 280581a4a3
commit bbe190a085
3 changed files with 23 additions and 21 deletions


@@ -68,6 +68,7 @@ class TrainingConfig(BaseModel):
     n_epochs: int
     max_steps_per_epoch: int
     gradient_accumulation_steps: int
+    max_validation_steps: int
     data_config: DataConfig
     optimizer_config: OptimizerConfig
     efficiency_config: Optional[EfficiencyConfig] = None
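The new max_validation_steps field caps how many batches each validation pass consumes. A minimal sketch of constructing the config with the new field (the surrounding values and the data_config/optimizer_config objects are illustrative placeholders, not taken from this commit):

    # Illustrative values only; data_config and optimizer_config are assumed
    # to be valid DataConfig / OptimizerConfig instances.
    training_config = TrainingConfig(
        n_epochs=1,
        max_steps_per_epoch=100,
        gradient_accumulation_steps=1,
        max_validation_steps=20,  # new field: stop validation after 20 batches
        data_config=data_config,
        optimizer_config=optimizer_config,
    )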


@@ -52,9 +52,9 @@ class ModelConfig(BaseModel):
 class DatasetSchema(BaseModel):
     alpaca: List[Dict[str, ParamType]]
-    instruct: Dict[str, ParamType]
-    chat_sharegpt: Dict[str, ParamType]
-    chat_openai: Dict[str, ParamType]
+    instruct: List[Dict[str, ParamType]]
+    chat_sharegpt: List[Dict[str, ParamType]]
+    chat_openai: List[Dict[str, ParamType]]

 MODEL_CONFIGS: Dict[str, ModelConfig] = {
@@ -96,16 +96,22 @@ EXPECTED_DATASET_SCHEMA = DatasetSchema(
             ColumnName.output.value: StringType(),
         },
     ],
-    instruct={
-        ColumnName.input.value: StringType(),
-        ColumnName.output.value: StringType(),
-    },
-    chat_sharegpt={
-        ColumnName.conversations.value: StringType(),
-    },
-    chat_openai={
-        ColumnName.messages.value: StringType(),
-    },
+    instruct=[
+        {
+            ColumnName.input.value: StringType(),
+            ColumnName.output.value: StringType(),
+        }
+    ],
+    chat_sharegpt=[
+        {
+            ColumnName.conversations.value: StringType(),
+        }
+    ],
+    chat_openai=[
+        {
+            ColumnName.messages.value: StringType(),
+        }
+    ],
 )

 BuildLoraModelCallable = Callable[..., torch.nn.Module]
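Switching instruct, chat_sharegpt, and chat_openai from a single Dict to List[Dict] gives every format the same shape as alpaca: a list of acceptable column layouts rather than exactly one. A hedged sketch of how such a list could be consumed (columns_match is a hypothetical helper for illustration, not part of this diff):

    def columns_match(
        dataset_columns: Dict[str, ParamType],
        accepted_layouts: List[Dict[str, ParamType]],
    ) -> bool:
        # A dataset is valid if its columns match any one accepted layout.
        return any(dataset_columns == layout for layout in accepted_layouts)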


@@ -126,8 +126,7 @@ class LoraFinetuningSingleDevice:
         self._output_dir = str(DEFAULT_CHECKPOINT_DIR)
-        # self.seed = training.set_seed(seed=config.torch_seed)
-        self.seed = 42
+        self.seed = training.set_seed(seed=config.torch_seed)
         self.epochs_run = 0
         self.total_epochs = training_config.n_epochs
         self._data_format = training_config.data_config.data_format
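This drops the hard-coded self.seed = 42 left over from debugging and restores configurable seeding. The diff itself shows that torchtune's training.set_seed returns the seed it applied, so storing the return value keeps the recipe's record accurate. A minimal standalone sketch (assuming torchtune is installed):

    from torchtune import training

    # Seeds the common RNGs (Python, NumPy, torch) and returns the seed used.
    seed = training.set_seed(seed=1234)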
@@ -140,6 +139,7 @@ class LoraFinetuningSingleDevice:
         self.global_step = 0
         self._gradient_accumulation_steps = training_config.gradient_accumulation_steps
+        self.max_validation_steps = training_config.max_validation_steps
         self._clip_grad_norm = 1.0
         self._enable_activation_checkpointing = (
@@ -366,7 +366,6 @@ class LoraFinetuningSingleDevice:
             column_map=self._column_map,
         )
         data_transform = await utils.get_data_transform(self._data_format)
-        print("data_transform", data_transform.__name__)
         ds = SFTDataset(
             rows,
             message_transform=data_transform(
@@ -450,10 +449,6 @@ class LoraFinetuningSingleDevice:
     async def _loss_step(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
         # Shape [b, s], needed for the loss not the model
-        # print("tokens", batch["tokens"])
-        torch.save(batch["tokens"], "/home/markchen1015/new_alpaca_tokens.pth")
-        # print("labels", batch["labels"])
-        torch.save(batch["labels"], "/home/markchen1015/new_alpaca_labels.pth")
         labels = batch.pop("labels")
         # run model
         with self.activations_handling_ctx:
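The removed torch.save calls were debug dumps of token and label tensors to a developer's home directory and had no place in the training path. What remains is the standard causal-LM loss step: pop the labels, run the forward pass, and score each position against the next token. A self-contained sketch of that pattern (reconstructed for illustration, not copied from this recipe):

    import torch
    import torch.nn.functional as F

    def loss_step(model: torch.nn.Module, batch: dict) -> torch.Tensor:
        labels = batch.pop("labels")          # shape [b, s]
        logits = model(**batch)               # shape [b, s, vocab]
        # Shift so logits at position i are scored against the token at i + 1.
        shift_logits = logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()
        return F.cross_entropy(
            shift_logits.view(-1, shift_logits.size(-1)),
            shift_labels.view(-1),
            ignore_index=-100,  # conventional mask value for padding/prompt tokens
        )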
@@ -595,7 +590,7 @@ class LoraFinetuningSingleDevice:
         log.info("Starting validation...")
         pbar = tqdm(total=len(self._validation_dataloader))
         for idx, batch in enumerate(self._validation_dataloader):
-            if idx == 10:
+            if idx == self.max_validation_steps:
                 break
             torchtune_utils.batch_to_device(batch, self._device)
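This replaces the magic number 10 with the configured cap, so the length of a validation pass is controlled from TrainingConfig rather than hard-coded in the recipe body. The bounded-loop pattern in isolation (a sketch with stand-in names):

    import torch

    def run_validation(model, dataloader, max_validation_steps: int):
        model.eval()
        with torch.no_grad():
            for idx, batch in enumerate(dataloader):
                # Stop once the configured number of batches is consumed.
                if idx == max_validation_steps:
                    break
                # ... forward pass and metric accumulation would go here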