Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-10-04 04:04:14 +00:00)
update inference config to take model and not model_dir

commit 039861f1c7 (parent 08c3802f45)
9 changed files with 400 additions and 101 deletions
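For orientation, here is a minimal sketch of how the updated config might be constructed after this change. Only the MetaReferenceImplConfig class name and the model and max_batch_size fields appear in the diff below; the import path and the example model identifier are assumptions, not verbatim repo code.

# Hedged sketch: the import path and the "Meta-Llama3.1-8B-Instruct" identifier are
# assumptions for illustration; the diff only shows that the config now carries a
# model identifier instead of a checkpoint/model_dir path.
from llama_toolchain.inference.meta_reference.config import MetaReferenceImplConfig

config = MetaReferenceImplConfig(
    model="Meta-Llama3.1-8B-Instruct",  # resolved via resolve_model() at load time
    max_batch_size=1,                   # field referenced in the diff (config.max_batch_size)
)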
@@ -27,11 +27,22 @@ from llama_models.llama3_1.api.chat_format import ChatFormat, ModelInput
 from llama_models.llama3_1.api.datatypes import Message
 from llama_models.llama3_1.api.model import Transformer
 from llama_models.llama3_1.api.tokenizer import Tokenizer
+from llama_models.sku_list import resolve_model
 from termcolor import cprint

+from llama_toolchain.common.model_utils import model_local_dir
 from llama_toolchain.inference.api import QuantizationType

-from .config import CheckpointType, MetaReferenceImplConfig
+from .config import MetaReferenceImplConfig
+
+
+def model_checkpoint_dir(model) -> str:
+    checkpoint_dir = Path(model_local_dir(model))
+    if not Path(checkpoint_dir / "consolidated.00.pth").exists():
+        checkpoint_dir = checkpoint_dir / "original"
+
+    assert checkpoint_dir.exists(), f"Could not find checkpoint dir: {checkpoint_dir}"
+    return str(checkpoint_dir)


 @dataclass
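A usage sketch for the model_checkpoint_dir helper added in this hunk, assuming a descriptor obtained from resolve_model; the identifier string is a placeholder, not taken from the repository.

# Illustrative only: resolve a model identifier to its local checkpoint directory
# using the model_checkpoint_dir helper defined above; the identifier is a placeholder.
from llama_models.sku_list import resolve_model

model = resolve_model("Meta-Llama3.1-8B-Instruct")
assert model is not None, "unknown model identifier"

ckpt_dir = model_checkpoint_dir(model)  # falls back to <dir>/original when consolidated.00.pth is absent
print(f"loading checkpoints from {ckpt_dir}")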
@@ -51,9 +62,7 @@ class Llama:
         This method initializes the distributed process group, sets the device to CUDA,
         and loads the pre-trained model and tokenizer.
         """
-        checkpoint = config.checkpoint_config.checkpoint
-        if checkpoint.checkpoint_type != CheckpointType.pytorch.value:
-            raise NotImplementedError("HuggingFace checkpoints not supported yet")
+        model = resolve_model(config.model)

         if (
             config.quantization
@@ -67,7 +76,7 @@ class Llama:
         if not torch.distributed.is_initialized():
             torch.distributed.init_process_group("nccl")

-        model_parallel_size = checkpoint.model_parallel_size
+        model_parallel_size = model.hardware_requirements.gpu_count
         if not model_parallel_is_initialized():
             initialize_model_parallel(model_parallel_size)

@@ -82,7 +91,8 @@ class Llama:
             sys.stdout = open(os.devnull, "w")

         start_time = time.time()
-        ckpt_dir = checkpoint.checkpoint_dir
+        ckpt_dir = model_checkpoint_dir(model)
+
         checkpoints = sorted(Path(ckpt_dir).glob("*.pth"))
         assert len(checkpoints) > 0, f"no checkpoint files found in {ckpt_dir}"
         assert model_parallel_size == len(
@@ -103,7 +113,9 @@ class Llama:
             max_batch_size=config.max_batch_size,
             **params,
         )
-        tokenizer = Tokenizer(model_path=checkpoint.tokenizer_path)
+
+        tokenizer_path = os.path.join(ckpt_dir, "tokenizer.model")
+        tokenizer = Tokenizer(model_path=tokenizer_path)

         assert (
             model_args.vocab_size == tokenizer.n_words
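As the last hunk shows, the tokenizer path is no longer a separate checkpoint setting; it is derived from the resolved checkpoint directory. A short, illustrative sketch:

import os

from llama_models.llama3_1.api.tokenizer import Tokenizer

# `ckpt_dir` comes from model_checkpoint_dir(model) as in the earlier sketch;
# the tokenizer is expected to sit alongside the *.pth checkpoint files.
tokenizer = Tokenizer(model_path=os.path.join(ckpt_dir, "tokenizer.model"))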