# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.

import json
import os
import sys
import time
from dataclasses import dataclass
from pathlib import Path
from typing import Generator, List, Optional, TypedDict

import torch
import torch.nn.functional as F
from fairscale.nn.model_parallel.initialize import (
    get_model_parallel_rank,
    initialize_model_parallel,
    model_parallel_is_initialized,
)
from models.llama3_1.api.args import ModelArgs
from models.llama3_1.api.chat_format import ChatFormat, ModelInput
from models.llama3_1.api.datatypes import Message
from models.llama3_1.api.model import Transformer
from models.llama3_1.api.tokenizer import Tokenizer
from termcolor import cprint


@dataclass
class TokenResult:
    token: int
    text: str
    logprobs: Optional[List[float]] = None


class CompletionPrediction(TypedDict, total=False):
    generation: str
    tokens: List[str]  # not required
    logprobs: List[float]  # not required


class Llama:
    @staticmethod
    def build(
        ckpt_dir: str,
        tokenizer_path: str,
        max_seq_len: int,
        max_batch_size: int,
        model_parallel_size: Optional[int] = None,
        seed: int = 1,
    ) -> "Llama":
        """
        Build a Llama instance by initializing and loading a model checkpoint.

        Args:
            ckpt_dir (str): Path to the directory containing checkpoint files.
            tokenizer_path (str): Path to the tokenizer file.
            max_seq_len (int): Maximum sequence length for input text.
            max_batch_size (int): Maximum batch size for inference.
            model_parallel_size (Optional[int], optional): Number of model parallel processes.
                If not provided, it's determined from the environment. Defaults to None.
            seed (int, optional): Random seed, shared by all processes. Defaults to 1.

        Returns:
            Llama: An instance of the Llama class with the loaded model and tokenizer.

        Raises:
            AssertionError: If there are no checkpoint files in the specified directory,
                or if the model parallel size does not match the number of checkpoint files.

        Note:
            This method initializes the distributed process group, sets the device to CUDA,
            and loads the pre-trained model and tokenizer.
        """
        if not torch.distributed.is_initialized():
            torch.distributed.init_process_group("nccl")
        if not model_parallel_is_initialized():
            if model_parallel_size is None:
                model_parallel_size = int(os.environ.get("WORLD_SIZE", 1))
            initialize_model_parallel(model_parallel_size)

        local_rank = int(os.environ.get("LOCAL_RANK", 0))
        torch.cuda.set_device(local_rank)

        # seed must be the same in all processes
        torch.manual_seed(seed)

        if local_rank > 0:
            sys.stdout = open(os.devnull, "w")

        start_time = time.time()
        checkpoints = sorted(Path(ckpt_dir).glob("*.pth"))
        assert len(checkpoints) > 0, f"no checkpoint files found in {ckpt_dir}"
        assert model_parallel_size == len(
            checkpoints
        ), f"Loading a checkpoint for MP={len(checkpoints)} but world size is {model_parallel_size}"
        ckpt_path = checkpoints[get_model_parallel_rank()]
        checkpoint = torch.load(ckpt_path, map_location="cpu", weights_only=True)
        with open(Path(ckpt_dir) / "params.json", "r") as f:
            params = json.loads(f.read())

        # TODO(ashwin): this block is so we can load internal checkpoints without
        # additional fuss. The final code should _not_ have this blurb.
        if "model" in params:
            params = params["model"]

        model_args: ModelArgs = ModelArgs(
            max_seq_len=max_seq_len,
            max_batch_size=max_batch_size,
            **params,
        )
        tokenizer = Tokenizer(model_path=tokenizer_path)
        assert (
            model_args.vocab_size == tokenizer.n_words
        ), f"model_args vocab = {model_args.vocab_size} but tokenizer vocab = {tokenizer.n_words}"

        if torch.cuda.is_bf16_supported():
            torch.set_default_tensor_type(torch.cuda.BFloat16Tensor)
        else:
            torch.set_default_tensor_type(torch.cuda.HalfTensor)

        model = Transformer(model_args)
        model.load_state_dict(checkpoint, strict=False)
        print(f"Loaded in {time.time() - start_time:.2f} seconds")

        return Llama(model, tokenizer, model_args)

    def __init__(self, model: Transformer, tokenizer: Tokenizer, args: ModelArgs):
        self.args = args
        self.model = model
        self.tokenizer = tokenizer
        self.formatter = ChatFormat(tokenizer)
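
    # Note on the streaming API below: `generate` is the single low-level loop
    # shared by `text_completion` and `chat_completion`. On its first step it
    # feeds the whole prompt (prev_pos=0 .. min_prompt_len), then decodes one
    # position at a time, yielding a TokenResult per token until a stop token
    # is sampled or the budget (max_gen_len, capped by max_seq_len) runs out.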

    @torch.inference_mode()
    def generate(
        self,
        model_input: ModelInput,
        max_gen_len: int,
        temperature: float = 0.6,
        top_p: float = 0.9,
        logprobs: bool = False,
        echo: bool = False,
        include_stop_token: bool = False,
    ) -> Generator:
        # NOTE: `echo` and `include_stop_token` are accepted for API parity
        # with the callers but are not yet used in the loop below.
        params = self.model.params

        # cprint("Input to model -> " + self.tokenizer.decode(model_input.tokens), "red")
        prompt_tokens = [model_input.tokens]

        bsz = 1
        assert bsz <= params.max_batch_size, (bsz, params.max_batch_size)

        min_prompt_len = min(len(t) for t in prompt_tokens)
        max_prompt_len = max(len(t) for t in prompt_tokens)

        if max_prompt_len >= params.max_seq_len:
            cprint(
                f"Out of token budget {max_prompt_len} vs {params.max_seq_len}", "red"
            )
            return

        total_len = min(max_gen_len + max_prompt_len, params.max_seq_len)

        pad_id = self.tokenizer.pad_id
        tokens = torch.full((bsz, total_len), pad_id, dtype=torch.long, device="cuda")
        for k, t in enumerate(prompt_tokens):
            tokens[k, : len(t)] = torch.tensor(t, dtype=torch.long, device="cuda")
        if logprobs:
            token_logprobs = torch.zeros_like(tokens, dtype=torch.float)

        prev_pos = 0
        eos_reached = torch.tensor([False] * bsz, device="cuda")
        input_text_mask = tokens != pad_id

        if min_prompt_len == total_len:
            # TODO(ashwin): unify this branch with the one below and figure out multimodal crap
            logits = self.model.forward(tokens, prev_pos)
            token_logprobs = -F.cross_entropy(
                input=logits.transpose(1, 2),
                target=tokens,
                reduction="none",
                ignore_index=pad_id,
            )

        # keep stop tokens on the same device as `next_token` so torch.isin works
        stop_tokens = torch.tensor(list(self.tokenizer.stop_tokens), device="cuda")

        for cur_pos in range(min_prompt_len, total_len):
            logits = self.model.forward(tokens[:, prev_pos:cur_pos], prev_pos)
            if temperature > 0:
                probs = torch.softmax(logits[:, -1] / temperature, dim=-1)
                next_token = sample_top_p(probs, top_p)
            else:
                next_token = torch.argmax(logits[:, -1], dim=-1)

            next_token = next_token.reshape(-1)
            # only replace token if prompt has already been generated
            next_token = torch.where(
                input_text_mask[:, cur_pos], tokens[:, cur_pos], next_token
            )
            tokens[:, cur_pos] = next_token

            if logprobs:
                token_logprobs[:, prev_pos + 1 : cur_pos + 1] = -F.cross_entropy(
                    input=logits.transpose(1, 2),
                    target=tokens[:, prev_pos + 1 : cur_pos + 1],
                    reduction="none",
                    ignore_index=pad_id,
                )
            eos_reached |= (~input_text_mask[:, cur_pos]) & (
                torch.isin(next_token, stop_tokens)
            )
            yield TokenResult(
                token=next_token[0].item(),
                text=self.tokenizer.decode(next_token.tolist()),
                logprobs=(
                    token_logprobs[:, prev_pos + 1 : cur_pos + 1][0].tolist()
                    if logprobs
                    else None
                ),
            )

            prev_pos = cur_pos
            if all(eos_reached):
                break

    def text_completion(
        self,
        prompt: str,
        temperature: float = 0.6,
        top_p: float = 0.9,
        max_gen_len: Optional[int] = None,
        logprobs: bool = False,
        echo: bool = False,
    ) -> Generator:
        if (
            max_gen_len is None
            or max_gen_len == 0
            or max_gen_len >= self.model.params.max_seq_len
        ):
            max_gen_len = self.model.params.max_seq_len - 1

        prompt_tokens = self.tokenizer.encode(prompt, bos=True, eos=False)

        yield from self.generate(
            model_input=ModelInput(tokens=prompt_tokens),
            max_gen_len=max_gen_len,
            temperature=temperature,
            top_p=top_p,
            logprobs=logprobs,
            echo=echo,
        )

    def chat_completion(
        self,
        messages: List[Message],
        temperature: float = 0.6,
        top_p: float = 0.9,
        max_gen_len: Optional[int] = None,
        logprobs: bool = False,
    ) -> Generator:
        if (
            max_gen_len is None
            or max_gen_len == 0
            or max_gen_len >= self.model.params.max_seq_len
        ):
            max_gen_len = self.model.params.max_seq_len - 1

        yield from self.generate(
            model_input=self.formatter.encode_dialog_prompt(messages),
            max_gen_len=max_gen_len,
            temperature=temperature,
            top_p=top_p,
            logprobs=logprobs,
            include_stop_token=True,
        )
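

# Worked example of the masking rule in `sample_top_p` below (illustrative
# values, not from the source): with sorted probs [0.5, 0.3, 0.15, 0.05] and
# p = 0.8, the cumulative sums are [0.5, 0.8, 0.95, 1.0]. A token is dropped
# when the cumulative mass *before* it (probs_sum - probs_sort) strictly
# exceeds p: those values are [0.0, 0.5, 0.8, 0.95], so only the last token
# (0.95 > 0.8) is zeroed and the first three survive. The survivors are then
# renormalized by their sum (0.95) before `torch.multinomial` draws from them.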
def sample_top_p(probs, p):
    """
    Perform top-p (nucleus) sampling on a probability distribution.

    Args:
        probs (torch.Tensor): Probability distribution tensor.
        p (float): Probability threshold for top-p sampling.

    Returns:
        torch.Tensor: Sampled token indices.

    Note:
        Top-p sampling selects the smallest set of tokens whose cumulative probability mass
        exceeds the threshold p. The distribution is renormalized based on the selected tokens.
    """
    probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
    probs_sum = torch.cumsum(probs_sort, dim=-1)
    # Zero out tokens whose preceding cumulative mass already exceeds p,
    # then renormalize what remains before sampling.
    mask = probs_sum - probs_sort > p
    probs_sort[mask] = 0.0
    probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
    next_token = torch.multinomial(probs_sort, num_samples=1)
    # Map the sampled position back to the original (unsorted) vocabulary index.
    next_token = torch.gather(probs_idx, -1, next_token)
    return next_token
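

if __name__ == "__main__":
    # Minimal usage sketch, illustrative only: the checkpoint and tokenizer
    # paths below are placeholders, and the script is assumed to be launched
    # with torchrun so the NCCL process group and model-parallel ranks exist,
    # e.g. `torchrun --nproc_per_node=1 generation.py`.
    llama = Llama.build(
        ckpt_dir="/path/to/checkpoints",  # placeholder path
        tokenizer_path="/path/to/tokenizer.model",  # placeholder path
        max_seq_len=512,
        max_batch_size=1,
    )
    # `text_completion` streams TokenResult objects one token at a time.
    for result in llama.text_completion(
        "The theory of relativity states that", max_gen_len=64
    ):
        print(result.text, end="", flush=True)
    print()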