fix: Mirror llama4 rope scaling fixes, small model simplify (#1917)

See:
- https://github.com/meta-llama/llama-models/pull/322
- https://github.com/meta-llama/llama-models/pull/320

Author: Ashwin Bharambe (2025-04-09 11:28:45 -07:00), committed by GitHub
parent 770b38f8b5
commit e2299291c4
2 changed files with 36 additions and 28 deletions


@@ -23,37 +23,25 @@ from .ffn import FeedForward
 from .moe import MoE


+def rmsnorm(x, eps):
+    def _norm(y):
+        return y * torch.rsqrt(y.pow(2).mean(-1, keepdim=True) + eps)
+
+    return _norm(x.float()).type_as(x)
+
+
 class RMSNorm(torch.nn.Module):
     def __init__(self, dim: int, eps: float = 1e-6):
         super().__init__()
         self.eps = eps
         self.weight = nn.Parameter(torch.ones(dim))

-    def _norm(self, x):
-        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
-
     def forward(self, x):
-        output = self._norm(x.float()).type_as(x)
-        return output * self.weight
+        return rmsnorm(x, self.eps) * self.weight
-
-
-class L2Norm(torch.nn.Module):
-    def __init__(self, dim: int, eps: float = 1e-6):
-        super().__init__()
-        self.eps = eps
-
-    def _norm(self, x):
-        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
-
-    def forward(self, x):
-        return self._norm(x.float()).type_as(x)


-def apply_scaling(freqs: torch.Tensor):
-    # Values obtained from grid search
-    scale_factor = 8
+def apply_scaling(freqs: torch.Tensor, scale_factor: float, high_freq_factor: float):
     low_freq_factor = 1
-    high_freq_factor = 4
     old_context_len = 8192  # original llama3 length

     low_freq_wavelen = old_context_len / low_freq_factor
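
The body of apply_scaling between these two hunks is not shown in the diff. For context, a hedged sketch (an assumption, not the file's code) of the Llama 3-style wavelength-based interpolation that scale_factor and high_freq_factor, previously hardcoded to 8 and 4, are now expected to parameterize:

# Sketch only: assumed to mirror the public Llama 3 RoPE scaling recipe.
import math

import torch


def apply_scaling_sketch(freqs: torch.Tensor, scale_factor: float, high_freq_factor: float) -> torch.Tensor:
    low_freq_factor = 1
    old_context_len = 8192  # original llama3 length

    low_freq_wavelen = old_context_len / low_freq_factor
    high_freq_wavelen = old_context_len / high_freq_factor

    new_freqs = []
    for freq in freqs.tolist():
        wavelen = 2 * math.pi / freq
        if wavelen < high_freq_wavelen:
            # High-frequency components (short wavelengths) are kept as-is.
            new_freqs.append(freq)
        elif wavelen > low_freq_wavelen:
            # Low-frequency components are slowed down by the full scale factor.
            new_freqs.append(freq / scale_factor)
        else:
            # Smooth interpolation between the two regimes.
            smooth = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor)
            new_freqs.append((1 - smooth) * freq / scale_factor + smooth * freq)
    return torch.tensor(new_freqs, dtype=freqs.dtype, device=freqs.device)

Passing scale_factor=8 and high_freq_factor=4 reproduces the previously hardcoded behavior; model configs can now override both values.
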
@@ -72,11 +60,18 @@ def apply_scaling(freqs: torch.Tensor):
     return torch.tensor(new_freqs, dtype=freqs.dtype, device=freqs.device)


-def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0, use_scaled: bool = False):
+def precompute_freqs_cis(
+    dim: int,
+    end: int,
+    theta: float,
+    use_scaled: bool,
+    scale_factor: float,
+    high_freq_factor: float,
+):
     freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
     t = torch.arange(end, device=freqs.device, dtype=torch.float32)
     if use_scaled:
-        freqs = apply_scaling(freqs)
+        freqs = apply_scaling(freqs, scale_factor, high_freq_factor)
     freqs = torch.outer(t, freqs)
     freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64
     return freqs_cis
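
With the defaults on theta and use_scaled removed and two required scaling arguments added, every caller now has to spell out all six values. A hypothetical call-site sketch, assuming the precompute_freqs_cis defined above is in scope; the argument values are illustrative only, not taken from any shipped config:

# Hypothetical values: head_dim, context length, and theta are illustrative.
freqs_cis = precompute_freqs_cis(
    dim=128,            # head_dim
    end=8192 * 2,       # mirrors args.max_seq_len * 2 in the Transformer hunk below
    theta=500000.0,
    use_scaled=True,
    scale_factor=8.0,
    high_freq_factor=4.0,
)
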
@@ -174,9 +169,7 @@ class Attention(nn.Module):
                 self.head_dim,
             )
         ).cuda()
-        self.qk_norm = None
-        if self.use_qk_norm:
-            self.qk_norm = L2Norm(args.norm_eps)
+        self.norm_eps = args.norm_eps
         self._register_load_state_dict_pre_hook(self.load_hook)

     def load_hook(
@@ -220,8 +213,8 @@ class Attention(nn.Module):
         xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)

         if self.use_qk_norm:
-            xq = self.qk_norm(xq)
-            xk = self.qk_norm(xk)
+            xq = rmsnorm(xq, self.norm_eps)
+            xk = rmsnorm(xk, self.norm_eps)

         # We are applying temperature tuning (https://arxiv.org/abs/2501.19399) to NoPE layers, where
         # the inference-time temperature tuning function is customized to not affect short context
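
The removed L2Norm module carried no learnable weight, so swapping self.qk_norm(xq) for the module-free rmsnorm helper is a pure simplification rather than a behavior change. A small sanity-check sketch (tensor shapes are illustrative, not taken from the repo):

import torch


def rmsnorm(x, eps):
    # Same helper as added at the top of the file.
    def _norm(y):
        return y * torch.rsqrt(y.pow(2).mean(-1, keepdim=True) + eps)

    return _norm(x.float()).type_as(x)


class L2Norm(torch.nn.Module):
    # Equivalent to the module removed by this commit: RMS normalization, no weight.
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps

    def forward(self, x):
        return (x.float() * torch.rsqrt(x.float().pow(2).mean(-1, keepdim=True) + self.eps)).type_as(x)


xq = torch.randn(2, 16, 8, 128)  # (bs, seqlen, n_local_heads, head_dim), illustrative
eps = 1e-5
assert torch.equal(rmsnorm(xq, eps), L2Norm(128, eps)(xq))
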
@@ -362,6 +355,8 @@ class Transformer(nn.Module):
             args.max_seq_len * 2,
             args.rope_theta,
             args.use_scaled_rope,
+            args.rope_scaling_factor,
+            args.rope_high_freq_factor,
         )
         vision_args = self.args.vision_args
         if vision_args:
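
The Transformer call above reads two new fields, args.rope_scaling_factor and args.rope_high_freq_factor, presumably defined in the model args (the second changed file, not shown here). A purely hypothetical sketch of what such fields could look like; the field names come from the call above, while the defaults are assumptions mirroring the previously hardcoded values:

from dataclasses import dataclass


@dataclass
class RopeArgsSketch:
    # Hypothetical: defaults below are assumptions (the old hardcoded 8 and 4),
    # not values taken from the repo's args file.
    rope_theta: float = 500000.0
    use_scaled_rope: bool = True
    rope_scaling_factor: float = 8.0
    rope_high_freq_factor: float = 4.0
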