Botao Chen 2024-12-17 13:38:19 -08:00
parent 415b8f2dbd
commit 48482ff9c3
9 changed files with 18 additions and 57 deletions

@@ -62,7 +62,8 @@ def model_checkpoint_dir(model_id) -> str:
     assert checkpoint_dir.exists(), (
         f"Could not find checkpoints in: {model_local_dir(model_id)}. "
-        f"Please download model using `llama download --model-id {model_id}`"
+        f"If you want to use the native llama model, please download it using `llama download --model-id {model_id}`. "
+        f"Otherwise, please save your model checkpoint under {model_local_dir(model_id)}."
     )
     return str(checkpoint_dir)
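
Note on the message above: adjacent f-string literals are concatenated into a single string, so each fragment needs its own trailing space or punctuation or the sentences run together. A small illustration (the model id and path below are made up for the example; local_dir stands in for model_local_dir(model_id)):

model_id = "Llama3.1-8B-Instruct"  # hypothetical id, illustration only
local_dir = f"/home/user/.llama/checkpoints/{model_id}"  # stand-in for model_local_dir(model_id)
msg = (
    f"Could not find checkpoints in: {local_dir}. "
    f"If you want to use the native llama model, please download it using "
    f"`llama download --model-id {model_id}`. "
    f"Otherwise, please save your model checkpoint under {local_dir}."
)
print(msg)  # one readable, properly spaced message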
@@ -91,14 +92,9 @@ class Llama:
"""
llama_model_id = llama_model.core_model_id.value
if not torch.distributed.is_initialized():
print("I reach torch.distributed.init_process_group")
torch.distributed.init_process_group("nccl")
model_parallel_size = (
config.model_parallel_size
if config.model_parallel_size
else llama_model.pth_file_count
)
model_parallel_size = llama_model.pth_file_count
if not model_parallel_is_initialized():
initialize_model_parallel(model_parallel_size)
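
This hunk removes a debug print and the config override: the model-parallel size now always comes from the number of checkpoint shards (llama_model.pth_file_count). A minimal sketch of the resulting initialization order, assuming fairscale's model-parallel helpers as imported in this file (the init_torch_distributed wrapper name is ours):

import torch
from fairscale.nn.model_parallel.initialize import (
    initialize_model_parallel,
    model_parallel_is_initialized,
)

def init_torch_distributed(llama_model):
    # Create the process group first (NCCL for GPU collectives), once per process.
    if not torch.distributed.is_initialized():
        torch.distributed.init_process_group("nccl")

    # After this change, the tensor-parallel world size is tied to the number of
    # consolidated .pth shards in the checkpoint rather than to a config value.
    model_parallel_size = llama_model.pth_file_count
    if not model_parallel_is_initialized():
        initialize_model_parallel(model_parallel_size)
    return model_parallel_size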
@@ -106,8 +102,6 @@ class Llama:
         local_rank = int(os.environ.get("LOCAL_RANK", 0))
         torch.cuda.set_device(local_rank)
-        print("torch.cuda.set_device")
         # seed must be the same in all processes
         if config.torch_seed is not None:
             torch.manual_seed(config.torch_seed)
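
The last hunk removes another debug print; what remains is the per-rank setup: pin each process to its GPU via LOCAL_RANK, then seed every rank identically so model-parallel workers stay in lockstep. A sketch, assuming a torchrun-style launcher sets LOCAL_RANK (the helper name is ours):

import os
import torch

def setup_rank(torch_seed=None):
    # LOCAL_RANK is provided by the launcher (e.g. torchrun); default to 0
    # for single-process runs.
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    torch.cuda.set_device(local_rank)

    # The seed must be identical across all ranks so that model-parallel
    # workers make the same random choices (e.g. during sampling).
    if torch_seed is not None:
        torch.manual_seed(torch_seed)
    return local_rank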