diff --git a/llama_stack/apis/post_training/post_training.py b/llama_stack/apis/post_training/post_training.py
index ed15c6de4..8324be5f6 100644
--- a/llama_stack/apis/post_training/post_training.py
+++ b/llama_stack/apis/post_training/post_training.py
@@ -17,13 +17,6 @@ from llama_stack.apis.common.training_types import Checkpoint
 from llama_stack.schema_utils import json_schema_type, register_schema, webmethod


-@json_schema_type
-class OptimizerType(Enum):
-    adam = "adam"
-    adamw = "adamw"
-    sgd = "sgd"
-
-
 @json_schema_type
 class DatasetFormat(Enum):
     instruct = "instruct"
@@ -33,50 +26,39 @@ class DatasetFormat(Enum):
 @json_schema_type
 class DataConfig(BaseModel):
     dataset_id: str
-    batch_size: int
-    shuffle: bool
-    data_format: DatasetFormat
+    batch_size: Optional[int] = 1
+    shuffle: Optional[bool] = True
+    data_format: Optional[DatasetFormat] = DatasetFormat.instruct
     validation_dataset_id: Optional[str] = None
-    packed: Optional[bool] = False
     train_on_input: Optional[bool] = False


 @json_schema_type
 class OptimizerConfig(BaseModel):
-    optimizer_type: OptimizerType
-    lr: float
-    weight_decay: float
-    num_warmup_steps: int
-
-
-@json_schema_type
-class EfficiencyConfig(BaseModel):
-    enable_activation_checkpointing: Optional[bool] = False
-    enable_activation_offloading: Optional[bool] = False
-    memory_efficient_fsdp_wrap: Optional[bool] = False
-    fsdp_cpu_offload: Optional[bool] = False
+    lr: Optional[float] = 2e-5
+    weight_decay: Optional[float] = 0.1
+    num_warmup_steps: Optional[int] = 20


 @json_schema_type
 class TrainingConfig(BaseModel):
-    n_epochs: int
-    max_steps_per_epoch: int
-    gradient_accumulation_steps: int
-    max_validation_steps: int
     data_config: DataConfig
     optimizer_config: OptimizerConfig
-    efficiency_config: Optional[EfficiencyConfig] = None
+    n_epochs: Optional[int] = 1
+    max_steps_per_epoch: Optional[int] = None
+    gradient_accumulation_steps: Optional[int] = 1
+    max_validation_steps: Optional[int] = None
     dtype: Optional[str] = "bf16"


 @json_schema_type
 class LoraFinetuningConfig(BaseModel):
     type: Literal["LoRA"] = "LoRA"
-    lora_attn_modules: List[str]
-    apply_lora_to_mlp: bool
-    apply_lora_to_output: bool
-    rank: int
-    alpha: int
+    lora_attn_modules: Optional[List[str]] = ["q_proj", "v_proj", "output_proj"]
+    apply_lora_to_mlp: Optional[bool] = True
+    apply_lora_to_output: Optional[bool] = False
+    rank: Optional[int] = 8
+    alpha: Optional[int] = 16
     use_dora: Optional[bool] = False
     quantize_base: Optional[bool] = False
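
A minimal sketch (not part of the patch) of what these new defaults allow a caller to do: only dataset_id, data_config, and optimizer_config remain required, and everything else falls back to the defaults introduced above. The import path is assumed from the file touched by the diff, and the dataset id is hypothetical.

# Sketch only: construct a training config relying on the new defaults.
# Import path assumed from the diff's file; adjust if the package re-exports differ.
from llama_stack.apis.post_training.post_training import (
    DataConfig,
    LoraFinetuningConfig,
    OptimizerConfig,
    TrainingConfig,
)

# DataConfig now defaults batch_size=1, shuffle=True, data_format=instruct.
# "my-sft-dataset" is a placeholder dataset id, not from the patch.
data_config = DataConfig(dataset_id="my-sft-dataset")

# OptimizerConfig now defaults lr=2e-5, weight_decay=0.1, num_warmup_steps=20,
# so an empty constructor is enough for a baseline run.
training_config = TrainingConfig(
    data_config=data_config,
    optimizer_config=OptimizerConfig(),
)

# LoRA defaults: q_proj/v_proj/output_proj attention modules, rank=8, alpha=16.
algorithm_config = LoraFinetuningConfig()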