llama-stack-mirror/llama_stack/providers/impls/meta_reference/inference/generation.py
Xi Yan 4d5f7459aa
[bugfix] Fix logprobs on meta-reference impl (#213)
* fix log probs

* add back LogProbsConfig

* error handling

* bugfix
2024-10-07 19:42:39 -07:00

388 lines
14 KiB
Python

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
import json
import os
import sys
import time
from dataclasses import dataclass
from pathlib import Path
from typing import Generator, List, Optional
import torch
import torch.nn.functional as F
from fairscale.nn.model_parallel.initialize import (
get_model_parallel_rank,
initialize_model_parallel,
model_parallel_is_initialized,
)
from llama_models.llama3.api.args import ModelArgs
from llama_models.llama3.api.chat_format import ChatFormat, ModelInput
from llama_models.llama3.api.datatypes import (
InterleavedTextMedia,
Message,
ToolPromptFormat,
)
from llama_models.llama3.api.tokenizer import Tokenizer
from llama_models.llama3.reference_impl.model import Transformer
from llama_models.llama3.reference_impl.multimodal.model import (
CrossAttentionTransformer,
)
from llama_models.sku_list import resolve_model
from termcolor import cprint
from llama_stack.apis.inference import QuantizationType
from llama_stack.distribution.utils.model_utils import model_local_dir
from .config import MetaReferenceImplConfig
def model_checkpoint_dir(model) -> str:
checkpoint_dir = Path(model_local_dir(model.descriptor()))
paths = [Path(checkpoint_dir / f"consolidated.{ext}") for ext in ["pth", "00.pth"]]
if not any(p.exists() for p in paths):
checkpoint_dir = checkpoint_dir / "original"
assert checkpoint_dir.exists(), (
f"Could not find checkpoints in: {model_local_dir(model.descriptor())}. "
f"Please download model using `llama download --model-id {model.descriptor()}`"
)
return str(checkpoint_dir)
@dataclass
class TokenResult:
token: int
text: str
logprobs: Optional[List[float]] = None
class Llama:
@staticmethod
def build(config: MetaReferenceImplConfig):
"""
Build a Llama instance by initializing and loading a model checkpoint.
Note:
This method initializes the distributed process group, sets the device to CUDA,
and loads the pre-trained model and tokenizer.
"""
model = resolve_model(config.model)
if (
config.quantization
and config.quantization.type == QuantizationType.fp8.value
):
from .quantization.loader import is_fbgemm_available
if not is_fbgemm_available():
raise ImportError("fbgemm-gpu is required for FP8 quantization")
if not torch.distributed.is_initialized():
torch.distributed.init_process_group("nccl")
model_parallel_size = config.model_parallel_size
if not model_parallel_is_initialized():
initialize_model_parallel(model_parallel_size)
local_rank = int(os.environ.get("LOCAL_RANK", 0))
torch.cuda.set_device(local_rank)
# seed must be the same in all processes
if config.torch_seed is not None:
torch.manual_seed(config.torch_seed)
if local_rank > 0:
sys.stdout = open(os.devnull, "w")
start_time = time.time()
ckpt_dir = model_checkpoint_dir(model)
checkpoints = sorted(Path(ckpt_dir).glob("*.pth"))
assert len(checkpoints) > 0, f"no checkpoint files found in {ckpt_dir}"
assert model_parallel_size == len(
checkpoints
), f"Loading a checkpoint for MP={len(checkpoints)} but world size is {model_parallel_size}"
ckpt_path = checkpoints[get_model_parallel_rank()]
state_dict = torch.load(ckpt_path, map_location="cpu", weights_only=True)
with open(Path(ckpt_dir) / "params.json", "r") as f:
params = json.loads(f.read())
if "model" in params:
params = params["model"]
model_args: ModelArgs = ModelArgs(
max_seq_len=config.max_seq_len,
max_batch_size=config.max_batch_size,
**params,
)
tokenizer_path = os.path.join(ckpt_dir, "tokenizer.model")
tokenizer = Tokenizer(model_path=tokenizer_path)
assert (
model_args.vocab_size == tokenizer.n_words
), f"model_args vocab = {model_args.vocab_size} but tokenizer vocab = {tokenizer.n_words}"
fp8 = (
config.quantization
and config.quantization.type == QuantizationType.fp8.value
)
if fp8:
from .quantization.loader import convert_to_quantized_model
# load on CPU in bf16 so that fp8 conversion does not find an
# unexpected (fp32, e.g.) datatype
torch.set_default_tensor_type(torch.BFloat16Tensor)
if model_args.vision_chunk_size > 0:
model = CrossAttentionTransformer(model_args)
model.setup_cache(model_args.max_batch_size, torch.bfloat16)
else:
model = Transformer(model_args)
model.load_state_dict(state_dict, strict=False)
model = convert_to_quantized_model(model, config)
else:
if torch.cuda.is_bf16_supported():
torch.set_default_tensor_type(torch.cuda.BFloat16Tensor)
else:
torch.set_default_tensor_type(torch.cuda.HalfTensor)
if model_args.vision_chunk_size > 0:
model = CrossAttentionTransformer(model_args)
model.setup_cache(model_args.max_batch_size, torch.bfloat16)
else:
model = Transformer(model_args)
model.load_state_dict(state_dict, strict=False)
print(f"Loaded in {time.time() - start_time:.2f} seconds")
return Llama(model, tokenizer, model_args)
def __init__(self, model: Transformer, tokenizer: Tokenizer, args: ModelArgs):
self.args = args
self.model = model
self.tokenizer = tokenizer
self.formatter = ChatFormat(tokenizer)
@torch.inference_mode()
def generate(
self,
model_input: ModelInput,
max_gen_len: int,
temperature: float = 0.6,
top_p: float = 0.9,
logprobs: bool = False,
echo: bool = False,
include_stop_token: bool = False,
) -> Generator:
params = self.model.params
# input_tokens = [
# self.formatter.vision_token if t == 128256 else t
# for t in model_input.tokens
# ]
# cprint("Input to model -> " + self.tokenizer.decode(input_tokens), "red")
prompt_tokens = [model_input.tokens]
bsz = 1
assert bsz <= params.max_batch_size, (bsz, params.max_batch_size)
min_prompt_len = min(len(t) for t in prompt_tokens)
max_prompt_len = max(len(t) for t in prompt_tokens)
if max_prompt_len >= params.max_seq_len:
cprint(
f"Out of token budget {max_prompt_len} vs {params.max_seq_len}", "red"
)
return
total_len = min(max_gen_len + max_prompt_len, params.max_seq_len)
is_vision = isinstance(self.model, CrossAttentionTransformer)
if is_vision:
images = model_input.vision.images if model_input.vision is not None else []
mask = model_input.vision.mask if model_input.vision is not None else []
# the method works for bsz > 1 so add a batch dimension
xattn_caches, cross_attention_masks, full_text_row_masked_out_mask = (
self.model.compute_vision_tokens_masks(
batch_images=[images],
batch_masks=[mask],
total_len=total_len,
)
)
pad_id = self.tokenizer.pad_id
tokens = torch.full((bsz, total_len), pad_id, dtype=torch.long, device="cuda")
for k, t in enumerate(prompt_tokens):
tokens[k, : len(t)] = torch.tensor(t, dtype=torch.long, device="cuda")
if logprobs:
token_logprobs = torch.zeros_like(tokens, dtype=torch.float)
prev_pos = 0
eos_reached = torch.tensor([False] * bsz, device="cuda")
input_text_mask = tokens != pad_id
if min_prompt_len == total_len:
# TODO(ashwin): unify this branch with the one below and figure out multimodal crap
logits = self.model.forward(tokens, prev_pos)
token_logprobs = -F.cross_entropy(
input=logits.transpose(1, 2),
target=tokens,
reduction="none",
ignore_index=pad_id,
)
stop_tokens = torch.tensor(self.tokenizer.stop_tokens)
for cur_pos in range(min_prompt_len, total_len):
if is_vision:
position_ids = torch.arange(
prev_pos, cur_pos, dtype=torch.long, device="cuda"
)
logits = self.model.forward(
position_ids,
tokens,
cross_attention_masks,
full_text_row_masked_out_mask,
xattn_caches,
)
else:
logits = self.model.forward(tokens[:, prev_pos:cur_pos], prev_pos)
if temperature > 0:
probs = torch.softmax(logits[:, -1] / temperature, dim=-1)
next_token = sample_top_p(probs, top_p)
else:
next_token = torch.argmax(logits[:, -1], dim=-1)
next_token = next_token.reshape(-1)
# only replace token if prompt has already been generated
next_token = torch.where(
input_text_mask[:, cur_pos], tokens[:, cur_pos], next_token
)
tokens[:, cur_pos] = next_token
target = tokens[:, prev_pos + 1 : cur_pos + 1]
if is_vision:
# the logits space (num_classes) is designed to never contain a media_token
# however our input token stream does contain them. we need to nuke them here
# or else the CUDA kernels will crash with an illegal memory access
vision_tokens = [self.tokenizer.special_tokens["<|image|>"], 128256]
masks = [target.eq(t) for t in vision_tokens]
if len(masks) > 1:
mask = torch.logical_or(*masks)
else:
mask = masks[0]
target[mask] = 0
if logprobs:
token_logprobs[:, prev_pos + 1 : cur_pos + 1] = -F.cross_entropy(
input=logits.transpose(1, 2),
target=tokens[:, prev_pos + 1 : cur_pos + 1],
reduction="none",
ignore_index=pad_id,
)
eos_reached |= (~input_text_mask[:, cur_pos]) & (
torch.isin(next_token, stop_tokens)
)
yield TokenResult(
token=next_token[0].item(),
text=self.tokenizer.decode(next_token.tolist()),
logprobs=(
token_logprobs[:, cur_pos : cur_pos + 1][0].tolist()
if logprobs
else None
),
)
prev_pos = cur_pos
if all(eos_reached):
break
def text_completion(
self,
content: InterleavedTextMedia,
temperature: float = 0.6,
top_p: float = 0.9,
max_gen_len: Optional[int] = None,
logprobs: bool = False,
echo: bool = False,
) -> Generator:
if (
max_gen_len is None
or max_gen_len == 0
or max_gen_len >= self.model.params.max_seq_len
):
max_gen_len = self.model.params.max_seq_len - 1
model_input = self.formatter.encode_content(content)
yield from self.generate(
model_input=model_input,
max_gen_len=max_gen_len,
temperature=temperature,
top_p=top_p,
logprobs=logprobs,
echo=echo,
)
def chat_completion(
self,
messages: List[Message],
temperature: float = 0.6,
top_p: float = 0.9,
max_gen_len: Optional[int] = None,
logprobs: bool = False,
tool_prompt_format: ToolPromptFormat = ToolPromptFormat.json,
) -> Generator:
if (
max_gen_len is None
or max_gen_len == 0
or max_gen_len >= self.model.params.max_seq_len
):
max_gen_len = self.model.params.max_seq_len - 1
yield from self.generate(
model_input=self.formatter.encode_dialog_prompt(
messages,
tool_prompt_format,
),
max_gen_len=max_gen_len,
temperature=temperature,
top_p=top_p,
logprobs=logprobs,
include_stop_token=True,
)
def sample_top_p(probs, p):
"""
Perform top-p (nucleus) sampling on a probability distribution.
Args:
probs (torch.Tensor): Probability distribution tensor.
p (float): Probability threshold for top-p sampling.
Returns:
torch.Tensor: Sampled token indices.
Note:
Top-p sampling selects the smallest set of tokens whose cumulative probability mass
exceeds the threshold p. The distribution is renormalized based on the selected tokens.
"""
probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
probs_sum = torch.cumsum(probs_sort, dim=-1)
mask = probs_sum - probs_sort > p
probs_sort[mask] = 0.0
probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
next_token = torch.multinomial(probs_sort, num_samples=1)
next_token = torch.gather(probs_idx, -1, next_token)
return next_token