mirror of
https://github.com/meta-llama/llama-stack.git
synced 2026-01-04 10:32:16 +00:00
several fixes
This commit is contained in:
parent
e2e2820c9a
commit
53a8086e37
60 changed files with 1006 additions and 1078 deletions
|
|
@ -100,31 +100,21 @@ class Experts(nn.Module):
|
|||
|
||||
class MoE(torch.nn.Module):
|
||||
"""
|
||||
This EC implementation is modified from the original EC module.
|
||||
We refactored the token permutation and unpermutation logic and added support to tp and dp2ep sharding.
|
||||
This module supports 3 sharding methods of the experts:
|
||||
- tp: each TP rank has n_experts experts. Experts are sharded following the conventional row/column-parallel TP sharding.
|
||||
- tp2ep: each TP rank has n_experts/tp experts. Experts are not sharded.
|
||||
- dp2ep: each EP rank has n_experts/ep experts. Experts are sharded following the row/column-parallel TP sharding.
|
||||
Tensors used in this module are annotated with the suffixes that indicate the shape of the tensor.
|
||||
Several commonly used annotations include:
|
||||
- a: bsz*slen
|
||||
- E: number of experts
|
||||
- e: number of local experts per ep (n_experts/ep)
|
||||
- et: number of local experts per tp (n_experts/tp)
|
||||
- D: hidden dimension
|
||||
- d: D/tp
|
||||
- F: model dimension
|
||||
- f: F/tp (used in column/row-parallel linear)
|
||||
- G: number of tokens per expert (a * capacity_factor / E)
|
||||
- g: number of tokens per expert per TP rank (i.e., G/TP)
|
||||
- GG: G*EP (number of tokens per expert received via inter-EP a2a when ag_along_first_dim=False)
|
||||
- gg: g*EP (number of tokens per expert received via inter-EP a2a when ag_along_first_dim=True)
|
||||
|
||||
Examples:
|
||||
x_aD [a, D]
|
||||
routed_in_etG_D [et*G, D]
|
||||
x_eGGD: [e, GG, D]
|
||||
x_eGD: [e, G, D]
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
|
|
@ -207,13 +197,13 @@ class MoE(torch.nn.Module):
|
|||
routed_in_EG_D = routed_in_EG_D * router_scores.reshape(-1, 1)
|
||||
|
||||
out_aD = self.shared_expert(x_aD)
|
||||
routed_out_egg_D = self.experts(routed_in_EG_D.detach())
|
||||
routed_out_eg_D = self.experts(routed_in_EG_D.detach())
|
||||
|
||||
router_indices_EG_D = router_indices.reshape(-1, 1).expand(-1, D)
|
||||
out_aD.scatter_add_(
|
||||
dim=0,
|
||||
index=router_indices_EG_D,
|
||||
src=routed_out_egg_D.view(-1, D),
|
||||
src=routed_out_eg_D.view(-1, D),
|
||||
)
|
||||
out_aD = reduce_from_model_parallel_region(out_aD)
|
||||
return out_aD.view(-1, slen, D)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue