mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-12-17 21:59:49 +00:00)
temp commit

parent 30f6eb282f
commit 81e1957446

10 changed files with 54 additions and 39 deletions
@@ -25,12 +25,12 @@ from fairscale.nn.model_parallel.initialize import (
 )
 from llama_models.llama3.api.args import ModelArgs
 from llama_models.llama3.api.chat_format import ChatFormat, ModelInput
+from llama_models.llama3.api.datatypes import Model
 from llama_models.llama3.api.tokenizer import Tokenizer
 from llama_models.llama3.reference_impl.model import Transformer
 from llama_models.llama3.reference_impl.multimodal.model import (
     CrossAttentionTransformer,
 )
-from llama_models.sku_list import resolve_model
 from pydantic import BaseModel

 from llama_stack.apis.inference import *  # noqa: F403

@@ -53,16 +53,16 @@ from .config import (
 log = logging.getLogger(__name__)


-def model_checkpoint_dir(model) -> str:
-    checkpoint_dir = Path(model_local_dir(model.descriptor()))
+def model_checkpoint_dir(model_id) -> str:
+    checkpoint_dir = Path(model_local_dir(model_id))

     paths = [Path(checkpoint_dir / f"consolidated.{ext}") for ext in ["pth", "00.pth"]]
     if not any(p.exists() for p in paths):
         checkpoint_dir = checkpoint_dir / "original"

     assert checkpoint_dir.exists(), (
-        f"Could not find checkpoints in: {model_local_dir(model.descriptor())}. "
-        f"Please download model using `llama download --model-id {model.descriptor()}`"
+        f"Could not find checkpoints in: {model_local_dir(model_id)}. "
+        f"Please download model using `llama download --model-id {model_id}`"
     )
     return str(checkpoint_dir)

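For context only (not part of this commit): model_checkpoint_dir now keys off the plain model_id string instead of a Model object. A minimal standalone sketch of the same fallback logic, using a hypothetical base_dir argument in place of model_local_dir:

from pathlib import Path


def find_checkpoint_dir(base_dir: str) -> Path:
    # Mirrors the lookup above: prefer the download root, and fall back to an
    # "original" subdirectory when neither consolidated.pth nor
    # consolidated.00.pth exists at the top level.
    checkpoint_dir = Path(base_dir)
    candidates = [checkpoint_dir / f"consolidated.{ext}" for ext in ("pth", "00.pth")]
    if not any(p.exists() for p in candidates):
        checkpoint_dir = checkpoint_dir / "original"
    if not checkpoint_dir.exists():
        raise FileNotFoundError(f"Could not find checkpoints in: {base_dir}")
    return checkpoint_dir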
@@ -80,6 +80,7 @@ class Llama:
             MetaReferenceInferenceConfig, MetaReferenceQuantizedInferenceConfig
         ],
         model_id: str,
+        llama_model: Model,
     ):
         """
         Build a Llama instance by initializing and loading a model checkpoint.

@@ -88,14 +89,16 @@ class Llama:
         This method initializes the distributed process group, sets the device to CUDA,
         and loads the pre-trained model and tokenizer.
         """
-        model = resolve_model(model_id)
-
-        llama_model = model.core_model_id.value
-
+        llama_model_id = llama_model.core_model_id.value
         if not torch.distributed.is_initialized():
+            print("I reach torch.distributed.init_process_group")
             torch.distributed.init_process_group("nccl")

-        model_parallel_size = config.model_parallel_size
+        model_parallel_size = (
+            config.model_parallel_size
+            if config.model_parallel_size
+            else llama_model.pth_file_count
+        )

         if not model_parallel_is_initialized():
             initialize_model_parallel(model_parallel_size)

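For context only: the model-parallel size now falls back to the SKU's checkpoint shard count when the config leaves it unset. A small illustration of that fallback rule, using stand-in classes rather than the real config and Model types:

from dataclasses import dataclass
from typing import Optional


@dataclass
class FakeConfig:  # stand-in for MetaReferenceInferenceConfig
    model_parallel_size: Optional[int] = None


@dataclass
class FakeModel:  # stand-in for llama_models' Model
    pth_file_count: int = 8


def pick_model_parallel_size(config: FakeConfig, llama_model: FakeModel) -> int:
    # Same rule as the diff: an explicit config value wins, otherwise use the
    # number of consolidated .pth shards the checkpoint ships with.
    return (
        config.model_parallel_size
        if config.model_parallel_size
        else llama_model.pth_file_count
    )


assert pick_model_parallel_size(FakeConfig(), FakeModel(pth_file_count=8)) == 8
assert pick_model_parallel_size(FakeConfig(model_parallel_size=2), FakeModel()) == 2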
@@ -103,6 +106,8 @@ class Llama:
         local_rank = int(os.environ.get("LOCAL_RANK", 0))
         torch.cuda.set_device(local_rank)

+        print("torch.cuda.set_device")
+
         # seed must be the same in all processes
         if config.torch_seed is not None:
             torch.manual_seed(config.torch_seed)

@@ -114,7 +119,7 @@ class Llama:
         if config.checkpoint_dir and config.checkpoint_dir != "null":
             ckpt_dir = config.checkpoint_dir
         else:
-            ckpt_dir = model_checkpoint_dir(model)
+            ckpt_dir = model_checkpoint_dir(model_id)  # true model id

         checkpoints = sorted(Path(ckpt_dir).glob("*.pth"))
         assert len(checkpoints) > 0, f"no checkpoint files found in {ckpt_dir}"

@@ -190,7 +195,7 @@ class Llama:
         model.load_state_dict(state_dict, strict=False)

         log.info(f"Loaded in {time.time() - start_time:.2f} seconds")
-        return Llama(model, tokenizer, model_args, llama_model)
+        return Llama(model, tokenizer, model_args, llama_model_id)

     def __init__(
         self,
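For context only: with resolve_model() removed from build(), the caller is now expected to resolve the SKU itself and pass both the provider-facing model_id and the resolved Model. A hypothetical call site consistent with the new signature (the load_generator wrapper is illustrative, not from this commit):

from llama_models.sku_list import resolve_model


def load_generator(config, model_id: str):
    # Hypothetical wrapper: resolve the SKU up front, then hand both the raw
    # model_id (used for checkpoint lookup) and the resolved Model to build().
    # `Llama` is the class patched in the diff above; its import path is omitted here.
    llama_model = resolve_model(model_id)
    if llama_model is None:
        raise ValueError(f"Unknown model id: {model_id}")
    return Llama.build(config, model_id=model_id, llama_model=llama_model)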