diff --git a/llama_stack/apis/post_training/post_training.py b/llama_stack/apis/post_training/post_training.py index f593749e2..758817ac0 100644 --- a/llama_stack/apis/post_training/post_training.py +++ b/llama_stack/apis/post_training/post_training.py @@ -182,12 +182,8 @@ class PostTraining(Protocol): async def preference_optimize( self, job_uuid: str, - finetuned_model: URL, - dataset_id: str, - validation_dataset_id: str, - algorithm: RLHFAlgorithm, + finetuned_model: str, algorithm_config: DPOAlignmentConfig, - optimizer_config: OptimizerConfig, training_config: TrainingConfig, hyperparam_search_config: Dict[str, Any], logger_config: Dict[str, Any], diff --git a/llama_stack/providers/inline/post_training/torchtune/__init__.py b/llama_stack/providers/inline/post_training/torchtune/__init__.py index a756385de..247ae22b2 100644 --- a/llama_stack/providers/inline/post_training/torchtune/__init__.py +++ b/llama_stack/providers/inline/post_training/torchtune/__init__.py @@ -10,6 +10,8 @@ from llama_stack.distribution.datatypes import Api, ProviderSpec from .config import TorchtunePostTrainingConfig +# The post_training API and the torchtune provider are still experimental and under heavy development + async def get_provider_impl( config: TorchtunePostTrainingConfig, @@ -21,5 +23,4 @@ async def get_provider_impl( config, deps[Api.datasetio], ) - # await impl.initialize() return impl diff --git a/llama_stack/providers/inline/post_training/torchtune/post_training.py b/llama_stack/providers/inline/post_training/torchtune/post_training.py index 74124325a..f33ca059a 100644 --- a/llama_stack/providers/inline/post_training/torchtune/post_training.py +++ b/llama_stack/providers/inline/post_training/torchtune/post_training.py @@ -51,12 +51,8 @@ class TorchtunePostTrainingImpl: async def preference_optimize( self, job_uuid: str, - finetuned_model: URL, - dataset_id: str, - validation_dataset_id: str, - algorithm: RLHFAlgorithm, + finetuned_model: str, algorithm_config: 
DPOAlignmentConfig, - optimizer_config: OptimizerConfig, training_config: TrainingConfig, hyperparam_search_config: Dict[str, Any], logger_config: Dict[str, Any], diff --git a/llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py b/llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py index c97651a34..17d3cbc2c 100644 --- a/llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py +++ b/llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py @@ -57,8 +57,8 @@ class LoraFinetuningSingleDevice: # Resume from checkpoint hasn't been supported yet # Validation hasn't been supported yet - # TODO @markchen1015 figure out the logging for this training recipe - # and make it work with telemetry + # Currently logging only logs limited training metrics to local disk + # will figure out more logging and how it works with telemetry in future PRs def __init__( self, config: TorchtunePostTrainingConfig, @@ -70,7 +70,6 @@ class LoraFinetuningSingleDevice: algorithm_config: Optional[Union[LoraFinetuningConfig, QATFinetuningConfig]], datasetio_api: DatasetIO, ) -> None: - # Assume the training only happens on GPU self.training_config = training_config self.algorithm_config = algorithm_config self._device = torchtune_utils.get_device(device="cuda") @@ -149,7 +148,6 @@ class LoraFinetuningSingleDevice: return checkpoint_dict async def setup(self) -> None: - # temporily log to local disk, will figure out how to interop with telemetry self._metric_logger = DiskLogger(log_dir=self._output_dir) checkpoint_dict = await self.load_checkpoint() @@ -316,7 +314,7 @@ class LoraFinetuningSingleDevice: all_rows = await fetch_rows() rows = all_rows.rows - # Curretly only support instruct dataset + # Currently only supports alpaca instruct dataset # TODO @markchen1015 make the message_transform swappable and support more dataset types ds = 
SFTDataset( rows,