several fixes

This commit is contained in:
Ashwin Bharambe 2025-04-07 10:31:20 -07:00
parent e2e2820c9a
commit 53a8086e37
60 changed files with 1006 additions and 1078 deletions

View file

@ -4,32 +4,43 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# top-level folder for each specific model found within the models/ directory at
# the top-level of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
import codecs
import io
import json
import os
import sys
import time
from enum import Enum
from pathlib import Path
from typing import Callable, Generator, List, Optional
import torch
import torch.nn.functional as F
from fairscale.nn.model_parallel.initialize import (
get_model_parallel_rank,
initialize_model_parallel,
model_parallel_is_initialized,
)
from termcolor import cprint
from ..common import TokenResult
from ..checkpoint import maybe_reshard_state_dict
from ..datatypes import GenerationResult, QuantizationMode
from .args import ModelArgs
from .chat_format import (
ChatFormat,
RawContent,
RawMessage,
)
from .chat_format import ChatFormat, RawContent, RawMessage
from .datatypes import LLMInput, MaskedEmbedding, TransformerInput
from .model import Transformer
from .tokenizer import Tokenizer
@ -37,12 +48,6 @@ from .tokenizer import Tokenizer
torch.serialization.add_safe_globals([io.BytesIO, codecs.encode])
class QuantizationMode(str, Enum):
none = "none"
fp8_mixed = "fp8_mixed"
int4_mixed = "int4_mixed"
class Llama4:
@staticmethod
def build(
@ -50,7 +55,7 @@ class Llama4:
max_seq_len: int,
max_batch_size: int,
world_size: Optional[int] = None,
quantization_mode: Optional[str] = None,
quantization_mode: Optional[QuantizationMode] = None,
seed: int = 1,
):
if not torch.distributed.is_initialized():
@ -71,11 +76,9 @@ class Llama4:
start_time = time.time()
checkpoints = sorted(Path(ckpt_dir).glob("*.pth"))
assert len(checkpoints) > 0, f"no checkpoint files found in {ckpt_dir}"
assert world_size == len(checkpoints), (
f"Loading a checkpoint for MP={len(checkpoints)} but world size is {world_size}"
)
ckpt_paths = sorted(Path(ckpt_dir).glob("*.pth"))
assert len(ckpt_paths) > 0, f"no checkpoint files found in {ckpt_dir}"
print(f"Loading a checkpoint (shards={len(ckpt_paths)}, current-mp-size={world_size})")
with open(Path(ckpt_dir) / "params.json", "r") as f:
params = json.loads(f.read())
@ -92,10 +95,11 @@ class Llama4:
assert model_args.vocab_size == tokenizer.n_words, f"{model_args.vocab_size=} vs. {tokenizer.n_words=} mismatch"
print("Model args:\n", model_args.model_dump_json(indent=2))
ckpt_path = checkpoints[get_model_parallel_rank()]
print(f"Loading checkpoint from {ckpt_dir}...")
with open(ckpt_path, "rb") as f:
checkpoint = torch.load(f, map_location="cpu", weights_only=True)
state_dict = maybe_reshard_state_dict(
ckpt_paths,
n_kv_heads=model_args.n_kv_heads if model_args.n_kv_heads else model_args.n_heads,
moe_num_experts=model_args.moe_args.num_experts,
)
print("Loaded checkpoint")
if quantization_mode == QuantizationMode.fp8_mixed or quantization_mode == QuantizationMode.int4_mixed:
from .quantization.loader import convert_to_quantized_model
@ -103,9 +107,9 @@ class Llama4:
torch.set_default_tensor_type(torch.BFloat16Tensor)
model = Transformer(model_args)
print("Loading state dict...")
model.load_state_dict(checkpoint, strict=False)
model.load_state_dict(state_dict, strict=False)
print("Done...")
model = convert_to_quantized_model(model, ckpt_dir)
model = convert_to_quantized_model(model, ckpt_dir, quantization_mode)
else:
if torch.cuda.is_bf16_supported():
torch.set_default_tensor_type(torch.cuda.BFloat16Tensor)
@ -114,7 +118,7 @@ class Llama4:
model = Transformer(model_args)
print("Loading state dict...")
model.load_state_dict(checkpoint, strict=False)
model.load_state_dict(state_dict, strict=False)
print("Done...")
print(f"Loaded in {time.time() - start_time:.2f} seconds")
@ -129,7 +133,7 @@ class Llama4:
@torch.inference_mode()
def generate(
self,
llm_input: LLMInput,
llm_inputs: List[LLMInput],
temperature: float = 0.6,
top_p: float = 0.9,
max_gen_len: Optional[int] = None,
@ -137,22 +141,20 @@ class Llama4:
echo: bool = False,
print_model_input: bool = False,
logits_processor: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
) -> Generator:
) -> Generator[List[GenerationResult], None, None]:
if max_gen_len is None or max_gen_len == 0 or max_gen_len >= self.model.args.max_seq_len:
max_gen_len = self.model.args.max_seq_len - 1
params = self.model.args
print_model_input = print_model_input or os.environ.get("LLAMA_MODELS_DEBUG", "0") == "1"
if print_model_input and get_model_parallel_rank() == 0:
tokens_to_print = list(llm_input.tokens)
cprint(
"Input to model:\n" + self.tokenizer.decode(tokens_to_print) + "\n",
"red",
)
prompt_tokens = [llm_input.tokens]
if print_model_input:
cprint("Input to model:\n", "yellow")
for inp in llm_inputs:
cprint(self.tokenizer.decode(inp.tokens.tolist()), "grey")
prompt_tokens = [inp.tokens for inp in llm_inputs]
bsz = 1
bsz = len(llm_inputs)
assert bsz <= params.max_batch_size, (bsz, params.max_batch_size)
min_prompt_len = min(len(t) for t in prompt_tokens)
@ -175,24 +177,33 @@ class Llama4:
input_text_mask = tokens != pad_id
if echo:
for i, t in enumerate(llm_input.tokens):
yield TokenResult(
token=t,
text=self.tokenizer.decode([t]),
logprobs=(token_logprobs[0, i : i + 1].tolist() if logprobs else None),
)
for i in range(max_prompt_len):
results = []
for j, t in enumerate(tokens[:, i]):
results.append(
GenerationResult(
token=t.item(),
text=self.tokenizer.decode([t.item()]),
source="input",
logprobs=(token_logprobs[j, i : i + 1].tolist() if logprobs else None),
batch_idx=j,
finished=False,
ignore_token=t.item() == pad_id,
)
)
yield results
stop_tokens = torch.tensor(self.tokenizer.stop_tokens, device="cuda")
prev_pos = 0
for cur_pos in range(min_prompt_len, total_len):
image_embedding = None
if prev_pos == 0 and llm_input.images is not None and len(llm_input.images) > 0:
if prev_pos == 0 and any(inp.images is not None and len(inp.images) > 0 for inp in llm_inputs):
image_mask = tokens[:, prev_pos:cur_pos] == self.tokenizer.special_tokens["<|patch|>"]
image_mask = image_mask.unsqueeze(-1)
h = self.model.tok_embeddings(tokens[:, prev_pos:cur_pos])
image_batch = [llm_input.images]
image_batch = [inp.images if inp.images is not None else [] for inp in llm_inputs]
image_embedding = MaskedEmbedding(
embedding=self.model.vision_embeddings(image_batch, image_mask, h),
mask=image_mask,
@ -228,11 +239,21 @@ class Llama4:
ignore_index=pad_id,
)
eos_reached |= (~input_text_mask[:, cur_pos]) & (torch.isin(next_token, stop_tokens))
yield TokenResult(
token=next_token[0].item(),
text=self.tokenizer.decode(next_token.tolist()),
logprobs=(token_logprobs[:, cur_pos : cur_pos + 1][0].tolist() if logprobs else None),
)
results = []
for idx, t in enumerate(next_token):
results.append(
GenerationResult(
token=t.item(),
text=self.tokenizer.decode([t.item()]),
source="output",
logprobs=(token_logprobs[idx, cur_pos : cur_pos + 1].tolist() if logprobs else None),
batch_idx=idx,
finished=eos_reached[idx],
ignore_token=cur_pos < len(prompt_tokens[idx]),
)
)
yield results
prev_pos = cur_pos
if all(eos_reached):
@ -240,68 +261,47 @@ class Llama4:
def completion(
self,
content: RawContent,
contents: List[RawContent],
temperature: float = 0.6,
top_p: float = 0.9,
max_gen_len: Optional[int] = None,
logprobs: bool = False,
echo: bool = False,
) -> Generator:
llm_input = self.formatter.encode_content(content)
) -> Generator[List[GenerationResult], None, None]:
llm_inputs = [self.formatter.encode_content(c) for c in contents]
for result in self.generate(
llm_input=llm_input,
llm_inputs=llm_inputs,
temperature=temperature,
top_p=top_p,
max_gen_len=max_gen_len,
logprobs=logprobs,
echo=echo,
):
if result.token in self.tokenizer.stop_tokens:
break
yield result
if all(r.finished for r in result):
break
def chat_completion(
self,
messages: List[RawMessage],
messages_batch: List[List[RawMessage]],
temperature: float = 0.6,
top_p: float = 0.9,
max_gen_len: Optional[int] = None,
logprobs: bool = False,
echo: bool = False,
) -> Generator:
llm_input = self.formatter.encode_dialog_prompt(messages)
) -> Generator[List[GenerationResult], None, None]:
llm_inputs = [self.formatter.encode_dialog_prompt(messages) for messages in messages_batch]
for result in self.generate(
llm_input=llm_input,
llm_inputs=llm_inputs,
temperature=temperature,
top_p=top_p,
max_gen_len=max_gen_len,
logprobs=logprobs,
echo=echo,
):
if result.token in self.tokenizer.stop_tokens:
break
yield result
def chat_completion_raw(
self,
messages: List[RawMessage],
temperature: float = 0.6,
top_p: float = 0.9,
max_gen_len: Optional[int] = None,
logprobs: bool = False,
):
llm_input = self.formatter.encode_dialog_prompt(messages)
output_tokens = []
for result in self.generate(
llm_input=llm_input,
temperature=temperature,
top_p=top_p,
max_gen_len=max_gen_len,
logprobs=logprobs,
):
output_tokens.append(result.token)
return llm_input.tokens, output_tokens
if all(r.finished for r in result):
break
def sample_top_p(probs, p):