From 0751a960a518785a821407bee4b855fbf56e88cb Mon Sep 17 00:00:00 2001 From: Charlie Doern Date: Sat, 12 Apr 2025 04:13:45 -0400 Subject: [PATCH] feat: make training config fields optional (#1861) # What does this PR do? Today, supervised_fine_tune itself and the `TrainingConfig` class have a bunch of required fields that a provider implementation might not need. for example, if a provider wants to handle hyperparameters in its configuration as well as any type of dataset retrieval, optimizer or LoRA config, a user will still need to pass in a virtually empty `DataConfig`, `OptimizerConfig` and `AlgorithmConfig` in some cases. Many of these fields are intended to work specifically with llama models and knobs intended for customizing inline. Adding remote post_training providers will require loosening these arguments, or forcing users to pass in empty objects to satisfy the pydantic models. Signed-off-by: Charlie Doern --- docs/_static/llama-stack-spec.html | 17 ++++++++--------- docs/_static/llama-stack-spec.yaml | 7 +++---- llama_stack/apis/post_training/post_training.py | 16 ++++++++-------- .../recipes/lora_finetuning_single_device.py | 10 ++++++++++ 4 files changed, 29 insertions(+), 21 deletions(-) diff --git a/docs/_static/llama-stack-spec.html b/docs/_static/llama-stack-spec.html index 36bfad49e..cdd6b3b53 100644 --- a/docs/_static/llama-stack-spec.html +++ b/docs/_static/llama-stack-spec.html @@ -9778,13 +9778,16 @@ "type": "integer" }, "max_steps_per_epoch": { - "type": "integer" + "type": "integer", + "default": 1 }, "gradient_accumulation_steps": { - "type": "integer" + "type": "integer", + "default": 1 }, "max_validation_steps": { - "type": "integer" + "type": "integer", + "default": 1 }, "data_config": { "$ref": "#/components/schemas/DataConfig" @@ -9804,10 +9807,7 @@ "required": [ "n_epochs", "max_steps_per_epoch", - "gradient_accumulation_steps", - "max_validation_steps", - "data_config", - "optimizer_config" + "gradient_accumulation_steps" ], "title": "TrainingConfig" }, @@ -10983,8 +10983,7 @@ "job_uuid", "training_config", "hyperparam_search_config", - "logger_config", - "model" + "logger_config" ], "title": "SupervisedFineTuneRequest" }, diff --git a/docs/_static/llama-stack-spec.yaml b/docs/_static/llama-stack-spec.yaml index 82faf450a..aa8d9456e 100644 --- a/docs/_static/llama-stack-spec.yaml +++ b/docs/_static/llama-stack-spec.yaml @@ -6744,10 +6744,13 @@ components: type: integer max_steps_per_epoch: type: integer + default: 1 gradient_accumulation_steps: type: integer + default: 1 max_validation_steps: type: integer + default: 1 data_config: $ref: '#/components/schemas/DataConfig' optimizer_config: @@ -6762,9 +6765,6 @@ components: - n_epochs - max_steps_per_epoch - gradient_accumulation_steps - - max_validation_steps - - data_config - - optimizer_config title: TrainingConfig PreferenceOptimizeRequest: type: object @@ -7498,7 +7498,6 @@ components: - training_config - hyperparam_search_config - logger_config - - model title: SupervisedFineTuneRequest SyntheticDataGenerateRequest: type: object diff --git a/llama_stack/apis/post_training/post_training.py b/llama_stack/apis/post_training/post_training.py index d49668e23..e5f1bcb65 100644 --- a/llama_stack/apis/post_training/post_training.py +++ b/llama_stack/apis/post_training/post_training.py @@ -60,11 +60,11 @@ class EfficiencyConfig(BaseModel): @json_schema_type class TrainingConfig(BaseModel): n_epochs: int - max_steps_per_epoch: int - gradient_accumulation_steps: int - max_validation_steps: int - data_config: DataConfig - optimizer_config: OptimizerConfig + max_steps_per_epoch: int = 1 + gradient_accumulation_steps: int = 1 + max_validation_steps: Optional[int] = 1 + data_config: Optional[DataConfig] = None + optimizer_config: Optional[OptimizerConfig] = None efficiency_config: Optional[EfficiencyConfig] = None dtype: Optional[str] = "bf16" @@ -177,9 +177,9 @@ class PostTraining(Protocol): training_config: TrainingConfig, hyperparam_search_config: Dict[str, Any], logger_config: Dict[str, Any], - model: str = Field( - default="Llama3.2-3B-Instruct", - description="Model descriptor from `llama model list`", + model: Optional[str] = Field( + default=None, + description="Model descriptor for training if not in provider config`", ), checkpoint_dir: Optional[str] = None, algorithm_config: Optional[AlgorithmConfig] = None, diff --git a/llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py b/llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py index edc1ceb90..04bf86b97 100644 --- a/llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py +++ b/llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py @@ -38,6 +38,8 @@ from llama_stack.apis.datasetio import DatasetIO from llama_stack.apis.datasets import Datasets from llama_stack.apis.post_training import ( Checkpoint, + DataConfig, + EfficiencyConfig, LoraFinetuningConfig, OptimizerConfig, QATFinetuningConfig, @@ -89,6 +91,10 @@ class LoraFinetuningSingleDevice: datasetio_api: DatasetIO, datasets_api: Datasets, ) -> None: + assert isinstance(training_config.data_config, DataConfig), "DataConfig must be initialized" + + assert isinstance(training_config.efficiency_config, EfficiencyConfig), "EfficiencyConfig must be initialized" + self.job_uuid = job_uuid self.training_config = training_config if not isinstance(algorithm_config, LoraFinetuningConfig): @@ -188,6 +194,7 @@ class LoraFinetuningSingleDevice: self._tokenizer = await self._setup_tokenizer() log.info("Tokenizer is initialized.") + assert isinstance(self.training_config.optimizer_config, OptimizerConfig), "OptimizerConfig must be initialized" self._optimizer = await self._setup_optimizer(optimizer_config=self.training_config.optimizer_config) log.info("Optimizer is initialized.") @@ -195,6 +202,8 @@ class LoraFinetuningSingleDevice: self._model.set_num_output_chunks(self._loss_fn.num_output_chunks) log.info("Loss is initialized.") + assert isinstance(self.training_config.data_config, DataConfig), "DataConfig must be initialized" + self._training_sampler, self._training_dataloader = await self._setup_data( dataset_id=self.training_config.data_config.dataset_id, tokenizer=self._tokenizer, @@ -452,6 +461,7 @@ class LoraFinetuningSingleDevice: """ The core training loop. """ + assert isinstance(self.training_config.data_config, DataConfig), "DataConfig must be initialized" # Initialize tokens count and running loss (for grad accumulation) t0 = time.perf_counter() running_loss: float = 0.0