temp commit

This commit is contained in:
Botao Chen 2024-11-25 17:27:26 -08:00
parent 900b0556e7
commit d7598c68d7
6 changed files with 491 additions and 3 deletions

View file

@ -16,6 +16,7 @@ from pydantic import BaseModel, Field
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.apis.datasets import * # noqa: F403
from llama_stack.apis.common.training_types import * # noqa: F403
import torch
class OptimizerType(Enum):
@ -30,18 +31,22 @@ class OptimizerConfig(BaseModel):
lr: float
lr_min: float
weight_decay: float
num_warmup_steps: int
@json_schema_type
class TrainingConfig(BaseModel):
dtype: torch.dtype
n_epochs: int
max_steps_per_epoch: int
gradient_accumulation_steps: int
batch_size: int
shuffle: bool
n_iters: int
# n_iters: int
enable_activation_checkpointing: bool
memory_efficient_fsdp_wrap: bool
fsdp_cpu_offload: bool
memory_efficient_fsdp_wrap: Optional[bool]
fsdp_cpu_offload: Optional[bool]
@json_schema_type