mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-06-28 02:53:30 +00:00
fix: Mirror llama4 rope scaling fixes, small model simplify (#1917)
See: - https://github.com/meta-llama/llama-models/pull/322 - https://github.com/meta-llama/llama-models/pull/320
This commit is contained in:
parent
770b38f8b5
commit
e2299291c4
2 changed files with 36 additions and 28 deletions
|
@ -70,6 +70,9 @@ class ModelArgs(BaseModel):
|
|||
attention_chunk_size: Optional[int] = None
|
||||
rope_theta: float = 500000
|
||||
use_scaled_rope: bool = False
|
||||
rope_scaling_factor: Optional[float] = None
|
||||
rope_high_freq_factor: Optional[float] = None
|
||||
|
||||
nope_layer_interval: Optional[int] = None # No position encoding in every n layers
|
||||
use_qk_norm: bool = False
|
||||
# Set to True to enable inference-time temperature tuning (useful for very long context)
|
||||
|
@ -92,4 +95,14 @@ class ModelArgs(BaseModel):
|
|||
f"n_heads ({self.n_heads}) must be divisible by n_kv_heads ({self.n_kv_heads})"
|
||||
)
|
||||
assert self.dim % self.n_heads == 0, f"dim ({self.dim}) must be divisible by n_heads ({self.n_heads})"
|
||||
|
||||
if self.use_scaled_rope:
|
||||
# NOTE: ideally these values should have come from params.json. However, we have
|
||||
# shipped the models everywhere. Only Llama-4-Scout uses scaled rope and needs these
|
||||
# specific values.
|
||||
if self.rope_scaling_factor is None:
|
||||
self.rope_scaling_factor = 16
|
||||
if self.rope_high_freq_factor is None:
|
||||
self.rope_high_freq_factor = 1
|
||||
|
||||
return self
|
||||
|
|
|
@ -23,37 +23,25 @@ from .ffn import FeedForward
|
|||
from .moe import MoE
|
||||
|
||||
|
||||
def rmsnorm(x, eps):
|
||||
def _norm(y):
|
||||
return y * torch.rsqrt(y.pow(2).mean(-1, keepdim=True) + eps)
|
||||
|
||||
return _norm(x.float()).type_as(x)
|
||||
|
||||
|
||||
class RMSNorm(torch.nn.Module):
|
||||
def __init__(self, dim: int, eps: float = 1e-6):
|
||||
super().__init__()
|
||||
self.eps = eps
|
||||
self.weight = nn.Parameter(torch.ones(dim))
|
||||
|
||||
def _norm(self, x):
|
||||
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
|
||||
|
||||
def forward(self, x):
|
||||
output = self._norm(x.float()).type_as(x)
|
||||
return output * self.weight
|
||||
return rmsnorm(x, self.eps) * self.weight
|
||||
|
||||
|
||||
class L2Norm(torch.nn.Module):
|
||||
def __init__(self, dim: int, eps: float = 1e-6):
|
||||
super().__init__()
|
||||
self.eps = eps
|
||||
|
||||
def _norm(self, x):
|
||||
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
|
||||
|
||||
def forward(self, x):
|
||||
return self._norm(x.float()).type_as(x)
|
||||
|
||||
|
||||
def apply_scaling(freqs: torch.Tensor):
|
||||
# Values obtained from grid search
|
||||
scale_factor = 8
|
||||
def apply_scaling(freqs: torch.Tensor, scale_factor: float, high_freq_factor: float):
|
||||
low_freq_factor = 1
|
||||
high_freq_factor = 4
|
||||
old_context_len = 8192 # original llama3 length
|
||||
|
||||
low_freq_wavelen = old_context_len / low_freq_factor
|
||||
|
@ -72,11 +60,18 @@ def apply_scaling(freqs: torch.Tensor):
|
|||
return torch.tensor(new_freqs, dtype=freqs.dtype, device=freqs.device)
|
||||
|
||||
|
||||
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0, use_scaled: bool = False):
|
||||
def precompute_freqs_cis(
|
||||
dim: int,
|
||||
end: int,
|
||||
theta: float,
|
||||
use_scaled: bool,
|
||||
scale_factor: float,
|
||||
high_freq_factor: float,
|
||||
):
|
||||
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
|
||||
t = torch.arange(end, device=freqs.device, dtype=torch.float32)
|
||||
if use_scaled:
|
||||
freqs = apply_scaling(freqs)
|
||||
freqs = apply_scaling(freqs, scale_factor, high_freq_factor)
|
||||
freqs = torch.outer(t, freqs)
|
||||
freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64
|
||||
return freqs_cis
|
||||
|
@ -174,9 +169,7 @@ class Attention(nn.Module):
|
|||
self.head_dim,
|
||||
)
|
||||
).cuda()
|
||||
self.qk_norm = None
|
||||
if self.use_qk_norm:
|
||||
self.qk_norm = L2Norm(args.norm_eps)
|
||||
self.norm_eps = args.norm_eps
|
||||
self._register_load_state_dict_pre_hook(self.load_hook)
|
||||
|
||||
def load_hook(
|
||||
|
@ -220,8 +213,8 @@ class Attention(nn.Module):
|
|||
xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)
|
||||
|
||||
if self.use_qk_norm:
|
||||
xq = self.qk_norm(xq)
|
||||
xk = self.qk_norm(xk)
|
||||
xq = rmsnorm(xq, self.norm_eps)
|
||||
xk = rmsnorm(xk, self.norm_eps)
|
||||
|
||||
# We are applying temperature tuning (https://arxiv.org/abs/2501.19399) to NoPE layers, where
|
||||
# the inference-time temperature tuning function is customized to not affect short context
|
||||
|
@ -362,6 +355,8 @@ class Transformer(nn.Module):
|
|||
args.max_seq_len * 2,
|
||||
args.rope_theta,
|
||||
args.use_scaled_rope,
|
||||
args.rope_scaling_factor,
|
||||
args.rope_high_freq_factor,
|
||||
)
|
||||
vision_args = self.args.vision_args
|
||||
if vision_args:
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue