several fixes

Ashwin Bharambe 2025-04-07 10:31:20 -07:00
parent e2e2820c9a
commit 53a8086e37
60 changed files with 1006 additions and 1078 deletions


@ -4,13 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# top-level folder for each specific model found within the models/ directory at
# the top-level of this source tree.
from dataclasses import dataclass
from enum import Enum
from typing import Optional


@ -4,13 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# top-level folder for each specific model found within the models/ directory at
# the top-level of this source tree.
import io
import json
import uuid


@ -4,59 +4,37 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# top-level folder for each specific model found within the models/ directory at
# the top-level of this source tree.
import json
import os
import sys
import time
from dataclasses import dataclass
from pathlib import Path
from typing import Callable, Generator, List, Optional
import torch
import torch.nn.functional as F
from fairscale.nn.model_parallel.initialize import (
get_model_parallel_rank,
initialize_model_parallel,
model_parallel_is_initialized,
)
from termcolor import cprint
from ..datatypes import RawContent, RawMessage, StopReason, ToolPromptFormat
from ..checkpoint import maybe_reshard_state_dict
from ..datatypes import GenerationResult, QuantizationMode, RawContent, RawMessage, ToolPromptFormat
from .args import ModelArgs
from .chat_format import ChatFormat, LLMInput
from .model import Transformer
from .multimodal.model import CrossAttentionTransformer
from .tokenizer import Tokenizer
@dataclass
class CompletionPrediction:
generation: str
decoded_tokens: Optional[List[str]] = None
logprobs: Optional[List[List[float]]] = None
@dataclass
class ChatPrediction:
generation: RawMessage
decoded_tokens: Optional[List[str]] = None
logprobs: Optional[List[List[float]]] = None
@dataclass
class TokenResult:
token: int
text: str
logprobs: Optional[List[float]] = None
# TODO: make this completely parallel to the llama4 generation.py file and share common code
# from llama-models also
class Llama3:
@staticmethod
def build(
@ -64,7 +42,7 @@ class Llama3:
max_seq_len: int,
max_batch_size: int,
world_size: Optional[int] = None,
tokenizer_path: Optional[str] = None,
quantization_mode: Optional[QuantizationMode] = None,
seed: int = 1,
device: str = "cuda",
):
@ -101,13 +79,9 @@ class Llama3:
start_time = time.time()
checkpoints = sorted(Path(ckpt_dir).glob("*.pth"))
assert len(checkpoints) > 0, f"no checkpoint files found in {ckpt_dir}"
assert world_size == len(checkpoints), (
f"Loading a checkpoint for MP={len(checkpoints)} but world size is {world_size}"
)
ckpt_path = checkpoints[get_model_parallel_rank()]
checkpoint = torch.load(ckpt_path, map_location="cpu", weights_only=True)
ckpt_paths = sorted(Path(ckpt_dir).glob("*.pth"))
assert len(ckpt_paths) > 0, f"no checkpoint files found in {ckpt_dir}"
print(f"Loading a checkpoint (shards={len(ckpt_paths)}, current-mp-size={world_size})")
with open(Path(ckpt_dir) / "params.json", "r") as f:
params = json.loads(f.read())
@ -116,40 +90,58 @@ class Llama3:
max_batch_size=max_batch_size,
**params,
)
if tokenizer_path:
tokenizer = Tokenizer(model_path=tokenizer_path)
else:
tokenizer = Tokenizer.get_instance()
tokenizer = Tokenizer.get_instance()
state_dict = maybe_reshard_state_dict(
ckpt_paths,
n_kv_heads=model_args.n_kv_heads if model_args.n_kv_heads else model_args.n_heads,
)
assert model_args.vocab_size == tokenizer.n_words
torch.set_default_device(device)
if device.type == "cuda":
if torch.cuda.is_bf16_supported():
torch.set_default_dtype(torch.bfloat16)
else:
torch.set_default_dtype(torch.half)
elif device.type == "xpu":
if torch.xpu.is_bf16_supported():
torch.set_default_dtype(torch.bfloat16)
else:
torch.set_default_dtype(torch.half)
else:
torch.set_default_dtype(torch.half)
if model_args.vision_chunk_size > 0:
from .multimodal.model import CrossAttentionTransformer
def build_model():
if model_args.vision_chunk_size > 0:
model = CrossAttentionTransformer(model_args)
model.setup_cache(model_args.max_batch_size, device=device, dtype=torch.get_default_dtype())
else:
model = Transformer(model_args)
return model
model = CrossAttentionTransformer(model_args)
model.setup_cache(model_args.max_batch_size, torch.get_default_dtype())
if quantization_mode == QuantizationMode.fp8_mixed or quantization_mode == QuantizationMode.int4_mixed:
from .quantization.loader import convert_to_quantized_model
torch.set_default_tensor_type(torch.BFloat16Tensor)
model = build_model()
print("Loading state dict...")
model.load_state_dict(state_dict, strict=False)
print("Done...")
model = convert_to_quantized_model(model, ckpt_dir, quantization_mode, device=device)
torch.set_default_device(device)
else:
model = Transformer(model_args)
model.load_state_dict(checkpoint, strict=True)
model.to(device)
print(f"Setting default device to {device}")
torch.set_default_device(device)
if device.type == "cuda":
if torch.cuda.is_bf16_supported():
torch.set_default_dtype(torch.bfloat16)
else:
torch.set_default_dtype(torch.half)
elif device.type == "xpu":
if torch.xpu.is_bf16_supported():
torch.set_default_dtype(torch.bfloat16)
else:
torch.set_default_dtype(torch.half)
model = build_model()
print("Loading state dict...")
model.load_state_dict(state_dict, strict=True)
model.to(device)
print("Done...")
print(f"Loaded in {time.time() - start_time:.2f} seconds")
return Llama(model, tokenizer, model_args)
return Llama3(model, tokenizer, model_args)
def __init__(self, model: Transformer, tokenizer: Tokenizer, args: ModelArgs):
def __init__(self, model: Transformer | CrossAttentionTransformer, tokenizer: Tokenizer, args: ModelArgs):
self.args = args
self.model = model
self.tokenizer = tokenizer
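
For illustration, a minimal sketch of how the updated build() signature might be called after this change; the import locations, checkpoint path, and sizes below are placeholder assumptions rather than values taken from this commit.

# Hypothetical usage of the new build() parameters; paths and sizes are
# placeholders, and the import paths are inferred from the package layout.
import torch
from llama_stack.models.llama.datatypes import QuantizationMode
from llama_stack.models.llama.llama3.generation import Llama3

llama = Llama3.build(
    ckpt_dir="/path/to/checkpoints/Llama3.1-8B-Instruct",
    max_seq_len=2048,
    max_batch_size=4,
    world_size=1,                                  # single model-parallel rank
    quantization_mode=QuantizationMode.fp8_mixed,  # or None for plain bf16 weights
    seed=1,
    device="cuda" if torch.cuda.is_available() else "cpu",
)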
@ -158,26 +150,30 @@ class Llama3:
@torch.inference_mode()
def generate(
self,
model_input: LLMInput,
max_gen_len: int,
model_inputs: List[LLMInput],
temperature: float = 0.6,
top_p: float = 0.9,
max_gen_len: Optional[int] = None,
logprobs: bool = False,
echo: bool = False,
print_model_input: bool = False,
logits_processor: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
) -> Generator:
) -> Generator[List[GenerationResult], None, None]:
if max_gen_len is None or max_gen_len == 0 or max_gen_len >= self.args.max_seq_len:
max_gen_len = self.args.max_seq_len - 1
params = self.model.params
print_model_input = print_model_input or os.environ.get("LLAMA_MODELS_DEBUG", "0") == "1"
if print_model_input:
tokens_to_print = [self.formatter.vision_token if t == 128256 else t for t in model_input.tokens]
cprint(
"Input to model:\n" + self.tokenizer.decode(tokens_to_print) + "\n",
"red",
)
prompt_tokens = [model_input.tokens]
for inp in model_inputs:
tokens_to_print = [self.formatter.vision_token if t == 128256 else t for t in inp.tokens]
cprint(
"Input to model:\n" + self.tokenizer.decode(tokens_to_print) + "\n",
"red",
)
prompt_tokens = [inp.tokens for inp in model_inputs]
bsz = 1
bsz = len(model_inputs)
assert bsz <= params.max_batch_size, (bsz, params.max_batch_size)
min_prompt_len = min(len(t) for t in prompt_tokens)
@ -189,18 +185,6 @@ class Llama3:
total_len = min(max_gen_len + max_prompt_len, params.max_seq_len)
is_vision = not isinstance(self.model, Transformer)
if is_vision:
images = model_input.vision.images if model_input.vision is not None else []
mask = model_input.vision.mask if model_input.vision is not None else []
# the method works for bsz > 1 so add a batch dimension
xattn_caches, cross_attention_masks, full_text_row_masked_out_mask = self.model.compute_vision_tokens_masks(
batch_images=[images],
batch_masks=[mask],
total_len=total_len,
)
pad_id = self.tokenizer.pad_id
tokens = torch.full((bsz, total_len), pad_id, dtype=torch.long)
for k, t in enumerate(prompt_tokens):
@ -208,23 +192,45 @@ class Llama3:
if logprobs:
token_logprobs = torch.zeros_like(tokens, dtype=torch.float)
prev_pos = 0
is_vision = not isinstance(self.model, Transformer)
if is_vision:
images = [inp.vision.images if inp.vision is not None else [] for inp in model_inputs]
mask = [inp.vision.mask if inp.vision is not None else [] for inp in model_inputs]
xattn_caches, cross_attention_masks, full_text_row_masked_out_mask = self.model.compute_vision_tokens_masks(
batch_images=images,
batch_masks=mask,
total_len=total_len,
device=tokens.device,
)
eos_reached = torch.tensor([False] * bsz)
input_text_mask = tokens != pad_id
if echo:
for i, t in enumerate(model_input.tokens):
yield TokenResult(
token=t,
text=self.tokenizer.decode([t]),
logprobs=(token_logprobs[0, i : i + 1].tolist() if logprobs else None),
)
for i in range(max_prompt_len):
results = []
for j, t in enumerate(tokens[:, i]):
results.append(
GenerationResult(
token=t.item(),
text=self.tokenizer.decode([t.item()]),
source="input",
logprobs=(token_logprobs[j, i : i + 1].tolist() if logprobs else None),
batch_idx=j,
finished=False,
ignore_token=t.item() == pad_id,
)
)
yield results
stop_tokens = torch.tensor(self.tokenizer.stop_tokens)
prev_pos = 0
for cur_pos in range(min_prompt_len, total_len):
if is_vision:
position_ids = torch.arange(prev_pos, cur_pos, dtype=torch.long)
text_only_inference = model_input.vision is None
text_only_inference = all(inp.vision is None for inp in model_inputs)
logits = self.model.forward(
position_ids,
tokens,
@ -271,155 +277,69 @@ class Llama3:
ignore_index=pad_id,
)
eos_reached |= (~input_text_mask[:, cur_pos]) & (torch.isin(next_token, stop_tokens))
yield TokenResult(
token=next_token[0].item(),
text=self.tokenizer.decode(next_token.tolist()),
logprobs=(token_logprobs[:, cur_pos : cur_pos + 1][0].tolist() if logprobs else None),
)
results = []
for idx, t in enumerate(next_token):
results.append(
GenerationResult(
token=t.item(),
text=self.tokenizer.decode([t.item()]),
source="output",
logprobs=(token_logprobs[idx, cur_pos : cur_pos + 1].tolist() if logprobs else None),
batch_idx=idx,
finished=eos_reached[idx],
ignore_token=cur_pos < len(prompt_tokens[idx]),
)
)
yield results
prev_pos = cur_pos
if all(eos_reached):
break
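
For illustration, a sketch of consuming the batched generator above, which now yields one GenerationResult per batch element at every step; `llama` and `model_inputs` are assumed to already exist, and the accumulation logic is illustrative rather than taken from this commit.

# Illustrative consumption of the batched generate() API; the GenerationResult
# fields used here (text, batch_idx, finished, ignore_token) match the diff above.
outputs = [[] for _ in model_inputs]
for results in llama.generate(model_inputs=model_inputs, temperature=0.0):
    for r in results:
        if not r.ignore_token and not r.finished:
            outputs[r.batch_idx].append(r.text)
    if all(r.finished for r in results):
        break
texts = ["".join(pieces) for pieces in outputs]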
def text_completion(
def completion(
self,
content: RawContent,
contents: List[RawContent],
temperature: float = 0.6,
top_p: float = 0.9,
max_gen_len: Optional[int] = None,
logprobs: bool = False,
echo: bool = False,
) -> CompletionPrediction:
if max_gen_len is None or max_gen_len == 0 or max_gen_len >= self.model.params.max_seq_len:
max_gen_len = self.model.params.max_seq_len - 1
model_input = self.formatter.encode_content(content)
tokens = []
token_logprobs = []
decoded_tokens = []
) -> Generator[List[GenerationResult], None, None]:
model_inputs = [self.formatter.encode_content(c) for c in contents]
for result in self.generate(
model_input=model_input,
max_gen_len=max_gen_len,
model_inputs=model_inputs,
temperature=temperature,
top_p=top_p,
max_gen_len=max_gen_len,
logprobs=logprobs,
echo=echo,
):
tokens.append(result.token)
if logprobs:
decoded_tokens.append(result.text)
token_logprobs.append(result.logprobs)
generation = self.tokenizer.decode(tokens)
if logprobs:
return CompletionPrediction(
generation=generation,
logprobs=token_logprobs,
decoded_tokens=decoded_tokens,
)
return CompletionPrediction(generation=generation)
yield result
if all(r.finished for r in result):
break
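
A hedged usage sketch of the renamed completion() entry point, assuming RawContent accepts plain strings as it does elsewhere in this codebase; the prompts are placeholders.

# Streamed, batched text completion over two placeholder prompts.
prompts = ["The capital of France is", "1 + 1 ="]
for results in llama.completion(contents=prompts, temperature=0.0, max_gen_len=32):
    for r in results:
        if not r.ignore_token and not r.finished:
            print(f"[{r.batch_idx}] {r.text}", end="", flush=True)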
def chat_completion(
self,
messages: List[RawMessage],
messages_batch: List[List[RawMessage]],
temperature: float = 0.6,
top_p: float = 0.9,
max_gen_len: Optional[int] = None,
logprobs: bool = False,
tool_prompt_format: ToolPromptFormat = ToolPromptFormat.json,
echo: bool = False,
) -> ChatPrediction:
if max_gen_len is None or max_gen_len == 0 or max_gen_len >= self.model.params.max_seq_len:
max_gen_len = self.model.params.max_seq_len - 1
tokens = []
token_logprobs = []
decoded_tokens = []
stop_reason = None
) -> Generator[List[GenerationResult], None, None]:
model_inputs = [self.formatter.encode_dialog_prompt(messages) for messages in messages_batch]
for result in self.generate(
model_input=self.formatter.encode_dialog_prompt(messages, tool_prompt_format),
max_gen_len=max_gen_len,
model_inputs=model_inputs,
temperature=temperature,
top_p=top_p,
max_gen_len=max_gen_len,
logprobs=logprobs,
echo=echo,
):
tokens.append(result.token)
if result.text == "<|eot_id|>":
stop_reason = StopReason.end_of_turn
elif result.text == "<|eom_id|>":
stop_reason = StopReason.end_of_message
if logprobs:
decoded_tokens.append(result.text)
token_logprobs.append(result.logprobs)
if stop_reason is None:
stop_reason = StopReason.out_of_tokens
message = self.formatter.decode_assistant_message(tokens, stop_reason)
if logprobs:
return ChatPrediction(
generation=message,
logprobs=token_logprobs,
decoded_tokens=decoded_tokens,
)
return ChatPrediction(generation=message)
def chat_completion_raw(
self,
messages: List[RawMessage],
temperature: float = 0.6,
top_p: float = 0.9,
max_gen_len: Optional[int] = None,
tool_prompt_format: ToolPromptFormat = ToolPromptFormat.json,
) -> List[int]:
if max_gen_len is None or max_gen_len == 0 or max_gen_len >= self.model.params.max_seq_len:
max_gen_len = self.model.params.max_seq_len - 1
output_tokens = []
model_input = self.formatter.encode_dialog_prompt(messages, tool_prompt_format)
input_tokens = model_input.tokens
for result in self.generate(
model_input=model_input,
max_gen_len=max_gen_len,
temperature=temperature,
top_p=top_p,
logprobs=False,
):
output_tokens.append(result.token)
return input_tokens, output_tokens
def text_completion_raw(
self,
content: RawContent,
temperature: float = 0.6,
top_p: float = 0.9,
max_gen_len: Optional[int] = None,
):
if max_gen_len is None or max_gen_len == 0 or max_gen_len >= self.model.params.max_seq_len:
max_gen_len = self.model.params.max_seq_len - 1
model_input = self.formatter.encode_content(content)
input_tokens = model_input.tokens
output_tokens = []
for result in self.generate(
model_input=model_input,
max_gen_len=max_gen_len,
temperature=temperature,
top_p=top_p,
logprobs=False,
):
output_tokens.append(result.token)
return input_tokens, output_tokens
yield result
if all(r.finished for r in result):
break
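
Similarly, a sketch of the batched chat_completion() call; the RawMessage fields used here (role, content) and its import path are assumptions based on the surrounding codebase.

# Two independent dialogs generated in one batch.
from llama_stack.models.llama.datatypes import RawMessage  # assumed import path

dialogs = [
    [RawMessage(role="user", content="What is 2 + 2?")],
    [RawMessage(role="user", content="Name one primary color.")],
]
for results in llama.chat_completion(messages_batch=dialogs, temperature=0.0):
    for r in results:
        if not r.ignore_token and not r.finished:
            print(f"[{r.batch_idx}] {r.text}", end="", flush=True)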
def sample_top_p(probs, p):
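
The body of sample_top_p is elided by the diff view; for reference, a sketch of the standard nucleus-sampling implementation used in the Llama reference code, which this helper is presumed to mirror.

def sample_top_p(probs, p):
    # Keep the smallest prefix of tokens (sorted by probability) whose
    # cumulative mass exceeds p, renormalize, and sample from that set.
    probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
    probs_sum = torch.cumsum(probs_sort, dim=-1)
    mask = probs_sum - probs_sort > p
    probs_sort[mask] = 0.0
    probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
    next_token = torch.multinomial(probs_sort, num_samples=1)
    next_token = torch.gather(probs_idx, -1, next_token)
    return next_token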


@ -4,16 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# top-level folder for each specific model found within the models/ directory at
# the top-level of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
import math
from typing import Optional, Tuple
@ -29,6 +19,10 @@ from torch import nn
from .args import ModelArgs
# **NOTE**: This code is not runnable without installing `torch` and `fairscale`
# dependencies. These dependencies are not part of the default dependencies
# (requirements.txt) of the `llama-models` package.
class RMSNorm(torch.nn.Module):
def __init__(self, dim: int, eps: float = 1e-6):
@ -111,9 +105,9 @@ class Attention(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
model_parallel_size = fs_init.get_model_parallel_world_size()
self.n_local_heads = args.n_heads // model_parallel_size
self.n_local_kv_heads = self.n_kv_heads // model_parallel_size
world_size = fs_init.get_model_parallel_world_size()
self.n_local_heads = args.n_heads // world_size
self.n_local_kv_heads = self.n_kv_heads // world_size
self.n_rep = self.n_local_heads // self.n_local_kv_heads
self.head_dim = args.dim // args.n_heads


@ -4,16 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# top-level folder for each specific model found within the models/ directory at
# the top-level of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
import logging
import math
from functools import partial
@ -180,14 +170,14 @@ class ImageAttention(nn.Module):
n_heads,
):
super().__init__()
model_parallel_size = fs_init.get_model_parallel_world_size()
world_size = fs_init.get_model_parallel_world_size()
qkvo_replication = 1
if model_parallel_size > 16:
qkvo_replication = model_parallel_size // 8
if world_size > 16:
qkvo_replication = world_size // 8
self.n_kv_heads = n_heads
self.n_local_heads = n_heads * qkvo_replication // model_parallel_size
self.n_local_kv_heads = self.n_kv_heads * qkvo_replication // model_parallel_size
self.n_local_heads = n_heads * qkvo_replication // world_size
self.n_local_kv_heads = self.n_kv_heads * qkvo_replication // world_size
self.n_rep = self.n_local_heads // self.n_local_kv_heads
self.head_dim = dim // n_heads
@ -536,16 +526,16 @@ class Attention(nn.Module):
cache_v (torch.Tensor): Cached values for attention.
"""
super().__init__()
model_parallel_size = fs_init.get_model_parallel_world_size()
world_size = fs_init.get_model_parallel_world_size()
replication_factor = 1
if model_parallel_size > 8:
replication_factor = model_parallel_size // MP_SCALE
if world_size > 8:
replication_factor = world_size // MP_SCALE
self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
self.n_kv_heads *= replication_factor
self.n_local_heads = args.n_heads // model_parallel_size
self.n_local_kv_heads = self.n_kv_heads // model_parallel_size
self.n_local_heads = args.n_heads // world_size
self.n_local_kv_heads = self.n_kv_heads // world_size
self.n_rep = self.n_local_heads // self.n_local_kv_heads
self.head_dim = args.dim // args.n_heads
self.max_seq_len = args.max_seq_len
@ -587,13 +577,11 @@ class Attention(nn.Module):
self.n_local_kv_heads,
self.head_dim,
)
device = next(self.parameters()).device
self.register_buffer(
"key_cache",
torch.zeros(
cache_shape,
dtype=dtype,
device=device,
),
persistent=False,
)
@ -602,7 +590,6 @@ class Attention(nn.Module):
torch.zeros(
cache_shape,
dtype=dtype,
device=device,
),
persistent=False,
)
@ -614,6 +601,9 @@ class Attention(nn.Module):
freqs_cis: torch.Tensor,
position_ids: torch.LongTensor,
):
self.key_cache = self.key_cache.to(x.device)
self.value_cache = self.value_cache.to(x.device)
xq, xk, xv = [F.linear(x, w) for w in [self.wq.weight, self.wk.weight, self.wv.weight]]
bs, slen, _ = xq.shape
@ -832,10 +822,10 @@ class CrossAttention(torch.nn.Module):
norm_eps: float,
):
super().__init__()
self.model_parallel_size = fs_init.get_model_parallel_world_size()
self.world_size = fs_init.get_model_parallel_world_size()
replication_factor = 1
if self.model_parallel_size > 8:
replication_factor = self.model_parallel_size // MP_SCALE
if self.world_size > 8:
replication_factor = self.world_size // MP_SCALE
n_kv_heads *= replication_factor
assert n_heads % n_kv_heads == 0
@ -889,10 +879,10 @@ class CrossAttention(torch.nn.Module):
# trunk LLM (i.e., group query attention) -- @dubeya
# local heads
assert self.n_heads % self.n_kv_heads == 0
assert self.n_heads % self.model_parallel_size == 0
assert self.n_kv_heads % self.model_parallel_size == 0
self.n_local_heads = self.n_heads // self.model_parallel_size
self.n_local_kv_heads = self.n_kv_heads // self.model_parallel_size
assert self.n_heads % self.world_size == 0
assert self.n_kv_heads % self.world_size == 0
self.n_local_heads = self.n_heads // self.world_size
self.n_local_kv_heads = self.n_kv_heads // self.world_size
self.n_rep = self.n_local_heads // self.n_local_kv_heads
def _compute_xattn_kv_cache(self, xattn_tokens: torch.Tensor) -> torch.Tensor:
@ -1041,7 +1031,7 @@ class CrossAttentionTransformerVision(torch.nn.Module):
self.image_res = args.vision_chunk_size
self.max_num_chunks = args.vision_max_num_chunks
if return_intermediate is not None:
return_intermediate = [int(level) for level in return_intermediate.split(",")]
return_intermediate = [int(layer) for layer in return_intermediate.split(",")]
self.vision_input_dim = (len(return_intermediate) + 1) * self.vision_input_dim
self.patch_size = 14
self.vision_encoder = VisionEncoder(
@ -1076,15 +1066,15 @@ class CrossAttentionTransformerText(torch.nn.Module):
def __init__(self, args: ModelArgs) -> None:
super().__init__()
self.model_parallel_size = fs_init.get_model_parallel_world_size()
self.world_size = fs_init.get_model_parallel_world_size()
assert args.vocab_size > 0
self.vocab_size = args.vocab_size
self.n_layers = args.n_layers
self.dim = args.dim
self.head_dim = args.dim // args.n_heads
self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
self.n_local_kv_heads = self.n_kv_heads // self.model_parallel_size
assert self.vocab_size % self.model_parallel_size == 0
self.n_local_kv_heads = self.n_kv_heads // self.world_size
assert self.vocab_size % self.world_size == 0
self.tok_embeddings = VocabParallelEmbedding(args.vocab_size, args.dim, init_method=lambda x: x)
self.pos_embeddings = None
# final norm layer (not necessary for post-norm)
@ -1184,6 +1174,8 @@ class CrossAttentionTransformerText(torch.nn.Module):
text_only_inference: bool = False,
):
assert self.cache_is_setup, "Please set up cache before calling forward"
self.mask_cache = self.mask_cache.to(h.device)
self.freqs_cis = self.freqs_cis.to(h.device)
mask = self.mask_cache.index_select(2, position_ids)
freqs_cis = self.freqs_cis.index_select(0, position_ids)
@ -1212,9 +1204,8 @@ class CrossAttentionTransformerText(torch.nn.Module):
output = gather_from_tensor_model_parallel_region(output)
return output.float()
def setup_cache(self, max_batch_size: int, dtype=torch.bfloat16):
def setup_cache(self, max_batch_size: int, device: torch.device, dtype=torch.bfloat16):
# Set up the text kv caches
device = next(self.parameters()).device
ones = torch.ones(
(self.max_seq_len, self.max_seq_len),
dtype=torch.bool,
@ -1265,7 +1256,7 @@ class CrossAttentionTransformerText(torch.nn.Module):
return (
cross_attention_masks.to(device=text_device, dtype=text_dtype),
full_text_row_masked_out_mask,
full_text_row_masked_out_mask.to(device=text_device),
)
@ -1284,14 +1275,15 @@ class CrossAttentionTransformer(torch.nn.Module):
max_num_chunks=args.vision_max_num_chunks,
)
def setup_cache(self, max_batch_size: int, dtype: torch.dtype):
self.text_model.setup_cache(max_batch_size, dtype)
def setup_cache(self, max_batch_size: int, device: torch.device, dtype: torch.dtype):
self.text_model.setup_cache(max_batch_size, device, dtype)
def compute_vision_tokens_masks(
self,
batch_images: List[List[PIL_Image.Image]],
batch_masks: List[List[List[int]]],
total_len: int,
device: torch.device,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
skip_vision_encoder = False
@ -1318,6 +1310,7 @@ class CrossAttentionTransformer(torch.nn.Module):
image_res=self.params.vision_chunk_size,
max_num_images=max_num_images,
)
stacked_images = stacked_images.to(device=device)
if skip_vision_encoder:
vision_tokens = torch.zeros(
@ -1330,7 +1323,7 @@ class CrossAttentionTransformer(torch.nn.Module):
),
)
else:
vision_tokens = self.vision_model(stacked_images, aspect_ratios)
vision_tokens = self.vision_model(stacked_images, aspect_ratios).to(device=device)
bsz, nimg, nchunk, ntok, image_token_dim = tuple(vision_tokens.shape)
xattn_caches = torch.stack(


@ -15,7 +15,7 @@ import textwrap
from datetime import datetime
from typing import Any, List, Optional
from llama_stack.models.llama.datatypes import (
from llama_stack.apis.inference import (
BuiltinTool,
ToolDefinition,
ToolParamDefinition,


@ -4,9 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
# type: ignore
import os
from typing import Any, Dict, List, Optional, cast
@ -18,22 +15,15 @@ from fairscale.nn.model_parallel.mappings import reduce_from_model_parallel_regi
from torch import Tensor, nn
from torchao.quantization.GPTQ import Int8DynActInt4WeightLinear
from llama_stack.apis.inference import QuantizationType
from llama_stack.log import get_logger
from llama_stack.models.llama.sku_list import resolve_model
from ...config import MetaReferenceQuantizedInferenceConfig
from ...datatypes import CheckpointQuantizationFormat
from ...datatypes import QuantizationMode
from ...quantize_impls import (
Fp8ScaledWeights,
ffn_swiglu,
load_fp8,
quantize_fp8,
)
from ..args import ModelArgs
from ..model import Transformer, TransformerBlock
log = get_logger(__name__, category="quantization")
from ..multimodal.model import CrossAttentionTransformer
def swiglu_wrapper(
@ -44,30 +34,34 @@ def swiglu_wrapper(
return reduce_from_model_parallel_region(out)
def convert_to_quantized_model(
model: Transformer | CrossAttentionTransformer,
checkpoint_dir: str,
quantization_mode: Optional[str] = None,
fp8_activation_scale_ub: Optional[float] = 1200.0,
device: Optional[torch.device] = None,
) -> Transformer | CrossAttentionTransformer:
if quantization_mode == QuantizationMode.fp8_mixed:
return convert_to_fp8_quantized_model(model, checkpoint_dir, fp8_activation_scale_ub, device)
elif quantization_mode == QuantizationMode.int4_mixed:
return convert_to_int4_quantized_model(model, checkpoint_dir, device)
else:
raise ValueError(f"Unsupported quantization mode: {quantization_mode}")
def convert_to_fp8_quantized_model(
model: Transformer,
config: MetaReferenceQuantizedInferenceConfig,
checkpoint_dir: str,
fp8_activation_scale_ub: Optional[float] = 1200.0,
device: Optional[torch.device] = None,
) -> Transformer:
if config.quantization.type == QuantizationType.bf16.value:
return model
elif config.quantization.type != QuantizationType.fp8.value:
raise ValueError("Only FP8 quantization is supported")
assert config.model is not None, "Model must be specified for quantized inference"
llama_model = resolve_model(config.model)
assert llama_model is not None, f"Model {config.model} not found"
# Move weights to GPU with quantization
if llama_model.quantization_format == CheckpointQuantizationFormat.fp8_mixed.value:
log.info("Loading fp8 scales...")
fp8_scales_path = os.path.join(checkpoint_dir, f"fp8_scales_{get_model_parallel_rank()}.pt")
assert os.path.isfile(fp8_scales_path), f"fp8_scales_path not found for rank {get_model_parallel_rank()}"
fp8_scales_path = os.path.join(checkpoint_dir, f"fp8_scales_{get_model_parallel_rank()}.pt")
if os.path.isfile(fp8_scales_path):
print("Loading fp8 scales...")
fp8_scales = torch.load(fp8_scales_path, weights_only=True)
for block in model.layers:
for _, block in model.named_modules():
if isinstance(block, TransformerBlock):
if block.layer_id == 0 or block.layer_id == (model.n_layers - 1):
continue
@ -81,8 +75,8 @@ def convert_to_fp8_quantized_model(
fp8_activation_scale_ub,
)
else:
log.info("Quantizing fp8 weights from bf16...")
for block in model.layers:
print("Quantizing fp8 weights from bf16...")
for _, block in model.named_modules():
if isinstance(block, TransformerBlock):
if block.layer_id == 0 or block.layer_id == (model.n_layers - 1):
continue
@ -92,12 +86,12 @@ def convert_to_fp8_quantized_model(
param.weight = quantize_fp8(
param.weight,
fp8_activation_scale_ub,
output_device=torch.device("cuda"),
output_device=device,
)
for _, parameter in model.named_parameters():
if not isinstance(parameter, Fp8ScaledWeights):
parameter.data = parameter.to(device="cuda")
parameter.data = parameter.to(device=device)
return model
@ -290,11 +284,12 @@ def _prepare_model_int4_weight_int8_dynamic_activation(
def convert_to_int4_quantized_model(
model: Transformer,
model_args: ModelArgs,
) -> Transformer:
model: Transformer | CrossAttentionTransformer,
checkpoint_dir: str,
device: Optional[torch.device] = None,
) -> Transformer | CrossAttentionTransformer:
"""Convert the model to int4 quantized model."""
model_args = model.params
assert model_args.quantization_args is not None, "Quantization args must be specified."
quantization_args = model_args.quantization_args
if quantization_args.scheme is None:
@ -318,5 +313,4 @@ def convert_to_int4_quantized_model(
lora_scale = model_args.lora_args.scale
_prepare_model_int4_weight_int8_dynamic_activation(model, group_size, lora_rank, lora_scale)
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
return cast(Transformer, model.to(device))
return cast(Transformer | CrossAttentionTransformer, model.to(device=device))


@ -4,16 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# top-level folder for each specific model found within the models/ directory at
# the top-level of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
import os
from logging import getLogger
from pathlib import Path