From f7f2248771d7d0cfea05cedfc6f03f77513518fd Mon Sep 17 00:00:00 2001
From: HELSON <c2h214748@gmail.com>
Date: Thu, 22 Sep 2022 13:56:30 +0800
Subject: [PATCH] [moe] fix MoE bugs (#1628)

* remove the forced-FP32 gate module: the gate weight is now a plain parameter
  that is cast to fp32 only for the gating matmul (see the first sketch below)

* correct the positions of the no_shard contexts: no_shard_zero_decrator now
  wraps the whole class instead of its __init__ and returns the wrapped call's
  result (see the second sketch below)
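
The two sketches below are illustrative only: they use made-up names
(TinyGate, DummyExperts, _wrapper) and a dummy no_shard_zero_context rather
than the real colossalai code, and simply show the behaviour the fixes aim for.

First sketch: the gate weight stays an ordinary parameter that follows the
model's dtype, and both the tokens and the weight are cast to fp32 just for
the gating matmul, so no module needs to be pinned to FP32 any more.

    import math
    import torch
    import torch.nn.functional as F
    from torch import nn

    class TinyGate(nn.Module):
        """Stand-in for the gating part of MoeLayer, not the real class."""

        def __init__(self, dim_model: int, num_experts: int):
            super().__init__()
            self.gate_weight = nn.Parameter(torch.empty(num_experts, dim_model))
            nn.init.trunc_normal_(self.gate_weight, std=math.sqrt(0.1 / dim_model))

        def forward(self, tokens: torch.Tensor) -> torch.Tensor:
            # gating is numerically sensitive, so run it in fp32 even when the
            # rest of the model (and gate_weight itself) is fp16
            return F.linear(tokens.to(torch.float), self.gate_weight.to(torch.float))

    gate = TinyGate(dim_model=8, num_experts=4).half()
    logits = gate(torch.randn(16, 8).half())
    assert gate.gate_weight.dtype == torch.half and logits.dtype == torch.float32

Second sketch: once the decorator moves from __init__ to the class itself,
calling Experts(...) actually calls the wrapper, so the wrapper must return
whatever the wrapped callable (here, the class) returns; otherwise the
constructor would appear to return None.

    from contextlib import contextmanager

    @contextmanager
    def no_shard_zero_context(is_replicated: bool):
        # dummy stand-in for the real context manager in init_context.py
        yield

    def no_shard_zero_decrator(is_replicated: bool = True):

        def _wrapper(init_func):

            def _no_shard(*args, **kwargs):
                with no_shard_zero_context(is_replicated):
                    ret = init_func(*args, **kwargs)    # init_func is the decorated class
                return ret    # hand the new instance back, otherwise callers get None

            return _no_shard

        return _wrapper

    @no_shard_zero_decrator(is_replicated=False)
    class DummyExperts:

        def __init__(self, num_experts: int):
            self.num_experts = num_experts

    experts = DummyExperts(4)
    assert experts is not None and experts.num_experts == 4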
---
 colossalai/nn/layer/moe/experts.py       |  2 +-
 colossalai/nn/layer/moe/layers.py        | 31 +++++++++++++-----------
 colossalai/zero/init_ctx/init_context.py |  3 ++-
 tests/test_moe/test_kernel.py            |  7 +++---
 tests/test_moe/test_moe_zero_init.py     |  8 +-----
 tests/test_moe/test_moe_zero_optim.py    |  6 -----
 tests/test_zero/common.py                |  2 +-
 7 files changed, 26 insertions(+), 33 deletions(-)

diff --git a/colossalai/nn/layer/moe/experts.py b/colossalai/nn/layer/moe/experts.py
index 367278d4d..055afded9 100644
--- a/colossalai/nn/layer/moe/experts.py
+++ b/colossalai/nn/layer/moe/experts.py
@@ -24,6 +24,7 @@ class MoeExperts(nn.Module):
         self.num_local_experts, self.dist_info = MOE_CONTEXT.get_info(num_experts)
 
 
+@no_shard_zero_decrator(is_replicated=False)
 class Experts(MoeExperts):
     """A wrapper class to create experts. It will create E experts across the
     moe model parallel group, where E is the number of experts. Every expert
@@ -35,7 +36,6 @@ class Experts(MoeExperts):
         expert_args: Args used to initialize experts, the args could be found in corresponding expert class
     """
 
-    @no_shard_zero_decrator(is_replicated=False)
     def __init__(self, expert_cls: Type[nn.Module], num_experts: int, **expert_args):
         super().__init__("all_to_all", num_experts)
 
diff --git a/colossalai/nn/layer/moe/layers.py b/colossalai/nn/layer/moe/layers.py
index aaa261b23..d308c1253 100644
--- a/colossalai/nn/layer/moe/layers.py
+++ b/colossalai/nn/layer/moe/layers.py
@@ -228,6 +228,7 @@ class FP32LinearGate(nn.Module):
         return F.linear(x, self.weight)
 
 
+@no_shard_zero_decrator(is_replicated=True)
 class MoeLayer(nn.Module):
     """A MoE layer, that puts its input tensor to its gate and uses the output logits
     to route all tokens, is mainly used to exchange all tokens for every expert across
@@ -241,12 +242,11 @@ class MoeLayer(nn.Module):
         experts (:class:`torch.nn.Module`): Instance of experts generated by Expert.
     """
 
-    @no_shard_zero_decrator(is_replicated=True)
     def __init__(self, dim_model: int, num_experts: int, router: nn.Module, experts: MoeExperts):
         super().__init__()
         self.d_model = dim_model
         self.num_experts = num_experts
-        self.gate = FP32LinearGate(dim_model, num_experts)
+        self.gate_weight = torch.nn.Parameter(torch.empty(num_experts, dim_model))
         self.router = router
         self.experts = experts
         self.use_kernel = True if COL_MOE_KERNEL_FLAG and MOE_CONTEXT.use_kernel_optim else False
@@ -254,16 +254,14 @@ class MoeLayer(nn.Module):
         self.ep_size = experts.dist_info.ep_size
         self.num_local_experts = experts.num_local_experts
 
+        nn.init.trunc_normal_(self.gate_weight, std=math.sqrt(0.1 / dim_model))
+
     def a2a_process(self, dispatch_data: torch.Tensor):
         expert_input = AllToAll.apply(dispatch_data, self.ep_group)
-
         input_shape = expert_input.shape
-
         expert_input = expert_input.reshape(self.ep_size, self.num_local_experts, -1, self.d_model)
-
         expert_output = self.experts(expert_input)
         expert_output = expert_output.reshape(input_shape)
-
         expert_output = AllToAll.apply(expert_output, self.ep_group)
         return expert_output
 
@@ -274,16 +272,22 @@ class MoeLayer(nn.Module):
         return expert_out
 
     def forward(self, inputs: torch.Tensor) -> torch.Tensor:
+        # reshape the input tokens
         tokens = inputs.reshape(-1, self.d_model)
-        fp32_input = tokens.to(torch.float32) if inputs.dtype != torch.float32 else tokens
-        gate_output = self.gate(fp32_input)
-        router_res = self.router(inputs=gate_output, use_kernel=self.use_kernel, ep_group=self.ep_group)
+
+        # the gating computation should be done in fp32
+        fp32_input = tokens.to(torch.float)
+        fp32_weight = self.gate_weight.to(torch.float)
+        gate_output = F.linear(fp32_input, fp32_weight)
+
+        # the result from the router
+        route_result_list = self.router(inputs=gate_output, use_kernel=self.use_kernel, ep_group=self.ep_group)
 
         if self.use_kernel:
-            dispatch_data = MoeDispatch.apply(tokens, *router_res[1:])
+            dispatch_data = MoeDispatch.apply(tokens, *route_result_list[1:])
             dispatch_data = dispatch_data.reshape(self.num_experts, -1, self.d_model)
         else:
-            sec_mask_f = router_res[1].type_as(inputs)
+            sec_mask_f = route_result_list[1].type_as(inputs)
             dispatch_data = torch.matmul(sec_mask_f.permute(1, 2, 0), tokens)
 
         # dispatch_data [e, c, h]
@@ -295,12 +299,11 @@ class MoeLayer(nn.Module):
             raise NotImplementedError("This kind of communication has not been implemented yet.\n Please use Experts "
                                       "build function.")
         # expert_output [e, c, h]
-
         if self.use_kernel:
             expert_output = expert_output.reshape(-1, self.d_model)
-            ans = MoeCombine.apply(expert_output, *router_res)
+            ans = MoeCombine.apply(expert_output, *route_result_list)
         else:
-            combine_weights = router_res[0].type_as(inputs)
+            combine_weights = route_result_list[0].type_as(inputs)
             combine_weights = combine_weights.view(combine_weights.shape[0], -1)
             expert_output = expert_output.view(-1, expert_output.shape[-1])
             ans = torch.matmul(combine_weights, expert_output)
diff --git a/colossalai/zero/init_ctx/init_context.py b/colossalai/zero/init_ctx/init_context.py
index f4142da08..572ddd9e4 100644
--- a/colossalai/zero/init_ctx/init_context.py
+++ b/colossalai/zero/init_ctx/init_context.py
@@ -258,7 +258,8 @@ def no_shard_zero_decrator(is_replicated: bool = True):
 
         def _no_shard(*args, **kwargs):
             with no_shard_zero_context(is_replicated):
-                init_func(*args, **kwargs)
+                ret = init_func(*args, **kwargs)
+            return ret
 
         return _no_shard
 
diff --git a/tests/test_moe/test_kernel.py b/tests/test_moe/test_kernel.py
index e5b5aa68d..bd87a3f58 100644
--- a/tests/test_moe/test_kernel.py
+++ b/tests/test_moe/test_kernel.py
@@ -38,6 +38,7 @@ def run_routing(rank, world_size, port, rs=2, hidden_size=128, data_type=torch.f
     expert_factor = dict(in_features=hidden_size, out_features=hidden_size, device=get_current_device())
     expert = Experts(expert_module, NUM_EXPERTS, **expert_factor)
     layer = MoeLayer(hidden_size, NUM_EXPERTS, router(capacity_factor_train=1.0), expert)
+    layer = layer.to(get_current_device())
     if data_type == torch.float16:
         layer = layer.half()
 
@@ -50,11 +51,11 @@ def run_routing(rank, world_size, port, rs=2, hidden_size=128, data_type=torch.f
 
     # save all results
     o_tk_grad = tokens.grad.data.clone()
-    o_gt_grad = layer.gate.weight.grad.data.clone()
+    o_gt_grad = layer.gate_weight.grad.data.clone()
 
     # reset all gradients
     tokens.grad.zero_()
-    layer.gate.weight.grad.zero_()
+    layer.gate_weight.grad.zero_()
 
     layer.use_kernel = True
     new_out = layer(tokens)    # get outputs through colossal kernel
@@ -67,7 +68,7 @@ def run_routing(rank, world_size, port, rs=2, hidden_size=128, data_type=torch.f
 
     new_out.backward(grad)    # get new type gradient
     n_tk_grad = tokens.grad.data.clone()
-    n_gt_grad = layer.gate.weight.grad.data.clone()
+    n_gt_grad = layer.gate_weight.grad.data.clone()
 
     if data_type == torch.float32:
         check_equal(o_tk_grad, n_tk_grad)
diff --git a/tests/test_moe/test_moe_zero_init.py b/tests/test_moe/test_moe_zero_init.py
index b6bc08006..b5746f562 100644
--- a/tests/test_moe/test_moe_zero_init.py
+++ b/tests/test_moe/test_moe_zero_init.py
@@ -58,15 +58,9 @@ def run_moe_zero_init(init_device_type, shard_strategy_class):
     for name, param in model.named_parameters():
         assert hasattr(param, 'colo_attr')
 
-        # the weights in the gate should be fp32
-        if 'gate' in name:
-            assert param.colo_attr.sharded_data_tensor.dtype == torch.float32
-        else:
-            assert param.colo_attr.sharded_data_tensor.dtype == torch.half
-
         # the parameters in moe experts and its gate should not be sharded
         if ('experts' in name) or ('gate' in name) or ('residual_combine' in name):
-            assert not param.colo_attr.sharded_data_tensor.is_sharded
+            assert not param.colo_attr.sharded_data_tensor.is_sharded, "parameter `{}` should not be sharded".format(name)
         else:
             assert param.colo_attr.sharded_data_tensor.is_sharded
 
diff --git a/tests/test_moe/test_moe_zero_optim.py b/tests/test_moe/test_moe_zero_optim.py
index 08a36cb36..afc6ba5f7 100644
--- a/tests/test_moe/test_moe_zero_optim.py
+++ b/tests/test_moe/test_moe_zero_optim.py
@@ -94,12 +94,6 @@ def _run_test_sharded_optim_v2(cpu_offload,
     apex_model, apex_optimizer = convert_to_apex_amp(model, optim, amp_config)
     apex_grad_handler = MoeGradientHandler(model)
 
-    # Since MOE is not compatible with apex_amp now, we need to convert gate weight to fp32
-    for (n, p), zp in zip(apex_model.named_parameters(), zero_model.parameters()):
-        if 'gate' in n:
-            p.data = p.float()
-            p.data.copy_(zp.colo_attr.data_payload)
-
     for i, (data, label) in enumerate(train_dataloader):
         if i > 5:
             break
diff --git a/tests/test_zero/common.py b/tests/test_zero/common.py
index 5d2ff173f..bc6cd75a6 100644
--- a/tests/test_zero/common.py
+++ b/tests/test_zero/common.py
@@ -135,5 +135,5 @@ def check_sharded_model_params(model, zero_model, loose=False, reuse_fp16_shard=
         else:
             zero_p = zero_p.colo_attr.data_payload.to(p.device)
 
-        assert p.dtype == zero_p.dtype
+        assert p.dtype == zero_p.dtype, "Parameter `{}` dtype mismatch: {} vs {}".format(name, p.dtype, zero_p.dtype)
         assert allclose(p, zero_p, loose=loose), f'{p} vs {zero_p}'