merge operands in topk gating

pull/521/head
Wenwen Qu 2023-11-28 14:52:50 +08:00 committed by Qu Wenwen
parent 06e8301861
commit 95263fa1d0
1 changed files with 32 additions and 37 deletions

View File

@ -97,23 +97,32 @@ def einsum(rule, a, b):
if USE_EINSUM:
return torch.einsum(rule, a, b)
elif rule == "s,se->se":
# [1, s] * [s, e]
# [s, 1] * [s, e]
return a.reshape(a.shape[0], -1) * b
elif rule == "ks,kse->kse":
# [k, s, 1] * [s, e]
return a.reshape(a.shape[0], a.shape[1], -1) * b
elif rule == "se,sc->sec":
# [s,e,1] * [s,1,c]
return a.unsqueeze(2) * b.unsqueeze(1)
elif rule == "kse,ksc->ksec":
# [k,s,e,1] * [k,s,1,c]
return a.unsqueeze(3) * b.unsqueeze(2)
elif rule == "se,se->s":
# [s,1,e] * [s,e,1]
return torch.bmm(a.unsqueeze(1), b.unsqueeze(2)).reshape(-1)
elif rule == "se,kse->ks":
# [s,1,e] * [k,s,e,1]
return torch.matmul(a.unsqueeze(1), b.unsqueeze(3)).reshape(b.shape[0], -1)
elif rule == "sec,sm->ecm":
# [e*c, s] * [s, m]
# [e*c, s] @ [s, m]
s = a.shape[0]
e = a.shape[1]
c = a.shape[2]
m = b.shape[1]
return torch.matmul(a.reshape(s, -1).t(), b).reshape(e, c, m)
elif rule == "sec,ecm->sm":
# [s, e*c] * [e*c, m]
# [s, e*c] @ [e*c, m]
return torch.matmul(a.reshape(a.shape[0], -1), b.reshape(-1, b.shape[-1]))
elif rule == "ks,ksm->sm":
k = b.shape[0]
@ -259,47 +268,43 @@ def top2gating(logits: Tensor, capacity_factor: float, min_capacity: int) -> Tup
indices2_s = torch.argmax(logits_except1, dim=1)
mask2 = F.one_hot(indices2_s, num_classes=num_experts)
# merge operands in topk gating to save launch overhead
masks = torch.cat((mask1, mask2), dim=0)
# Compute locations in capacity buffer
locations1 = torch.cumsum(mask1, dim=0) - 1
locations2 = torch.cumsum(mask2, dim=0) - 1
# Update 2nd's location by accounting for locations of 1st
locations2 += torch.sum(mask1, dim=0, keepdim=True)
locations = torch.cumsum(masks, dim=0) - 1
# reshape (s,e) to (k,s,e)
masks = masks.reshape(-1, gates.shape[0], num_experts)
locations = locations.reshape(-1, gates.shape[0], num_experts)
# gating decisions
exp_counts = torch.sum(mask1, dim=0).detach().to("cpu")
exp_counts = torch.sum(masks[0], dim=0).detach().to("cpu")
# Compute l_aux
me = torch.mean(gates, dim=0)
ce = torch.mean(mask1.type_as(logits), dim=0)
ce = torch.mean(masks[0].type_as(logits), dim=0)
l_aux = torch.mean(me * ce) * num_experts * num_experts
# Remove locations outside capacity from mask
mask1 *= torch.lt(locations1, capacity)
mask2 *= torch.lt(locations2, capacity)
masks *= torch.lt(locations, capacity)
# Store the capacity location for each token
locations1_s = torch.sum(locations1 * mask1, dim=1)
locations2_s = torch.sum(locations2 * mask2, dim=1)
locations_s = torch.sum(locations * masks, dim=2)
# Normalize gate probabilities
mask1_float = mask1.type_as(logits)
mask2_float = mask2.type_as(logits)
gates1_s = einsum("se,se->s", gates, mask1_float)
gates2_s = einsum("se,se->s", gates, mask2_float)
denom_s = gates1_s + gates2_s
mask_float = masks.type_as(logits)
gate_s = einsum("se,kse->ks", gates, mask_float)
denom_s = torch.sum(gate_s, dim=0)
# Avoid divide-by-zero
denom_s = torch.clamp(denom_s, min=torch.finfo(denom_s.dtype).eps)
gates1_s /= denom_s
gates2_s /= denom_s
gate_s /= denom_s
# Calculate combine_weights and dispatch_mask
gates1 = einsum("s,se->se", gates1_s, mask1_float)
gates2 = einsum("s,se->se", gates2_s, mask2_float)
locations1_sc = F.one_hot(locations1_s, num_classes=capacity).type_as(logits)
locations2_sc = F.one_hot(locations2_s, num_classes=capacity).type_as(logits)
combine1_sec = einsum("se,sc->sec", gates1, locations1_sc)
combine2_sec = einsum("se,sc->sec", gates2, locations2_sc)
combine_weights = combine1_sec + combine2_sec
gate_all = einsum("ks,kse->kse", gate_s, mask_float)
locations_sc = F.one_hot(locations_s, num_classes=capacity).type_as(logits)
combine_sec = einsum("kse,ksc->ksec", gate_all, locations_sc)
combine_weights = torch.sum(combine_sec, dim=0)
dispatch_mask = combine_weights.bool()
return l_aux, combine_weights, dispatch_mask, exp_counts
@ -347,18 +352,12 @@ class TopKGate(Module):
self.eval_capacity_factor = eval_capacity_factor
self.min_capacity = min_capacity
self.noisy_gate_policy = noisy_gate_policy
self.wall_clock_breakdown = False
self.gate_time = 0.0
self.drop_tokens = drop_tokens
self.use_rts = use_rts
def forward(
self, inputs: torch.Tensor, used_token: torch.Tensor = None
) -> Tuple[Tensor, Tensor, Tensor]: # type: ignore
if self.wall_clock_breakdown:
timer("TopKGate").start()
# input jittering
if self.noisy_gate_policy == "Jitter" and self.training:
inputs = multiplicative_jitter(inputs, device=inputs.device)
@ -380,10 +379,6 @@ class TopKGate(Module):
logits, self.capacity_factor if self.training else self.eval_capacity_factor, self.min_capacity
)
if self.wall_clock_breakdown:
timer("TopKGate").stop()
self.gate_time = timer("TopKGate").elapsed(reset=False)
return gate_output