From e2299291c42c4d1e29506bbdc366678c8ff4d987 Mon Sep 17 00:00:00 2001
From: Ashwin Bharambe
Date: Wed, 9 Apr 2025 11:28:45 -0700
Subject: [PATCH] fix: Mirror llama4 rope scaling fixes, small model simplify
 (#1917)

See:
- https://github.com/meta-llama/llama-models/pull/322
- https://github.com/meta-llama/llama-models/pull/320
---
 llama_stack/models/llama/llama4/args.py  | 13 ++++++
 llama_stack/models/llama/llama4/model.py | 51 +++++++++++-------------
 2 files changed, 36 insertions(+), 28 deletions(-)

diff --git a/llama_stack/models/llama/llama4/args.py b/llama_stack/models/llama/llama4/args.py
index 6d7c1d409..dd5f7cbde 100644
--- a/llama_stack/models/llama/llama4/args.py
+++ b/llama_stack/models/llama/llama4/args.py
@@ -70,6 +70,9 @@ class ModelArgs(BaseModel):
     attention_chunk_size: Optional[int] = None
     rope_theta: float = 500000
     use_scaled_rope: bool = False
+    rope_scaling_factor: Optional[float] = None
+    rope_high_freq_factor: Optional[float] = None
+
     nope_layer_interval: Optional[int] = None  # No position encoding in every n layers
     use_qk_norm: bool = False
     # Set to True to enable inference-time temperature tuning (useful for very long context)
@@ -92,4 +95,14 @@ class ModelArgs(BaseModel):
             f"n_heads ({self.n_heads}) must be divisible by n_kv_heads ({self.n_kv_heads})"
         )
         assert self.dim % self.n_heads == 0, f"dim ({self.dim}) must be divisible by n_heads ({self.n_heads})"
+
+        if self.use_scaled_rope:
+            # NOTE: ideally these values should have come from params.json. However, we have
+            # shipped the models everywhere. Only Llama-4-Scout uses scaled rope and needs these
+            # specific values.
+            if self.rope_scaling_factor is None:
+                self.rope_scaling_factor = 16
+            if self.rope_high_freq_factor is None:
+                self.rope_high_freq_factor = 1
+
         return self
diff --git a/llama_stack/models/llama/llama4/model.py b/llama_stack/models/llama/llama4/model.py
index 08fac7714..2272b868d 100644
--- a/llama_stack/models/llama/llama4/model.py
+++ b/llama_stack/models/llama/llama4/model.py
@@ -23,37 +23,25 @@ from .ffn import FeedForward
 from .moe import MoE
 
 
+def rmsnorm(x, eps):
+    def _norm(y):
+        return y * torch.rsqrt(y.pow(2).mean(-1, keepdim=True) + eps)
+
+    return _norm(x.float()).type_as(x)
+
+
 class RMSNorm(torch.nn.Module):
     def __init__(self, dim: int, eps: float = 1e-6):
         super().__init__()
         self.eps = eps
         self.weight = nn.Parameter(torch.ones(dim))
 
-    def _norm(self, x):
-        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
-
     def forward(self, x):
-        output = self._norm(x.float()).type_as(x)
-        return output * self.weight
+        return rmsnorm(x, self.eps) * self.weight
 
 
-class L2Norm(torch.nn.Module):
-    def __init__(self, dim: int, eps: float = 1e-6):
-        super().__init__()
-        self.eps = eps
-
-    def _norm(self, x):
-        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
-
-    def forward(self, x):
-        return self._norm(x.float()).type_as(x)
-
-
-def apply_scaling(freqs: torch.Tensor):
-    # Values obtained from grid search
-    scale_factor = 8
+def apply_scaling(freqs: torch.Tensor, scale_factor: float, high_freq_factor: float):
     low_freq_factor = 1
-    high_freq_factor = 4
     old_context_len = 8192  # original llama3 length
 
     low_freq_wavelen = old_context_len / low_freq_factor
@@ -72,11 +60,18 @@
     return torch.tensor(new_freqs, dtype=freqs.dtype, device=freqs.device)
 
 
-def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0, use_scaled: bool = False):
+def precompute_freqs_cis(
+    dim: int,
+    end: int,
+    theta: float,
+    use_scaled: bool,
+    scale_factor: float,
+    high_freq_factor: float,
+):
     freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
     t = torch.arange(end, device=freqs.device, dtype=torch.float32)
     if use_scaled:
-        freqs = apply_scaling(freqs)
+        freqs = apply_scaling(freqs, scale_factor, high_freq_factor)
     freqs = torch.outer(t, freqs)
     freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64
     return freqs_cis
@@ -174,9 +169,7 @@ class Attention(nn.Module):
                 self.head_dim,
             )
         ).cuda()
-        self.qk_norm = None
-        if self.use_qk_norm:
-            self.qk_norm = L2Norm(args.norm_eps)
+        self.norm_eps = args.norm_eps
         self._register_load_state_dict_pre_hook(self.load_hook)
 
     def load_hook(
@@ -220,8 +213,8 @@ class Attention(nn.Module):
         xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)
 
         if self.use_qk_norm:
-            xq = self.qk_norm(xq)
-            xk = self.qk_norm(xk)
+            xq = rmsnorm(xq, self.norm_eps)
+            xk = rmsnorm(xk, self.norm_eps)
 
         # We are applying temperature tuning (https://arxiv.org/abs/2501.19399) to NoPE layers, where
         # the inference-time temperature tuning function is customized to not affect short context
@@ -362,6 +355,8 @@ class Transformer(nn.Module):
             args.max_seq_len * 2,
             args.rope_theta,
             args.use_scaled_rope,
+            args.rope_scaling_factor,
+            args.rope_high_freq_factor,
         )
         vision_args = self.args.vision_args
         if vision_args:
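
Note (not part of the patch): the QK-norm simplification above works because the removed L2Norm module computed exactly the normalization that the new module-level rmsnorm() helper computes, just without a learned weight. Below is a minimal standalone sketch of that equivalence; it reuses only code visible in this diff, while the eps value and tensor shape are illustrative stand-ins (the real values come from args.norm_eps and the attention head layout).

    import torch


    def rmsnorm(x, eps):
        # New module-level helper introduced by the patch.
        def _norm(y):
            return y * torch.rsqrt(y.pow(2).mean(-1, keepdim=True) + eps)

        return _norm(x.float()).type_as(x)


    class L2Norm(torch.nn.Module):
        # The parameter-free module removed by the patch, kept here only for comparison.
        def __init__(self, dim: int, eps: float = 1e-6):
            super().__init__()
            self.eps = eps

        def _norm(self, x):
            return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

        def forward(self, x):
            return self._norm(x.float()).type_as(x)


    if __name__ == "__main__":
        eps = 1e-5  # illustrative; the model passes args.norm_eps
        xq = torch.randn(2, 16, 8, 64)  # illustrative (bs, seqlen, n_heads, head_dim)
        assert torch.equal(rmsnorm(xq, eps), L2Norm(xq.shape[-1], eps)(xq))
        print("rmsnorm(x, eps) matches the removed L2Norm module")

Because L2Norm carried no parameters, swapping it for a plain function does not touch any checkpoint keys; only RMSNorm keeps its learned weight.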