fix: Mirror llama4 rope scaling fixes, small model simplify (#1917)

See:
- https://github.com/meta-llama/llama-models/pull/322
- https://github.com/meta-llama/llama-models/pull/320
Author: Ashwin Bharambe, 2025-04-09 11:28:45 -07:00 (committed by GitHub)
parent 770b38f8b5
commit e2299291c4
2 changed files with 36 additions and 28 deletions

File 1 of 2:

@@ -70,6 +70,9 @@ class ModelArgs(BaseModel):
     attention_chunk_size: Optional[int] = None
     rope_theta: float = 500000
     use_scaled_rope: bool = False
+    rope_scaling_factor: Optional[float] = None
+    rope_high_freq_factor: Optional[float] = None
     nope_layer_interval: Optional[int] = None  # No position encoding in every n layers
     use_qk_norm: bool = False
     # Set to True to enable inference-time temperature tuning (useful for very long context)
@@ -92,4 +95,14 @@ class ModelArgs(BaseModel):
             f"n_heads ({self.n_heads}) must be divisible by n_kv_heads ({self.n_kv_heads})"
         )
         assert self.dim % self.n_heads == 0, f"dim ({self.dim}) must be divisible by n_heads ({self.n_heads})"
+
+        if self.use_scaled_rope:
+            # NOTE: ideally these values should have come from params.json. However, we have
+            # shipped the models everywhere. Only Llama-4-Scout uses scaled rope and needs these
+            # specific values.
+            if self.rope_scaling_factor is None:
+                self.rope_scaling_factor = 16
+            if self.rope_high_freq_factor is None:
+                self.rope_high_freq_factor = 1
+
         return self
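
The new fields are back-filled by the model validator only when use_scaled_rope is set. A minimal sketch of that pattern, assuming a pydantic v2 `model_validator(mode="after")` and using a stripped-down stand-in for the real ModelArgs (which has many more fields):

```python
from typing import Optional

from pydantic import BaseModel, model_validator


class RopeArgsSketch(BaseModel):
    # Stand-in for ModelArgs; only the rope-related fields are shown here.
    use_scaled_rope: bool = False
    rope_scaling_factor: Optional[float] = None
    rope_high_freq_factor: Optional[float] = None

    @model_validator(mode="after")
    def fill_rope_defaults(self):
        if self.use_scaled_rope:
            # Llama-4-Scout values, applied only when they were not supplied explicitly.
            if self.rope_scaling_factor is None:
                self.rope_scaling_factor = 16
            if self.rope_high_freq_factor is None:
                self.rope_high_freq_factor = 1
        return self


print(RopeArgsSketch(use_scaled_rope=True).rope_scaling_factor)   # 16.0
print(RopeArgsSketch(use_scaled_rope=False).rope_scaling_factor)  # None
```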

File 2 of 2:

@@ -23,37 +23,25 @@ from .ffn import FeedForward
 from .moe import MoE


+def rmsnorm(x, eps):
+    def _norm(y):
+        return y * torch.rsqrt(y.pow(2).mean(-1, keepdim=True) + eps)
+
+    return _norm(x.float()).type_as(x)
+
+
 class RMSNorm(torch.nn.Module):
     def __init__(self, dim: int, eps: float = 1e-6):
         super().__init__()
         self.eps = eps
         self.weight = nn.Parameter(torch.ones(dim))

-    def _norm(self, x):
-        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
-
     def forward(self, x):
-        output = self._norm(x.float()).type_as(x)
-        return output * self.weight
+        return rmsnorm(x, self.eps) * self.weight


-class L2Norm(torch.nn.Module):
-    def __init__(self, dim: int, eps: float = 1e-6):
-        super().__init__()
-        self.eps = eps
-
-    def _norm(self, x):
-        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
-
-    def forward(self, x):
-        return self._norm(x.float()).type_as(x)
-
-
-def apply_scaling(freqs: torch.Tensor):
-    # Values obtained from grid search
-    scale_factor = 8
+def apply_scaling(freqs: torch.Tensor, scale_factor: float, high_freq_factor: float):
     low_freq_factor = 1
-    high_freq_factor = 4
     old_context_len = 8192  # original llama3 length

     low_freq_wavelen = old_context_len / low_freq_factor
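
The removed RMSNorm._norm and the deleted L2Norm module computed the same un-weighted RMS normalization; the new free-standing rmsnorm is that expression factored out, so qk-norm no longer needs a parameter-less module. A small self-contained check (shapes and dtype below are illustrative):

```python
import torch


def rmsnorm(x, eps):
    # Same helper as in the hunk above: normalize in fp32, cast back to the input dtype.
    def _norm(y):
        return y * torch.rsqrt(y.pow(2).mean(-1, keepdim=True) + eps)

    return _norm(x.float()).type_as(x)


eps = 1e-6
x = torch.randn(2, 4, 8, dtype=torch.bfloat16)

# What the removed L2Norm.forward(x) evaluated to.
l2norm_out = (x.float() * torch.rsqrt(x.float().pow(2).mean(-1, keepdim=True) + eps)).type_as(x)
assert torch.equal(rmsnorm(x, eps), l2norm_out)
```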
@@ -72,11 +60,18 @@ def apply_scaling(freqs: torch.Tensor):
     return torch.tensor(new_freqs, dtype=freqs.dtype, device=freqs.device)


-def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0, use_scaled: bool = False):
+def precompute_freqs_cis(
+    dim: int,
+    end: int,
+    theta: float,
+    use_scaled: bool,
+    scale_factor: float,
+    high_freq_factor: float,
+):
     freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
     t = torch.arange(end, device=freqs.device, dtype=torch.float32)
     if use_scaled:
-        freqs = apply_scaling(freqs)
+        freqs = apply_scaling(freqs, scale_factor, high_freq_factor)
     freqs = torch.outer(t, freqs)
     freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64
     return freqs_cis
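
The body of apply_scaling between these context lines is not part of this diff. For orientation, here is a sketch of the Llama 3-style wavelength interpolation it implements, with scale_factor and high_freq_factor now passed in instead of hard-coded; this is an approximation based on the published Llama 3 reference scaling, not a copy of the file:

```python
import math

import torch


def apply_scaling(freqs: torch.Tensor, scale_factor: float, high_freq_factor: float):
    low_freq_factor = 1
    old_context_len = 8192  # original llama3 length

    low_freq_wavelen = old_context_len / low_freq_factor
    high_freq_wavelen = old_context_len / high_freq_factor

    new_freqs = []
    for freq in freqs:
        wavelen = 2 * math.pi / freq
        if wavelen < high_freq_wavelen:
            # Short wavelengths (local information) are left untouched.
            new_freqs.append(freq)
        elif wavelen > low_freq_wavelen:
            # Long wavelengths are stretched by the scale factor.
            new_freqs.append(freq / scale_factor)
        else:
            # Smoothly interpolate between the two regimes.
            assert low_freq_wavelen != high_freq_wavelen
            smooth = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor)
            new_freqs.append((1 - smooth) * freq / scale_factor + smooth * freq)
    return torch.tensor(new_freqs, dtype=freqs.dtype, device=freqs.device)
```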
@@ -174,9 +169,7 @@ class Attention(nn.Module):
                 self.head_dim,
             )
         ).cuda()
-        self.qk_norm = None
-        if self.use_qk_norm:
-            self.qk_norm = L2Norm(args.norm_eps)
+        self.norm_eps = args.norm_eps
         self._register_load_state_dict_pre_hook(self.load_hook)

     def load_hook(
@@ -220,8 +213,8 @@ class Attention(nn.Module):
         xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)

         if self.use_qk_norm:
-            xq = self.qk_norm(xq)
-            xk = self.qk_norm(xk)
+            xq = rmsnorm(xq, self.norm_eps)
+            xk = rmsnorm(xk, self.norm_eps)

         # We are applying temperature tuning (https://arxiv.org/abs/2501.19399) to NoPE layers, where
         # the inference-time temperature tuning function is customized to not affect short context
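
With L2Norm gone, qk-norm is just the un-weighted rmsnorm applied over the head dimension after rotary embeddings. A shape-only illustration, reusing the rmsnorm helper from the first hunk (all sizes below are made up):

```python
import torch

bs, seqlen, n_heads, head_dim = 2, 16, 8, 64  # illustrative sizes
norm_eps = 1e-5                               # stands in for args.norm_eps

xq = torch.randn(bs, seqlen, n_heads, head_dim)
xk = torch.randn(bs, seqlen, n_heads, head_dim)

# Normalization happens over the last axis (head_dim), per head and per position,
# with no learned weight involved.
xq = rmsnorm(xq, norm_eps)
xk = rmsnorm(xk, norm_eps)
assert xq.shape == (bs, seqlen, n_heads, head_dim)
```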
@@ -362,6 +355,8 @@ class Transformer(nn.Module):
             args.max_seq_len * 2,
             args.rope_theta,
             args.use_scaled_rope,
+            args.rope_scaling_factor,
+            args.rope_high_freq_factor,
         )
         vision_args = self.args.vision_args
         if vision_args:
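
For completeness, a sketch of what the rope-table call looks like after this change, filled in with the Llama-4-Scout defaults from the validator above (head_dim and max_seq_len are illustrative values, not taken from this diff):

```python
head_dim = 128      # illustrative
max_seq_len = 8192  # illustrative

freqs_cis = precompute_freqs_cis(
    head_dim,
    max_seq_len * 2,
    500000.0,  # args.rope_theta (default in ModelArgs)
    True,      # args.use_scaled_rope
    16.0,      # args.rope_scaling_factor (Scout default)
    1.0,       # args.rope_high_freq_factor (Scout default)
)
```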