Allow the literal "auto" for lr, num_warmup_steps, batch_size and
gradient_accumulation_steps in TrainingStrategy, and make "auto" their default
(previous defaults: 2e-5, 0, 1 and 1 respectively).

diff --git a/llama_stack/apis/post_training/post_training.py b/llama_stack/apis/post_training/post_training.py
index df05fbf31..8cbc66daf 100644
--- a/llama_stack/apis/post_training/post_training.py
+++ b/llama_stack/apis/post_training/post_training.py
@@ -19,19 +19,19 @@ from llama_stack.schema_utils import json_schema_type, register_schema, webmetho
 @json_schema_type
 class TrainingStrategy(BaseModel):
     # params that control Optimizer
-    lr: Optional[float] = 2e-5
+    lr: Optional[Union[float, Literal["auto"]]] = "auto"
     weight_decay: Optional[float] = 0.1
-    num_warmup_steps: Optional[int] = 0
+    num_warmup_steps: Optional[Union[int, Literal["auto"]]] = "auto"
 
     # paramas that control how data is fed for training
-    batch_size: Optional[int] = 1
+    batch_size: Optional[Union[int, Literal["auto"]]] = "auto"
     shuffle: Optional[bool] = True
     n_epochs: Optional[int] = 3
 
     # training loop control params
     max_training_steps: Optional[int] = None
     max_validation_steps: Optional[int] = None
-    gradient_accumulation_steps: Optional[int] = 1
+    gradient_accumulation_steps: Optional[Union[int, Literal["auto"]]] = "auto"
 
     # precision for training
     dtype: Optional[str] = "bf16"