From 840fd353f73ac73dd1e6b8cb8d4d56683565f416 Mon Sep 17 00:00:00 2001
From: Botao Chen
Date: Mon, 3 Mar 2025 23:41:11 -0800
Subject: [PATCH 1/5] init commit

---
 .../apis/post_training/post_training.py | 48 ++++++------------
 1 file changed, 15 insertions(+), 33 deletions(-)

diff --git a/llama_stack/apis/post_training/post_training.py b/llama_stack/apis/post_training/post_training.py
index ed15c6de4..8324be5f6 100644
--- a/llama_stack/apis/post_training/post_training.py
+++ b/llama_stack/apis/post_training/post_training.py
@@ -17,13 +17,6 @@ from llama_stack.apis.common.training_types import Checkpoint
 from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
 
 
-@json_schema_type
-class OptimizerType(Enum):
-    adam = "adam"
-    adamw = "adamw"
-    sgd = "sgd"
-
-
 @json_schema_type
 class DatasetFormat(Enum):
     instruct = "instruct"
@@ -33,50 +26,39 @@ class DatasetFormat(Enum):
 @json_schema_type
 class DataConfig(BaseModel):
     dataset_id: str
-    batch_size: int
-    shuffle: bool
-    data_format: DatasetFormat
+    batch_size: Optional[int] = 1
+    shuffle: Optional[bool] = True
+    data_format: Optional[DatasetFormat] = DatasetFormat.instruct
     validation_dataset_id: Optional[str] = None
-    packed: Optional[bool] = False
     train_on_input: Optional[bool] = False
 
 
 @json_schema_type
 class OptimizerConfig(BaseModel):
-    optimizer_type: OptimizerType
-    lr: float
-    weight_decay: float
-    num_warmup_steps: int
-
-
-@json_schema_type
-class EfficiencyConfig(BaseModel):
-    enable_activation_checkpointing: Optional[bool] = False
-    enable_activation_offloading: Optional[bool] = False
-    memory_efficient_fsdp_wrap: Optional[bool] = False
-    fsdp_cpu_offload: Optional[bool] = False
+    lr: Optional[float] = 2e-5
+    weight_decay: Optional[float] = 0.1
+    num_warmup_steps: Optional[int] = 20
 
 
 @json_schema_type
 class TrainingConfig(BaseModel):
-    n_epochs: int
-    max_steps_per_epoch: int
-    gradient_accumulation_steps: int
-    max_validation_steps: int
     data_config: DataConfig
     optimizer_config: OptimizerConfig
-    efficiency_config: Optional[EfficiencyConfig] = None
+    n_epochs: Optional[int] = 1
+    max_steps_per_epoch: Optional[int] = None
+    gradient_accumulation_steps: Optional[int] = 1
+    max_validation_steps: Optional[int] = None
     dtype: Optional[str] = "bf16"
 
 
 @json_schema_type
 class LoraFinetuningConfig(BaseModel):
     type: Literal["LoRA"] = "LoRA"
-    lora_attn_modules: List[str]
-    apply_lora_to_mlp: bool
-    apply_lora_to_output: bool
-    rank: int
-    alpha: int
+    lora_attn_modules: Optional[List[str]] = ["q_proj", "v_proj", "output_proj"]
+    apply_lora_to_mlp: Optional[bool] = True
+    apply_lora_to_output: Optional[bool] = False
+    rank: Optional[int] = 8
+    alpha: Optional[int] = 16
     use_dora: Optional[bool] = False
     quantize_base: Optional[bool] = False
 

From 0b92fef3ba70bbee26580dd783979b9e71339436 Mon Sep 17 00:00:00 2001
From: Botao Chen
Date: Mon, 3 Mar 2025 23:53:27 -0800
Subject: [PATCH 2/5] refine

---
 llama_stack/apis/post_training/post_training.py | 4 +---
 1 file changed, 1 insertion(+), 3 deletions(-)

diff --git a/llama_stack/apis/post_training/post_training.py b/llama_stack/apis/post_training/post_training.py
index 8324be5f6..d3fe56fb0 100644
--- a/llama_stack/apis/post_training/post_training.py
+++ b/llama_stack/apis/post_training/post_training.py
@@ -43,7 +43,7 @@ class OptimizerConfig(BaseModel):
 @json_schema_type
 class TrainingConfig(BaseModel):
     data_config: DataConfig
-    optimizer_config: OptimizerConfig
+    optimizer_config: Optional[OptimizerConfig] = OptimizerConfig()
     n_epochs: Optional[int] = 1
     max_steps_per_epoch: Optional[int] = None
     gradient_accumulation_steps: Optional[int] = 1
@@ -159,8 +159,6 @@ class PostTraining(Protocol):
         self,
         job_uuid: str,
         training_config: TrainingConfig,
-        hyperparam_search_config: Dict[str, Any],
-        logger_config: Dict[str, Any],
         model: str = Field(
             default="Llama3.2-3B-Instruct",
             description="Model descriptor from `llama model list`",

From 51282456b9e24b2e216ecd50dc55155483244498 Mon Sep 17 00:00:00 2001
From: Botao Chen
Date: Sun, 9 Mar 2025 17:17:18 -0700
Subject: [PATCH 3/5] refine

---
 .../apis/post_training/post_training.py | 100 +++++++-----------
 1 file changed, 36 insertions(+), 64 deletions(-)

diff --git a/llama_stack/apis/post_training/post_training.py b/llama_stack/apis/post_training/post_training.py
index d3fe56fb0..df05fbf31 100644
--- a/llama_stack/apis/post_training/post_training.py
+++ b/llama_stack/apis/post_training/post_training.py
@@ -16,43 +16,29 @@ from llama_stack.apis.common.job_types import JobStatus
 from llama_stack.apis.common.training_types import Checkpoint
 from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
 
-
 @json_schema_type
-class DatasetFormat(Enum):
-    instruct = "instruct"
-    dialog = "dialog"
-
-
-@json_schema_type
-class DataConfig(BaseModel):
-    dataset_id: str
-    batch_size: Optional[int] = 1
-    shuffle: Optional[bool] = True
-    data_format: Optional[DatasetFormat] = DatasetFormat.instruct
-    validation_dataset_id: Optional[str] = None
-    train_on_input: Optional[bool] = False
-
-
-@json_schema_type
-class OptimizerConfig(BaseModel):
+class TrainingStrategy(BaseModel):
+    # params that control Optimizer
     lr: Optional[float] = 2e-5
     weight_decay: Optional[float] = 0.1
-    num_warmup_steps: Optional[int] = 20
-
-
-@json_schema_type
-class TrainingConfig(BaseModel):
-    data_config: DataConfig
-    optimizer_config: Optional[OptimizerConfig] = OptimizerConfig()
-    n_epochs: Optional[int] = 1
-    max_steps_per_epoch: Optional[int] = None
-    gradient_accumulation_steps: Optional[int] = 1
+    num_warmup_steps: Optional[int] = 0
+
+    # params that control how data is fed for training
+    batch_size: Optional[int] = 1
+    shuffle: Optional[bool] = True
+    n_epochs: Optional[int] = 3
+
+    # training loop control params
+    max_training_steps: Optional[int] = None
     max_validation_steps: Optional[int] = None
+    gradient_accumulation_steps: Optional[int] = 1
+
+    # precision for training
     dtype: Optional[str] = "bf16"
 
 
 @json_schema_type
-class LoraFinetuningConfig(BaseModel):
+class LoraFinetuningStrategy(BaseModel):
     type: Literal["LoRA"] = "LoRA"
     lora_attn_modules: Optional[List[str]] = ["q_proj", "v_proj", "output_proj"]
     apply_lora_to_mlp: Optional[bool] = True
@@ -64,15 +50,15 @@ class LoraFinetuningConfig(BaseModel):
 
 
 @json_schema_type
-class QATFinetuningConfig(BaseModel):
+class QATFinetuningStrategy(BaseModel):
     type: Literal["QAT"] = "QAT"
     quantizer_name: str
     group_size: int
 
 
-AlgorithmConfig = register_schema(
-    Annotated[Union[LoraFinetuningConfig, QATFinetuningConfig], Field(discriminator="type")],
-    name="AlgorithmConfig",
+AlgorithmStrategy = register_schema(
+    Annotated[Union[LoraFinetuningStrategy, QATFinetuningStrategy], Field(discriminator="type")],
+    name="AlgorithmStrategy",
 )
 
 
@@ -90,35 +76,13 @@ class RLHFAlgorithm(Enum):
 
 
 @json_schema_type
-class DPOAlignmentConfig(BaseModel):
+class DPOAlignmentStrategy(BaseModel):
     reward_scale: float
     reward_clip: float
     epsilon: float
     gamma: float
 
 
-@json_schema_type
-class PostTrainingRLHFRequest(BaseModel):
-    """Request to finetune a model."""
-
-    job_uuid: str
-
-    finetuned_model: URL
-
-    dataset_id: str
-    validation_dataset_id: str
-
-    algorithm: RLHFAlgorithm
-    algorithm_config: DPOAlignmentConfig
-
-    optimizer_config: OptimizerConfig
-    training_config: TrainingConfig
-
-    # TODO: define these
-    hyperparam_search_config: Dict[str, Any]
-    logger_config: Dict[str, Any]
-
-
 class PostTrainingJob(BaseModel):
     job_uuid: str
 
@@ -158,24 +122,32 @@ class PostTraining(Protocol):
     async def supervised_fine_tune(
         self,
         job_uuid: str,
-        training_config: TrainingConfig,
+        training_dataset_id: str,
         model: str = Field(
             default="Llama3.2-3B-Instruct",
             description="Model descriptor from `llama model list`",
         ),
-        checkpoint_dir: Optional[str] = None,
-        algorithm_config: Optional[AlgorithmConfig] = None,
+
+        # Optional
+        validation_dataset_id: Optional[str] = None,
+        training_strategy: Optional[TrainingStrategy] = TrainingStrategy(),
+        algorithm: Optional[AlgorithmStrategy] = LoraFinetuningStrategy(),
     ) -> PostTrainingJob: ...
 
     @webmethod(route="/post-training/preference-optimize", method="POST")
     async def preference_optimize(
         self,
         job_uuid: str,
-        finetuned_model: str,
-        algorithm_config: DPOAlignmentConfig,
-        training_config: TrainingConfig,
-        hyperparam_search_config: Dict[str, Any],
-        logger_config: Dict[str, Any],
+        training_dataset_id: str,
+        model: str = Field(
+            default="Llama3.2-3B-Instruct",
+            description="Model descriptor from `llama model list`",
+        ),
+
+        # Optional
+        validation_dataset_id: Optional[str] = None,
+        training_strategy: Optional[TrainingStrategy] = TrainingStrategy(),
+        algorithm: Optional[AlgorithmStrategy] = LoraFinetuningStrategy(),
     ) -> PostTrainingJob: ...
 
     @webmethod(route="/post-training/jobs", method="GET")

From 41f86b5ce643e74478ef23220b6deacaf339e047 Mon Sep 17 00:00:00 2001
From: Botao Chen
Date: Sun, 9 Mar 2025 17:30:23 -0700
Subject: [PATCH 4/5] refine

---
 llama_stack/apis/post_training/post_training.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/llama_stack/apis/post_training/post_training.py b/llama_stack/apis/post_training/post_training.py
index df05fbf31..8cbc66daf 100644
--- a/llama_stack/apis/post_training/post_training.py
+++ b/llama_stack/apis/post_training/post_training.py
@@ -19,19 +19,19 @@ from llama_stack.schema_utils import json_schema_type, register_schema, webmetho
 @json_schema_type
 class TrainingStrategy(BaseModel):
     # params that control Optimizer
-    lr: Optional[float] = 2e-5
+    lr: Optional[Union[float, Literal["auto"]]] = "auto"
     weight_decay: Optional[float] = 0.1
-    num_warmup_steps: Optional[int] = 0
+    num_warmup_steps: Optional[Union[int, Literal["auto"]]] = "auto"
 
     # params that control how data is fed for training
-    batch_size: Optional[int] = 1
+    batch_size: Optional[Union[int, Literal["auto"]]] = "auto"
     shuffle: Optional[bool] = True
     n_epochs: Optional[int] = 3
 
     # training loop control params
     max_training_steps: Optional[int] = None
     max_validation_steps: Optional[int] = None
-    gradient_accumulation_steps: Optional[int] = 1
+    gradient_accumulation_steps: Optional[Union[int, Literal["auto"]]] = "auto"
 
     # precision for training
     dtype: Optional[str] = "bf16"

From 357141f6de171db6f2d65fd66f950896207a0e7b Mon Sep 17 00:00:00 2001
From: Botao Chen
Date: Sun, 9 Mar 2025 18:14:26 -0700
Subject: [PATCH 5/5] refine

---
 llama_stack/apis/post_training/post_training.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/llama_stack/apis/post_training/post_training.py b/llama_stack/apis/post_training/post_training.py
index 8cbc66daf..134b6bc15 100644
--- a/llama_stack/apis/post_training/post_training.py
+++ b/llama_stack/apis/post_training/post_training.py
@@ -19,7 +19,7 @@ from llama_stack.schema_utils import json_schema_type, register_schema, webmetho
 @json_schema_type
 class TrainingStrategy(BaseModel):
     # params that control Optimizer
-    lr: Optional[Union[float, Literal["auto"]]] = "auto"
+    learning_rate: Optional[Union[float, Literal["auto"]]] = "auto"
     weight_decay: Optional[float] = 0.1
     num_warmup_steps: Optional[Union[int, Literal["auto"]]] = "auto"
 
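
Usage sketch (illustrative, not part of the series): the snippet below exercises the API surface as it stands after PATCH 5/5 — an optional TrainingStrategy with "auto" defaults, an optional LoRA algorithm strategy, and the dataset-id based supervised_fine_tune signature. The `impl` object and the dataset/job identifiers are placeholders standing in for whichever provider implements the PostTraining protocol; they are not added by these patches.

from llama_stack.apis.post_training.post_training import (
    LoraFinetuningStrategy,
    TrainingStrategy,
)


async def run_sft(impl) -> None:
    # `impl` is any provider object implementing the PostTraining protocol (placeholder).
    strategy = TrainingStrategy(
        learning_rate="auto",  # let the provider pick a learning rate
        batch_size="auto",     # let the provider size batches for the hardware
        n_epochs=1,
    )
    job = await impl.supervised_fine_tune(
        job_uuid="sft-demo-001",
        training_dataset_id="my-instruct-dataset",
        model="Llama3.2-3B-Instruct",
        training_strategy=strategy,
        algorithm=LoraFinetuningStrategy(rank=8, alpha=16),
    )
    # PostTrainingJob carries only the job handle; status is polled via the jobs routes.
    print(job.job_uuid)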