update inference config to take model and not model_dir

Hardik Shah 2024-08-06 15:02:41 -07:00
parent 08c3802f45
commit 039861f1c7
9 changed files with 400 additions and 101 deletions


@@ -27,11 +27,22 @@ from llama_models.llama3_1.api.chat_format import ChatFormat, ModelInput
 from llama_models.llama3_1.api.datatypes import Message
 from llama_models.llama3_1.api.model import Transformer
 from llama_models.llama3_1.api.tokenizer import Tokenizer
+from llama_models.sku_list import resolve_model
 from termcolor import cprint

+from llama_toolchain.common.model_utils import model_local_dir
 from llama_toolchain.inference.api import QuantizationType

-from .config import CheckpointType, MetaReferenceImplConfig
+from .config import MetaReferenceImplConfig
+
+
+def model_checkpoint_dir(model) -> str:
+    checkpoint_dir = Path(model_local_dir(model))
+    if not Path(checkpoint_dir / "consolidated.00.pth").exists():
+        checkpoint_dir = checkpoint_dir / "original"
+
+    assert checkpoint_dir.exists(), f"Could not find checkpoint dir: {checkpoint_dir}"
+    return str(checkpoint_dir)


 @dataclass
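
A minimal usage sketch of the new helper (the SKU string and the commented path are illustrative assumptions, not taken from this commit):

from llama_models.sku_list import resolve_model

model = resolve_model("Meta-Llama3.1-8B-Instruct")  # look up descriptor by SKU (hypothetical SKU)
assert model is not None, "unknown model SKU"
ckpt_dir = model_checkpoint_dir(model)
# Falls back to <model_local_dir>/original when consolidated.00.pth is not
# at the top level, then asserts the resolved directory exists.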
@@ -51,9 +62,7 @@ class Llama:
         This method initializes the distributed process group, sets the device to CUDA,
         and loads the pre-trained model and tokenizer.
         """
-        checkpoint = config.checkpoint_config.checkpoint
-        if checkpoint.checkpoint_type != CheckpointType.pytorch.value:
-            raise NotImplementedError("HuggingFace checkpoints not supported yet")
+        model = resolve_model(config.model)

         if (
             config.quantization
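
For orientation, a rough sketch of the config shape this hunk implies; only model, quantization, and max_batch_size are visible in the diff, and everything else (the base class, max_seq_len, the defaults) is an assumption rather than the actual class:

from typing import Any, Optional
from pydantic import BaseModel  # assumed; the real base class is not shown here

class MetaReferenceImplConfigSketch(BaseModel):
    model: str                          # SKU passed to resolve_model()
    quantization: Optional[Any] = None  # the diff checks config.quantization.type
    max_seq_len: int = 4096             # assumed default, not shown in this diff
    max_batch_size: int = 1             # forwarded to ModelArgs below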
@@ -67,7 +76,7 @@
         if not torch.distributed.is_initialized():
             torch.distributed.init_process_group("nccl")

-        model_parallel_size = checkpoint.model_parallel_size
+        model_parallel_size = model.hardware_requirements.gpu_count

         if not model_parallel_is_initialized():
             initialize_model_parallel(model_parallel_size)
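
Since the parallelism degree is now read from the model card rather than the checkpoint config, a defensive check against the visible GPU count could look like this (a sketch using only standard PyTorch calls, not code from this commit):

import torch

required = model.hardware_requirements.gpu_count
available = torch.cuda.device_count()
if available < required:
    raise RuntimeError(
        f"model needs {required} GPU(s) for tensor parallelism, found {available}"
    )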
@@ -82,7 +91,8 @@
             sys.stdout = open(os.devnull, "w")

         start_time = time.time()
-        ckpt_dir = checkpoint.checkpoint_dir
+
+        ckpt_dir = model_checkpoint_dir(model)
         checkpoints = sorted(Path(ckpt_dir).glob("*.pth"))
         assert len(checkpoints) > 0, f"no checkpoint files found in {ckpt_dir}"
         assert model_parallel_size == len(
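
The shard-count assert pairs one *.pth file per model-parallel rank; each rank then loads its own shard, roughly as in the reference Llama loader (a sketch; the surrounding lines are elided from this hunk):

from fairscale.nn.model_parallel.initialize import get_model_parallel_rank
import torch

ckpt_path = checkpoints[get_model_parallel_rank()]  # one shard per rank
state_dict = torch.load(ckpt_path, map_location="cpu")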
@@ -103,7 +113,9 @@
             max_batch_size=config.max_batch_size,
             **params,
         )
-        tokenizer = Tokenizer(model_path=checkpoint.tokenizer_path)
+
+        tokenizer_path = os.path.join(ckpt_dir, "tokenizer.model")
+        tokenizer = Tokenizer(model_path=tokenizer_path)
         assert (
             model_args.vocab_size == tokenizer.n_words
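
Net effect of the tokenizer change: tokenizer.model is now expected to ship alongside the weights, so the resolved checkpoint directory plausibly looks like this (an assumed layout, inferred from the diff rather than stated in it):

# <model_local_dir>/          # or <model_local_dir>/original/
#     consolidated.00.pth     # one shard per model-parallel rank
#     params.json             # plausibly the source of **params above
#     tokenizer.model         # loaded via os.path.join(ckpt_dir, "tokenizer.model")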