temp commit

Botao Chen 2024-12-30 16:39:01 -08:00
parent 7c1e3daa75
commit 82d881c95b
2 changed files with 22 additions and 15 deletions

@@ -90,18 +90,24 @@ class TorchtuneCheckpointer:
         model_file_path.mkdir(parents=True, exist_ok=True)
 
         # copy the related files for inference
-        shutil.copy(
-            Path.joinpath(self._checkpoint_dir, "params.json"),
-            Path.joinpath(model_file_path, "params.json"),
-        )
-        shutil.copy(
-            Path.joinpath(self._checkpoint_dir, "tokenizer.model"),
-            Path.joinpath(model_file_path, "tokenizer.model"),
-        )
-        shutil.copy(
-            Path.joinpath(self._checkpoint_dir, "orig_params.json"),
-            Path.joinpath(model_file_path, "orig_params.json"),
-        )
+        source_path = Path.joinpath(self._checkpoint_dir, "params.json")
+        if source_path.exists():
+            shutil.copy(
+                source_path,
+                Path.joinpath(model_file_path, "params.json"),
+            )
+        source_path = Path.joinpath(self._checkpoint_dir, "tokenizer.model")
+        if source_path.exists():
+            shutil.copy(
+                source_path,
+                Path.joinpath(model_file_path, "tokenizer.model"),
+            )
+        source_path = Path.joinpath(self._checkpoint_dir, "orig_params.json")
+        if source_path.exists():
+            shutil.copy(
+                source_path,
+                Path.joinpath(model_file_path, "orig_params.json"),
+            )
 
         if not adapter_only:
             model_state_dict = state_dict[training.MODEL_KEY]
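The same guard-and-copy pattern now appears three times in this hunk. As a minimal sketch only (assuming plain pathlib/shutil; the copy_if_exists helper is hypothetical and not part of this commit), the repetition could be folded into one function:

    import shutil
    from pathlib import Path

    def copy_if_exists(checkpoint_dir: Path, model_file_path: Path, filename: str) -> None:
        # Copy an inference-related file only if the source checkpoint actually
        # contains it, so a missing optional file (e.g. orig_params.json) no
        # longer raises FileNotFoundError.
        source_path = checkpoint_dir / filename
        if source_path.exists():
            shutil.copy(source_path, model_file_path / filename)

    # Usage mirroring the hunk above:
    # for name in ("params.json", "tokenizer.model", "orig_params.json"):
    #     copy_if_exists(self._checkpoint_dir, model_file_path, name)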

@@ -19,8 +19,9 @@ from llama_models.sku_list import resolve_model
 from llama_stack.apis.common.type_system import ParamType, StringType
 from llama_stack.apis.datasets import Datasets
-from torchtune.models.llama3 import llama3_tokenizer, lora_llama3_8b
+from torchtune.models.llama3 import llama3_tokenizer
 from torchtune.models.llama3._tokenizer import Llama3Tokenizer
+from torchtune.models.llama3_1 import lora_llama3_1_8b
 from torchtune.models.llama3_2 import lora_llama3_2_3b
@@ -47,8 +48,8 @@ MODEL_CONFIGS: Dict[str, ModelConfig] = {
         tokenizer_type=llama3_tokenizer,
         checkpoint_type="LLAMA3_2",
     ),
-    "Llama-3-8B-Instruct": ModelConfig(
-        model_definition=lora_llama3_8b,
+    "Llama3.1-8B-Instruct": ModelConfig(
+        model_definition=lora_llama3_1_8b,
         tokenizer_type=llama3_tokenizer,
         checkpoint_type="LLAMA3",
     ),
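For orientation, here is a minimal sketch of the registry shape these hunks edit; the ModelConfig dataclass below is inferred from the field names visible in the diff, not copied from the repository:

    from dataclasses import dataclass
    from typing import Any, Callable, Dict

    from torchtune.models.llama3 import llama3_tokenizer
    from torchtune.models.llama3_1 import lora_llama3_1_8b

    @dataclass
    class ModelConfig:
        model_definition: Callable[..., Any]  # torchtune model builder, e.g. lora_llama3_1_8b
        tokenizer_type: Callable[..., Any]    # tokenizer factory, e.g. llama3_tokenizer
        checkpoint_type: str                  # checkpoint family tag, e.g. "LLAMA3"

    MODEL_CONFIGS: Dict[str, ModelConfig] = {
        "Llama3.1-8B-Instruct": ModelConfig(
            model_definition=lora_llama3_1_8b,
            tokenizer_type=llama3_tokenizer,
            checkpoint_type="LLAMA3",
        ),
    }

The point of the change is consistency: the old key "Llama-3-8B-Instruct" was paired with the Llama 3 builder lora_llama3_8b; renaming the key to "Llama3.1-8B-Instruct" and switching model_definition to lora_llama3_1_8b makes the identifier, the LoRA builder, and the import all refer to the same Llama 3.1 8B model.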