Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-12-17 20:29:47 +00:00)

Commit 48482ff9c3 ("refine")
Parent: 415b8f2dbd
9 changed files with 18 additions and 57 deletions
@@ -62,7 +62,8 @@ def model_checkpoint_dir(model_id) -> str:
     assert checkpoint_dir.exists(), (
         f"Could not find checkpoints in: {model_local_dir(model_id)}. "
-        f"Please download model using `llama download --model-id {model_id}`"
+        f"If you are using a native Llama model, please download it using `llama download --model-id {model_id}`. "
+        f"Otherwise, please save your model checkpoint under {model_local_dir(model_id)}"
     )
     return str(checkpoint_dir)
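For readers seeing only this hunk: model_checkpoint_dir resolves the on-disk checkpoint directory for a model ID and fails fast with the message above when nothing is there. Below is a minimal sketch of the helper's shape after this commit. The import path for model_local_dir and the directory resolution are assumptions, since the lines above the hunk are not shown.

from pathlib import Path

# Assumed import path, for illustration only; the diff does not show
# where model_local_dir is defined.
from llama_stack.distribution.utils.model_utils import model_local_dir


def model_checkpoint_dir(model_id) -> str:
    # Stand-in resolution: treat the local download root as the
    # checkpoint directory.
    checkpoint_dir = Path(model_local_dir(model_id))

    assert checkpoint_dir.exists(), (
        f"Could not find checkpoints in: {model_local_dir(model_id)}. "
        f"If you are using a native Llama model, please download it using "
        f"`llama download --model-id {model_id}`. "
        f"Otherwise, please save your model checkpoint under {model_local_dir(model_id)}"
    )
    return str(checkpoint_dir)

A caller that trips the assertion can either run `llama download --model-id <id>` or copy an existing checkpoint into the printed directory.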
@@ -91,14 +92,9 @@ class Llama:
         """
         llama_model_id = llama_model.core_model_id.value
         if not torch.distributed.is_initialized():
-            print("I reach torch.distributed.init_process_group")
             torch.distributed.init_process_group("nccl")

-        model_parallel_size = (
-            config.model_parallel_size
-            if config.model_parallel_size
-            else llama_model.pth_file_count
-        )
+        model_parallel_size = llama_model.pth_file_count

         if not model_parallel_is_initialized():
             initialize_model_parallel(model_parallel_size)
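The surviving lines are the standard two-step bootstrap from the Llama reference code: create the global NCCL process group first, then let fairscale carve model-parallel groups out of it. After this commit the parallel size always follows the number of .pth shards rather than an optional config override. A minimal, runnable sketch of that ordering under a torchrun-style launch (the helper name setup_model_parallel is ours):

import torch
from fairscale.nn.model_parallel.initialize import (
    initialize_model_parallel,
    model_parallel_is_initialized,
)


def setup_model_parallel(pth_file_count: int) -> None:
    # Step 1: the global NCCL process group. Requires torchrun-style
    # env vars (RANK, WORLD_SIZE, MASTER_ADDR, MASTER_PORT).
    if not torch.distributed.is_initialized():
        torch.distributed.init_process_group("nccl")

    # Step 2: model-parallel size is tied to the number of .pth shards,
    # mirroring the line this commit keeps.
    if not model_parallel_is_initialized():
        initialize_model_parallel(pth_file_count)

Launched as torchrun --nproc_per_node=<number of .pth shards> script.py, every rank runs the same sequence and init_process_group picks up its rank and world size from the environment.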
@@ -106,8 +102,6 @@ class Llama:
         local_rank = int(os.environ.get("LOCAL_RANK", 0))
         torch.cuda.set_device(local_rank)

-        print("torch.cuda.set_device")
-
         # seed must be the same in all processes
         if config.torch_seed is not None:
             torch.manual_seed(config.torch_seed)
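The remaining lines pin each rank to its local GPU and seed the RNG identically everywhere, which matters because model-parallel ranks must make the same random choices during initialization and sampling. A small sketch of the same setup as a standalone helper (the function name is ours; config.torch_seed is passed as a plain argument):

import os

import torch


def setup_device_and_seed(torch_seed: int | None) -> None:
    # torchrun exports LOCAL_RANK; fall back to 0 for single-GPU runs.
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    torch.cuda.set_device(local_rank)

    # The seed must be identical in all processes so that model-parallel
    # ranks stay in lockstep.
    if torch_seed is not None:
        torch.manual_seed(torch_seed)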