diff --git a/llama_stack/apis/post_training/post_training.py b/llama_stack/apis/post_training/post_training.py index 8e2d131c0..d1e4d30e7 100644 --- a/llama_stack/apis/post_training/post_training.py +++ b/llama_stack/apis/post_training/post_training.py @@ -180,16 +180,16 @@ class PostTraining(Protocol): @webmethod(route="/post-training/supervised-fine-tune") def supervised_fine_tune( self, - job_uuid: Optional[str], - model: Optional[str], - dataset_id: Optional[str], - validation_dataset_id: Optional[str], - algorithm: Optional[FinetuningAlgorithm], - algorithm_config: Optional[LoraFinetuningConfig], - optimizer_config: Optional[OptimizerConfig], - training_config: Optional[TrainingConfig], - hyperparam_search_config: Optional[Dict[str, Any]], - logger_config: Optional[Dict[str, Any]], + job_uuid: str, + model: str, + dataset_id: str, + validation_dataset_id: str, + algorithm: FinetuningAlgorithm, + algorithm_config: LoraFinetuningConfig, + optimizer_config: OptimizerConfig, + training_config: TrainingConfig, + hyperparam_search_config: Dict[str, Any], + logger_config: Dict[str, Any], ) -> PostTrainingJob: ... @webmethod(route="/post-training/preference-optimize") diff --git a/llama_stack/providers/inline/post_training/meta_reference/post_training.py b/llama_stack/providers/inline/post_training/meta_reference/post_training.py index 5f7a70742..ca5ef67e5 100644 --- a/llama_stack/providers/inline/post_training/meta_reference/post_training.py +++ b/llama_stack/providers/inline/post_training/meta_reference/post_training.py @@ -20,46 +20,18 @@ class MetaReferencePostTrainingImpl: self.config = config self.datasetio_api = datasetio_api - LoraFinetuningConfig( - lora_attn_modules=["q_proj", "v_proj", "output_proj"], - apply_lora_to_mlp=True, - apply_lora_to_output=False, - rank=8, - alpha=16, - ) - - OptimizerConfig( - optimizer_type=OptimizerType.adamw, - lr=3e-4, - lr_min=3e-5, - weight_decay=0.1, - num_warmup_steps=100, - ) - - TrainingConfig( - dtype="bf16", - n_epochs=1, - max_steps_per_epoch=10, - gradient_accumulation_steps=1, - batch_size=1, - shuffle=1, - enable_activation_checkpointing=False, - memory_efficient_fsdp_wrap=False, - fsdp_cpu_offload=False, - ) - def supervised_fine_tune( self, - job_uuid: Optional[str] = "1234", - model: Optional[str] = " meta-llama/Llama-3.2-3B-Instruct", - dataset_id: Optional[str] = "alpaca", - validation_dataset_id: Optional[str] = "alpaca", - algorithm: Optional[FinetuningAlgorithm] = FinetuningAlgorithm.lora, - algorithm_config: Optional[LoraFinetuningConfig] = LoraFinetuningConfig, - optimizer_config: Optional[OptimizerConfig] = OptimizerConfig, - training_config: Optional[TrainingConfig] = TrainingConfig, - hyperparam_search_config: Optional[Dict[str, Any]] = {}, - logger_config: Optional[Dict[str, Any]] = {}, + job_uuid: str, + model: str, + dataset_id: str, + validation_dataset_id: str, + algorithm: FinetuningAlgorithm, + algorithm_config: LoraFinetuningConfig, + optimizer_config: OptimizerConfig, + training_config: TrainingConfig, + hyperparam_search_config: Dict[str, Any], + logger_config: Dict[str, Any], ) -> PostTrainingJob: # wrapper request to make it easier to pass around (internal only, not exposed to API) request = PostTrainingSFTRequest( @@ -71,6 +43,7 @@ class MetaReferencePostTrainingImpl: algorithm_config=algorithm_config, optimizer_config=optimizer_config, training_config=training_config, + hyperparam_search_config=hyperparam_search_config, logger_config=logger_config, ) if request.algorithm == FinetuningAlgorithm.lora: diff --git a/llama_stack/providers/inline/post_training/meta_reference/recipes/lora_finetuning_single_device.py b/llama_stack/providers/inline/post_training/meta_reference/recipes/lora_finetuning_single_device.py index 7bf99ad21..047d2805a 100644 --- a/llama_stack/providers/inline/post_training/meta_reference/recipes/lora_finetuning_single_device.py +++ b/llama_stack/providers/inline/post_training/meta_reference/recipes/lora_finetuning_single_device.py @@ -13,6 +13,7 @@ from typing import Any, Dict, List, Optional, Tuple, Union import torch from llama_stack.apis.datasetio import DatasetIO from torch import nn +from torchtune import utils as torchtune_utils from llama_stack.apis.post_training import * # noqa from llama_stack.apis.post_training import PostTrainingSFTRequest @@ -56,14 +57,14 @@ class LoraFinetuningSingleDevice: # self._device = utils.get_device(device=cfg.device) self.config = config self.request = request - self._device = training.utils.get_device(device="cuda") + self._device = torchtune_utils.get_device(device="cuda") self._dtype = training.get_dtype( request.training_config.dtype, device=self._device ) - self.model_id = request.model + self.model_id = config.model # hardcode it for now and see how it works with get_training_job_artifacts - self._output_dir = f"~/.llama/checkpoints/post_training/{request.model_id}" + self._output_dir = f"~/.llama/checkpoints/post_training/{self.model_id}" self._log_every_n_steps = 1 self._log_peak_memory_stats = False @@ -111,8 +112,8 @@ class LoraFinetuningSingleDevice: return [f"Error: The directory '{checkpoint_dir}' does not exist."] self._checkpointer = training.FullModelMetaCheckpointer( - checkpoint_dir=self.config.checkpoint_dir, - checkpoint_files=get_checkpoint_files, + checkpoint_dir=self.checkpoint_dir, + checkpoint_files=get_checkpoint_files(self.checkpoint_dir), output_dir=self._output_dir, # todo: automatically get this info from model model_type="LLAMA3", @@ -449,7 +450,7 @@ class LoraFinetuningSingleDevice: # ): # torch.cuda.memory._record_memory_history() - training.utils.batch_to_device(batch, self._device) + torchtune_utils.batch_to_device(batch, self._device) # Calculate the number of unmasked tokens in the current batch # and increment the total number of tokens seen in the step