|
|
|
@ -45,9 +45,13 @@ class MoeRouter(nn.Module, ABC):
|
|
|
|
|
self._z_loss = None
|
|
|
|
|
self.use_kernel = use_kernel
|
|
|
|
|
|
|
|
|
|
def get_capacity(self, logits_shape):
|
|
|
|
|
def get_capacity(self, num_tokens, num_experts, ep_group=None):
|
|
|
|
|
if ep_group is not None:
|
|
|
|
|
num_tokens_tensor = torch.tensor(num_tokens, device=get_current_device())
|
|
|
|
|
dist.all_reduce(num_tokens_tensor, group=ep_group)
|
|
|
|
|
num_tokens = num_tokens_tensor.item() // dist.get_world_size(ep_group)
|
|
|
|
|
capacity_factor = self.capacity_factor_train if self.training else self.capacity_factor_eval
|
|
|
|
|
capacity = math.floor(self.k_value * capacity_factor * logits_shape[-2] / logits_shape[-1])
|
|
|
|
|
capacity = math.floor(self.k_value * capacity_factor * num_tokens / num_experts)
|
|
|
|
|
capacity += capacity % 2
|
|
|
|
|
capacity = max(capacity, self.min_capacity)
|
|
|
|
|
assert capacity > 0
|
|
|
|
@ -175,7 +179,8 @@ class Top1Router(MoeRouter):
|
|
|
|
|
assert inputs.dtype == torch.float
|
|
|
|
|
probs = F.softmax(inputs, dim=-1)
|
|
|
|
|
num_experts = probs.size(-1)
|
|
|
|
|
capacity = self.get_capacity(inputs.shape)
|
|
|
|
|
num_tokens = inputs.size(0)
|
|
|
|
|
capacity = self.get_capacity(num_tokens, num_experts, ep_group)
|
|
|
|
|
|
|
|
|
|
top1_idx = torch.argmax(inputs, dim=-1)
|
|
|
|
|
mask = F.one_hot(top1_idx, num_classes=num_experts).to(torch.int32)
|
|
|
|
@ -276,7 +281,8 @@ class Top2Router(MoeRouter):
|
|
|
|
|
probs = probs / routing_weights.sum(dim=-1, keepdim=True)
|
|
|
|
|
|
|
|
|
|
num_experts = probs.size(-1)
|
|
|
|
|
capacity = self.get_capacity(inputs.shape)
|
|
|
|
|
num_tokens = inputs.size(0)
|
|
|
|
|
capacity = self.get_capacity(num_tokens, num_experts, ep_group)
|
|
|
|
|
|
|
|
|
|
top1_idx = torch.argmax(probs, dim=-1)
|
|
|
|
|
mask1 = F.one_hot(top1_idx, num_classes=num_experts).to(torch.int32)
|
|
|
|
|