feat: introduce llama4 support (#1877)

As title says. Details in README, elsewhere.
This commit is contained in:
Ashwin Bharambe 2025-04-05 11:53:35 -07:00 committed by GitHub
parent 23a99a4b22
commit b8f1561956
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
61 changed files with 205222 additions and 6439 deletions

View file

@ -0,0 +1,58 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# top-level folder for each specific model found within the models/ directory at
# the top-level of this source tree.
from typing import Any, Dict, List
from fairscale.nn.model_parallel.layers import ColumnParallelLinear, RowParallelLinear
from fairscale.nn.model_parallel.mappings import reduce_from_model_parallel_region
from torch import nn
from torch.nn import functional as F
class FeedForward(nn.Module):
def __init__(
self,
dim: int,
hidden_dim: int,
do_reduce: bool = True,
):
super().__init__()
self.do_reduce = do_reduce
self.w1 = ColumnParallelLinear(dim, hidden_dim, bias=False, gather_output=False, init_method=lambda x: x)
self.w2 = RowParallelLinear(hidden_dim, dim, bias=False, input_is_parallel=True, init_method=lambda x: x)
self.w3 = ColumnParallelLinear(dim, hidden_dim, bias=False, gather_output=False, init_method=lambda x: x)
self._register_load_state_dict_pre_hook(self.load_hook)
def load_hook(
self,
state_dict: Dict[str, Any],
prefix: str,
local_metadata: Dict[str, Any],
strict: bool,
missing_keys: List[str],
unexpected_keys: List[str],
error_msgs: List[str],
) -> None:
if prefix + "mlp.fc1_weight" in state_dict:
w1, w3 = state_dict.pop(prefix + "mlp.fc1_weight").chunk(2, dim=0)
state_dict[prefix + "w1.weight"] = w1
state_dict[prefix + "w3.weight"] = w3
state_dict[prefix + "w2.weight"] = state_dict.pop(prefix + "mlp.fc2_weight")
def forward(self, x):
x = F.silu(F.linear(x, self.w1.weight)) * F.linear(x, self.w3.weight)
out = F.linear(x, self.w2.weight)
if self.do_reduce:
return reduce_from_model_parallel_region(out)
return out