temp commit

This commit is contained in:
Botao Chen 2024-12-16 21:43:30 -08:00
parent 30f6eb282f
commit 81e1957446
10 changed files with 54 additions and 39 deletions

View file

@ -25,12 +25,12 @@ from fairscale.nn.model_parallel.initialize import (
)
from llama_models.llama3.api.args import ModelArgs
from llama_models.llama3.api.chat_format import ChatFormat, ModelInput
from llama_models.llama3.api.datatypes import Model
from llama_models.llama3.api.tokenizer import Tokenizer
from llama_models.llama3.reference_impl.model import Transformer
from llama_models.llama3.reference_impl.multimodal.model import (
CrossAttentionTransformer,
)
from llama_models.sku_list import resolve_model
from pydantic import BaseModel
from llama_stack.apis.inference import * # noqa: F403
@ -53,16 +53,16 @@ from .config import (
log = logging.getLogger(__name__)
def model_checkpoint_dir(model) -> str:
checkpoint_dir = Path(model_local_dir(model.descriptor()))
def model_checkpoint_dir(model_id) -> str:
checkpoint_dir = Path(model_local_dir(model_id))
paths = [Path(checkpoint_dir / f"consolidated.{ext}") for ext in ["pth", "00.pth"]]
if not any(p.exists() for p in paths):
checkpoint_dir = checkpoint_dir / "original"
assert checkpoint_dir.exists(), (
f"Could not find checkpoints in: {model_local_dir(model.descriptor())}. "
f"Please download model using `llama download --model-id {model.descriptor()}`"
f"Could not find checkpoints in: {model_local_dir(model_id)}. "
f"Please download model using `llama download --model-id {model_id}`"
)
return str(checkpoint_dir)
@ -80,6 +80,7 @@ class Llama:
MetaReferenceInferenceConfig, MetaReferenceQuantizedInferenceConfig
],
model_id: str,
llama_model: Model,
):
"""
Build a Llama instance by initializing and loading a model checkpoint.
@ -88,14 +89,16 @@ class Llama:
This method initializes the distributed process group, sets the device to CUDA,
and loads the pre-trained model and tokenizer.
"""
model = resolve_model(model_id)
llama_model = model.core_model_id.value
llama_model_id = llama_model.core_model_id.value
if not torch.distributed.is_initialized():
print("I reach torch.distributed.init_process_group")
torch.distributed.init_process_group("nccl")
model_parallel_size = config.model_parallel_size
model_parallel_size = (
config.model_parallel_size
if config.model_parallel_size
else llama_model.pth_file_count
)
if not model_parallel_is_initialized():
initialize_model_parallel(model_parallel_size)
@ -103,6 +106,8 @@ class Llama:
local_rank = int(os.environ.get("LOCAL_RANK", 0))
torch.cuda.set_device(local_rank)
print("torch.cuda.set_device")
# seed must be the same in all processes
if config.torch_seed is not None:
torch.manual_seed(config.torch_seed)
@ -114,7 +119,7 @@ class Llama:
if config.checkpoint_dir and config.checkpoint_dir != "null":
ckpt_dir = config.checkpoint_dir
else:
ckpt_dir = model_checkpoint_dir(model)
ckpt_dir = model_checkpoint_dir(model_id) # true model id
checkpoints = sorted(Path(ckpt_dir).glob("*.pth"))
assert len(checkpoints) > 0, f"no checkpoint files found in {ckpt_dir}"
@ -190,7 +195,7 @@ class Llama:
model.load_state_dict(state_dict, strict=False)
log.info(f"Loaded in {time.time() - start_time:.2f} seconds")
return Llama(model, tokenizer, model_args, llama_model)
return Llama(model, tokenizer, model_args, llama_model_id)
def __init__(
self,