refactor: move all llama code to models/llama out of meta reference

This commit is contained in:
Ashwin Bharambe 2025-04-06 16:08:48 -07:00
parent 28e262ecdc
commit e2e2820c9a
29 changed files with 495 additions and 382 deletions

View file

@ -0,0 +1,82 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# top-level folder for each specific model found within the models/ directory at
# the top-level of this source tree.
from dataclasses import dataclass
from enum import Enum
from typing import Optional
class QuantizationScheme(Enum):
int4_weight_int8_dynamic_activation = "int4_weight_int8_dynamic_activation"
@dataclass
class QuantizationArgs:
scheme: Optional[QuantizationScheme] = None
group_size: Optional[int] = None
spinquant: bool = False
def __init__(self, **kwargs):
for k, v in kwargs.items():
if k == "scheme":
setattr(self, k, QuantizationScheme(v))
else:
if hasattr(self, k):
setattr(self, k, v)
@dataclass
class LoRAArgs:
rank: int
scale: float
@dataclass
class ModelArgs:
dim: int = 4096
n_layers: int = 32
n_heads: int = 32
n_kv_heads: Optional[int] = None
vocab_size: int = -1
multiple_of: int = 256 # make SwiGLU hidden layer size multiple of large power of 2
ffn_dim_multiplier: Optional[float] = None
norm_eps: float = 1e-5
rope_theta: float = 500000
use_scaled_rope: bool = False
max_batch_size: int = 32
max_seq_len: int = 2048
# vision model params
vision_chunk_size: int = -1 # image resolution for image models
vision_max_num_chunks: int = 4
vision_num_cross_attention_layers: int = -1
quantization_args: Optional[QuantizationArgs] = None
lora_args: Optional[LoRAArgs] = None
def __init__(self, **kwargs):
for k, v in kwargs.items():
if k == "lora_args":
setattr(self, k, LoRAArgs(**v))
elif k == "quantization_args":
setattr(self, k, QuantizationArgs(**v))
else:
if hasattr(self, k):
setattr(self, k, v)
if self.n_kv_heads is None:
self.n_kv_heads = self.n_heads
assert self.n_kv_heads <= self.n_heads
assert self.n_heads % self.n_kv_heads == 0
assert self.dim % self.n_heads == 0

View file

