From 18ae57776060ccaf3c4065e69995d34e45372b0f Mon Sep 17 00:00:00 2001
From: Botao Chen
Date: Wed, 27 Nov 2024 14:35:01 -0800
Subject: [PATCH] temp commit

---
 .../apis/post_training/post_training.py      | 20 +++++++++---------
 .../meta_reference/post_training.py          | 21 ++++++++++---------
 2 files changed, 21 insertions(+), 20 deletions(-)

diff --git a/llama_stack/apis/post_training/post_training.py b/llama_stack/apis/post_training/post_training.py
index d851d7858..b218d22e5 100644
--- a/llama_stack/apis/post_training/post_training.py
+++ b/llama_stack/apis/post_training/post_training.py
@@ -183,16 +183,16 @@ class PostTraining(Protocol):
     @webmethod(route="/post-training/supervised-fine-tune")
     def supervised_fine_tune(
         self,
-        job_uuid: str,
-        model: str,
-        dataset_id: str,
-        validation_dataset_id: str,
-        algorithm: FinetuningAlgorithm,
-        algorithm_config: LoraFinetuningConfig,
-        optimizer_config: OptimizerConfig,
-        training_config: TrainingConfig,
-        hyperparam_search_config: Dict[str, Any],
-        logger_config: Dict[str, Any],
+        job_uuid: Optional[str],
+        model: Optional[str],
+        dataset_id: Optional[str],
+        validation_dataset_id: Optional[str],
+        algorithm: Optional[FinetuningAlgorithm],
+        algorithm_config: Optional[LoraFinetuningConfig],
+        optimizer_config: Optional[OptimizerConfig],
+        training_config: Optional[TrainingConfig],
+        hyperparam_search_config: Optional[Dict[str, Any]],
+        logger_config: Optional[Dict[str, Any]],
     ) -> PostTrainingJob: ...
 
     @webmethod(route="/post-training/preference-optimize")
diff --git a/llama_stack/providers/inline/post_training/meta_reference/post_training.py b/llama_stack/providers/inline/post_training/meta_reference/post_training.py
index aea069cb0..5e89b3479 100644
--- a/llama_stack/providers/inline/post_training/meta_reference/post_training.py
+++ b/llama_stack/providers/inline/post_training/meta_reference/post_training.py
@@ -30,6 +30,7 @@ class MetaReferencePostTrainingImpl:
         )
 
         OptimizerConfig(
+            optimizer_type=OptimizerType.adamw,
             lr=3e-4,
             lr_min=3e-5,
             weight_decay=0.1,
@@ -50,16 +51,16 @@ class MetaReferencePostTrainingImpl:
 
     def supervised_fine_tune(
         self,
-        job_uuid: str = "1234",
-        model: str = " meta-llama/Llama-3.2-3B-Instruct",
-        dataset_id: str = "alpaca",
-        validation_dataset_id: str = "alpaca",
-        algorithm: FinetuningAlgorithm = FinetuningAlgorithm.lora,
-        algorithm_config: LoraFinetuningConfig = LoraFinetuningConfig,
-        optimizer_config: OptimizerConfig = OptimizerConfig,
-        training_config: TrainingConfig = TrainingConfig,
-        hyperparam_search_config: Dict[str, Any] = {},
-        logger_config: Dict[str, Any] = {},
+        job_uuid: Optional[str] = "1234",
+        model: Optional[str] = " meta-llama/Llama-3.2-3B-Instruct",
+        dataset_id: Optional[str] = "alpaca",
+        validation_dataset_id: Optional[str] = "alpaca",
+        algorithm: Optional[FinetuningAlgorithm] = FinetuningAlgorithm.lora,
+        algorithm_config: Optional[LoraFinetuningConfig] = LoraFinetuningConfig,
+        optimizer_config: Optional[OptimizerConfig] = OptimizerConfig,
+        training_config: Optional[TrainingConfig] = TrainingConfig,
+        hyperparam_search_config: Optional[Dict[str, Any]] = {},
+        logger_config: Optional[Dict[str, Any]] = {},
     ) -> PostTrainingJob:
         # wrapper request to make it easier to pass around (internal only, not exposed to API)
         request = PostTrainingSFTRequest(