several fixes

This commit is contained in:
Ashwin Bharambe 2025-04-07 10:31:20 -07:00
parent e2e2820c9a
commit 53a8086e37
60 changed files with 1006 additions and 1078 deletions

View file

@ -4,59 +4,37 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# top-level folder for each specific model found within the models/ directory at
# the top-level of this source tree.
import json
import os
import sys
import time
from dataclasses import dataclass
from pathlib import Path
from typing import Callable, Generator, List, Optional
import torch
import torch.nn.functional as F
from fairscale.nn.model_parallel.initialize import (
get_model_parallel_rank,
initialize_model_parallel,
model_parallel_is_initialized,
)
from termcolor import cprint
from ..datatypes import RawContent, RawMessage, StopReason, ToolPromptFormat
from ..checkpoint import maybe_reshard_state_dict
from ..datatypes import GenerationResult, QuantizationMode, RawContent, RawMessage, ToolPromptFormat
from .args import ModelArgs
from .chat_format import ChatFormat, LLMInput
from .model import Transformer
from .multimodal.model import CrossAttentionTransformer
from .tokenizer import Tokenizer
@dataclass
class CompletionPrediction:
generation: str
decoded_tokens: Optional[List[str]] = None
logprobs: Optional[List[List[float]]] = None
@dataclass
class ChatPrediction:
generation: RawMessage
decoded_tokens: Optional[List[str]] = None
logprobs: Optional[List[List[float]]] = None
@dataclass
class TokenResult:
token: int
text: str
logprobs: Optional[List[float]] = None
# TODO: make this completely parallel to the llama4 generation.py file and share common code
# from llama-models also
class Llama3:
@staticmethod
def build(
@ -64,7 +42,7 @@ class Llama3:
max_seq_len: int,
max_batch_size: int,
world_size: Optional[int] = None,
tokenizer_path: Optional[str] = None,
quantization_mode: Optional[QuantizationMode] = None,
seed: int = 1,
device: str = "cuda",
):
@ -101,13 +79,9 @@ class Llama3:
start_time = time.time()
checkpoints = sorted(Path(ckpt_dir).glob("*.pth"))
assert len(checkpoints) > 0, f"no checkpoint files found in {ckpt_dir}"
assert world_size == len(checkpoints), (
f"Loading a checkpoint for MP={len(checkpoints)} but world size is {world_size}"
)
ckpt_path = checkpoints[get_model_parallel_rank()]
checkpoint = torch.load(ckpt_path, map_location="cpu", weights_only=True)
ckpt_paths = sorted(Path(ckpt_dir).glob("*.pth"))
assert len(ckpt_paths) > 0, f"no checkpoint files found in {ckpt_dir}"
print(f"Loading a checkpoint (shards={len(ckpt_paths)}, current-mp-size={world_size})")
with open(Path(ckpt_dir) / "params.json", "r") as f:
params = json.loads(f.read())
@ -116,40 +90,58 @@ class Llama3:
max_batch_size=max_batch_size,
**params,
)
if tokenizer_path:
tokenizer = Tokenizer(model_path=tokenizer_path)
else:
tokenizer = Tokenizer.get_instance()
tokenizer = Tokenizer.get_instance()
state_dict = maybe_reshard_state_dict(
ckpt_paths,
n_kv_heads=model_args.n_kv_heads if model_args.n_kv_heads else model_args.n_heads,
)
assert model_args.vocab_size == tokenizer.n_words
torch.set_default_device(device)
if device.type == "cuda":
if torch.cuda.is_bf16_supported():
torch.set_default_dtype(torch.bfloat16)
else:
torch.set_default_dtype(torch.half)
elif device.type == "xpu":
if torch.xpu.is_bf16_supported():
torch.set_default_dtype(torch.bfloat16)
else:
torch.set_default_dtype(torch.half)
else:
torch.set_default_dtype(torch.half)
if model_args.vision_chunk_size > 0:
from .multimodal.model import CrossAttentionTransformer
def build_model():
if model_args.vision_chunk_size > 0:
model = CrossAttentionTransformer(model_args)
model.setup_cache(model_args.max_batch_size, device=device, dtype=torch.get_default_dtype())
else:
model = Transformer(model_args)
return model
model = CrossAttentionTransformer(model_args)
model.setup_cache(model_args.max_batch_size, torch.get_default_dtype())
if quantization_mode == QuantizationMode.fp8_mixed or quantization_mode == QuantizationMode.int4_mixed:
from .quantization.loader import convert_to_quantized_model
torch.set_default_tensor_type(torch.BFloat16Tensor)
model = build_model()
print("Loading state dict...")
model.load_state_dict(state_dict, strict=False)
print("Done...")
model = convert_to_quantized_model(model, ckpt_dir, quantization_mode, device=device)
torch.set_default_device(device)
else:
model = Transformer(model_args)
model.load_state_dict(checkpoint, strict=True)
model.to(device)
print(f"Setting default device to {device}")
torch.set_default_device(device)
if device.type == "cuda":
if torch.cuda.is_bf16_supported():
torch.set_default_dtype(torch.bfloat16)
else:
torch.set_default_dtype(torch.half)
elif device.type == "xpu":
if torch.xpu.is_bf16_supported():
torch.set_default_dtype(torch.bfloat16)
else:
torch.set_default_dtype(torch.half)
model = build_model()
print("Loading state dict...")
model.load_state_dict(state_dict, strict=True)
model.to(device)
print("Done...")
print(f"Loaded in {time.time() - start_time:.2f} seconds")
return Llama(model, tokenizer, model_args)
return Llama3(model, tokenizer, model_args)
def __init__(self, model: Transformer, tokenizer: Tokenizer, args: ModelArgs):
def __init__(self, model: Transformer | CrossAttentionTransformer, tokenizer: Tokenizer, args: ModelArgs):
self.args = args
self.model = model
self.tokenizer = tokenizer
@ -158,26 +150,30 @@ class Llama3:
@torch.inference_mode()
def generate(
self,
model_input: LLMInput,
max_gen_len: int,
model_inputs: List[LLMInput],
temperature: float = 0.6,
top_p: float = 0.9,
max_gen_len: Optional[int] = None,
logprobs: bool = False,
echo: bool = False,
print_model_input: bool = False,
logits_processor: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
) -> Generator:
) -> Generator[List[GenerationResult], None, None]:
if max_gen_len is None or max_gen_len == 0 or max_gen_len >= self.args.max_seq_len:
max_gen_len = self.args.max_seq_len - 1
params = self.model.params
print_model_input = print_model_input or os.environ.get("LLAMA_MODELS_DEBUG", "0") == "1"
if print_model_input:
tokens_to_print = [self.formatter.vision_token if t == 128256 else t for t in model_input.tokens]
cprint(
"Input to model:\n" + self.tokenizer.decode(tokens_to_print) + "\n",
"red",
)
prompt_tokens = [model_input.tokens]
for inp in model_inputs:
tokens_to_print = [self.formatter.vision_token if t == 128256 else t for t in inp.tokens]
cprint(
"Input to model:\n" + self.tokenizer.decode(tokens_to_print) + "\n",
"red",
)
prompt_tokens = [inp.tokens for inp in model_inputs]
bsz = 1
bsz = len(model_inputs)
assert bsz <= params.max_batch_size, (bsz, params.max_batch_size)
min_prompt_len = min(len(t) for t in prompt_tokens)
@ -189,18 +185,6 @@ class Llama3:
total_len = min(max_gen_len + max_prompt_len, params.max_seq_len)
is_vision = not isinstance(self.model, Transformer)
if is_vision:
images = model_input.vision.images if model_input.vision is not None else []
mask = model_input.vision.mask if model_input.vision is not None else []
# the method works for bsz > 1 so add a batch dimension
xattn_caches, cross_attention_masks, full_text_row_masked_out_mask = self.model.compute_vision_tokens_masks(
batch_images=[images],
batch_masks=[mask],
total_len=total_len,
)
pad_id = self.tokenizer.pad_id
tokens = torch.full((bsz, total_len), pad_id, dtype=torch.long)
for k, t in enumerate(prompt_tokens):
@ -208,23 +192,45 @@ class Llama3:
if logprobs:
token_logprobs = torch.zeros_like(tokens, dtype=torch.float)
prev_pos = 0
is_vision = not isinstance(self.model, Transformer)
if is_vision:
images = [inp.vision.images if inp.vision is not None else [] for inp in model_inputs]
mask = [inp.vision.mask if inp.vision is not None else [] for inp in model_inputs]
xattn_caches, cross_attention_masks, full_text_row_masked_out_mask = self.model.compute_vision_tokens_masks(
batch_images=images,
batch_masks=mask,
total_len=total_len,
device=tokens.device,
)
eos_reached = torch.tensor([False] * bsz)
input_text_mask = tokens != pad_id
if echo:
for i, t in enumerate(model_input.tokens):
yield TokenResult(
token=t,
text=self.tokenizer.decode([t]),
logprobs=(token_logprobs[0, i : i + 1].tolist() if logprobs else None),
)
for i in range(max_prompt_len):
results = []
for j, t in enumerate(tokens[:, i]):
results.append(
GenerationResult(
token=t.item(),
text=self.tokenizer.decode([t.item()]),
source="input",
logprobs=(token_logprobs[j, i : i + 1].tolist() if logprobs else None),
batch_idx=j,
finished=False,
ignore_token=t.item() == pad_id,
)
)
yield results
stop_tokens = torch.tensor(self.tokenizer.stop_tokens)
prev_pos = 0
for cur_pos in range(min_prompt_len, total_len):
if is_vision:
position_ids = torch.arange(prev_pos, cur_pos, dtype=torch.long)
text_only_inference = model_input.vision is None
text_only_inference = all(inp.vision is None for inp in model_inputs)
logits = self.model.forward(
position_ids,
tokens,
@ -271,155 +277,69 @@ class Llama3:
ignore_index=pad_id,
)
eos_reached |= (~input_text_mask[:, cur_pos]) & (torch.isin(next_token, stop_tokens))
yield TokenResult(
token=next_token[0].item(),
text=self.tokenizer.decode(next_token.tolist()),
logprobs=(token_logprobs[:, cur_pos : cur_pos + 1][0].tolist() if logprobs else None),
)
results = []
for idx, t in enumerate(next_token):
results.append(
GenerationResult(
token=t.item(),
text=self.tokenizer.decode([t.item()]),
source="output",
logprobs=(token_logprobs[idx, cur_pos : cur_pos + 1].tolist() if logprobs else None),
batch_idx=idx,
finished=eos_reached[idx],
ignore_token=cur_pos < len(prompt_tokens[idx]),
)
)
yield results
prev_pos = cur_pos
if all(eos_reached):
break
def text_completion(
def completion(
self,
content: RawContent,
contents: List[RawContent],
temperature: float = 0.6,
top_p: float = 0.9,
max_gen_len: Optional[int] = None,
logprobs: bool = False,
echo: bool = False,
) -> CompletionPrediction:
if max_gen_len is None or max_gen_len == 0 or max_gen_len >= self.model.params.max_seq_len:
max_gen_len = self.model.params.max_seq_len - 1
model_input = self.formatter.encode_content(content)
tokens = []
token_logprobs = []
decoded_tokens = []
) -> Generator[List[GenerationResult], None, None]:
model_inputs = [self.formatter.encode_content(c) for c in contents]
for result in self.generate(
model_input=model_input,
max_gen_len=max_gen_len,
model_inputs=model_inputs,
temperature=temperature,
top_p=top_p,
max_gen_len=max_gen_len,
logprobs=logprobs,
echo=echo,
):
tokens.append(result.token)
if logprobs:
decoded_tokens.append(result.text)
token_logprobs.append(result.logprobs)
generation = self.tokenizer.decode(tokens)
if logprobs:
return CompletionPrediction(
generation=generation,
logprobs=token_logprobs,
decoded_tokens=decoded_tokens,
)
return CompletionPrediction(generation=generation)
yield result
if all(r.finished for r in result):
break
def chat_completion(
self,
messages: List[RawMessage],
messages_batch: List[List[RawMessage]],
temperature: float = 0.6,
top_p: float = 0.9,
max_gen_len: Optional[int] = None,
logprobs: bool = False,
tool_prompt_format: ToolPromptFormat = ToolPromptFormat.json,
echo: bool = False,
) -> ChatPrediction:
if max_gen_len is None or max_gen_len == 0 or max_gen_len >= self.model.params.max_seq_len:
max_gen_len = self.model.params.max_seq_len - 1
tokens = []
token_logprobs = []
decoded_tokens = []
stop_reason = None
) -> Generator[List[GenerationResult], None, None]:
model_inputs = [self.formatter.encode_dialog_prompt(messages) for messages in messages_batch]
for result in self.generate(
model_input=self.formatter.encode_dialog_prompt(messages, tool_prompt_format),
max_gen_len=max_gen_len,
model_inputs=model_inputs,
temperature=temperature,
top_p=top_p,
max_gen_len=max_gen_len,
logprobs=logprobs,
echo=echo,
):
tokens.append(result.token)
if result.text == "<|eot_id|>":
stop_reason = StopReason.end_of_turn
elif result.text == "<|eom_id|>":
stop_reason = StopReason.end_of_message
if logprobs:
decoded_tokens.append(result.text)
token_logprobs.append(result.logprobs)
if stop_reason is None:
stop_reason = StopReason.out_of_tokens
message = self.formatter.decode_assistant_message(tokens, stop_reason)
if logprobs:
return ChatPrediction(
generation=message,
logprobs=token_logprobs,
decoded_tokens=decoded_tokens,
)
return ChatPrediction(generation=message)
def chat_completion_raw(
self,
messages: List[RawMessage],
temperature: float = 0.6,
top_p: float = 0.9,
max_gen_len: Optional[int] = None,
tool_prompt_format: ToolPromptFormat = ToolPromptFormat.json,
) -> List[int]:
if max_gen_len is None or max_gen_len == 0 or max_gen_len >= self.model.params.max_seq_len:
max_gen_len = self.model.params.max_seq_len - 1
output_tokens = []
model_input = self.formatter.encode_dialog_prompt(messages, tool_prompt_format)
input_tokens = model_input.tokens
for result in self.generate(
model_input=model_input,
max_gen_len=max_gen_len,
temperature=temperature,
top_p=top_p,
logprobs=False,
):
output_tokens.append(result.token)
return input_tokens, output_tokens
def text_completion_raw(
self,
content: RawContent,
temperature: float = 0.6,
top_p: float = 0.9,
max_gen_len: Optional[int] = None,
):
if max_gen_len is None or max_gen_len == 0 or max_gen_len >= self.model.params.max_seq_len:
max_gen_len = self.model.params.max_seq_len - 1
model_input = self.formatter.encode_content(content)
input_tokens = model_input.tokens
output_tokens = []
for result in self.generate(
model_input=model_input,
max_gen_len=max_gen_len,
temperature=temperature,
top_p=top_p,
logprobs=False,
):
output_tokens.append(result.token)
return input_tokens, output_tokens
yield result
if all(r.finished for r in result):
break
def sample_top_p(probs, p):