commit 41f86b5ce6
parent 51282456b9
Author: Botao Chen
Date:   2025-03-09 17:30:23 -07:00


@@ -19,19 +19,19 @@ from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
 @json_schema_type
 class TrainingStrategy(BaseModel):
     # params that control Optimizer
-    lr: Optional[float] = 2e-5
+    lr: Optional[Union[float, Literal["auto"]]] = "auto"
     weight_decay: Optional[float] = 0.1
-    num_warmup_steps: Optional[int] = 0
+    num_warmup_steps: Optional[Union[int, Literal["auto"]]] = "auto"
     # paramas that control how data is fed for training
-    batch_size: Optional[int] = 1
+    batch_size: Optional[Union[int, Literal["auto"]]] = "auto"
     shuffle: Optional[bool] = True
     n_epochs: Optional[int] = 3
     # training loop control params
     max_training_steps: Optional[int] = None
     max_validation_steps: Optional[int] = None
-    gradient_accumulation_steps: Optional[int] = 1
+    gradient_accumulation_steps: Optional[Union[int, Literal["auto"]]] = "auto"
     # precision for training
     dtype: Optional[str] = "bf16"