From 41f86b5ce643e74478ef23220b6deacaf339e047 Mon Sep 17 00:00:00 2001
From: Botao Chen
Date: Sun, 9 Mar 2025 17:30:23 -0700
Subject: [PATCH] refine: default TrainingStrategy hyperparameters to "auto"

Allow lr, num_warmup_steps, batch_size, and gradient_accumulation_steps
to accept the literal "auto" in addition to their numeric types, and
make "auto" the default so a concrete value no longer has to be chosen
up front.

---
 llama_stack/apis/post_training/post_training.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/llama_stack/apis/post_training/post_training.py b/llama_stack/apis/post_training/post_training.py
index df05fbf31..8cbc66daf 100644
--- a/llama_stack/apis/post_training/post_training.py
+++ b/llama_stack/apis/post_training/post_training.py
@@ -19,19 +19,19 @@ from llama_stack.schema_utils import json_schema_type, register_schema, webmetho
 @json_schema_type
 class TrainingStrategy(BaseModel):
     # params that control Optimizer
-    lr: Optional[float] = 2e-5
+    lr: Optional[Union[float, Literal["auto"]]] = "auto"
     weight_decay: Optional[float] = 0.1
-    num_warmup_steps: Optional[int] = 0
+    num_warmup_steps: Optional[Union[int, Literal["auto"]]] = "auto"
 
     # paramas that control how data is fed for training
-    batch_size: Optional[int] = 1
+    batch_size: Optional[Union[int, Literal["auto"]]] = "auto"
     shuffle: Optional[bool] = True
     n_epochs: Optional[int] = 3
 
     # training loop control params
     max_training_steps: Optional[int] = None
     max_validation_steps: Optional[int] = None
-    gradient_accumulation_steps: Optional[int] = 1
+    gradient_accumulation_steps: Optional[Union[int, Literal["auto"]]] = "auto"
 
     # precision for training
     dtype: Optional[str] = "bf16"
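
For reviewers: below is a minimal, self-contained sketch of how the
patched model behaves. It re-declares TrainingStrategy locally instead
of importing it from llama_stack, so it runs with only pydantic
installed; the field definitions are copied from the diff above, and
the example values passed in are illustrative, not defaults mandated
anywhere else.

# Standalone sketch of the patched TrainingStrategy, for illustration only;
# the real class lives in llama_stack/apis/post_training/post_training.py.
from typing import Literal, Optional, Union

from pydantic import BaseModel


class TrainingStrategy(BaseModel):
    # params that control Optimizer
    lr: Optional[Union[float, Literal["auto"]]] = "auto"
    weight_decay: Optional[float] = 0.1
    num_warmup_steps: Optional[Union[int, Literal["auto"]]] = "auto"

    # params that control how data is fed for training
    batch_size: Optional[Union[int, Literal["auto"]]] = "auto"
    shuffle: Optional[bool] = True
    n_epochs: Optional[int] = 3

    # training loop control params
    max_training_steps: Optional[int] = None
    max_validation_steps: Optional[int] = None
    gradient_accumulation_steps: Optional[Union[int, Literal["auto"]]] = "auto"

    # precision for training
    dtype: Optional[str] = "bf16"


# With no arguments, every tunable defaults to "auto", leaving the choice
# of concrete values to whatever executes the training job.
auto = TrainingStrategy()
assert auto.lr == "auto" and auto.batch_size == "auto"

# Explicit numeric values (e.g. the old defaults) still validate as before.
manual = TrainingStrategy(lr=2e-5, num_warmup_steps=0, batch_size=1,
                          gradient_accumulation_steps=1)
assert manual.lr == 2e-5

# Anything that is neither numeric nor the literal "auto" is rejected:
# TrainingStrategy(lr="fast")  # raises pydantic.ValidationError

One design note: modeling "auto" as Literal["auto"] rather than reusing
None keeps the generated JSON schema explicit (a union of the numeric
type and the constant "auto"), while None remains available from the
surrounding Optional.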