mirror of https://github.com/InternLM/InternLM
merge operand if noisy_gate_policy is not used
parent: 95263fa1d0
commit: 3443ab1f5b
@@ -144,6 +144,7 @@ model = dict(
     num_experts=8,
     moe_use_residual=False,
     moe_gate_k=2,
+    moe_noisy_gate_policy="RSample",
 )
 """
 zero1 parallel:
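The new `moe_noisy_gate_policy` key is the only config change; the diff exercises two values, the `"RSample"` policy enabled here and the `None` default installed by the sanity check below. A minimal sketch of just the MoE-related model settings (every other field a real config needs, such as dimensions and layer counts, is omitted):

model = dict(
    num_experts=8,                    # experts per MoE layer
    moe_use_residual=False,           # no residual dense expert alongside the routed experts
    moe_gate_k=2,                     # route each token to its top-2 experts
    moe_noisy_gate_policy="RSample",  # Gumbel noise on 2nd-expert selection; None keeps gating deterministic
)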
@@ -306,6 +306,8 @@ def args_sanity_check():
         model._add_item("moe_use_residual", False)
     if "moe_gate_k" not in model:
         model._add_item("moe_gate_k", 2)
+    if "moe_noisy_gate_policy" not in model:
+        model._add_item("moe_noisy_gate_policy", None)
     # process the parallel config
     if "sequence_parallel" not in gpc.config.parallel:
         gpc.config.parallel._add_item("sequence_parallel", False)
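The sanity check back-fills the new key so configs written before this change keep working; the default is `None`, so gate noise stays opt-in even though the example config above turns it on. A dict-backed stand-in for the config object (hypothetical, InternLM's real class is more elaborate) illustrates the `_add_item` defaulting pattern:

class SimpleConfig(dict):
    # Hypothetical stand-in mimicking the `_add_item` defaulting pattern.
    def _add_item(self, key, value):
        self[key] = value

model = SimpleConfig(num_experts=8, moe_gate_k=2)   # an older config, key absent
if "moe_noisy_gate_policy" not in model:
    model._add_item("moe_noisy_gate_policy", None)  # default: no gate noise
assert model["moe_noisy_gate_policy"] is None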
@@ -248,28 +248,40 @@ def top2gating(
     return l_aux, combine_weights, dispatch_mask, exp_counts
 
 
-def top2gating(logits: Tensor, capacity_factor: float, min_capacity: int) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
+def top2gating(
+    logits: Tensor,
+    capacity_factor: float,
+    min_capacity: int,
+    noisy_gate_policy: Optional[str] = None,
+) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
     """Implements Top2Gating on logits."""
     # everything is in fp32 in this function
     gates = F.softmax(logits, dim=1)
+    num_experts = int(gates.shape[1])
 
     capacity = _capacity(gates, torch.tensor(capacity_factor * 2), torch.tensor(min_capacity))
 
-    # Create a mask for 1st's expert per token
-    indices1_s = torch.argmax(gates, dim=1)
-    num_experts = int(gates.shape[1])
-    mask1 = F.one_hot(indices1_s, num_classes=num_experts)
     # NOTE: here we just add noise on 2nd expert, following
     # https://github.com/facebookresearch/fairscale/blob/main/fairscale/nn/moe/top2gate.py
-    # Create a mask for 2nd's expert per token using Gumbel-max trick
-    # https://timvieira.github.io/blog/post/2014/07/31/gumbel-max-trick/
-    logits_w_noise = logits + gumbel_rsample(logits.shape, device=logits.device)
-    # Replace top-expert with min value
-    logits_except1 = logits_w_noise.masked_fill(mask1.bool(), torch.finfo(logits.dtype).min)
-    indices2_s = torch.argmax(logits_except1, dim=1)
-    mask2 = F.one_hot(indices2_s, num_classes=num_experts)
-    # merge operands in topk gating to save launch overhead
-    masks = torch.cat((mask1, mask2), dim=0)
+    if noisy_gate_policy == "RSample":
+        # Create a mask for 1st's expert per token
+        indices1_s = torch.argmax(gates, dim=1)
+        mask1 = F.one_hot(indices1_s, num_classes=num_experts)
+
+        # Create a mask for 2nd's expert per token using Gumbel-max trick
+        # https://timvieira.github.io/blog/post/2014/07/31/gumbel-max-trick/
+        logits_w_noise = logits + gumbel_rsample(logits.shape, device=logits.device)
+        # Replace top-expert with min value
+        logits_except1 = logits_w_noise.masked_fill(mask1.bool(), torch.finfo(logits.dtype).min)
+        indices2_s = torch.argmax(logits_except1, dim=1)
+        mask2 = F.one_hot(indices2_s, num_classes=num_experts)
+        # merge operands in topk gating to save launch overhead
+        masks = torch.cat((mask1, mask2), dim=0)
+    else:
+        # Create a mask by top-2 experts
+        indices_s = torch.topk(gates, 2, dim=1).indices
+        indices_s = indices_s.permute(1, 0).reshape(-1)
+        masks = F.one_hot(indices_s, num_classes=num_experts)
 
     # Compute locations in capacity buffer
     locations = torch.cumsum(masks, dim=0) - 1
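The `RSample` branch draws its noise from `gumbel_rsample`, which the comments trace back to fairscale's top2gate. For reference, a minimal implementation consistent with the `gumbel_rsample(logits.shape, device=logits.device)` call above (the upstream versions also cache the sampler per device; that detail is omitted here):

import torch

def gumbel_rsample(shape, device):
    # Standard Gumbel(0, 1) noise via the reparameterized sampler, so the
    # draw stays differentiable with respect to the distribution parameters.
    one = torch.tensor(1.0, device=device)
    zero = torch.tensor(0.0, device=device)
    return torch.distributions.gumbel.Gumbel(zero, one).rsample(shape)

noise = gumbel_rsample((4, 8), torch.device("cpu"))  # same shape as a [tokens, experts] logit tensor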
@@ -376,7 +388,10 @@ class TopKGate(Module):
 
         else:
             gate_output = top2gating(
-                logits, self.capacity_factor if self.training else self.eval_capacity_factor, self.min_capacity
+                logits,
+                self.capacity_factor if self.training else self.eval_capacity_factor,
+                self.min_capacity,
+                self.noisy_gate_policy,
             )
 
         return gate_output
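Why the merge is safe: softmax is monotonic per row, so ranking tokens' gates ranks their logits, and with no Gumbel noise the old argmax-then-masked-argmax selection is just a plain top-2. The standalone check below (plain PyTorch, not repository test code) confirms the merged else-branch reproduces the unmerged masks, letting one topk/one_hot/cumsum chain replace two:

import torch
import torch.nn.functional as F

torch.manual_seed(0)
num_tokens, num_experts = 6, 4
logits = torch.randn(num_tokens, num_experts)
gates = F.softmax(logits, dim=1)

# Merged path (new else-branch): both experts from one topk call.
indices_s = torch.topk(gates, 2, dim=1).indices    # [tokens, 2]
indices_s = indices_s.permute(1, 0).reshape(-1)    # 1st-choice block, then 2nd-choice block
masks = F.one_hot(indices_s, num_classes=num_experts)

# Unmerged path (old code with the noise disabled): two argmax calls.
indices1_s = torch.argmax(gates, dim=1)
mask1 = F.one_hot(indices1_s, num_classes=num_experts)
logits_except1 = logits.masked_fill(mask1.bool(), torch.finfo(logits.dtype).min)
indices2_s = torch.argmax(logits_except1, dim=1)
mask2 = F.one_hot(indices2_s, num_classes=num_experts)

assert torch.equal(masks, torch.cat((mask1, mask2), dim=0))

# One cumsum then assigns capacity slots for both choices at once, with
# every 1st-choice assignment taking priority over the 2nd choices.
locations = torch.cumsum(masks, dim=0) - 1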