Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-08-03 09:21:45 +00:00)
refine

commit bbe190a085 (parent 280581a4a3)
3 changed files with 23 additions and 21 deletions
@@ -68,6 +68,7 @@ class TrainingConfig(BaseModel):
     n_epochs: int
     max_steps_per_epoch: int
     gradient_accumulation_steps: int
+    max_validation_steps: int
     data_config: DataConfig
     optimizer_config: OptimizerConfig
     efficiency_config: Optional[EfficiencyConfig] = None
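Note: the new max_validation_steps field caps how many validation batches are evaluated (see the validation-loop hunk further down). A minimal standalone sketch of the field's shape, using a stand-in pydantic model and made-up values rather than llama-stack's actual TrainingConfig:

```python
# Standalone sketch only: a stand-in model mirroring the fields in the hunk above.
from pydantic import BaseModel


class TrainingConfigSketch(BaseModel):
    n_epochs: int
    max_steps_per_epoch: int
    gradient_accumulation_steps: int
    max_validation_steps: int  # new field introduced by this commit


cfg = TrainingConfigSketch(
    n_epochs=1,
    max_steps_per_epoch=100,
    gradient_accumulation_steps=1,
    max_validation_steps=50,  # illustrative value
)
print(cfg.max_validation_steps)  # 50
```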
@@ -52,9 +52,9 @@ class ModelConfig(BaseModel):

 class DatasetSchema(BaseModel):
     alpaca: List[Dict[str, ParamType]]
-    instruct: Dict[str, ParamType]
-    chat_sharegpt: Dict[str, ParamType]
-    chat_openai: Dict[str, ParamType]
+    instruct: List[Dict[str, ParamType]]
+    chat_sharegpt: List[Dict[str, ParamType]]
+    chat_openai: List[Dict[str, ParamType]]


 MODEL_CONFIGS: Dict[str, ModelConfig] = {
@@ -96,16 +96,22 @@ EXPECTED_DATASET_SCHEMA = DatasetSchema(
             ColumnName.output.value: StringType(),
         },
     ],
-    instruct={
-        ColumnName.input.value: StringType(),
-        ColumnName.output.value: StringType(),
-    },
-    chat_sharegpt={
-        ColumnName.conversations.value: StringType(),
-    },
-    chat_openai={
-        ColumnName.messages.value: StringType(),
-    },
+    instruct=[
+        {
+            ColumnName.input.value: StringType(),
+            ColumnName.output.value: StringType(),
+        }
+    ],
+    chat_sharegpt=[
+        {
+            ColumnName.conversations.value: StringType(),
+        }
+    ],
+    chat_openai=[
+        {
+            ColumnName.messages.value: StringType(),
+        }
+    ],
 )

 BuildLoraModelCallable = Callable[..., torch.nn.Module]
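Note: each dataset format in DatasetSchema now maps to a list of accepted column schemas rather than a single dict, so one format can accept more than one column layout. A standalone sketch (hypothetical helper name, plain strings in place of ParamType) of the "match any variant in the list" check this enables:

```python
# Hypothetical helper; llama-stack's actual schema validation lives elsewhere.
from typing import Dict, List


def matches_any_schema(columns: Dict[str, str], variants: List[Dict[str, str]]) -> bool:
    """True if the dataset's column->type mapping equals any accepted variant."""
    return any(columns == variant for variant in variants)


instruct_variants = [{"input": "string", "output": "string"}]
print(matches_any_schema({"input": "string", "output": "string"}, instruct_variants))  # True
print(matches_any_schema({"prompt": "string"}, instruct_variants))                     # False
```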
@@ -126,8 +126,7 @@ class LoraFinetuningSingleDevice:

         self._output_dir = str(DEFAULT_CHECKPOINT_DIR)

-        # self.seed = training.set_seed(seed=config.torch_seed)
-        self.seed = 42
+        self.seed = training.set_seed(seed=config.torch_seed)
         self.epochs_run = 0
         self.total_epochs = training_config.n_epochs
         self._data_format = training_config.data_config.data_format
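Note: the recipe now derives its seed from config.torch_seed via torchtune's training.set_seed instead of hardcoding 42. A standalone sketch of the same "seed everything from one configured value" pattern (not torchtune's implementation):

```python
# Standalone sketch: seed the common RNGs from a single configured value.
import random

import torch


def seed_everything(torch_seed: int) -> int:
    random.seed(torch_seed)
    torch.manual_seed(torch_seed)
    return torch_seed


seed = seed_everything(torch_seed=1234)
print(seed)  # 1234
```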
@@ -140,6 +139,7 @@ class LoraFinetuningSingleDevice:
         self.global_step = 0

         self._gradient_accumulation_steps = training_config.gradient_accumulation_steps
+        self.max_validation_steps = training_config.max_validation_steps

         self._clip_grad_norm = 1.0
         self._enable_activation_checkpointing = (
@@ -366,7 +366,6 @@ class LoraFinetuningSingleDevice:
             column_map=self._column_map,
         )
         data_transform = await utils.get_data_transform(self._data_format)
-        print("data_transform", data_transform.__name__)
         ds = SFTDataset(
             rows,
             message_transform=data_transform(
@@ -450,10 +449,6 @@ class LoraFinetuningSingleDevice:

     async def _loss_step(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
         # Shape [b, s], needed for the loss not the model
-        # print("tokens", batch["tokens"])
-        torch.save(batch["tokens"], "/home/markchen1015/new_alpaca_tokens.pth")
-        # print("labels", batch["labels"])
-        torch.save(batch["labels"], "/home/markchen1015/new_alpaca_labels.pth")
         labels = batch.pop("labels")
         # run model
         with self.activations_handling_ctx:
@@ -595,7 +590,7 @@ class LoraFinetuningSingleDevice:
         log.info("Starting validation...")
         pbar = tqdm(total=len(self._validation_dataloader))
         for idx, batch in enumerate(self._validation_dataloader):
-            if idx == 10:
+            if idx == self.max_validation_steps:
                 break
             torchtune_utils.batch_to_device(batch, self._device)

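Note: the validation loop now stops at the configured self.max_validation_steps instead of a hardcoded 10 batches. A standalone sketch of the capping behavior, with a plain iterable standing in for the real dataloader:

```python
# Standalone sketch: stop validating once the configured step cap is reached.
def run_validation(dataloader, max_validation_steps: int) -> int:
    evaluated = 0
    for idx, batch in enumerate(dataloader):
        if idx == max_validation_steps:
            break
        # ... compute and accumulate the validation loss for `batch` here ...
        evaluated += 1
    return evaluated


print(run_validation(range(100), max_validation_steps=10))  # 10
```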