mirror of https://github.com/InternLM/InternLM
merge operands in topk gating
parent 06e8301861
commit 95263fa1d0
@@ -97,23 +97,32 @@ def einsum(rule, a, b):
     if USE_EINSUM:
         return torch.einsum(rule, a, b)
     elif rule == "s,se->se":
-        # [1, s] * [s, e]
+        # [s, 1] * [s, e]
         return a.reshape(a.shape[0], -1) * b
+    elif rule == "ks,kse->kse":
+        # [k, s, 1] * [s, e]
+        return a.reshape(a.shape[0], a.shape[1], -1) * b
     elif rule == "se,sc->sec":
         # [s,e,1] * [s,1,c]
         return a.unsqueeze(2) * b.unsqueeze(1)
+    elif rule == "kse,ksc->ksec":
+        # [k,s,e,1] * [k,s,1,c]
+        return a.unsqueeze(3) * b.unsqueeze(2)
     elif rule == "se,se->s":
         # [s,1,e] * [s,e,1]
         return torch.bmm(a.unsqueeze(1), b.unsqueeze(2)).reshape(-1)
+    elif rule == "se,kse->ks":
+        # [s,1,e] * [k,s,e,1]
+        return torch.matmul(a.unsqueeze(1), b.unsqueeze(3)).reshape(b.shape[0], -1)
     elif rule == "sec,sm->ecm":
-        # [e*c, s] * [s, m]
+        # [e*c, s] @ [s, m]
         s = a.shape[0]
         e = a.shape[1]
         c = a.shape[2]
         m = b.shape[1]
         return torch.matmul(a.reshape(s, -1).t(), b).reshape(e, c, m)
     elif rule == "sec,ecm->sm":
-        # [s, e*c] * [e*c, m]
+        # [s, e*c] @ [e*c, m]
         return torch.matmul(a.reshape(a.shape[0], -1), b.reshape(-1, b.shape[-1]))
     elif rule == "ks,ksm->sm":
         k = b.shape[0]
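The branches added in this hunk swap `torch.einsum` for explicit broadcast and `matmul` equivalents, with the `USE_EINSUM` switch as the fallback path. A quick sanity check, not part of this commit, comparing one of the new hand-written rules against `torch.einsum`; the sizes below are arbitrary illustration values:

import torch

s, e, k = 8, 4, 2                 # arbitrary token / expert / top-k sizes
gates = torch.rand(s, e)          # [s, e]
mask = torch.rand(k, s, e)        # [k, s, e]

ref = torch.einsum("se,kse->ks", gates, mask)
# [s, 1, e] @ [k, s, e, 1] -> [k, s, 1, 1] -> reshape to [k, s]
manual = torch.matmul(gates.unsqueeze(1), mask.unsqueeze(3)).reshape(mask.shape[0], -1)
assert torch.allclose(ref, manual, atol=1e-6)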
@@ -259,47 +268,43 @@ def top2gating(logits: Tensor, capacity_factor: float, min_capacity: int) -> Tup
     indices2_s = torch.argmax(logits_except1, dim=1)
     mask2 = F.one_hot(indices2_s, num_classes=num_experts)
 
+    # merge operands in topk gating to save launch overhead
+    masks = torch.cat((mask1, mask2), dim=0)
+
     # Compute locations in capacity buffer
-    locations1 = torch.cumsum(mask1, dim=0) - 1
-    locations2 = torch.cumsum(mask2, dim=0) - 1
-    # Update 2nd's location by accounting for locations of 1st
-    locations2 += torch.sum(mask1, dim=0, keepdim=True)
+    locations = torch.cumsum(masks, dim=0) - 1
+
+    # reshape (s,e) to (k,s,e)
+    masks = masks.reshape(-1, gates.shape[0], num_experts)
+    locations = locations.reshape(-1, gates.shape[0], num_experts)
 
     # gating decisions
-    exp_counts = torch.sum(mask1, dim=0).detach().to("cpu")
+    exp_counts = torch.sum(masks[0], dim=0).detach().to("cpu")
 
     # Compute l_aux
     me = torch.mean(gates, dim=0)
-    ce = torch.mean(mask1.type_as(logits), dim=0)
+    ce = torch.mean(masks[0].type_as(logits), dim=0)
     l_aux = torch.mean(me * ce) * num_experts * num_experts
 
     # Remove locations outside capacity from mask
-    mask1 *= torch.lt(locations1, capacity)
-    mask2 *= torch.lt(locations2, capacity)
+    masks *= torch.lt(locations, capacity)
 
     # Store the capacity location for each token
-    locations1_s = torch.sum(locations1 * mask1, dim=1)
-    locations2_s = torch.sum(locations2 * mask2, dim=1)
+    locations_s = torch.sum(locations * masks, dim=2)
 
     # Normalize gate probabilities
-    mask1_float = mask1.type_as(logits)
-    mask2_float = mask2.type_as(logits)
-    gates1_s = einsum("se,se->s", gates, mask1_float)
-    gates2_s = einsum("se,se->s", gates, mask2_float)
-    denom_s = gates1_s + gates2_s
+    mask_float = masks.type_as(logits)
+    gate_s = einsum("se,kse->ks", gates, mask_float)
+    denom_s = torch.sum(gate_s, dim=0)
     # Avoid divide-by-zero
     denom_s = torch.clamp(denom_s, min=torch.finfo(denom_s.dtype).eps)
-    gates1_s /= denom_s
-    gates2_s /= denom_s
+    gate_s /= denom_s
 
     # Calculate combine_weights and dispatch_mask
-    gates1 = einsum("s,se->se", gates1_s, mask1_float)
-    gates2 = einsum("s,se->se", gates2_s, mask2_float)
-    locations1_sc = F.one_hot(locations1_s, num_classes=capacity).type_as(logits)
-    locations2_sc = F.one_hot(locations2_s, num_classes=capacity).type_as(logits)
-    combine1_sec = einsum("se,sc->sec", gates1, locations1_sc)
-    combine2_sec = einsum("se,sc->sec", gates2, locations2_sc)
-    combine_weights = combine1_sec + combine2_sec
+    gate_all = einsum("ks,kse->kse", gate_s, mask_float)
+    locations_sc = F.one_hot(locations_s, num_classes=capacity).type_as(logits)
+    combine_sec = einsum("kse,ksc->ksec", gate_all, locations_sc)
+    combine_weights = torch.sum(combine_sec, dim=0)
     dispatch_mask = combine_weights.bool()
 
     return l_aux, combine_weights, dispatch_mask, exp_counts
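Why the merged formulation is equivalent: because mask2 is concatenated below mask1, the running cumsum over the stacked (2s, e) tensor already carries the full per-expert counts of mask1 by the time it reaches the mask2 rows, so the old explicit offset `locations2 += torch.sum(mask1, dim=0, keepdim=True)` falls out for free, and a single cumsum/lt/sum launch replaces a pair of launches each. A small verification sketch, not part of this commit, using arbitrary sizes:

import torch

s, e = 16, 4                                  # arbitrary token / expert counts
mask1 = torch.randint(0, 2, (s, e))
mask2 = torch.randint(0, 2, (s, e))

# original two-kernel formulation
locations1 = torch.cumsum(mask1, dim=0) - 1
locations2 = torch.cumsum(mask2, dim=0) - 1
locations2 += torch.sum(mask1, dim=0, keepdim=True)

# merged formulation from the diff: one cumsum over the stacked masks
masks = torch.cat((mask1, mask2), dim=0)                      # [2s, e]
locations = (torch.cumsum(masks, dim=0) - 1).reshape(-1, s, e)  # [2, s, e]

assert torch.equal(locations[0], locations1)
assert torch.equal(locations[1], locations2)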
@@ -347,18 +352,12 @@ class TopKGate(Module):
        self.eval_capacity_factor = eval_capacity_factor
        self.min_capacity = min_capacity
        self.noisy_gate_policy = noisy_gate_policy
        self.wall_clock_breakdown = False
        self.gate_time = 0.0
        self.drop_tokens = drop_tokens
        self.use_rts = use_rts

    def forward(
        self, inputs: torch.Tensor, used_token: torch.Tensor = None
    ) -> Tuple[Tensor, Tensor, Tensor]:  # type: ignore

        if self.wall_clock_breakdown:
            timer("TopKGate").start()

        # input jittering
        if self.noisy_gate_policy == "Jitter" and self.training:
            inputs = multiplicative_jitter(inputs, device=inputs.device)

@@ -380,10 +379,6 @@ class TopKGate(Module):
            logits, self.capacity_factor if self.training else self.eval_capacity_factor, self.min_capacity
        )

        if self.wall_clock_breakdown:
            timer("TopKGate").stop()
            self.gate_time = timer("TopKGate").elapsed(reset=False)

        return gate_output
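For context, the l_aux / combine_weights / dispatch_mask tuple returned by the gate is normally consumed by the surrounding MoE layer through the "sec,sm->ecm" and "sec,ecm->sm" rules handled in the first hunk. The sketch below follows the usual GShard/DeepSpeed dispatch-and-combine pattern and is illustrative only: `experts` is a hypothetical callable, and `einsum` is the helper defined at the top of this file.

def moe_forward(gate, experts, inputs):
    # inputs: [s, m] token embeddings; the gate returns the tuple shown above.
    l_aux, combine_weights, dispatch_mask, exp_counts = gate(inputs)
    # Dispatch each token into its (expert, capacity-slot) bucket: [e, c, m].
    dispatched = einsum("sec,sm->ecm", dispatch_mask.type_as(inputs), inputs)
    expert_out = experts(dispatched)  # hypothetical expert stack, [e, c, m] -> [e, c, m]
    # Combine expert outputs back onto their source tokens: [s, m].
    combined = einsum("sec,ecm->sm", combine_weights.type_as(inputs), expert_out)
    return combined, l_aux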