This commit is contained in:
Botao Chen 2025-03-09 17:30:23 -07:00
parent 51282456b9
commit 41f86b5ce6

View file

@@ -19,19 +19,19 @@ from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
@json_schema_type
class TrainingStrategy(BaseModel):
# params that control Optimizer
lr: Optional[float] = 2e-5
lr: Optional[Union[float, Literal["auto"]]] = "auto"
weight_decay: Optional[float] = 0.1
num_warmup_steps: Optional[int] = 0
num_warmup_steps: Optional[Union[int, Literal["auto"]]] = "auto"
# params that control how data is fed for training
batch_size: Optional[int] = 1
batch_size: Optional[Union[int, Literal["auto"]]] = "auto"
shuffle: Optional[bool] = True
n_epochs: Optional[int] = 3
# training loop control params
max_training_steps: Optional[int] = None
max_validation_steps: Optional[int] = None
gradient_accumulation_steps: Optional[int] = 1
gradient_accumulation_steps: Optional[Union[int, Literal["auto"]]] = "auto"
# precision for training
dtype: Optional[str] = "bf16"