mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-03 17:29:01 +00:00
temp commit
This commit is contained in:
parent
7c1e3daa75
commit
82d881c95b
2 changed files with 22 additions and 15 deletions
|
@ -90,18 +90,24 @@ class TorchtuneCheckpointer:
|
|||
model_file_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# copy the related files for inference
|
||||
shutil.copy(
|
||||
Path.joinpath(self._checkpoint_dir, "params.json"),
|
||||
Path.joinpath(model_file_path, "params.json"),
|
||||
)
|
||||
shutil.copy(
|
||||
Path.joinpath(self._checkpoint_dir, "tokenizer.model"),
|
||||
Path.joinpath(model_file_path, "tokenizer.model"),
|
||||
)
|
||||
shutil.copy(
|
||||
Path.joinpath(self._checkpoint_dir, "orig_params.json"),
|
||||
Path.joinpath(model_file_path, "orig_params.json"),
|
||||
)
|
||||
source_path = Path.joinpath(self._checkpoint_dir, "params.json")
|
||||
if source_path.exists():
|
||||
shutil.copy(
|
||||
source_path,
|
||||
Path.joinpath(model_file_path, "params.json"),
|
||||
)
|
||||
source_path = Path.joinpath(self._checkpoint_dir, "tokenizer.model")
|
||||
if source_path.exists():
|
||||
shutil.copy(
|
||||
source_path,
|
||||
Path.joinpath(model_file_path, "tokenizer.model"),
|
||||
)
|
||||
source_path = Path.joinpath(self._checkpoint_dir, "orig_params.json")
|
||||
if source_path.exists():
|
||||
shutil.copy(
|
||||
source_path,
|
||||
Path.joinpath(model_file_path, "orig_params.json"),
|
||||
)
|
||||
|
||||
if not adapter_only:
|
||||
model_state_dict = state_dict[training.MODEL_KEY]
|
||||
|
|
|
@ -19,8 +19,9 @@ from llama_models.sku_list import resolve_model
|
|||
from llama_stack.apis.common.type_system import ParamType, StringType
|
||||
from llama_stack.apis.datasets import Datasets
|
||||
|
||||
from torchtune.models.llama3 import llama3_tokenizer, lora_llama3_8b
|
||||
from torchtune.models.llama3 import llama3_tokenizer
|
||||
from torchtune.models.llama3._tokenizer import Llama3Tokenizer
|
||||
from torchtune.models.llama3_1 import lora_llama3_1_8b
|
||||
from torchtune.models.llama3_2 import lora_llama3_2_3b
|
||||
|
||||
|
||||
|
@ -47,8 +48,8 @@ MODEL_CONFIGS: Dict[str, ModelConfig] = {
|
|||
tokenizer_type=llama3_tokenizer,
|
||||
checkpoint_type="LLAMA3_2",
|
||||
),
|
||||
"Llama-3-8B-Instruct": ModelConfig(
|
||||
model_definition=lora_llama3_8b,
|
||||
"Llama3.1-8B-Instruct": ModelConfig(
|
||||
model_definition=lora_llama3_1_8b,
|
||||
tokenizer_type=llama3_tokenizer,
|
||||
checkpoint_type="LLAMA3",
|
||||
),
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue