@@ -228,6 +228,7 @@ class FP32LinearGate(nn.Module):
         return F.linear(x, self.weight)


+@no_shard_zero_decrator(is_replicated=True)
 class MoeLayer(nn.Module):
"""A MoE layer, that puts its input tensor to its gate and uses the output logits |
|
|
|
|
to router all tokens, is mainly used to exchange all tokens for every expert across |
|
|
|
@@ -241,12 +242,11 @@ class MoeLayer(nn.Module):
         experts (:class:`torch.nn.Module`): Instance of experts generated by Expert.
     """

-    @no_shard_zero_decrator(is_replicated=True)
     def __init__(self, dim_model: int, num_experts: int, router: nn.Module, experts: MoeExperts):
         super().__init__()
         self.d_model = dim_model
         self.num_experts = num_experts
-        self.gate = FP32LinearGate(dim_model, num_experts)
+        self.gate_weight = torch.nn.Parameter(torch.empty(num_experts, dim_model))
         self.router = router
         self.experts = experts
         self.use_kernel = True if COL_MOE_KERNEL_FLAG and MOE_CONTEXT.use_kernel_optim else False
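The gate is now a bare fp32-initialized `Parameter` plus an `F.linear` call instead of an `FP32LinearGate` module. A minimal standalone sketch of that gating path (toy sizes are hypothetical; the real layer performs the cast inside `forward`, as a later hunk shows):

```python
import math

import torch
import torch.nn.functional as F

# Toy sizes (hypothetical); dim_model and num_experts come from the config.
dim_model, num_experts, num_tokens = 8, 4, 16

# A bare parameter replaces the FP32LinearGate module; same init as the diff.
gate_weight = torch.nn.Parameter(torch.empty(num_experts, dim_model))
torch.nn.init.trunc_normal_(gate_weight, std=math.sqrt(0.1 / dim_model))

# Route in fp32 even when tokens arrive in half precision.
tokens = torch.randn(num_tokens, dim_model, dtype=torch.float16)
gate_output = F.linear(tokens.to(torch.float), gate_weight.to(torch.float))
print(gate_output.shape)  # torch.Size([16, 4]): one routing logit per expert
```

Keeping the logits in fp32 presumably protects the router's softmax from half-precision overflow, and a bare parameter avoids the module indirection that ZeRO and AMP wrappers would otherwise have to special-case.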
@@ -254,16 +254,14 @@ class MoeLayer(nn.Module):
         self.ep_size = experts.dist_info.ep_size
         self.num_local_experts = experts.num_local_experts

+        nn.init.trunc_normal_(self.gate_weight, std=math.sqrt(0.1 / dim_model))
+
     def a2a_process(self, dispatch_data: torch.Tensor):
         expert_input = AllToAll.apply(dispatch_data, self.ep_group)
-
         input_shape = expert_input.shape
-
         expert_input = expert_input.reshape(self.ep_size, self.num_local_experts, -1, self.d_model)
-
         expert_output = self.experts(expert_input)
         expert_output = expert_output.reshape(input_shape)
-
         expert_output = AllToAll.apply(expert_output, self.ep_group)
         return expert_output

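`a2a_process` only needs `AllToAll` to permute rows across ranks; the shape bookkeeping around it can be checked locally. A sketch of the reshape round trip with hypothetical toy sizes (no process group; the exchange itself is elided):

```python
import torch

# Hypothetical toy sizes: num_experts = ep_size * num_local_experts.
ep_size, num_local_experts, capacity, d_model = 2, 2, 3, 8
num_experts = ep_size * num_local_experts

# dispatch_data is [e, c, h]; after the (elided) all-to-all, the rows are
# regrouped as [ep_size, num_local_experts, -1, h] so each rank's expert
# stack only sees its own experts.
dispatch_data = torch.randn(num_experts, capacity, d_model)
input_shape = dispatch_data.shape
expert_input = dispatch_data.reshape(ep_size, num_local_experts, -1, d_model)
assert expert_input.shape == (ep_size, num_local_experts, capacity, d_model)

# The expert output is reshaped back to [e, c, h] before the second exchange.
expert_output = expert_input.reshape(input_shape)
assert torch.equal(expert_output, dispatch_data)
```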
@@ -274,16 +272,22 @@ class MoeLayer(nn.Module):
         return expert_out

     def forward(self, inputs: torch.Tensor) -> torch.Tensor:
         # reshape the input tokens
         tokens = inputs.reshape(-1, self.d_model)
-        fp32_input = tokens.to(torch.float32) if inputs.dtype != torch.float32 else tokens
-        gate_output = self.gate(fp32_input)
-        router_res = self.router(inputs=gate_output, use_kernel=self.use_kernel, ep_group=self.ep_group)
+
+        # the data type of the inputs in the gating should be fp32
+        fp32_input = tokens.to(torch.float)
+        fp32_weight = self.gate_weight.to(torch.float)
+        gate_output = F.linear(fp32_input, fp32_weight)
+
+        # the result from the router
+        route_result_list = self.router(inputs=gate_output, use_kernel=self.use_kernel, ep_group=self.ep_group)

         if self.use_kernel:
-            dispatch_data = MoeDispatch.apply(tokens, *router_res[1:])
+            dispatch_data = MoeDispatch.apply(tokens, *route_result_list[1:])
             dispatch_data = dispatch_data.reshape(self.num_experts, -1, self.d_model)
         else:
-            sec_mask_f = router_res[1].type_as(inputs)
+            sec_mask_f = route_result_list[1].type_as(inputs)
             dispatch_data = torch.matmul(sec_mask_f.permute(1, 2, 0), tokens)

+        # dispatch_data [e, c, h]
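In the non-kernel path, `sec_mask_f` appears to be a one-hot `[s, e, c]` assignment of each token to an (expert, capacity) slot, so the `permute(1, 2, 0)` matmul is a gather, equivalent to `torch.einsum("sec,sh->ech", sec_mask_f, tokens)`. A toy sketch (sizes and slot assignments are made up):

```python
import torch

# Hypothetical toy sizes: s tokens, e experts, capacity c, hidden size h.
s, e, c, h = 4, 2, 2, 8
tokens = torch.randn(s, h)

# One-hot [s, e, c] mask: token i is assigned slot c_i of expert e_i.
sec_mask_f = torch.zeros(s, e, c)
for i, (expert, slot) in enumerate([(0, 0), (0, 1), (1, 0), (1, 1)]):
    sec_mask_f[i, expert, slot] = 1.0

# [e, c, s] @ [s, h] -> [e, c, h]: each capacity slot gathers its token.
dispatch_data = torch.matmul(sec_mask_f.permute(1, 2, 0), tokens)
assert torch.equal(dispatch_data[0, 1], tokens[1])
assert torch.equal(dispatch_data[1, 0], tokens[2])
```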
@@ -295,12 +299,11 @@ class MoeLayer(nn.Module):
             raise NotImplementedError("This kind of communication has not been implemented yet.\n Please use Experts "
                                       "build function.")
         # expert_output [e, c, h]
-
         if self.use_kernel:
             expert_output = expert_output.reshape(-1, self.d_model)
-            ans = MoeCombine.apply(expert_output, *router_res)
+            ans = MoeCombine.apply(expert_output, *route_result_list)
         else:
-            combine_weights = router_res[0].type_as(inputs)
+            combine_weights = route_result_list[0].type_as(inputs)
             combine_weights = combine_weights.view(combine_weights.shape[0], -1)
             expert_output = expert_output.view(-1, expert_output.shape[-1])
             ans = torch.matmul(combine_weights, expert_output)
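The combine step is the inverse gather: flattening the presumably `[s, e, c]` combine weights to `[s, e*c]` and the expert output to `[e*c, h]` turns the weighted sum over slots into a single matmul. A toy round trip with identity "experts" (all names and sizes hypothetical):

```python
import torch

# Same toy sizes as the dispatch sketch; here s == e * c for simplicity.
s, e, c, h = 4, 2, 2, 8
tokens = torch.randn(s, h)

# Token i -> slot i, and the one-hot mask doubles as the combine weights.
sec_mask_f = torch.eye(s).reshape(s, e, c)
dispatch_data = torch.matmul(sec_mask_f.permute(1, 2, 0), tokens)  # [e, c, h]

# With identity "experts", combining gathers every token back unchanged.
expert_output = dispatch_data.reshape(-1, h)         # [e*c, h]
combine_weights = sec_mask_f.reshape(s, -1)          # [s, e*c]
ans = torch.matmul(combine_weights, expert_output)   # [s, h]
assert torch.allclose(ans, tokens)
```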