@ -19,7 +19,7 @@ from typing import Dict, List, Optional, Tuple
from PIL import Image as PIL_Image
from llama_stack.models.llama.datatypes import (
from ..datatypes import (
BuiltinTool,
RawContent,
RawMediaItem,
@ -30,7 +30,6 @@ from llama_stack.models.llama.datatypes import (
ToolCall,
ToolPromptFormat,
)
from .tokenizer import Tokenizer
from .tool_utils import ToolUtils

View file

@ -0,0 +1,447 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
import json
import os
import sys
import time
from dataclasses import dataclass
from pathlib import Path
from typing import Callable, Generator, List, Optional
import torch
import torch.nn.functional as F
from fairscale.nn.model_parallel.initialize import (
get_model_parallel_rank,
initialize_model_parallel,
model_parallel_is_initialized,
)
from termcolor import cprint
from ..datatypes import RawContent, RawMessage, StopReason, ToolPromptFormat
from .args import ModelArgs
from .chat_format import ChatFormat, LLMInput
from .model import Transformer
from .tokenizer import Tokenizer
@dataclass
class CompletionPrediction:
generation: str
decoded_tokens: Optional[List[str]] = None
logprobs: Optional[List[List[float]]] = None
@dataclass
class ChatPrediction:
generation: RawMessage
decoded_tokens: Optional[List[str]] = None
logprobs: Optional[List[List[float]]] = None
@dataclass
class TokenResult:
token: int
text: str
logprobs: Optional[List[float]] = None
# TODO: make this completely parallel to the llama4 generation.py file and share common code
# from llama-models also
class Llama3:
@staticmethod
def build(
ckpt_dir: str,
max_seq_len: int,
max_batch_size: int,
world_size: Optional[int] = None,
tokenizer_path: Optional[str] = None,
seed: int = 1,
device: str = "cuda",
):
device = torch.device(device)
if (
device.type == "cuda"
and not torch.cuda.is_available()
or device.type == "xpu"
and not torch.xpu.is_available()
):
raise RuntimeError(f"PyTorch backend for {device.type} device type is not available")
if not torch.distributed.is_initialized():
if device.type == "cuda":
torch.distributed.init_process_group("nccl")
else:
torch.distributed.init_process_group("gloo")
if not model_parallel_is_initialized():
if world_size is None:
world_size = int(os.environ.get("WORLD_SIZE", 1))
initialize_model_parallel(world_size)
local_rank = int(os.environ.get("LOCAL_RANK", 0))
if device.type == "cuda":
torch.cuda.set_device(local_rank)
elif device.type == "xpu":
torch.xpu.set_device(local_rank)
torch.manual_seed(seed)
if local_rank > 0:
sys.stdout = open(os.devnull, "w")
start_time = time.time()
checkpoints = sorted(Path(ckpt_dir).glob("*.pth"))
assert len(checkpoints) > 0, f"no checkpoint files found in {ckpt_dir}"
assert world_size == len(checkpoints), (
f"Loading a checkpoint for MP={len(checkpoints)} but world size is {world_size}"
)
ckpt_path = checkpoints[get_model_parallel_rank()]
checkpoint = torch.load(ckpt_path, map_location="cpu", weights_only=True)
with open(Path(ckpt_dir) / "params.json", "r") as f:
params = json.loads(f.read())
model_args: ModelArgs = ModelArgs(
max_seq_len=max_seq_len,
max_batch_size=max_batch_size,
**params,
)
if tokenizer_path:
tokenizer = Tokenizer(model_path=tokenizer_path)
else:
tokenizer = Tokenizer.get_instance()
assert model_args.vocab_size == tokenizer.n_words
torch.set_default_device(device)
if device.type == "cuda":
if torch.cuda.is_bf16_supported():
torch.set_default_dtype(torch.bfloat16)
else:
torch.set_default_dtype(torch.half)
elif device.type == "xpu":
if torch.xpu.is_bf16_supported():
torch.set_default_dtype(torch.bfloat16)
else:
torch.set_default_dtype(torch.half)
else:
torch.set_default_dtype(torch.half)
if model_args.vision_chunk_size > 0:
from .multimodal.model import CrossAttentionTransformer
model = CrossAttentionTransformer(model_args)
model.setup_cache(model_args.max_batch_size, torch.get_default_dtype())
else:
model = Transformer(model_args)
model.load_state_dict(checkpoint, strict=True)
model.to(device)
print(f"Loaded in {time.time() - start_time:.2f} seconds")
return Llama(model, tokenizer, model_args)
def __init__(self, model: Transformer, tokenizer: Tokenizer, args: ModelArgs):
self.args = args
self.model = model
self.tokenizer = tokenizer
self.formatter = ChatFormat(tokenizer)
@torch.inference_mode()
def generate(
self,
model_input: LLMInput,
max_gen_len: int,
temperature: float = 0.6,
top_p: float = 0.9,
logprobs: bool = False,
echo: bool = False,
print_model_input: bool = False,
logits_processor: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
) -> Generator:
params = self.model.params
if print_model_input:
tokens_to_print = [self.formatter.vision_token if t == 128256 else t for t in model_input.tokens]
cprint(
"Input to model:\n" + self.tokenizer.decode(tokens_to_print) + "\n",
"red",
)
prompt_tokens = [model_input.tokens]
bsz = 1
assert bsz <= params.max_batch_size, (bsz, params.max_batch_size)
min_prompt_len = min(len(t) for t in prompt_tokens)
max_prompt_len = max(len(t) for t in prompt_tokens)
if max_prompt_len >= params.max_seq_len:
cprint(f"Out of token budget {max_prompt_len} vs {params.max_seq_len}", "red")
return
total_len = min(max_gen_len + max_prompt_len, params.max_seq_len)
is_vision = not isinstance(self.model, Transformer)
if is_vision:
images = model_input.vision.images if model_input.vision is not None else []
mask = model_input.vision.mask if model_input.vision is not None else []
# the method works for bsz > 1 so add a batch dimension
xattn_caches, cross_attention_masks, full_text_row_masked_out_mask = self.model.compute_vision_tokens_masks(
batch_images=[images],
batch_masks=[mask],
total_len=total_len,
)
pad_id = self.tokenizer.pad_id
tokens = torch.full((bsz, total_len), pad_id, dtype=torch.long)
for k, t in enumerate(prompt_tokens):
tokens[k, : len(t)] = torch.tensor(t, dtype=torch.long)
if logprobs:
token_logprobs = torch.zeros_like(tokens, dtype=torch.float)
prev_pos = 0
eos_reached = torch.tensor([False] * bsz)
input_text_mask = tokens != pad_id
if echo:
for i, t in enumerate(model_input.tokens):
yield TokenResult(
token=t,
text=self.tokenizer.decode([t]),
logprobs=(token_logprobs[0, i : i + 1].tolist() if logprobs else None),
)
stop_tokens = torch.tensor(self.tokenizer.stop_tokens)
for cur_pos in range(min_prompt_len, total_len):
if is_vision:
position_ids = torch.arange(prev_pos, cur_pos, dtype=torch.long)
text_only_inference = model_input.vision is None
logits = self.model.forward(
position_ids,
tokens,
cross_attention_masks,
full_text_row_masked_out_mask,
xattn_caches,
text_only_inference,
)
else:
logits = self.model.forward(tokens[:, prev_pos:cur_pos], prev_pos)
if logits_processor is not None:
logits = logits_processor(tokens[:, :cur_pos], logits)
if temperature > 0:
probs = torch.softmax(logits[:, -1] / temperature, dim=-1)
next_token = sample_top_p(probs, top_p)
else:
next_token = torch.argmax(logits[:, -1], dim=-1)
next_token = next_token.reshape(-1)
# only replace token if prompt has already been generated
next_token = torch.where(input_text_mask[:, cur_pos], tokens[:, cur_pos], next_token)
tokens[:, cur_pos] = next_token
target = tokens[:, prev_pos + 1 : cur_pos + 1]
if is_vision:
# the logits space (num_classes) is designed to never contain a media_token
# however our input token stream does contain them. we need to nuke them here
# or else the CUDA kernels will crash with an illegal memory access
vision_tokens = [self.tokenizer.special_tokens["<|image|>"], 128256]
masks = [target.eq(t) for t in vision_tokens]
if len(masks) > 1:
mask = torch.logical_or(*masks)
else:
mask = masks[0]
target[mask] = 0
if logprobs:
token_logprobs[:, prev_pos + 1 : cur_pos + 1] = -F.cross_entropy(
input=logits.transpose(1, 2),
target=target,
reduction="none",
ignore_index=pad_id,
)
eos_reached |= (~input_text_mask[:, cur_pos]) & (torch.isin(next_token, stop_tokens))
yield TokenResult(
token=next_token[0].item(),
text=self.tokenizer.decode(next_token.tolist()),
logprobs=(token_logprobs[:, cur_pos : cur_pos + 1][0].tolist() if logprobs else None),
)
prev_pos = cur_pos
if all(eos_reached):
break
def text_completion(
self,
content: RawContent,
temperature: float = 0.6,
top_p: float = 0.9,
max_gen_len: Optional[int] = None,
logprobs: bool = False,
echo: bool = False,
) -> CompletionPrediction:
if max_gen_len is None or max_gen_len == 0 or max_gen_len >= self.model.params.max_seq_len:
max_gen_len = self.model.params.max_seq_len - 1
model_input = self.formatter.encode_content(content)
tokens = []
token_logprobs = []
decoded_tokens = []
for result in self.generate(
model_input=model_input,
max_gen_len=max_gen_len,
temperature=temperature,
top_p=top_p,
logprobs=logprobs,
echo=echo,
):
tokens.append(result.token)
if logprobs:
decoded_tokens.append(result.text)
token_logprobs.append(result.logprobs)
generation = self.tokenizer.decode(tokens)
if logprobs:
return CompletionPrediction(
generation=generation,
logprobs=token_logprobs,
decoded_tokens=decoded_tokens,
)
return CompletionPrediction(generation=generation)
def chat_completion(
self,
messages: List[RawMessage],
temperature: float = 0.6,
top_p: float = 0.9,
max_gen_len: Optional[int] = None,
logprobs: bool = False,
tool_prompt_format: ToolPromptFormat = ToolPromptFormat.json,
echo: bool = False,
) -> ChatPrediction:
if max_gen_len is None or max_gen_len == 0 or max_gen_len >= self.model.params.max_seq_len:
max_gen_len = self.model.params.max_seq_len - 1
tokens = []
token_logprobs = []
decoded_tokens = []
stop_reason = None
for result in self.generate(
model_input=self.formatter.encode_dialog_prompt(messages, tool_prompt_format),
max_gen_len=max_gen_len,
temperature=temperature,
top_p=top_p,
logprobs=logprobs,
echo=echo,
):
tokens.append(result.token)
if result.text == "<|eot_id|>":
stop_reason = StopReason.end_of_turn
elif result.text == "<|eom_id|>":
stop_reason = StopReason.end_of_message
if logprobs:
decoded_tokens.append(result.text)
token_logprobs.append(result.logprobs)
if stop_reason is None:
stop_reason = StopReason.out_of_tokens
message = self.formatter.decode_assistant_message(tokens, stop_reason)
if logprobs:
return ChatPrediction(
generation=message,
logprobs=token_logprobs,
decoded_tokens=decoded_tokens,
)
return ChatPrediction(generation=message)
def chat_completion_raw(
self,
messages: List[RawMessage],
temperature: float = 0.6,
top_p: float = 0.9,
max_gen_len: Optional[int] = None,
tool_prompt_format: ToolPromptFormat = ToolPromptFormat.json,
) -> List[int]:
if max_gen_len is None or max_gen_len == 0 or max_gen_len >= self.model.params.max_seq_len:
max_gen_len = self.model.params.max_seq_len - 1
output_tokens = []
model_input = self.formatter.encode_dialog_prompt(messages, tool_prompt_format)
input_tokens = model_input.tokens
for result in self.generate(
model_input=model_input,
max_gen_len=max_gen_len,
temperature=temperature,
top_p=top_p,
logprobs=False,
):
output_tokens.append(result.token)
return input_tokens, output_tokens
def text_completion_raw(
self,
content: RawContent,
temperature: float = 0.6,
top_p: float = 0.9,
max_gen_len: Optional[int] = None,
):
if max_gen_len is None or max_gen_len == 0 or max_gen_len >= self.model.params.max_seq_len:
max_gen_len = self.model.params.max_seq_len - 1
model_input = self.formatter.encode_content(content)
input_tokens = model_input.tokens
output_tokens = []
for result in self.generate(
model_input=model_input,
max_gen_len=max_gen_len,
temperature=temperature,
top_p=top_p,
logprobs=False,
):
output_tokens.append(result.token)
return input_tokens, output_tokens
def sample_top_p(probs, p):
"""
Perform top-p (nucleus) sampling on a probability distribution.
Args:
probs (torch.Tensor): Probability distribution tensor.
p (float): Probability threshold for top-p sampling.
Returns:
torch.Tensor: Sampled token indices.
Note:
Top-p sampling selects the smallest set of tokens whose cumulative probability mass
exceeds the threshold p. The distribution is renormalized based on the selected tokens.
"""
probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
probs_sum = torch.cumsum(probs_sort, dim=-1)
mask = probs_sum - probs_sort > p
probs_sort[mask] = 0.0
probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
next_token = torch.multinomial(probs_sort, num_samples=1)
next_token = torch.gather(probs_idx, -1, next_token)
return next_token

View file

@ -16,7 +16,7 @@ from typing import List, Optional
from termcolor import colored
from llama_stack.models.llama.datatypes import (
from ..datatypes import (
BuiltinTool,
RawMessage,
StopReason,
@ -24,7 +24,6 @@ from llama_stack.models.llama.datatypes import (
ToolDefinition,
ToolPromptFormat,
)
from . import template_data
from .chat_format import ChatFormat
from .prompt_templates import (

View file

@ -0,0 +1,311 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# top-level folder for each specific model found within the models/ directory at
# the top-level of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
import math
from typing import Optional, Tuple
import fairscale.nn.model_parallel.initialize as fs_init
import torch
import torch.nn.functional as F
from fairscale.nn.model_parallel.layers import (
ColumnParallelLinear,
RowParallelLinear,
VocabParallelEmbedding,
)
from torch import nn
from .args import ModelArgs
class RMSNorm(torch.nn.Module):
def __init__(self, dim: int, eps: float = 1e-6):
super().__init__()
self.eps = eps
self.weight = nn.Parameter(torch.ones(dim))
def _norm(self, x):
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
def forward(self, x):
output = self._norm(x.float()).type_as(x)
return output * self.weight
def apply_scaling(freqs: torch.Tensor) -> torch.Tensor:
# Values obtained from grid search
scale_factor = 8
low_freq_factor = 1
high_freq_factor = 4
old_context_len = 8192 # original llama3 length
low_freq_wavelen = old_context_len / low_freq_factor
high_freq_wavelen = old_context_len / high_freq_factor
wavelen = 2 * torch.pi / freqs
new_freqs = torch.where(wavelen > low_freq_wavelen, freqs / scale_factor, freqs)
smooth = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor)
return torch.where(
(wavelen >= high_freq_wavelen) & (wavelen <= low_freq_wavelen),
(1 - smooth) * new_freqs / scale_factor + smooth * new_freqs,
new_freqs,
)
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0, use_scaled: bool = False):
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
t = torch.arange(end, device=freqs.device, dtype=torch.float32)
if use_scaled:
freqs = apply_scaling(freqs)
freqs = torch.outer(t, freqs)
freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64
return freqs_cis
def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
ndim = x.ndim
assert 0 <= 1 < ndim
assert freqs_cis.shape == (x.shape[1], x.shape[-1])
shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
return freqs_cis.view(*shape)
def apply_rotary_emb(
xq: torch.Tensor,
xk: torch.Tensor,
freqs_cis: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
return xq_out.type_as(xq), xk_out.type_as(xk)
def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
"""torch.repeat_interleave(x, dim=2, repeats=n_rep)"""
bs, slen, n_kv_heads, head_dim = x.shape
if n_rep == 1:
return x
return (
x[:, :, :, None, :]
.expand(bs, slen, n_kv_heads, n_rep, head_dim)
.reshape(bs, slen, n_kv_heads * n_rep, head_dim)
)
class Attention(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
model_parallel_size = fs_init.get_model_parallel_world_size()
self.n_local_heads = args.n_heads // model_parallel_size
self.n_local_kv_heads = self.n_kv_heads // model_parallel_size
self.n_rep = self.n_local_heads // self.n_local_kv_heads
self.head_dim = args.dim // args.n_heads
self.wq = ColumnParallelLinear(
args.dim,
args.n_heads * self.head_dim,
bias=False,
gather_output=False,
init_method=lambda x: x,
)
self.wk = ColumnParallelLinear(
args.dim,
self.n_kv_heads * self.head_dim,
bias=False,
gather_output=False,
init_method=lambda x: x,
)
self.wv = ColumnParallelLinear(
args.dim,
self.n_kv_heads * self.head_dim,
bias=False,
gather_output=False,
init_method=lambda x: x,
)
self.wo = RowParallelLinear(
args.n_heads * self.head_dim,
args.dim,
bias=False,
input_is_parallel=True,
init_method=lambda x: x,
)
self.cache_k = torch.zeros(
(
args.max_batch_size,
args.max_seq_len,
self.n_local_kv_heads,
self.head_dim,
)
)
self.cache_v = torch.zeros(
(
args.max_batch_size,
args.max_seq_len,
self.n_local_kv_heads,
self.head_dim,
)
)
def forward(
self,
x: torch.Tensor,
start_pos: int,
freqs_cis: torch.Tensor,
mask: Optional[torch.Tensor],
):
bsz, seqlen, _ = x.shape
xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)
self.cache_k = self.cache_k.to(xq)
self.cache_v = self.cache_v.to(xq)
self.cache_k[:bsz, start_pos : start_pos + seqlen] = xk
self.cache_v[:bsz, start_pos : start_pos + seqlen] = xv
keys = self.cache_k[:bsz, : start_pos + seqlen]
values = self.cache_v[:bsz, : start_pos + seqlen]
# repeat k/v heads if n_kv_heads < n_heads
keys = repeat_kv(keys, self.n_rep) # (bs, cache_len + seqlen, n_local_heads, head_dim)
values = repeat_kv(values, self.n_rep) # (bs, cache_len + seqlen, n_local_heads, head_dim)
xq = xq.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim)
keys = keys.transpose(1, 2) # (bs, n_local_heads, cache_len + seqlen, head_dim)
values = values.transpose(1, 2) # (bs, n_local_heads, cache_len + seqlen, head_dim)
scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim)
if mask is not None:
scores = scores + mask # (bs, n_local_heads, seqlen, cache_len + seqlen)
scores = F.softmax(scores.float(), dim=-1).type_as(xq)
output = torch.matmul(scores, values) # (bs, n_local_heads, seqlen, head_dim)
output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
return self.wo(output)
class FeedForward(nn.Module):
def __init__(
self,
dim: int,
hidden_dim: int,
multiple_of: int,
ffn_dim_multiplier: Optional[float],
):
super().__init__()
hidden_dim = int(2 * hidden_dim / 3)
# custom dim factor multiplier
if ffn_dim_multiplier is not None:
hidden_dim = int(ffn_dim_multiplier * hidden_dim)
hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
self.w1 = ColumnParallelLinear(dim, hidden_dim, bias=False, gather_output=False, init_method=lambda x: x)
self.w2 = RowParallelLinear(hidden_dim, dim, bias=False, input_is_parallel=True, init_method=lambda x: x)
self.w3 = ColumnParallelLinear(dim, hidden_dim, bias=False, gather_output=False, init_method=lambda x: x)
def forward(self, x):
return self.w2(F.silu(self.w1(x)) * self.w3(x))
class TransformerBlock(nn.Module):
def __init__(self, layer_id: int, args: ModelArgs):
super().__init__()
self.n_heads = args.n_heads
self.dim = args.dim
self.head_dim = args.dim // args.n_heads
self.attention = Attention(args)
self.feed_forward = FeedForward(
dim=args.dim,
hidden_dim=4 * args.dim,
multiple_of=args.multiple_of,
ffn_dim_multiplier=args.ffn_dim_multiplier,
)
self.layer_id = layer_id
self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)
def forward(
self,
x: torch.Tensor,
start_pos: int,
freqs_cis: torch.Tensor,
mask: Optional[torch.Tensor],
):
h = x + self.attention(self.attention_norm(x), start_pos, freqs_cis, mask)
out = h + self.feed_forward(self.ffn_norm(h))
return out
class Transformer(nn.Module):
def __init__(self, params: ModelArgs):
super().__init__()
self.params = params
self.vocab_size = params.vocab_size
self.n_layers = params.n_layers
self.tok_embeddings = VocabParallelEmbedding(params.vocab_size, params.dim, init_method=lambda x: x)
self.layers = torch.nn.ModuleList()
for layer_id in range(params.n_layers):
self.layers.append(TransformerBlock(layer_id, params))
self.norm = RMSNorm(params.dim, eps=params.norm_eps)
self.output = ColumnParallelLinear(params.dim, params.vocab_size, bias=False, init_method=lambda x: x)
self.freqs_cis = precompute_freqs_cis(
params.dim // params.n_heads,
params.max_seq_len * 2,
params.rope_theta,
params.use_scaled_rope,
)
@torch.inference_mode()
def forward(self, tokens: torch.Tensor, start_pos: int):
_bsz, seqlen = tokens.shape
h = self.tok_embeddings(tokens)
self.freqs_cis = self.freqs_cis.to(h.device)
freqs_cis = self.freqs_cis[start_pos : start_pos + seqlen]
mask = None
if seqlen > 1:
mask = torch.full((seqlen, seqlen), float("-inf"), device=tokens.device)
mask = torch.triu(mask, diagonal=1)
# https://github.com/pytorch/pytorch/issues/100005
# torch.triu is buggy when the device is mps: filled values are
# nan instead of 0.
if mask.device.type == torch.device("mps").type:
mask = torch.nan_to_num(mask, nan=0.0)
# When performing key-value caching, we compute the attention scores
# only for the new sequence. Thus, the matrix of scores is of size
# (seqlen, cache_len + seqlen), and the only masked entries are (i, j) for
# j > cache_len + i, since row i corresponds to token cache_len + i.
mask = torch.hstack([torch.zeros((seqlen, start_pos), device=tokens.device), mask]).type_as(h)
for layer in self.layers:
h = layer(h, start_pos, freqs_cis, mask)
h = self.norm(h)
output = self.output(h).float()
return output

View file

@ -0,0 +1,12 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# top-level folder for each specific model found within the models/ directory at
# the top-level of this source tree.

View file

@ -0,0 +1,179 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# top-level folder for each specific model found within the models/ directory at
# the top-level of this source tree.
# Copyright (c) Meta Platforms, Inc. and its affiliates.
import math
from logging import getLogger
import torch
import torch.nn.functional as F
from .utils import get_negative_inf_value, to_2tuple
logger = getLogger()
def resize_local_position_embedding(orig_pos_embed, grid_size):
"""
Resize position embedding for vision encoder.
Original position embedding is [n_tiles * n_tiles + 1, dim]
New position embedding will be [grid_size[0] * grid_size[1] + 1, dim]
"""
new_grid_size = to_2tuple(grid_size)
orig_grid_size = to_2tuple(int(math.sqrt(len(orig_pos_embed) - 1)))
new_pos_emb_tok, new_pos_emb_img = (
orig_pos_embed[:1],
orig_pos_embed[1:],
)
logger.info(f"resizing position embedding grid-size from {orig_grid_size} to {new_grid_size}")
new_pos_emb_img = new_pos_emb_img.reshape(1, orig_grid_size[0], orig_grid_size[1], -1).permute(0, 3, 1, 2)
new_pos_emb_img = F.interpolate(
new_pos_emb_img,
size=new_grid_size,
mode="bilinear",
align_corners=True,
)
new_pos_emb_img = new_pos_emb_img.permute(0, 2, 3, 1).reshape(1, new_grid_size[0] * new_grid_size[1], -1)[0]
new_pos_embed = torch.cat([new_pos_emb_tok, new_pos_emb_img], dim=0)
return new_pos_embed
def initialize_global_position_embedding_from_local(pos_and_cls_embed, grid_size, x_scale, y_scale):
"""
Takes a local position embedding for vision encoder and uses it
to initialize the global position embedding.
Input: local position embedding of shape [grid_size[0] * grid_size[1] + 1, dim]
Returns: global position embedding of shape [x_scale, y_scale, grid_size[0] * grid_size[1] + 1, dim]
Here x_scale and y_scale are the number of tiles along x-axis and y-axis respectively.
"""
pos_embed = pos_and_cls_embed[1:]
cls_embed = pos_and_cls_embed[0].view(1, 1, 1, -1)
grid_size = to_2tuple(grid_size)
new_pos_emb_img = pos_embed.reshape(1, grid_size[0], grid_size[1], -1).permute(0, 3, 1, 2)
new_grid_size = (x_scale * grid_size[0], y_scale * grid_size[1])
new_pos_emb_img = F.interpolate(
new_pos_emb_img,
size=new_grid_size,
mode="bilinear",
align_corners=True,
)
new_pos_emb_img = new_pos_emb_img.permute(0, 2, 3, 1)
new_pos_emb_img = new_pos_emb_img.view(x_scale, grid_size[0], y_scale, grid_size[1], -1)
new_pos_emb_img = new_pos_emb_img.permute(0, 2, 1, 3, 4).contiguous()
new_pos_emb_img = new_pos_emb_img.reshape(x_scale, y_scale, grid_size[0] * grid_size[1], -1)
cls_embed = cls_embed.expand(x_scale, y_scale, -1, -1)
pos_and_cls_embed = torch.cat([cls_embed, new_pos_emb_img], dim=2)
return pos_and_cls_embed
def resize_global_position_embedding(pos_and_cls_embed, grid_size, x_scale, y_scale):
"""
Takes a global position embedding for vision encoder and resizes it to new size.
Input: global position embedding of shape [x_old, y_old, old_grid_size[0] * old_grid_size[1] + 1, dim]
Returns: global position embedding of shape [x_scale, y_scale, grid_size[0] * grid_size[1] + 1, dim]
Here x_scale and y_scale are the number of tiles along x-axis and y-axis respectively.
"""
# first remove cls token
pos_embed = pos_and_cls_embed[:, :, 1:]
cls_embed = pos_and_cls_embed[:, :, 0].unsqueeze(2)
xs_old, ys_old, ntok, dim = pos_embed.shape
old_grid_size = int(math.sqrt(ntok))
# move to correct form for interpolation
pos_embed = pos_embed.view(xs_old, ys_old, old_grid_size, old_grid_size, dim)
pos_embed = pos_embed.permute(0, 2, 1, 3, 4).contiguous()
pos_embed = pos_embed.view(xs_old * old_grid_size, ys_old * old_grid_size, dim)
pos_embed = pos_embed.unsqueeze(0)
# interpolate
new_size = (grid_size[0] * x_scale, grid_size[1] * y_scale)
pos_embed = pos_embed.permute(0, 3, 1, 2)
pos_embed_resized = F.interpolate(
pos_embed,
size=new_size,
mode="bilinear",
align_corners=True,
)
pos_embed = pos_embed_resized.permute(0, 2, 3, 1)[0]
# move it back in place
pos_embed = pos_embed.view(x_scale, grid_size[0], y_scale, grid_size[1], dim)
pos_embed = pos_embed.permute(0, 2, 1, 3, 4).contiguous()
pos_embed = pos_embed.view(x_scale, y_scale, grid_size[0] * grid_size[1], dim)
# interpolate cls token
cls_embed = cls_embed.permute(2, 3, 0, 1)
cls_embed_resized = F.interpolate(
cls_embed,
size=(x_scale, y_scale),
mode="bilinear",
align_corners=True,
)
cls_embed = cls_embed_resized.permute(2, 3, 0, 1)
# add cls token back in
pos_and_cls_embed = torch.cat([cls_embed, pos_embed], dim=2)
return pos_and_cls_embed
def build_encoder_attention_mask(
x: torch.Tensor,
ar: torch.Tensor,
ntok: int,
num_chunks: int,
n_heads: int,
):
"""
Build vision encoder attention mask that omits padding tokens.
"""
masks = []
for arx in ar:
mask_i = torch.ones((num_chunks, x.shape[2], 1), dtype=x.dtype)
mask_i[: arx[0] * arx[1], :ntok] = 0
mask_i = mask_i.view(num_chunks * x.shape[2], -1)
mask_i = mask_i @ mask_i.T * get_negative_inf_value(x.dtype)
mask_i = mask_i.unsqueeze(0)
masks.append(mask_i)
masks = torch.stack(masks).to(x.device).expand(-1, n_heads, -1, -1)
return masks
def expand_num_tokens_to_mult8(x):
num_pad_tokens = 8 - (x.shape[-2] % 8)
if num_pad_tokens == 0:
return x, 0
else:
return (
torch.cat(
[
x,
torch.zeros(
(x.shape[0], x.shape[1], num_pad_tokens, x.shape[-1]),
dtype=x.dtype,
device=x.device,
),
],
dim=-2,
),
num_pad_tokens,
)
def contract_num_tokens_from_mult8(x, num_pad_tokens):
if num_pad_tokens == 0:
return x
return x[:, :, :-num_pad_tokens]

View file

@ -0,0 +1,408 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# top-level folder for each specific model found within the models/ directory at
# the top-level of this source tree.
import math
from collections import defaultdict
from logging import getLogger
from typing import Any, Optional, Set, Tuple
import torch
import torchvision.transforms as tv
from PIL import Image
from torchvision.transforms import functional as F
IMAGE_RES = 224
logger = getLogger()
class VariableSizeImageTransform(object):
"""
This class accepts images of any size and dynamically resize, pads and chunks it
based on the image aspect ratio and the number of image chunks we allow.
The algorithm will NOT distort the image fit a certain aspect ratio, because
that leads to a significant degradation in image quality.
It can be summarized in 6 steps:
1. Find all possible canvas combinations of max_num_chunks;
2. Find the best canvas to fit the image;
3. Resize without distortion
4. Pad
5. Normalize
6. Chunk
For example, if an input image is of size 300x800, patch_size of 224,
and max_num_chunks = 8, it will find the closest aspect ratio that
is allowed within 8 image chunks, with some restrictions.
In this case, 2:4 = 2 horizontal patches and 4 vertical patches,
giving a total of 8 chunks.
If resize_to_max_canvas, the image will be resized (without distortion),
to the largest possible resolution. In this case, 388:896, and padded to 448:896,
where we maintain the original aspect ratio and pad with zeros value for the rest.
This approach minimizes the amount of padding required for any arbitrary resolution.
However, if limit_upscaling_to_patch_size is set to True,
the upscaling will be limited to the patch size. In the example above,
the image would remain 300x800 (no upscaling), and then padded to 448:896.
The final output will therefore be of shape (8, 3, 224, 224), where 2x4
patches are coming from the resizing and chunking.
"""
def __init__(self, size: int = IMAGE_RES) -> None:
self.size = size
logger.info(f"VariableSizeImageTransform size: {self.size}")
self.to_tensor = tv.ToTensor()
self._mean = (0.48145466, 0.4578275, 0.40821073)
self._std = (0.26862954, 0.26130258, 0.27577711)
self.normalize = tv.Normalize(
mean=self._mean,
std=self._std,
inplace=True,
)
self.resample = tv.InterpolationMode.BILINEAR
@staticmethod
def get_factors(n: int) -> Set[int]:
"""
Calculate all factors of a given number, i.e. a dividor that leaves
no remainder. For example, if n=12, it will return {1, 2, 3, 4, 6, 12}.
Args:
n (int): The number to find factors for.
Returns:
set: A set containing all factors of the number.
"""
factors_set = set()
for i in range(1, int(n**0.5) + 1):
if n % i == 0:
factors_set.add(i)
factors_set.add(n // i)
return factors_set
def find_supported_resolutions(self, max_num_chunks: int, patch_size: int) -> torch.Tensor:
"""
Computes all of the allowed resoltuions for a fixed number of chunks
and patch_size. Useful for when dividing an image into chunks.
Args:
max_num_chunks (int): Maximum number of chunks for processing.
patch_size (int): Size of the side of the patch.
Returns:
torch.Tensor: List of possible resolutions as tuples (height, width).
Example:
>>> max_num_chunks = 5
>>> patch_size = 224
>>> find_supported_resolutions(max_num_chunks, patch_size)
tensor([(224, 896), (448, 448), (224, 224), (896, 224), (224, 672),
(672, 224), (224, 448), (448, 224)])
Given max_num_chunks=4, patch_size=224, it will create a dictionary:
{
0.25: [(1, 4)],
1.0: [(2, 2), (1, 1)],
4.0: [(4, 1)],
0.33: [(1, 3)],
3.0: [(3, 1)],
0.5: [(1, 2)],
2.0: [(2, 1)]
}
and return the resolutions multiplied by the patch_size:
[(1*224, 4*224), (2*224, 2*224), ..., (2*224, 1*224)]
"""
asp_dict = defaultdict(list)
for chunk_size in range(max_num_chunks, 0, -1):
_factors = sorted(self.get_factors(chunk_size))
_asp_ratios = [(factor, chunk_size // factor) for factor in _factors]
for height, width in _asp_ratios:
ratio_float = height / width
asp_dict[ratio_float].append((height, width))
# get the resolutions multiplied by the patch_size
possible_resolutions = []
for value in asp_dict.values():
for height, depth in value:
possible_resolutions.append((height * patch_size, depth * patch_size))
return possible_resolutions
@staticmethod
def get_max_res_without_distortion(
image_size: Tuple[int, int],
target_size: Tuple[int, int],
) -> Tuple[int, int]:
"""
Determines the maximum resolution to which an image can be resized to without distorting its
aspect ratio, based on the target resolution.
Args:
image_size (Tuple[int, int]): The original resolution of the image (height, width).
target_resolution (Tuple[int, int]): The desired resolution to fit the image into (height, width).
Returns:
Tuple[int, int]: The optimal dimensions (height, width) to which the image should be resized.
Example:
>>> _get_max_res_without_distortion([200, 300], target_size = [450, 200])
(134, 200)
>>> _get_max_res_without_distortion([800, 600], target_size = [450, 1300])
(450, 338)
"""
original_width, original_height = image_size
target_width, target_height = target_size
scale_w = target_width / original_width
scale_h = target_height / original_height
if scale_w < scale_h:
new_width = target_width
new_height = min(math.floor(original_height * scale_w), target_height)
else:
new_height = target_height
new_width = min(math.floor(original_width * scale_h), target_width)
return new_width, new_height
def _pad(self, image: Image.Image, target_size) -> Image.Image:
new_width, new_height = target_size
new_im = Image.new(mode="RGB", size=(new_width, new_height), color=(0, 0, 0)) # type: ignore
new_im.paste(image)
return new_im
def _split(self, image: torch.Tensor, ncw: int, nch: int) -> torch.Tensor:
# Split image into number of required tiles (width x height)
num_channels, height, width = image.size()
image = image.view(num_channels, nch, height // nch, ncw, width // ncw)
# Permute dimensions to reorder the axes
image = image.permute(1, 3, 0, 2, 4).contiguous()
# Reshape into the desired output shape (batch_size * 4, num_channels, width/2, height/2)
image = image.view(ncw * nch, num_channels, height // nch, width // ncw)
return image
def resize_without_distortion(
self,
image: torch.Tensor,
target_size: Tuple[int, int],
max_upscaling_size: Optional[int],
) -> torch.Tensor:
"""
Used to resize an image to target_resolution, without distortion.
If target_size requires upscaling the image, the user can set max_upscaling_size to
limit the upscaling to a maximum size. In this case, since we rescale without distortion,
modifying target_size works as a boundary for the image's largest side.
Args:
resample (str): Resampling method used when resizing images.
Supports "nearest", "nearest_exact", "bilinear", "bicubic".
max_upscaling_size (int): The maximum size to upscale the image to.
If None, there is no limit.
Examples:
>>> target_size = (1000, 1200)
>>> max_upscaling_size = 600
>>> image_size = (400, 200)
>>> resize_without_distortion(image_size, target_size, max_upscaling_size)
(600, 300) # new_size_without_distortion
>>> target_size = (1000, 1200)
>>> max_upscaling_size = 600
>>> image_size = (2000, 200)
>>> resize_without_distortion(image_size, target_size, max_upscaling_size)
(1000, 100) # new_size_without_distortion
>>> target_size = (1000, 1200)
>>> max_upscaling_size = 2000
>>> image_size = (400, 200)
>>> resize_without_distortion(image_size, target_size, max_upscaling_size)
(1000, 500) # new_size_without_distortion
>>> target_size = (1000, 1200)
>>> max_upscaling_size = None
>>> image_size = (400, 200)
>>> resize_without_distortion(image_size, target_size, max_upscaling_size)
(1000, 500) # new_size_without_distortion
"""
image_width, image_height = image.size
image_size = (image_width, image_height)
# If target_size requires upscaling, we might want to limit the upscaling to max_upscaling_size
if max_upscaling_size is not None:
new_target_width = min(max(image_width, max_upscaling_size), target_size[0])
new_target_height = min(max(image_height, max_upscaling_size), target_size[1])
target_size = (new_target_width, new_target_height)
# resize to target_size while preserving aspect ratio
new_size_without_distortion = self.get_max_res_without_distortion(image_size, target_size)
image = F.resize(
image,
(new_size_without_distortion[1], new_size_without_distortion[0]),
interpolation=self.resample,
)
return image
def get_best_fit(
self,
image_size: Tuple[int, int],
possible_resolutions: torch.Tensor,
resize_to_max_canvas: bool = False,
) -> Tuple[int, int]:
"""
Determines the best canvas possible from a list of possible resolutions to, without distortion,
resize an image to.
For each possible resolution, calculates the scaling factors for
width and height, and selects the smallest one, which is the limiting side.
E.g. to match the canvas you can upscale height by 2x, and width by 1.5x,
therefore, the maximum upscaling you can do is min(2, 1.5) = 1.5.
If upscaling is possible (any of the scaling factors is greater than 1),
then picks the smallest upscaling factor > 1, unless resize_to_max_canvas is True.
If upscaling is not possible, then picks the largest scaling factor <= 1, i.e.
reduce downscaling as much as possible.
If there are multiple resolutions with the same max scale, we pick the one with the lowest area,
to minimize padding. E.g., the same image can be upscaled to 224x224 and 224x448, but the latter
has more padding.
Args:
image_size (Tuple[int, int]): A tuple containing the height and width of the image.
possible_resolutions (torch.Tensor): A tensor of shape (N, 2) where each
row represents a possible resolution (height, width).
use_max_upscaling (bool): If True, will return the largest upscaling resolution.
Returns:
List[int]: The best resolution [height, width] for the given image.
Example:
>>> image_size = (200, 300)
>>> possible_resolutions = torch.tensor([[224, 672],
... [672, 224],
... [224, 448],
... [448, 224],
... [224, 224]])
>>> _get_smallest_upscaling_possibility(image_size, possible_resolutions)
[224, 448]
We have:
scale_w = tensor([2.2400, 0.7467, 1.4933, 0.7467, 0.7467])
scale_h = tensor([1.1200, 3.3600, 1.1200, 2.2400, 1.1200])
scales = tensor([1.1200, 0.7467, 1.1200, 0.7467, 0.7467])
Only one of the scales > 1:
upscaling_possible = tensor([1.1200, 1.1200])
smallest_rescale = tensor(1.1200)
So we pick the resolution with the smallest smallest area:
areas = tensor([150528, 100352]) # [672, 224], [224, 448]
optimal_canvas = tensor([224, 448])
"""
original_width, original_height = image_size
# get all possible resolutions heights/widths
target_widths, target_heights = (
possible_resolutions[:, 0],
possible_resolutions[:, 1],
)
# get scaling factors to resize the image without distortion
scale_w = target_widths / original_width
scale_h = target_heights / original_height
# get the min scale between width and height (limiting side -> no distortion)
scales = torch.where(scale_w > scale_h, scale_h, scale_w)
# filter only scales that allow upscaling
upscaling_options = scales[scales >= 1]
if len(upscaling_options) > 0:
if resize_to_max_canvas:
selected_scale = torch.max(upscaling_options)
else:
selected_scale = torch.min(upscaling_options)
else:
# no upscaling possible,
# get the minimum downscaling (max scale for scales<1)
downscaling_options = scales[scales < 1]
selected_scale = torch.max(downscaling_options)
# get all resolutions that support this scaling factor,
# e.g. you can upscale to 224x224, 224x448, 224x672 without distortion
chosen_canvas = possible_resolutions[scales == selected_scale]
# if there are multiple resolutions,
# get the one with minimum area to reduce padding
if len(chosen_canvas) > 1:
areas = chosen_canvas[:, 0] * chosen_canvas[:, 1]
optimal_idx = torch.argmin(areas)
optimal_canvas = chosen_canvas[optimal_idx]
else:
optimal_canvas = chosen_canvas[0]
return tuple(optimal_canvas.tolist())
def __call__(
self,
image: Image.Image,
max_num_chunks: int,
normalize_img: bool = True,
resize_to_max_canvas: bool = False,
) -> Tuple[Any, Any]:
"""
Args:
image (PIL.Image): Image to be resized.
max_num_chunks (int): Maximum number of chunks to split the image into.
normalize_img (bool): Whether to normalize the image.
resize_to_max_canvas (bool): Whether to resize the image to the maximum canvas size.
If True, picks the canvas the allows the largest resizing without distortion.
If False, downsample as little as possible, including no resizing at all,
but never upsample, unless the image is smaller than the patch size.
"""
assert max_num_chunks > 0
assert isinstance(image, Image.Image), type(image)
w, h = image.size
possible_resolutions = self.find_supported_resolutions(max_num_chunks=max_num_chunks, patch_size=self.size)
possible_resolutions = torch.tensor(possible_resolutions)
best_resolution = self.get_best_fit(
image_size=(w, h),
possible_resolutions=possible_resolutions,
resize_to_max_canvas=resize_to_max_canvas,
)
max_upscaling_size = None if resize_to_max_canvas else self.size
image = self.resize_without_distortion(image, best_resolution, max_upscaling_size)
image = self._pad(image, best_resolution)
image = self.to_tensor(image)
if normalize_img:
image = self.normalize(image)
ratio_w, ratio_h = (
best_resolution[0] // self.size,
best_resolution[1] // self.size,
)
image = self._split(image, ratio_w, ratio_h) # type: ignore
ar = (ratio_h, ratio_w)
return image, ar

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,26 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# top-level folder for each specific model found within the models/ directory at
# the top-level of this source tree.
import collections
import torch
def get_negative_inf_value(dtype):
return torch.finfo(dtype).min
def to_2tuple(x):
if isinstance(x, collections.abc.Iterable):
return x
return (x, x)

View file

@ -0,0 +1,322 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
# type: ignore
import os
from typing import Any, Dict, List, Optional, cast
import torch
from fairscale.nn.model_parallel.initialize import get_model_parallel_rank
from fairscale.nn.model_parallel.layers import ColumnParallelLinear, RowParallelLinear
from fairscale.nn.model_parallel.mappings import reduce_from_model_parallel_region
from torch import Tensor, nn
from torchao.quantization.GPTQ import Int8DynActInt4WeightLinear
from llama_stack.apis.inference import QuantizationType
from llama_stack.log import get_logger
from llama_stack.models.llama.sku_list import resolve_model
from ...config import MetaReferenceQuantizedInferenceConfig
from ...datatypes import CheckpointQuantizationFormat
from ...quantize_impls import (
Fp8ScaledWeights,
ffn_swiglu,
load_fp8,
quantize_fp8,
)
from ..args import ModelArgs
from ..model import Transformer, TransformerBlock
log = get_logger(__name__, category="quantization")
def swiglu_wrapper(
self,
x: Tensor,
):
out = ffn_swiglu(x, self.w1.weight, self.w3.weight, self.w2.weight)
return reduce_from_model_parallel_region(out)
def convert_to_fp8_quantized_model(
model: Transformer,
config: MetaReferenceQuantizedInferenceConfig,
checkpoint_dir: str,
fp8_activation_scale_ub: Optional[float] = 1200.0,
) -> Transformer:
if config.quantization.type == QuantizationType.bf16.value:
return model
elif config.quantization.type != QuantizationType.fp8.value:
raise ValueError("Only FP8 quantization is supported")
assert config.model is not None, "Model must be specified for quantized inference"
llama_model = resolve_model(config.model)
assert llama_model is not None, f"Model {config.model} not found"
# Move weights to GPU with quantization
if llama_model.quantization_format == CheckpointQuantizationFormat.fp8_mixed.value:
log.info("Loading fp8 scales...")
fp8_scales_path = os.path.join(checkpoint_dir, f"fp8_scales_{get_model_parallel_rank()}.pt")
assert os.path.isfile(fp8_scales_path), f"fp8_scales_path not found for rank {get_model_parallel_rank()}"
fp8_scales = torch.load(fp8_scales_path, weights_only=True)
for block in model.layers:
if isinstance(block, TransformerBlock):
if block.layer_id == 0 or block.layer_id == (model.n_layers - 1):
continue
block.feed_forward.forward = swiglu_wrapper.__get__(block.feed_forward)
for key in ("w1", "w3", "w2"):
param = getattr(block.feed_forward, key)
param.weight = load_fp8(
param.weight,
fp8_scales[f"{block.layer_id}_feed_forward.{key}_{get_model_parallel_rank()}"],
fp8_activation_scale_ub,
)
else:
log.info("Quantizing fp8 weights from bf16...")
for block in model.layers:
if isinstance(block, TransformerBlock):
if block.layer_id == 0 or block.layer_id == (model.n_layers - 1):
continue
block.feed_forward.forward = swiglu_wrapper.__get__(block.feed_forward) # type: ignore
for key in ("w1", "w3", "w2"):
param = getattr(block.feed_forward, key)
param.weight = quantize_fp8(
param.weight,
fp8_activation_scale_ub,
output_device=torch.device("cuda"),
)
for _, parameter in model.named_parameters():
if not isinstance(parameter, Fp8ScaledWeights):
parameter.data = parameter.to(device="cuda")
return model
class Int8DynActInt4WeightLinearLoRA(Int8DynActInt4WeightLinear):
"""
Int8DynActInt4WeightLinear with LoRA adaptor.
Args:
in_features: Number of input features.
out_features: Number of output features.
bias: Whether to use bias.
device: Device to use.
group_size: Group size for quantization.
precision: Precision of quantization.
scales_precision: Precision of scales.
lora_rank: Rank of LoRA adaptor.
lora_scale: Scale of LoRA adaptor.
"""
def __init__(
self,
in_features: int,
out_features: int,
bias=False,
device=None,
# quantization parameters
group_size: int = 256,
precision: torch.dtype = torch.float32,
scales_precision: torch.dtype = torch.float32,
# LoRA parameters
lora_rank: Optional[int] = None,
lora_scale: Optional[float] = None,
) -> None:
super().__init__(
in_features,
out_features,
bias=bias,
device=device,
groupsize=group_size,
precision=precision,
scales_precision=scales_precision,
)
self.lora_scale: Optional[float] = None
self.adaptor: Optional[nn.Sequential] = None
if lora_rank is not None:
assert lora_scale is not None, "Please specify lora scale for LoRA."
# Low-rank adaptation. See paper for more details: https://arxiv.org/abs/2106.09685
self.adaptor = nn.Sequential()
self.adaptor.add_module("A", nn.Linear(in_features, lora_rank, bias=False))
self.adaptor.add_module("B", nn.Linear(lora_rank, out_features, bias=False))
self.lora_scale = lora_scale
self._register_load_state_dict_pre_hook(self.load_hook)
def load_hook(
self,
state_dict: Dict[str, Any],
prefix: str,
local_metadata: Dict[str, Any],
strict: bool,
missing_keys: List[str],
unexpected_keys: List[str],
error_msgs: List[str],
) -> None:
"""A hook to load the quantized weights from the state dict."""
if prefix + "zeros" not in state_dict:
# Zero-point may not be saved in the state dict. In this case, we assume it's zero.
assert prefix + "scales" in state_dict
state_dict[prefix + "zeros"] = torch.zeros_like(state_dict[prefix + "scales"])
def forward(self, input_: torch.Tensor) -> torch.Tensor:
module_out = super().forward(input_)
if self.adaptor is not None:
adaptor_out = self.adaptor(input_) * self.lora_scale
return module_out + adaptor_out
return module_out
class Int8WeightEmbedding(torch.nn.Embedding):
"""An embedding layer to load int8 weights.
Args:
num_embeddings: Number of embeddings.
embedding_dim: Embedding dimension.
padding_idx: Padding index.
"""
def __init__(
self,
num_embeddings: int,
embedding_dim: int,
padding_idx: int,
device=None,
) -> None:
super().__init__(num_embeddings, embedding_dim, padding_idx, device=device)
self._register_load_state_dict_pre_hook(self.load_hook)
def load_hook(
self,
state_dict: Dict[str, Any],
prefix: str,
local_metadata: Dict[str, Any],
strict: bool,
missing_keys: List[str],
unexpected_keys: List[str],
error_msgs: List[str],
) -> None:
"""A hook to load the quantized embedding weight and scales from the state dict."""
weights = state_dict.pop(prefix + "weight")
scales = state_dict.pop(prefix + "scales")
state_dict[prefix + "weight"] = weights * scales
class Int8WeightLinear(torch.nn.Linear):
"""A linear layer to load int8 weights.
Args:
in_features: Number of input features.
out_features: Number of output features.
bias: Whether to use bias.
"""
def __init__(self, in_features: int, out_features: int, bias: bool = True, device=None) -> None:
super().__init__(in_features, out_features, bias, device=device)
self._register_load_state_dict_pre_hook(self.load_hook)
def load_hook(
self,
state_dict: Dict[str, Any],
prefix: str,
local_metadata: Dict[str, Any],
strict: bool,
missing_keys: List[str],
unexpected_keys: List[str],
error_msgs: List[str],
) -> None:
"""A hook to load the quantized linear weight and scales from the state dict."""
weights = state_dict.pop(prefix + "weight")
scales = state_dict.pop(prefix + "scales")
state_dict[prefix + "weight"] = weights * scales
def _prepare_model_int4_weight_int8_dynamic_activation(
model: torch.nn.Module,
group_size: int,
lora_rank: Optional[int],
lora_scale: Optional[float],
):
"""Prepare the model for int4 weight and int8 dynamic activation quantization.
Note that the weights of embedding and output layers are quantized to int8.
"""
device = None
for module_name, module in model.named_children():
if module_name == "output":
quantized_module = Int8WeightLinear(
in_features=module.in_features,
out_features=module.out_features,
bias=module.bias,
device=device,
)
del module
setattr(model, module_name, quantized_module)
elif module_name == "tok_embeddings":
quantized_module = Int8WeightEmbedding(
num_embeddings=module.num_embeddings,
embedding_dim=module.embedding_dim,
padding_idx=module.padding_idx,
device=device,
)
del module
setattr(model, module_name, quantized_module)
elif isinstance(module, (ColumnParallelLinear, RowParallelLinear, nn.Linear)):
quantized_module = Int8DynActInt4WeightLinearLoRA(
in_features=module.in_features,
out_features=module.out_features,
bias=False,
group_size=group_size,
lora_rank=lora_rank,
lora_scale=lora_scale,
device=device,
)
del module
setattr(model, module_name, quantized_module)
else:
_prepare_model_int4_weight_int8_dynamic_activation(module, group_size, lora_rank, lora_scale)
return model
def convert_to_int4_quantized_model(
model: Transformer,
model_args: ModelArgs,
) -> Transformer:
"""Convert the model to int4 quantized model."""
assert model_args.quantization_args is not None, "Quantization args must be specified."
quantization_args = model_args.quantization_args
if quantization_args.scheme is None:
raise ValueError("Quantization scheme must be specified in 'quantization_args'.")
if quantization_args.scheme.value != "int4_weight_int8_dynamic_activation":
raise NotImplementedError(
"Only int4 quantization with 'int4_weight_int8_dynamic_activation' scheme is supported."
)
group_size = model_args.quantization_args.group_size
if group_size is None:
raise ValueError("'group_size' cannot be None in 'quantization_args'. Please specify it.")
if model_args.lora_args is None:
# Certain quantized models (e.g., SpinQuant) may not have LoRA.
lora_rank = None
lora_scale = None
else:
lora_rank = model_args.lora_args.rank
lora_scale = model_args.lora_args.scale
_prepare_model_int4_weight_int8_dynamic_activation(model, group_size, lora_rank, lora_scale)
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
return cast(Transformer, model.to(device))

View file

@ -12,8 +12,7 @@
# the top-level of this source tree.
from llama_stack.models.llama.datatypes import BuiltinTool, StopReason, ToolCall
from ..datatypes import BuiltinTool, StopReason, ToolCall
from .prompt_templates import (
BuiltinToolGenerator,
JsonCustomToolGenerator,

View file

@ -16,7 +16,8 @@ import re
from typing import Optional, Tuple
from llama_stack.log import get_logger
from llama_stack.models.llama.datatypes import BuiltinTool, RecursiveType, ToolCall, ToolPromptFormat
from ..datatypes import BuiltinTool, RecursiveType, ToolCall, ToolPromptFormat
logger = get_logger(name=__name__, category="inference")