From 95263fa1d0bdd1caaad7da8b46de61a454f17bab Mon Sep 17 00:00:00 2001
From: Wenwen Qu
Date: Tue, 28 Nov 2023 14:52:50 +0800
Subject: [PATCH] merge operands in topk gating

---
 internlm/moe/sharded_moe.py | 69 +++++++++++++++++--------------------
 1 file changed, 32 insertions(+), 37 deletions(-)

diff --git a/internlm/moe/sharded_moe.py b/internlm/moe/sharded_moe.py
index 5d695ac..5d1118b 100644
--- a/internlm/moe/sharded_moe.py
+++ b/internlm/moe/sharded_moe.py
@@ -97,23 +97,32 @@ def einsum(rule, a, b):
     if USE_EINSUM:
         return torch.einsum(rule, a, b)
     elif rule == "s,se->se":
-        # [1, s] * [s, e]
+        # [s, 1] * [s, e]
         return a.reshape(a.shape[0], -1) * b
+    elif rule == "ks,kse->kse":
+        # [k, s, 1] * [k, s, e]
+        return a.reshape(a.shape[0], a.shape[1], -1) * b
     elif rule == "se,sc->sec":
         # [s,e,1] * [s,1,c]
         return a.unsqueeze(2) * b.unsqueeze(1)
+    elif rule == "kse,ksc->ksec":
+        # [k,s,e,1] * [k,s,1,c]
+        return a.unsqueeze(3) * b.unsqueeze(2)
     elif rule == "se,se->s":
         # [s,1,e] * [s,e,1]
         return torch.bmm(a.unsqueeze(1), b.unsqueeze(2)).reshape(-1)
+    elif rule == "se,kse->ks":
+        # [s,1,e] * [k,s,e,1]
+        return torch.matmul(a.unsqueeze(1), b.unsqueeze(3)).reshape(b.shape[0], -1)
     elif rule == "sec,sm->ecm":
-        # [e*c, s] * [s, m]
+        # [e*c, s] @ [s, m]
         s = a.shape[0]
         e = a.shape[1]
         c = a.shape[2]
         m = b.shape[1]
         return torch.matmul(a.reshape(s, -1).t(), b).reshape(e, c, m)
     elif rule == "sec,ecm->sm":
-        # [s, e*c] * [e*c, m]
+        # [s, e*c] @ [e*c, m]
         return torch.matmul(a.reshape(a.shape[0], -1), b.reshape(-1, b.shape[-1]))
     elif rule == "ks,ksm->sm":
         k = b.shape[0]
@@ -259,47 +268,43 @@ def top2gating(logits: Tensor, capacity_factor: float, min_capacity: int) -> Tup
     indices2_s = torch.argmax(logits_except1, dim=1)
     mask2 = F.one_hot(indices2_s, num_classes=num_experts)
 
+    # merge operands in topk gating to save launch overhead
+    masks = torch.cat((mask1, mask2), dim=0)
+
     # Compute locations in capacity buffer
-    locations1 = torch.cumsum(mask1, dim=0) - 1
-    locations2 = torch.cumsum(mask2, dim=0) - 1
-    # Update 2nd's location by accounting for locations of 1st
-    locations2 += torch.sum(mask1, dim=0, keepdim=True)
+    locations = torch.cumsum(masks, dim=0) - 1
+
+    # reshape (k*s, e) to (k, s, e)
+    masks = masks.reshape(-1, gates.shape[0], num_experts)
+    locations = locations.reshape(-1, gates.shape[0], num_experts)
 
     # gating decisions
-    exp_counts = torch.sum(mask1, dim=0).detach().to("cpu")
+    exp_counts = torch.sum(masks[0], dim=0).detach().to("cpu")
 
     # Compute l_aux
     me = torch.mean(gates, dim=0)
-    ce = torch.mean(mask1.type_as(logits), dim=0)
+    ce = torch.mean(masks[0].type_as(logits), dim=0)
     l_aux = torch.mean(me * ce) * num_experts * num_experts
 
     # Remove locations outside capacity from mask
-    mask1 *= torch.lt(locations1, capacity)
-    mask2 *= torch.lt(locations2, capacity)
+    masks *= torch.lt(locations, capacity)
 
     # Store the capacity location for each token
-    locations1_s = torch.sum(locations1 * mask1, dim=1)
-    locations2_s = torch.sum(locations2 * mask2, dim=1)
+    locations_s = torch.sum(locations * masks, dim=2)
 
     # Normalize gate probabilities
-    mask1_float = mask1.type_as(logits)
-    mask2_float = mask2.type_as(logits)
-    gates1_s = einsum("se,se->s", gates, mask1_float)
-    gates2_s = einsum("se,se->s", gates, mask2_float)
-    denom_s = gates1_s + gates2_s
+    mask_float = masks.type_as(logits)
+    gate_s = einsum("se,kse->ks", gates, mask_float)
+    denom_s = torch.sum(gate_s, dim=0)
     # Avoid divide-by-zero
     denom_s = torch.clamp(denom_s, min=torch.finfo(denom_s.dtype).eps)
-    gates1_s /= denom_s
-    gates2_s /= denom_s
+    gate_s /= denom_s
 
     # Calculate combine_weights and dispatch_mask
-    gates1 = einsum("s,se->se", gates1_s, mask1_float)
-    gates2 = einsum("s,se->se", gates2_s, mask2_float)
-    locations1_sc = F.one_hot(locations1_s, num_classes=capacity).type_as(logits)
-    locations2_sc = F.one_hot(locations2_s, num_classes=capacity).type_as(logits)
-    combine1_sec = einsum("se,sc->sec", gates1, locations1_sc)
-    combine2_sec = einsum("se,sc->sec", gates2, locations2_sc)
-    combine_weights = combine1_sec + combine2_sec
+    gate_all = einsum("ks,kse->kse", gate_s, mask_float)
+    locations_sc = F.one_hot(locations_s, num_classes=capacity).type_as(logits)
+    combine_sec = einsum("kse,ksc->ksec", gate_all, locations_sc)
+    combine_weights = torch.sum(combine_sec, dim=0)
     dispatch_mask = combine_weights.bool()
 
     return l_aux, combine_weights, dispatch_mask, exp_counts
@@ -347,18 +352,12 @@ class TopKGate(Module):
         self.eval_capacity_factor = eval_capacity_factor
         self.min_capacity = min_capacity
         self.noisy_gate_policy = noisy_gate_policy
-        self.wall_clock_breakdown = False
-        self.gate_time = 0.0
         self.drop_tokens = drop_tokens
         self.use_rts = use_rts
 
     def forward(
         self, inputs: torch.Tensor, used_token: torch.Tensor = None
     ) -> Tuple[Tensor, Tensor, Tensor]:  # type: ignore
-
-        if self.wall_clock_breakdown:
-            timer("TopKGate").start()
-
         # input jittering
         if self.noisy_gate_policy == "Jitter" and self.training:
             inputs = multiplicative_jitter(inputs, device=inputs.device)
@@ -380,10 +379,6 @@ class TopKGate(Module):
                 logits, self.capacity_factor if self.training else self.eval_capacity_factor, self.min_capacity
             )
 
-        if self.wall_clock_breakdown:
-            timer("TopKGate").stop()
-            self.gate_time = timer("TopKGate").elapsed(reset=False)
-
         return gate_output
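Editor's note: the following standalone sketch is not part of the patch. It re-implements the pre-patch two-branch top-2 gating and the post-patch merged (stacked-k) version with plain tensor ops, and checks that both produce identical combine weights. The helper names (original_top2, merged_top2) and the toy sizes are illustrative assumptions, not code from the repository.

# Standalone verification sketch -- not part of the patch.
import torch
import torch.nn.functional as F


def original_top2(gates, mask1, mask2, capacity):
    # Pre-patch structure: every step is done once for the 1st choice and once
    # for the 2nd choice, so each step launches two kernels.
    locations1 = torch.cumsum(mask1, dim=0) - 1
    locations2 = torch.cumsum(mask2, dim=0) - 1
    locations2 += torch.sum(mask1, dim=0, keepdim=True)  # offset 2nd by 1st
    mask1 = mask1 * torch.lt(locations1, capacity)
    mask2 = mask2 * torch.lt(locations2, capacity)
    locations1_s = torch.sum(locations1 * mask1, dim=1)
    locations2_s = torch.sum(locations2 * mask2, dim=1)
    mask1_float, mask2_float = mask1.float(), mask2.float()
    gates1_s = torch.sum(gates * mask1_float, dim=1)
    gates2_s = torch.sum(gates * mask2_float, dim=1)
    denom_s = torch.clamp(gates1_s + gates2_s, min=torch.finfo(gates.dtype).eps)
    gates1_s, gates2_s = gates1_s / denom_s, gates2_s / denom_s
    gates1 = gates1_s.unsqueeze(1) * mask1_float
    gates2 = gates2_s.unsqueeze(1) * mask2_float
    locations1_sc = F.one_hot(locations1_s, num_classes=capacity).float()
    locations2_sc = F.one_hot(locations2_s, num_classes=capacity).float()
    combine1_sec = gates1.unsqueeze(2) * locations1_sc.unsqueeze(1)
    combine2_sec = gates2.unsqueeze(2) * locations2_sc.unsqueeze(1)
    return combine1_sec + combine2_sec


def merged_top2(gates, mask1, mask2, capacity):
    # Post-patch structure: the k=2 masks are stacked along a leading dimension
    # so every step runs once on a (k, s, e) tensor.
    s, e = gates.shape
    masks = torch.cat((mask1, mask2), dim=0)                      # (k*s, e)
    locations = torch.cumsum(masks, dim=0) - 1                    # 2nd-choice rows already offset
    masks = masks.reshape(-1, s, e)                               # (k, s, e)
    locations = locations.reshape(-1, s, e)
    masks = masks * torch.lt(locations, capacity)
    locations_s = torch.sum(locations * masks, dim=2)             # (k, s)
    mask_float = masks.float()
    gate_s = torch.sum(gates.unsqueeze(0) * mask_float, dim=2)    # "se,kse->ks"
    denom_s = torch.clamp(torch.sum(gate_s, dim=0), min=torch.finfo(gates.dtype).eps)
    gate_s = gate_s / denom_s
    gate_all = gate_s.unsqueeze(2) * mask_float                   # "ks,kse->kse"
    locations_sc = F.one_hot(locations_s, num_classes=capacity).float()  # (k, s, c)
    combine_sec = gate_all.unsqueeze(3) * locations_sc.unsqueeze(2)      # "kse,ksc->ksec"
    return torch.sum(combine_sec, dim=0)


if __name__ == "__main__":
    torch.manual_seed(0)
    s, e, capacity = 16, 4, 8                # tokens, experts, capacity (toy values)
    logits = torch.randn(s, e)
    gates = torch.softmax(logits, dim=1)
    indices1_s = torch.argmax(gates, dim=1)
    mask1 = F.one_hot(indices1_s, num_classes=e)
    logits_except1 = logits.masked_fill(mask1.bool(), float("-inf"))
    indices2_s = torch.argmax(logits_except1, dim=1)
    mask2 = F.one_hot(indices2_s, num_classes=e)
    ref = original_top2(gates, mask1, mask2, capacity)
    new = merged_top2(gates, mask1, mask2, capacity)
    print("combine_weights match:", torch.allclose(ref, new))

The single cumsum over the concatenated (k*s, e) masks is what makes the merge possible: the rows belonging to the second choice accumulate on top of all first-choice rows, so the old explicit offset `locations2 += torch.sum(mask1, dim=0, keepdim=True)` falls out of the concatenation automatically.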