mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-15 17:44:01 +00:00
support llama3.1 8B instruct in post training (#698)
## What does this PR do? - Change to support llama3.1 8B instruct model other than llama3 8B model as llama3.1 8B instruct model is a better model to finetune on top of - Make the copy files logic in checkpointer safer in case the file be copied doesn't exist in source path ## test issue a post training request from client and verify training works as expect <img width="1101" alt="Screenshot 2025-01-02 at 12 18 45 PM" src="https://github.com/user-attachments/assets/47cc4df9-3edc-4afd-b5dd-abe1f039f1ed" /> <img width="782" alt="Screenshot 2025-01-02 at 12 18 52 PM" src="https://github.com/user-attachments/assets/b9435274-ef1d-4570-bd8e-0880c3a4b2e9" />
This commit is contained in:
parent
485476c29a
commit
e86271aeac
2 changed files with 22 additions and 15 deletions
|
@ -90,18 +90,24 @@ class TorchtuneCheckpointer:
|
|||
model_file_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# copy the related files for inference
|
||||
shutil.copy(
|
||||
Path.joinpath(self._checkpoint_dir, "params.json"),
|
||||
Path.joinpath(model_file_path, "params.json"),
|
||||
)
|
||||
shutil.copy(
|
||||
Path.joinpath(self._checkpoint_dir, "tokenizer.model"),
|
||||
Path.joinpath(model_file_path, "tokenizer.model"),
|
||||
)
|
||||
shutil.copy(
|
||||
Path.joinpath(self._checkpoint_dir, "orig_params.json"),
|
||||
Path.joinpath(model_file_path, "orig_params.json"),
|
||||
)
|
||||
source_path = Path.joinpath(self._checkpoint_dir, "params.json")
|
||||
if source_path.exists():
|
||||
shutil.copy(
|
||||
source_path,
|
||||
Path.joinpath(model_file_path, "params.json"),
|
||||
)
|
||||
source_path = Path.joinpath(self._checkpoint_dir, "tokenizer.model")
|
||||
if source_path.exists():
|
||||
shutil.copy(
|
||||
source_path,
|
||||
Path.joinpath(model_file_path, "tokenizer.model"),
|
||||
)
|
||||
source_path = Path.joinpath(self._checkpoint_dir, "orig_params.json")
|
||||
if source_path.exists():
|
||||
shutil.copy(
|
||||
source_path,
|
||||
Path.joinpath(model_file_path, "orig_params.json"),
|
||||
)
|
||||
|
||||
if not adapter_only:
|
||||
model_state_dict = state_dict[training.MODEL_KEY]
|
||||
|
|
|
@ -21,8 +21,9 @@ from llama_stack.apis.datasets import Datasets
|
|||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from torchtune.models.llama3 import llama3_tokenizer, lora_llama3_8b
|
||||
from torchtune.models.llama3 import llama3_tokenizer
|
||||
from torchtune.models.llama3._tokenizer import Llama3Tokenizer
|
||||
from torchtune.models.llama3_1 import lora_llama3_1_8b
|
||||
from torchtune.models.llama3_2 import lora_llama3_2_3b
|
||||
|
||||
|
||||
|
@ -49,8 +50,8 @@ MODEL_CONFIGS: Dict[str, ModelConfig] = {
|
|||
tokenizer_type=llama3_tokenizer,
|
||||
checkpoint_type="LLAMA3_2",
|
||||
),
|
||||
"Llama-3-8B-Instruct": ModelConfig(
|
||||
model_definition=lora_llama3_8b,
|
||||
"Llama3.1-8B-Instruct": ModelConfig(
|
||||
model_definition=lora_llama3_1_8b,
|
||||
tokenizer_type=llama3_tokenizer,
|
||||
checkpoint_type="LLAMA3",
|
||||
),
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue