Mirror of https://github.com/meta-llama/llama-stack.git, synced 2025-06-27 18:50:41 +00:00
# What does this PR do?

The goal of this PR is code base modernization. Schema reflection code needed a minor adjustment to handle `UnionType`s and `collections.abc.AsyncIterator`, both of which are preferred in recent Python releases.

Note to reviewers: almost all changes here were generated automatically by pyupgrade; some additional unused imports were cleaned up as well. The only change worth noting is under `docs/openapi_generator` and `llama_stack/strong_typing/schema.py`, where the reflection code was updated to deal with the newer types.

Signed-off-by: Ihar Hrachyshka <ihar.hrachyshka@gmail.com>
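The reflection change itself is not shown on this page; as a rough illustration only (hypothetical helper names, not the actual code in `llama_stack/strong_typing/schema.py`), handling the newer spellings on Python 3.10+ amounts to recognizing them alongside their `typing`-module equivalents:

```python
# Illustrative sketch: treat PEP 604 unions (int | str) and
# collections.abc.AsyncIterator like their typing-module spellings.
import collections.abc
import types
import typing


def is_union_type(tp: object) -> bool:
    origin = typing.get_origin(tp)
    # typing.Union[int, str] -> typing.Union; int | str -> types.UnionType
    return origin is typing.Union or origin is types.UnionType


def is_async_iterator_type(tp: object) -> bool:
    # typing.AsyncIterator[str] and collections.abc.AsyncIterator[str] both
    # report collections.abc.AsyncIterator as their origin.
    return typing.get_origin(tp) is collections.abc.AsyncIterator


assert is_union_type(int | str) and is_union_type(typing.Union[int, str])
assert is_async_iterator_type(collections.abc.AsyncIterator[str])
assert is_async_iterator_type(typing.AsyncIterator[str])
```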
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

# ruff: noqa: N806
# pyre-strict

from typing import Any

import fairscale.nn.model_parallel.initialize as fs_init
import torch
from fairscale.nn.model_parallel.mappings import reduce_from_model_parallel_region
from torch import Tensor, nn
from torch.nn import functional as F

from .args import MoEArgs
from .ffn import FeedForward


class Experts(nn.Module):
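    # Expert weights are stored as batched 3-D tensors of shape
    # [num_local_experts, in_dim, out_dim]; the FFN hidden dimension is sharded
    # across model-parallel ranks, hence the divide_exact() calls below.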
    def __init__(
        self,
        num_local_experts: int,
        dim: int,
        hidden_dim: int,
    ) -> None:
        super().__init__()

        dtype = torch.get_default_dtype()
        self.num_local_experts = num_local_experts
        self.dim = dim
        divide_factor = fs_init.get_model_parallel_world_size()

        self.w1: nn.Parameter = nn.Parameter(
            torch.empty(
                num_local_experts,
                dim,
                divide_exact(hidden_dim, divide_factor),
                dtype=dtype,
            )
        )

        self.w2: nn.Parameter = nn.Parameter(
            torch.empty(
                num_local_experts,
                divide_exact(hidden_dim, divide_factor),
                dim,
                dtype=dtype,
            )
        )

        self.w3: nn.Parameter = nn.Parameter(
            torch.empty(
                num_local_experts,
                dim,
                divide_exact(hidden_dim, divide_factor),
                dtype=dtype,
            )
        )

        self._register_load_state_dict_pre_hook(self.load_hook)

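    # Pre-hook: some checkpoints store the expert weights under fused keys
    # (moe_w_in_eD_F, moe_w_out_eF_D, moe_w_swiglu_eD_F); reshape them into the
    # per-expert w1/w2/w3 layout used by this module.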
    def load_hook(
        self,
        state_dict: dict[str, Any],
        prefix: str,
        local_metadata: dict[str, Any],
        strict: bool,
        missing_keys: list[str],
        unexpected_keys: list[str],
        error_msgs: list[str],
    ) -> None:
        self.prefix = prefix
        if prefix + "moe_w_in_eD_F" in state_dict:
            e = self.num_local_experts
            D = self.dim
            state_dict[prefix + "w1"] = state_dict.pop(prefix + "moe_w_in_eD_F").view(e, D, -1)
            state_dict[prefix + "w2"] = state_dict.pop(prefix + "moe_w_out_eF_D").view(e, -1, D)
            state_dict[prefix + "w3"] = state_dict.pop(prefix + "moe_w_swiglu_eD_F").view(e, D, -1)

    def forward(
        self,
        routed_in_egD: torch.Tensor,  # noqa: N803
    ) -> torch.Tensor:
        e = self.num_local_experts
        D = self.dim

        x_egD = routed_in_egD.view(e, -1, D)

        out_egD = self.batched_swiglu(x_egD, self.w1, self.w3, self.w2)
        out_egD = out_egD.view(-1, D)

        return out_egD

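    # SwiGLU computed with batched matmuls over the expert dimension:
    # silu(x @ w1) * (x @ w3), projected back down with w2.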
    def batched_swiglu(self, x: Tensor, w1: Tensor, w3: Tensor, w2: Tensor) -> Tensor:
        middle_out_egF = F.silu(torch.bmm(x, w1)) * torch.bmm(x, w3)
        return torch.bmm(middle_out_egF, w2)


class MoE(torch.nn.Module):
    """
    Tensors used in this module are annotated with the suffixes that indicate the shape of the tensor.
    Several commonly used annotations include:
    - a: bsz*slen
    - E: number of experts
    - e: number of local experts per ep (n_experts/ep)
    - D: hidden dimension
    - d: D/tp
    - F: model dimension
    - G: number of tokens per expert (a * capacity_factor / E)
    - g: number of tokens per expert per TP rank (i.e., G/TP)

    Examples:
    x_aD [a, D]
    routed_in_etG_D [et*G, D]
    x_eGD: [e, G, D]
    """

    def __init__(
        self,
        dim: int,
        hidden_dim: int,
        ffn_dim_multiplier: float,
        multiple_of: int,
        moe_args: MoEArgs,
    ) -> None:
        super().__init__()

        self.moe_args = moe_args

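        # FFN sizing follows the usual Llama recipe: take 2/3 of hidden_dim,
        # apply ffn_dim_multiplier, optionally divide by (capacity_factor + 1)
        # when auto_scale_F is set, then round up to a multiple of multiple_of.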
        hidden_dim_denom: float = 1
        if moe_args.auto_scale_F:
            hidden_dim_denom = moe_args.capacity_factor + 1

        hidden_dim = int(2 * hidden_dim / 3)

        # custom dim factor multiplier
        hidden_dim = int(ffn_dim_multiplier * hidden_dim)

        if moe_args.auto_scale_F:
            hidden_dim = int(hidden_dim / hidden_dim_denom)

        hidden_dim += -hidden_dim % multiple_of

        num_local_experts: int = moe_args.num_experts
        dtype: torch.dtype = torch.get_default_dtype()
        self.experts = Experts(
            num_local_experts,
            dim,
            hidden_dim,
        )

        self.router_DE: nn.Parameter = nn.Parameter(torch.empty(dim, moe_args.num_experts, dtype=dtype))
        self.shared_expert = FeedForward(dim, hidden_dim, do_reduce=False)

        self._register_load_state_dict_pre_hook(self.load_hook)

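    # Pre-hook: remap shared-expert weights stored under their original
    # checkpoint names (w_in_shared_FD, w_swiglu_FD, w_out_shared_DF) to the
    # FeedForward submodule's w1/w3/w2 parameters.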
    def load_hook(
        self,
        state_dict: dict[str, Any],
        prefix: str,
        local_metadata: dict[str, Any],
        strict: bool,
        missing_keys: list[str],
        unexpected_keys: list[str],
        error_msgs: list[str],
    ) -> None:
        if prefix + "w_in_shared_FD.weight" in state_dict:
            state_dict[prefix + "shared_expert.w1.weight"] = state_dict.pop(prefix + "w_in_shared_FD.weight")
            state_dict[prefix + "shared_expert.w3.weight"] = state_dict.pop(prefix + "w_swiglu_FD.weight")
            state_dict[prefix + "shared_expert.w2.weight"] = state_dict.pop(prefix + "w_out_shared_DF.weight")

    def forward(self, x_bsD: Tensor) -> Tensor:  # noqa: N803
        _, slen, D = x_bsD.shape
        x_aD = x_bsD.view(-1, D)

        a = x_aD.shape[0]

        router_scores: Tensor = torch.matmul(x_aD, self.router_DE).transpose(0, 1)

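        # Top-k routing: keep the k largest logits per token and set the rest
        # to -inf so that the sigmoid below maps them to zero weight.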
        router_scores_aK, router_indices_aK = torch.topk(router_scores.transpose(0, 1), self.moe_args.top_k, dim=1)
        router_scores = (
            torch.full_like(router_scores.transpose(0, 1), float("-inf"))
            .scatter_(1, router_indices_aK, router_scores_aK)
            .transpose(0, 1)
        )
        router_indices = torch.arange(a, device=x_aD.device).view(1, -1).expand(router_scores.size(0), -1)

        router_scores = torch.sigmoid(router_scores)

        routed_in_EG_D: Tensor = torch.gather(
            x_aD,
            dim=0,
            index=router_indices.reshape(-1, 1).expand(-1, D),
        )
        routed_in_EG_D = routed_in_EG_D * router_scores.reshape(-1, 1)

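        # Every token always passes through the shared expert; the routed
        # experts' outputs are scatter-added on top of that result below.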
        out_aD = self.shared_expert(x_aD)
        routed_out_eg_D = self.experts(routed_in_EG_D.detach())

        router_indices_EG_D = router_indices.reshape(-1, 1).expand(-1, D)
        out_aD.scatter_add_(
            dim=0,
            index=router_indices_EG_D,
            src=routed_out_eg_D.view(-1, D),
        )
        out_aD = reduce_from_model_parallel_region(out_aD)
        return out_aD.view(-1, slen, D)


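# Integer division that asserts the split is exact; used above to shard the
# FFN hidden dimension across model-parallel ranks.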
def divide_exact(numerator: int, denominator: int) -> int:
    assert numerator % denominator == 0, f"{numerator} is not divisible by {denominator}"
    return numerator // denominator