diff --git a/configs/7B_MoE4_sft.py b/configs/7B_MoE8_sft.py similarity index 99% rename from configs/7B_MoE4_sft.py rename to configs/7B_MoE8_sft.py index 92a93d0..32d56e2 100644 --- a/configs/7B_MoE4_sft.py +++ b/configs/7B_MoE8_sft.py @@ -144,6 +144,7 @@ model = dict( num_experts=8, moe_use_residual=False, moe_gate_k=2, + moe_noisy_gate_policy="RSample", ) """ zero1 parallel: diff --git a/internlm/initialize/launch.py b/internlm/initialize/launch.py index e96d2d9..ec07cd8 100644 --- a/internlm/initialize/launch.py +++ b/internlm/initialize/launch.py @@ -306,6 +306,8 @@ def args_sanity_check(): model._add_item("moe_use_residual", False) if "moe_gate_k" not in model: model._add_item("moe_gate_k", 2) + if "moe_noisy_gate_policy" not in model: + model._add_item("moe_noisy_gate_policy", None) # process the parallel config if "sequence_parallel" not in gpc.config.parallel: gpc.config.parallel._add_item("sequence_parallel", False) diff --git a/internlm/moe/sharded_moe.py b/internlm/moe/sharded_moe.py index 5d1118b..e37f1cd 100644 --- a/internlm/moe/sharded_moe.py +++ b/internlm/moe/sharded_moe.py @@ -248,28 +248,40 @@ def top1gating( return l_aux, combine_weights, dispatch_mask, exp_counts -def top2gating(logits: Tensor, capacity_factor: float, min_capacity: int) -> Tuple[Tensor, Tensor, Tensor, Tensor]: +def top2gating( + logits: Tensor, + capacity_factor: float, + min_capacity: int, + noisy_gate_policy: Optional[str] = None, +) -> Tuple[Tensor, Tensor, Tensor, Tensor]: """Implements Top2Gating on logits.""" # everything is in fp32 in this function gates = F.softmax(logits, dim=1) + num_experts = int(gates.shape[1]) capacity = _capacity(gates, torch.tensor(capacity_factor * 2), torch.tensor(min_capacity)) - # Create a mask for 1st's expert per token - indices1_s = torch.argmax(gates, dim=1) - num_experts = int(gates.shape[1]) - mask1 = F.one_hot(indices1_s, num_classes=num_experts) + # NOTE: here we just add noise on 2nd expert, following + # https://github.com/facebookresearch/fairscale/blob/main/fairscale/nn/moe/top2gate.py + if noisy_gate_policy == "RSample": + # Create a mask for 1st's expert per token + indices1_s = torch.argmax(gates, dim=1) + mask1 = F.one_hot(indices1_s, num_classes=num_experts) - # Create a mask for 2nd's expert per token using Gumbel-max trick - # https://timvieira.github.io/blog/post/2014/07/31/gumbel-max-trick/ - logits_w_noise = logits + gumbel_rsample(logits.shape, device=logits.device) - # Replace top-expert with min value - logits_except1 = logits_w_noise.masked_fill(mask1.bool(), torch.finfo(logits.dtype).min) - indices2_s = torch.argmax(logits_except1, dim=1) - mask2 = F.one_hot(indices2_s, num_classes=num_experts) - - # merge operands in topk gating to save launch overhead - masks = torch.cat((mask1, mask2), dim=0) + # Create a mask for 2nd's expert per token using Gumbel-max trick + # https://timvieira.github.io/blog/post/2014/07/31/gumbel-max-trick/ + logits_w_noise = logits + gumbel_rsample(logits.shape, device=logits.device) + # Replace top-expert with min value + logits_except1 = logits_w_noise.masked_fill(mask1.bool(), torch.finfo(logits.dtype).min) + indices2_s = torch.argmax(logits_except1, dim=1) + mask2 = F.one_hot(indices2_s, num_classes=num_experts) + # merge operands in topk gating to save launch overhead + masks = torch.cat((mask1, mask2), dim=0) + else: + # Create a mask by top-2 experts + indices_s = torch.topk(gates, 2, dim=1).indices + indices_s = indices_s.permute(1, 0).reshape(-1) + masks = F.one_hot(indices_s, num_classes=num_experts) # Compute locations in capacity buffer locations = torch.cumsum(masks, dim=0) - 1 @@ -376,7 +388,10 @@ class TopKGate(Module): else: gate_output = top2gating( - logits, self.capacity_factor if self.training else self.eval_capacity_factor, self.min_capacity + logits, + self.capacity_factor if self.training else self.eval_capacity_factor, + self.min_capacity, + self.noisy_gate_policy, ) return gate_output