diff --git a/llama_stack/apis/post_training/post_training.py b/llama_stack/apis/post_training/post_training.py
index ed15c6de4..134b6bc15 100644
--- a/llama_stack/apis/post_training/post_training.py
+++ b/llama_stack/apis/post_training/post_training.py
@@ -16,81 +16,49 @@
 from llama_stack.apis.common.job_types import JobStatus
 from llama_stack.apis.common.training_types import Checkpoint
 from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
-
 @json_schema_type
-class OptimizerType(Enum):
-    adam = "adam"
-    adamw = "adamw"
-    sgd = "sgd"
-
-
-@json_schema_type
-class DatasetFormat(Enum):
-    instruct = "instruct"
-    dialog = "dialog"
-
-
-@json_schema_type
-class DataConfig(BaseModel):
-    dataset_id: str
-    batch_size: int
-    shuffle: bool
-    data_format: DatasetFormat
-    validation_dataset_id: Optional[str] = None
-    packed: Optional[bool] = False
-    train_on_input: Optional[bool] = False
-
-
-@json_schema_type
-class OptimizerConfig(BaseModel):
-    optimizer_type: OptimizerType
-    lr: float
-    weight_decay: float
-    num_warmup_steps: int
-
-
-@json_schema_type
-class EfficiencyConfig(BaseModel):
-    enable_activation_checkpointing: Optional[bool] = False
-    enable_activation_offloading: Optional[bool] = False
-    memory_efficient_fsdp_wrap: Optional[bool] = False
-    fsdp_cpu_offload: Optional[bool] = False
-
-
-@json_schema_type
-class TrainingConfig(BaseModel):
-    n_epochs: int
-    max_steps_per_epoch: int
-    gradient_accumulation_steps: int
-    max_validation_steps: int
-    data_config: DataConfig
-    optimizer_config: OptimizerConfig
-    efficiency_config: Optional[EfficiencyConfig] = None
+class TrainingStrategy(BaseModel):
+    # params that control the optimizer
+    learning_rate: Optional[Union[float, Literal["auto"]]] = "auto"
+    weight_decay: Optional[float] = 0.1
+    num_warmup_steps: Optional[Union[int, Literal["auto"]]] = "auto"
+
+    # params that control how data is fed for training
+    batch_size: Optional[Union[int, Literal["auto"]]] = "auto"
+    shuffle: Optional[bool] = True
+    n_epochs: Optional[int] = 3
+
+    # training loop control params
+    max_training_steps: Optional[int] = None
+    max_validation_steps: Optional[int] = None
+    gradient_accumulation_steps: Optional[Union[int, Literal["auto"]]] = "auto"
+
+    # precision for training
     dtype: Optional[str] = "bf16"
 
 
 @json_schema_type
-class LoraFinetuningConfig(BaseModel):
+class LoraFinetuningStrategy(BaseModel):
     type: Literal["LoRA"] = "LoRA"
-    lora_attn_modules: List[str]
-    apply_lora_to_mlp: bool
-    apply_lora_to_output: bool
-    rank: int
-    alpha: int
+    lora_attn_modules: Optional[List[str]] = ["q_proj", "v_proj", "output_proj"]
+    apply_lora_to_mlp: Optional[bool] = True
+    apply_lora_to_output: Optional[bool] = False
+    rank: Optional[int] = 8
+    alpha: Optional[int] = 16
     use_dora: Optional[bool] = False
     quantize_base: Optional[bool] = False
 
 
 @json_schema_type
-class QATFinetuningConfig(BaseModel):
+class QATFinetuningStrategy(BaseModel):
     type: Literal["QAT"] = "QAT"
     quantizer_name: str
     group_size: int
 
 
-AlgorithmConfig = register_schema(
-    Annotated[Union[LoraFinetuningConfig, QATFinetuningConfig], Field(discriminator="type")],
-    name="AlgorithmConfig",
+AlgorithmStrategy = register_schema(
+    Annotated[Union[LoraFinetuningStrategy, QATFinetuningStrategy], Field(discriminator="type")],
+    name="AlgorithmStrategy",
 )
 
 
@@ -108,35 +76,13 @@ class RLHFAlgorithm(Enum):
 
 
 @json_schema_type
-class DPOAlignmentConfig(BaseModel):
+class DPOAlignmentStrategy(BaseModel):
     reward_scale: float
     reward_clip: float
     epsilon: float
     gamma: float
 
 
-@json_schema_type
-class PostTrainingRLHFRequest(BaseModel):
-    """Request to finetune a model."""
-
-    job_uuid: str
-
-    finetuned_model: URL
-
-    dataset_id: str
-    validation_dataset_id: str
-
-    algorithm: RLHFAlgorithm
-    algorithm_config: DPOAlignmentConfig
-
-    optimizer_config: OptimizerConfig
-    training_config: TrainingConfig
-
-    # TODO: define these
-    hyperparam_search_config: Dict[str, Any]
-    logger_config: Dict[str, Any]
-
-
 class PostTrainingJob(BaseModel):
     job_uuid: str
 
@@ -176,26 +122,32 @@ class PostTraining(Protocol):
     async def supervised_fine_tune(
         self,
         job_uuid: str,
-        training_config: TrainingConfig,
-        hyperparam_search_config: Dict[str, Any],
-        logger_config: Dict[str, Any],
+        training_dataset_id: str,
         model: str = Field(
             default="Llama3.2-3B-Instruct",
             description="Model descriptor from `llama model list`",
         ),
-        checkpoint_dir: Optional[str] = None,
-        algorithm_config: Optional[AlgorithmConfig] = None,
+
+        # Optional
+        validation_dataset_id: Optional[str] = None,
+        training_strategy: Optional[TrainingStrategy] = TrainingStrategy(),
+        algorithm: Optional[AlgorithmStrategy] = LoraFinetuningStrategy(),
     ) -> PostTrainingJob: ...
 
     @webmethod(route="/post-training/preference-optimize", method="POST")
     async def preference_optimize(
         self,
         job_uuid: str,
-        finetuned_model: str,
-        algorithm_config: DPOAlignmentConfig,
-        training_config: TrainingConfig,
-        hyperparam_search_config: Dict[str, Any],
-        logger_config: Dict[str, Any],
+        training_dataset_id: str,
+        model: str = Field(
+            default="Llama3.2-3B-Instruct",
+            description="Model descriptor from `llama model list`",
+        ),
+
+        # Optional
+        validation_dataset_id: Optional[str] = None,
+        training_strategy: Optional[TrainingStrategy] = TrainingStrategy(),
+        algorithm: Optional[AlgorithmStrategy] = LoraFinetuningStrategy(),
    ) -> PostTrainingJob: ...
 
    @webmethod(route="/post-training/jobs", method="GET")
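Usage sketch (illustrative, not part of the diff): with the strategy models above, a caller holding any implementation of the PostTraining protocol can start a LoRA fine-tuning job and only override the few knobs it cares about, leaving everything else on its "auto"/default value. The `post_training` handle, job id, and dataset id below are hypothetical placeholders; the import path simply mirrors the file touched by this diff.

    from llama_stack.apis.post_training.post_training import (
        LoraFinetuningStrategy,
        PostTraining,
        TrainingStrategy,
    )


    async def launch_sft_job(post_training: PostTraining) -> str:
        # Only job_uuid and training_dataset_id are required; model defaults to
        # Llama3.2-3B-Instruct, and the strategies fall back to their defaults
        # ("auto" optimizer/batch settings, LoRA on q/v/output projections).
        job = await post_training.supervised_fine_tune(
            job_uuid="job-1234",                        # hypothetical job id
            training_dataset_id="my-instruct-dataset",  # hypothetical dataset id
            model="Llama3.2-3B-Instruct",
            training_strategy=TrainingStrategy(learning_rate=3e-4, n_epochs=1),
            algorithm=LoraFinetuningStrategy(rank=16, alpha=32),
        )
        return job.job_uuid

Because every strategy field carries a default, a provider that can infer sensible values (batch size, warmup steps, gradient accumulation) sees "auto" and chooses them for the caller, while callers who need exact control can still pin concrete values.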