From e86271aeac484f67c4e2ef6e75206f615001c5ac Mon Sep 17 00:00:00 2001
From: Botao Chen
Date: Fri, 3 Jan 2025 17:33:05 -0800
Subject: [PATCH] support llama3.1 8B instruct in post training (#698)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

## What does this PR do?

- Change post training to support the llama3.1 8B instruct model instead
  of the llama3 8B model, since the llama3.1 8B instruct model is a better
  model to finetune on top of.
- Make the copy-files logic in the checkpointer safer in case a file to be
  copied does not exist in the source path.

## Test

Issue a post training request from the client and verify that training
works as expected.

[Screenshot 2025-01-02 at 12 18 45 PM]
[Screenshot 2025-01-02 at 12 18 52 PM]
---
 .../torchtune/common/checkpointer.py         | 30 +++++++++++--------
 .../post_training/torchtune/common/utils.py  |  7 +++--
 2 files changed, 22 insertions(+), 15 deletions(-)

diff --git a/llama_stack/providers/inline/post_training/torchtune/common/checkpointer.py b/llama_stack/providers/inline/post_training/torchtune/common/checkpointer.py
index 688a03c25..359fc43ca 100644
--- a/llama_stack/providers/inline/post_training/torchtune/common/checkpointer.py
+++ b/llama_stack/providers/inline/post_training/torchtune/common/checkpointer.py
@@ -90,18 +90,24 @@ class TorchtuneCheckpointer:
         model_file_path.mkdir(parents=True, exist_ok=True)
 
         # copy the related files for inference
-        shutil.copy(
-            Path.joinpath(self._checkpoint_dir, "params.json"),
-            Path.joinpath(model_file_path, "params.json"),
-        )
-        shutil.copy(
-            Path.joinpath(self._checkpoint_dir, "tokenizer.model"),
-            Path.joinpath(model_file_path, "tokenizer.model"),
-        )
-        shutil.copy(
-            Path.joinpath(self._checkpoint_dir, "orig_params.json"),
-            Path.joinpath(model_file_path, "orig_params.json"),
-        )
+        source_path = Path.joinpath(self._checkpoint_dir, "params.json")
+        if source_path.exists():
+            shutil.copy(
+                source_path,
+                Path.joinpath(model_file_path, "params.json"),
+            )
+        source_path = Path.joinpath(self._checkpoint_dir, "tokenizer.model")
+        if source_path.exists():
+            shutil.copy(
+                source_path,
+                Path.joinpath(model_file_path, "tokenizer.model"),
+            )
+        source_path = Path.joinpath(self._checkpoint_dir, "orig_params.json")
+        if source_path.exists():
+            shutil.copy(
+                source_path,
+                Path.joinpath(model_file_path, "orig_params.json"),
+            )
 
         if not adapter_only:
             model_state_dict = state_dict[training.MODEL_KEY]
diff --git a/llama_stack/providers/inline/post_training/torchtune/common/utils.py b/llama_stack/providers/inline/post_training/torchtune/common/utils.py
index a5279cdbe..2b7a4ec93 100644
--- a/llama_stack/providers/inline/post_training/torchtune/common/utils.py
+++ b/llama_stack/providers/inline/post_training/torchtune/common/utils.py
@@ -21,8 +21,9 @@ from llama_stack.apis.datasets import Datasets
 
 from pydantic import BaseModel
 
-from torchtune.models.llama3 import llama3_tokenizer, lora_llama3_8b
+from torchtune.models.llama3 import llama3_tokenizer
 from torchtune.models.llama3._tokenizer import Llama3Tokenizer
+from torchtune.models.llama3_1 import lora_llama3_1_8b
 from torchtune.models.llama3_2 import lora_llama3_2_3b
 
 
@@ -49,8 +50,8 @@ MODEL_CONFIGS: Dict[str, ModelConfig] = {
         tokenizer_type=llama3_tokenizer,
         checkpoint_type="LLAMA3_2",
     ),
-    "Llama-3-8B-Instruct": ModelConfig(
-        model_definition=lora_llama3_8b,
+    "Llama3.1-8B-Instruct": ModelConfig(
+        model_definition=lora_llama3_1_8b,
         tokenizer_type=llama3_tokenizer,
         checkpoint_type="LLAMA3",
     ),
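
The checkpointer hunk repeats the same existence-check-and-copy pattern three times. A minimal sketch of the same behavior as a loop, using a hypothetical `copy_inference_files` helper (not part of this patch) over the three files the checkpointer copies:

```python
import shutil
from pathlib import Path


def copy_inference_files(checkpoint_dir: Path, model_file_path: Path) -> None:
    """Copy the files needed for inference, skipping any that are missing."""
    for filename in ("params.json", "tokenizer.model", "orig_params.json"):
        source_path = checkpoint_dir / filename
        # Same guard as the patch: only copy when the source file exists.
        if source_path.exists():
            shutil.copy(source_path, model_file_path / filename)
```

Both shapes behave the same; the loop just keeps the list of expected files in one place if more files need to be copied later.

The utils.py hunk maps the `Llama3.1-8B-Instruct` model id to torchtune's `lora_llama3_1_8b` builder. A hedged usage sketch, assuming torchtune's usual LoRA builder keyword arguments (`lora_attn_modules`, `lora_rank`, `lora_alpha`); the specific values are illustrative and not taken from this patch:

```python
from torchtune.models.llama3_1 import lora_llama3_1_8b

# Build a LoRA-adapted Llama 3.1 8B module with adapters on the attention
# query and value projections; rank and alpha here are example values.
model = lora_llama3_1_8b(
    lora_attn_modules=["q_proj", "v_proj"],
    lora_rank=8,
    lora_alpha=16,
)
```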