merge operands if noisy_gate_policy is not used

pull/521/head
Wenwen Qu 2023-11-28 16:17:49 +08:00 committed by Qu Wenwen
parent 95263fa1d0
commit 3443ab1f5b
3 changed files with 34 additions and 16 deletions

View File

@@ -144,6 +144,7 @@ model = dict(
     num_experts=8,
     moe_use_residual=False,
     moe_gate_k=2,
+    moe_noisy_gate_policy="RSample",
 )
 """
 zero1 parallel:
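For context, a minimal sketch of the MoE block of the model config with the new option set (only the moe_* keys come from the diff; everything else is illustrative). Leaving moe_noisy_gate_policy unset, or setting it to None, falls back to plain top-2 selection in the gating code further below.

model = dict(
    num_experts=8,
    moe_use_residual=False,
    moe_gate_k=2,
    # "RSample" adds Gumbel noise when picking the 2nd expert;
    # None (or omitting the key) selects noise-free top-2 gating.
    moe_noisy_gate_policy="RSample",
)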

View File

@@ -306,6 +306,8 @@ def args_sanity_check():
         model._add_item("moe_use_residual", False)
     if "moe_gate_k" not in model:
         model._add_item("moe_gate_k", 2)
+    if "moe_noisy_gate_policy" not in model:
+        model._add_item("moe_noisy_gate_policy", None)
     # process the parallel config
     if "sequence_parallel" not in gpc.config.parallel:
         gpc.config.parallel._add_item("sequence_parallel", False)
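A minimal standalone sketch of the defaulting behaviour added here, using a plain dict with setdefault in place of the gpc config object and its _add_item helper (the dict contents below are illustrative):

# Hypothetical stand-in for gpc.config.model; setdefault plays the role
# of the _add_item calls in the diff.
model = {"num_experts": 8, "moe_gate_k": 2}

model.setdefault("moe_use_residual", False)
model.setdefault("moe_noisy_gate_policy", None)  # None -> noise-free top-2 gating

assert model["moe_noisy_gate_policy"] is None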

View File

@@ -248,28 +248,40 @@ def top1gating(
     return l_aux, combine_weights, dispatch_mask, exp_counts
-def top2gating(logits: Tensor, capacity_factor: float, min_capacity: int) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
+def top2gating(
+    logits: Tensor,
+    capacity_factor: float,
+    min_capacity: int,
+    noisy_gate_policy: Optional[str] = None,
+) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
     """Implements Top2Gating on logits."""
     # everything is in fp32 in this function
     gates = F.softmax(logits, dim=1)
+    num_experts = int(gates.shape[1])
     capacity = _capacity(gates, torch.tensor(capacity_factor * 2), torch.tensor(min_capacity))
-    # Create a mask for 1st's expert per token
-    indices1_s = torch.argmax(gates, dim=1)
-    num_experts = int(gates.shape[1])
-    mask1 = F.one_hot(indices1_s, num_classes=num_experts)
+    # NOTE: here we just add noise on 2nd expert, following
+    # https://github.com/facebookresearch/fairscale/blob/main/fairscale/nn/moe/top2gate.py
+    if noisy_gate_policy == "RSample":
+        # Create a mask for 1st's expert per token
+        indices1_s = torch.argmax(gates, dim=1)
+        mask1 = F.one_hot(indices1_s, num_classes=num_experts)
+        # Create a mask for 2nd's expert per token using Gumbel-max trick
+        # https://timvieira.github.io/blog/post/2014/07/31/gumbel-max-trick/
+        logits_w_noise = logits + gumbel_rsample(logits.shape, device=logits.device)
+        # Replace top-expert with min value
+        logits_except1 = logits_w_noise.masked_fill(mask1.bool(), torch.finfo(logits.dtype).min)
+        indices2_s = torch.argmax(logits_except1, dim=1)
+        mask2 = F.one_hot(indices2_s, num_classes=num_experts)
+        # merge operands in topk gating to save launch overhead
+        masks = torch.cat((mask1, mask2), dim=0)
-    # Create a mask for 2nd's expert per token using Gumbel-max trick
-    # https://timvieira.github.io/blog/post/2014/07/31/gumbel-max-trick/
-    logits_w_noise = logits + gumbel_rsample(logits.shape, device=logits.device)
-    # Replace top-expert with min value
-    logits_except1 = logits_w_noise.masked_fill(mask1.bool(), torch.finfo(logits.dtype).min)
-    indices2_s = torch.argmax(logits_except1, dim=1)
-    mask2 = F.one_hot(indices2_s, num_classes=num_experts)
-    # merge operands in topk gating to save launch overhead
-    masks = torch.cat((mask1, mask2), dim=0)
+    else:
+        # Create a mask by top-2 experts
+        indices_s = torch.topk(gates, 2, dim=1).indices
+        indices_s = indices_s.permute(1, 0).reshape(-1)
+        masks = F.one_hot(indices_s, num_classes=num_experts)
     # Compute locations in capacity buffer
     locations = torch.cumsum(masks, dim=0) - 1
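"merge operands in topk gating to save launch overhead" refers to computing the capacity locations for both experts with a single cumsum over the concatenated masks, instead of two separate cumsums plus an offset for the 2nd expert. A small self-contained check of that equivalence (pure torch; the mask construction mirrors the RSample branch above without the noise, everything else is illustrative):

import torch
import torch.nn.functional as F

torch.manual_seed(0)
num_tokens, num_experts = 6, 4
logits = torch.randn(num_tokens, num_experts)
gates = F.softmax(logits, dim=1)

# Top-1 and top-2 expert masks, as in top2gating (noise-free here).
indices1_s = torch.argmax(gates, dim=1)
mask1 = F.one_hot(indices1_s, num_classes=num_experts)
logits_except1 = logits.masked_fill(mask1.bool(), torch.finfo(logits.dtype).min)
indices2_s = torch.argmax(logits_except1, dim=1)
mask2 = F.one_hot(indices2_s, num_classes=num_experts)

# Two-kernel reference: separate cumsums, then offset the 2nd expert's slots.
locations1 = torch.cumsum(mask1, dim=0) - 1
locations2 = torch.cumsum(mask2, dim=0) - 1
locations2 += torch.sum(mask1, dim=0, keepdim=True)

# Merged version: a single cumsum over the concatenated masks.
masks = torch.cat((mask1, mask2), dim=0)
locations = torch.cumsum(masks, dim=0) - 1

assert torch.equal(locations[:num_tokens], locations1)
assert torch.equal(locations[num_tokens:], locations2)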
@@ -376,7 +388,10 @@ class TopKGate(Module):
         else:
             gate_output = top2gating(
-                logits, self.capacity_factor if self.training else self.eval_capacity_factor, self.min_capacity
+                logits,
+                self.capacity_factor if self.training else self.eval_capacity_factor,
+                self.min_capacity,
+                self.noisy_gate_policy,
             )
         return gate_output
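For reference on the Gumbel-max trick used in the RSample branch: adding i.i.d. Gumbel(0, 1) noise to the logits and taking the argmax draws a sample from softmax(logits), which is what gumbel_rsample is assumed to supply. A quick self-contained Monte Carlo illustration (pure torch; numbers are arbitrary):

import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.tensor([1.0, 0.5, -0.5, 0.0])
num_samples = 200_000

# Gumbel(0, 1) noise via inverse transform sampling: -log(-log(U)), U ~ Uniform(0, 1).
uniform = torch.rand(num_samples, logits.numel()).clamp_min(1e-20)
gumbel_noise = -torch.log(-torch.log(uniform))

samples = torch.argmax(logits + gumbel_noise, dim=1)
empirical = torch.bincount(samples, minlength=logits.numel()).float() / num_samples

print(F.softmax(logits, dim=0))  # target distribution
print(empirical)                 # agrees up to Monte Carlo error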