forked from phoenix-oss/llama-stack-mirror
support llama3.1 8B instruct in post training (#698)
## What does this PR do? - Change to support llama3.1 8B instruct model other than llama3 8B model as llama3.1 8B instruct model is a better model to finetune on top of - Make the copy files logic in checkpointer safer in case the file be copied doesn't exist in source path ## test issue a post training request from client and verify training works as expect <img width="1101" alt="Screenshot 2025-01-02 at 12 18 45 PM" src="https://github.com/user-attachments/assets/47cc4df9-3edc-4afd-b5dd-abe1f039f1ed" /> <img width="782" alt="Screenshot 2025-01-02 at 12 18 52 PM" src="https://github.com/user-attachments/assets/b9435274-ef1d-4570-bd8e-0880c3a4b2e9" />
This commit is contained in:
parent
485476c29a
commit
e86271aeac
2 changed files with 22 additions and 15 deletions
|
@ -90,18 +90,24 @@ class TorchtuneCheckpointer:
|
||||||
model_file_path.mkdir(parents=True, exist_ok=True)
|
model_file_path.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
# copy the related files for inference
|
# copy the related files for inference
|
||||||
shutil.copy(
|
source_path = Path.joinpath(self._checkpoint_dir, "params.json")
|
||||||
Path.joinpath(self._checkpoint_dir, "params.json"),
|
if source_path.exists():
|
||||||
Path.joinpath(model_file_path, "params.json"),
|
shutil.copy(
|
||||||
)
|
source_path,
|
||||||
shutil.copy(
|
Path.joinpath(model_file_path, "params.json"),
|
||||||
Path.joinpath(self._checkpoint_dir, "tokenizer.model"),
|
)
|
||||||
Path.joinpath(model_file_path, "tokenizer.model"),
|
source_path = Path.joinpath(self._checkpoint_dir, "tokenizer.model")
|
||||||
)
|
if source_path.exists():
|
||||||
shutil.copy(
|
shutil.copy(
|
||||||
Path.joinpath(self._checkpoint_dir, "orig_params.json"),
|
source_path,
|
||||||
Path.joinpath(model_file_path, "orig_params.json"),
|
Path.joinpath(model_file_path, "tokenizer.model"),
|
||||||
)
|
)
|
||||||
|
source_path = Path.joinpath(self._checkpoint_dir, "orig_params.json")
|
||||||
|
if source_path.exists():
|
||||||
|
shutil.copy(
|
||||||
|
source_path,
|
||||||
|
Path.joinpath(model_file_path, "orig_params.json"),
|
||||||
|
)
|
||||||
|
|
||||||
if not adapter_only:
|
if not adapter_only:
|
||||||
model_state_dict = state_dict[training.MODEL_KEY]
|
model_state_dict = state_dict[training.MODEL_KEY]
|
||||||
|
|
|
@ -21,8 +21,9 @@ from llama_stack.apis.datasets import Datasets
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
from torchtune.models.llama3 import llama3_tokenizer, lora_llama3_8b
|
from torchtune.models.llama3 import llama3_tokenizer
|
||||||
from torchtune.models.llama3._tokenizer import Llama3Tokenizer
|
from torchtune.models.llama3._tokenizer import Llama3Tokenizer
|
||||||
|
from torchtune.models.llama3_1 import lora_llama3_1_8b
|
||||||
from torchtune.models.llama3_2 import lora_llama3_2_3b
|
from torchtune.models.llama3_2 import lora_llama3_2_3b
|
||||||
|
|
||||||
|
|
||||||
|
@ -49,8 +50,8 @@ MODEL_CONFIGS: Dict[str, ModelConfig] = {
|
||||||
tokenizer_type=llama3_tokenizer,
|
tokenizer_type=llama3_tokenizer,
|
||||||
checkpoint_type="LLAMA3_2",
|
checkpoint_type="LLAMA3_2",
|
||||||
),
|
),
|
||||||
"Llama-3-8B-Instruct": ModelConfig(
|
"Llama3.1-8B-Instruct": ModelConfig(
|
||||||
model_definition=lora_llama3_8b,
|
model_definition=lora_llama3_1_8b,
|
||||||
tokenizer_type=llama3_tokenizer,
|
tokenizer_type=llama3_tokenizer,
|
||||||
checkpoint_type="LLAMA3",
|
checkpoint_type="LLAMA3",
|
||||||
),
|
),
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue