diff --git a/colossalai/kernel/cuda_native/csrc/multi_tensor_adam.cu b/colossalai/kernel/cuda_native/csrc/multi_tensor_adam.cu
index 633e2d63f..eb854ce93 100644
--- a/colossalai/kernel/cuda_native/csrc/multi_tensor_adam.cu
+++ b/colossalai/kernel/cuda_native/csrc/multi_tensor_adam.cu
@@ -22,7 +22,7 @@ typedef enum
 
 using MATH_T = float;
 
-template <typename T>
+template <typename T_g, typename T_p>
 struct AdamFunctor
 {
     __device__ __forceinline__ void operator()(
@@ -50,16 +50,16 @@ struct AdamFunctor
         int chunk_idx = tl.block_to_chunk[blockIdx.x];
         int n = tl.sizes[tensor_loc];
 
-        T *g = (T *)tl.addresses[0][tensor_loc];
+        T_g *g = (T_g *)tl.addresses[0][tensor_loc];
         g += chunk_idx * chunk_size;
 
-        T *p = (T *)tl.addresses[1][tensor_loc];
+        T_p *p = (T_p *)tl.addresses[1][tensor_loc];
         p += chunk_idx * chunk_size;
 
-        T *m = (T *)tl.addresses[2][tensor_loc];
+        T_p *m = (T_p *)tl.addresses[2][tensor_loc];
         m += chunk_idx * chunk_size;
 
-        T *v = (T *)tl.addresses[3][tensor_loc];
+        T_p *v = (T_p *)tl.addresses[3][tensor_loc];
         v += chunk_idx * chunk_size;
 
         n -= chunk_idx * chunk_size;
@@ -155,15 +155,15 @@ void multi_tensor_adam_cuda(
         bias_correction2 = 1 - std::pow(beta2, step);
     }
 
-    // Assume single type across p,g,m1,m2 now
-    DISPATCH_DOUBLE_FLOAT_AND_HALF(
-        tensor_lists[0][0].scalar_type(), 0, "adam",
+    DISPATCH_FLOAT_AND_HALF_FOR_G_P(
+        tensor_lists[0][0].scalar_type(),
+        tensor_lists[1][0].scalar_type(), 0, "adam",
         multi_tensor_apply<4>(
             BLOCK_SIZE,
             chunk_size,
             noop_flag,
             tensor_lists,
-            AdamFunctor<scalar_t_0>(),
+            AdamFunctor<g_scalar_t_0, p_scalar_t_0>(),
             beta1,
             beta2,
             bias_correction1,
diff --git a/colossalai/kernel/cuda_native/csrc/type_shim.h b/colossalai/kernel/cuda_native/csrc/type_shim.h
index 2cae99a2d..cf83414af 100644
--- a/colossalai/kernel/cuda_native/csrc/type_shim.h
+++ b/colossalai/kernel/cuda_native/csrc/type_shim.h
@@ -173,6 +173,36 @@
         AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
     }
 
+#define DISPATCH_FLOAT_AND_HALF_FOR_G_P(GTYPE, PTYPE, LEVEL, NAME, ...)                          \
+    if (GTYPE == at::ScalarType::Float && PTYPE == at::ScalarType::Float)                        \
+    {                                                                                            \
+        using g_scalar_t_##LEVEL = float;                                                        \
+        using p_scalar_t_##LEVEL = float;                                                        \
+        __VA_ARGS__;                                                                             \
+    }                                                                                            \
+    else if (GTYPE == at::ScalarType::Float && PTYPE == at::ScalarType::Half)                    \
+    {                                                                                            \
+        using g_scalar_t_##LEVEL = float;                                                        \
+        using p_scalar_t_##LEVEL = at::Half;                                                     \
+        __VA_ARGS__;                                                                             \
+    }                                                                                            \
+    else if (GTYPE == at::ScalarType::Half && PTYPE == at::ScalarType::Float)                    \
+    {                                                                                            \
+        using g_scalar_t_##LEVEL = at::Half;                                                     \
+        using p_scalar_t_##LEVEL = float;                                                        \
+        __VA_ARGS__;                                                                             \
+    }                                                                                            \
+    else if (GTYPE == at::ScalarType::Half && PTYPE == at::ScalarType::Half)                     \
+    {                                                                                            \
+        using g_scalar_t_##LEVEL = at::Half;                                                     \
+        using p_scalar_t_##LEVEL = at::Half;                                                     \
+        __VA_ARGS__;                                                                             \
+    }                                                                                            \
+    else                                                                                         \
+    {                                                                                            \
+       AT_ERROR(#NAME, "not implemented for '", toString(GTYPE), toString(PTYPE), "'");          \
+    }                                                                                            \
+
 template <typename T>
 __device__ __forceinline__ T reduce_block_into_lanes(T *x,
                                                      T val,
diff --git a/colossalai/nn/optimizer/fused_adam.py b/colossalai/nn/optimizer/fused_adam.py
index e5c17c3e1..465e000a1 100644
--- a/colossalai/nn/optimizer/fused_adam.py
+++ b/colossalai/nn/optimizer/fused_adam.py
@@ -10,7 +10,7 @@ class FusedAdam(torch.optim.Optimizer):
     """Implements Adam algorithm.
 
     Currently GPU-only.  Requires ColossalAI to be installed via
-    ``pip install -v --no-cache-dir --global-option="--cuda_ext" ./``.
+    ``pip install .``.
 
     This version of fused Adam implements 2 fusions.
 
@@ -18,7 +18,7 @@ class FusedAdam(torch.optim.Optimizer):
       * A multi-tensor apply launch that batches the elementwise updates applied to all the model's parameters into one or a few kernel launches.
 
     :class:`colossalai.nn.optimizer.FusedAdam` may be used as a drop-in replacement for ``torch.optim.AdamW``,
-    or ``torch.optim.Adam`` with ``adam_w_mode=False``
+    or ``torch.optim.Adam`` with ``adamw_mode=False``
 
     :class:`colossalai.nn.optimizer.FusedAdam` may be used with or without Amp. 
 
@@ -36,7 +36,7 @@ class FusedAdam(torch.optim.Optimizer):
         amsgrad (boolean, optional): whether to use the AMSGrad variant of this
             algorithm from the paper `On the Convergence of Adam and Beyond`_
             (default: False) NOT SUPPORTED in FusedAdam!
-        adam_w_mode (boolean, optional): Apply L2 regularization or weight decay
+        adamw_mode (boolean, optional): Apply L2 regularization or weight decay
             True for decoupled weight decay(also known as AdamW) (default: True)
         set_grad_none (bool, optional): whether set grad to None when zero_grad()
             method is called. (default: True)
@@ -53,7 +53,7 @@ class FusedAdam(torch.optim.Optimizer):
                  bias_correction=True,
                  betas=(0.9, 0.999),
                  eps=1e-8,
-                 adam_w_mode=True,
+                 adamw_mode=True,
                  weight_decay=0.,
                  amsgrad=False,
                  set_grad_none=True):
@@ -62,7 +62,7 @@ class FusedAdam(torch.optim.Optimizer):
             raise RuntimeError('FusedAdam does not support the AMSGrad variant.')
         defaults = dict(lr=lr, bias_correction=bias_correction, betas=betas, eps=eps, weight_decay=weight_decay)
         super(FusedAdam, self).__init__(params, defaults)
-        self.adam_w_mode = 1 if adam_w_mode else 0
+        self.adamw_mode = 1 if adamw_mode else 0
         self.set_grad_none = set_grad_none
         if multi_tensor_applier.available:
             import colossal_C
@@ -109,8 +109,7 @@ class FusedAdam(torch.optim.Optimizer):
                 group['step'] = 1
 
             # create lists for multi-tensor apply
-            g_16, p_16, m_16, v_16 = [], [], [], []
-            g_32, p_32, m_32, v_32 = [], [], [], []
+            g_l, p_l, m_l, v_l = [], [], [], []
 
             for p in group['params']:
                 if p.grad is None:
@@ -127,26 +126,16 @@ class FusedAdam(torch.optim.Optimizer):
                     # Exponential moving average of squared gradient values
                     state['exp_avg_sq'] = torch.zeros_like(p.data)
 
-                if p.dtype == torch.float16:
-                    g_16.append(p.grad.data)
-                    p_16.append(p.data)
-                    m_16.append(state['exp_avg'])
-                    v_16.append(state['exp_avg_sq'])
-                elif p.dtype == torch.float32:
-                    g_32.append(p.grad.data)
-                    p_32.append(p.data)
-                    m_32.append(state['exp_avg'])
-                    v_32.append(state['exp_avg_sq'])
-                else:
+                if p.dtype not in [torch.float16, torch.float32]:
                     raise RuntimeError('FusedAdam only support fp16 and fp32.')
+                
+                g_l.append(p.grad.data)
+                p_l.append(p.data)
+                m_l.append(state['exp_avg'])
+                v_l.append(state['exp_avg_sq'])
 
-            if (len(g_16) > 0):
-                multi_tensor_applier(self.multi_tensor_adam, self._dummy_overflow_buf, [g_16, p_16, m_16, v_16],
-                                     group['lr'], beta1, beta2, group['eps'], group['step'], self.adam_w_mode,
-                                     bias_correction, group['weight_decay'])
-            if (len(g_32) > 0):
-                multi_tensor_applier(self.multi_tensor_adam, self._dummy_overflow_buf, [g_32, p_32, m_32, v_32],
-                                     group['lr'], beta1, beta2, group['eps'], group['step'], self.adam_w_mode,
-                                     bias_correction, group['weight_decay'])
+            multi_tensor_applier(self.multi_tensor_adam, self._dummy_overflow_buf, [g_l, p_l, m_l, v_l],
+                                    group['lr'], beta1, beta2, group['eps'], group['step'], self.adamw_mode,
+                                    bias_correction, group['weight_decay'])
 
         return loss
diff --git a/tests/test_optimizer/unittest_cpu_adam.py b/tests/test_optimizer/unittest_cpu_adam.py
index 401fc5241..55ba74f7e 100644
--- a/tests/test_optimizer/unittest_cpu_adam.py
+++ b/tests/test_optimizer/unittest_cpu_adam.py
@@ -1,38 +1,7 @@
-# BSD 3-Clause License
-#
-# Copyright (C) 2021 THL A29 Limited, a Tencent company.  All rights reserved.
-#
-# Redistribution and use in source and binary forms, with or without modification,
-# are permitted provided that the following conditions are met:
-#
-#  * Redistributions of source code must retain the above copyright notice, this
-#    list of conditions and the following disclaimer.
-#
-#  * Redistributions in binary form must reproduce the above copyright notice,
-#    this list of conditions and the following disclaimer in the documentation
-#    and/or other materials provided with the distribution.
-#
-#  * Neither the name of the psutil authors nor the names of its contributors
-#    may be used to endorse or promote products derived from this software without
-#    specific prior written permission.
-#
-# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
-# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
-# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
-# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
-# ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
-# (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
-# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
-# ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
-# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
-# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-
 import math
 import torch
-try:
-    import cpu_adam
-except ImportError:
-    raise ImportError("import cpu_adam error")
+
+from colossalai.testing import parameterize
 
 
 def torch_adam_update(
@@ -71,45 +40,46 @@ def torch_adam_update(
     param.addcdiv_(exp_avg, denom, value=-step_size)
 
 
-class Test():
+def assertLess(data_diff, threshold, msg):
+    assert data_diff < threshold, msg
 
-    def __init__(self):
-        self.opt_id = 0
 
-    def assertLess(self, data_diff, threshold, msg):
-        assert data_diff < threshold, msg
+def assertTrue(condition, msg):
+    assert condition, msg
 
-    def assertTrue(self, condition, msg):
-        assert condition, msg
 
-    def check_res(
-        self,
-        step,
-        lr,
-        eps,
-        beta1,
-        beta2,
-        weight_decay,
-        shape,
-        grad_dtype,
-        loss_scale,
-        use_adamw,
-        cpu_adam_op,
-    ):
-        p_data = torch.rand(shape, dtype=grad_dtype)
+@parameterize('adamw', [True, False])
+@parameterize('step', [1, 2])
+@parameterize('loss_scale', [-1, 2 ** 5])
+@parameterize('p_dtype', [torch.float, torch.half])
+@parameterize('g_dtype', [torch.float, torch.half])
+def test_cpu_adam(adamw, step, loss_scale, p_dtype, g_dtype):
+    lr = 1e-3
+    beta1, beta2 = 0.9, 0.999
+    eps = 1e-8
+    weight_decay = 0
+    
+    for i in range(1024):
+        p_data = torch.rand(64, dtype=p_dtype)
         p_data_copy = p_data.clone().float()
-        p_grad = torch.rand(shape, dtype=grad_dtype)
+        p_grad = torch.rand(64, dtype=g_dtype)
         if loss_scale > 0:
             p_grad.mul_(loss_scale)
         p_grad_copy = p_grad.clone().float()
-        exp_avg = torch.rand(shape)
+        exp_avg = torch.rand(p_data.shape)
         exp_avg_copy = exp_avg.clone()
-        exp_avg_sq = torch.rand(shape)
+        exp_avg_sq = torch.rand(p_data.shape)
         exp_avg_sq_copy = exp_avg_sq.clone()
 
-        cpu_adam_op.create_adam(0, lr, beta1, beta2, eps, weight_decay, use_adamw, True)
+        try:
+            import cpu_adam
+            cpu_adam_op = cpu_adam
+        except:
+            raise ImportError("...")
+
+        cpu_adam_op.create_adam(0, lr, beta1, beta2, eps, weight_decay, adamw, False)
         cpu_adam_op.adam_update(
-            self.opt_id,
+            0,
             step,
             lr,
             beta1,
@@ -136,62 +106,24 @@ class Test():
             exp_avg_copy,
             exp_avg_sq_copy,
             loss_scale,
-            use_adamw,
+            adamw,
         )
-
         if loss_scale > 0:
             p_grad.div_(loss_scale)
-
         var = p_data_copy - p_data
         data_diff = torch.max(torch.abs(var))
-        threshold = 2e-3 if grad_dtype else 1e-4
-        self.assertLess(
+        threshold = 1e-3
+        print(f"p_data diff {data_diff}. failed check, step {step}, lr {lr} eps "
+            f"{eps} beta1 {beta1} beta2 {beta2} weight_decay {weight_decay} p_dtype {p_dtype}, g_dtype {g_dtype}")
+        assertLess(
             data_diff,
             threshold,
-            f"p_data diff {data_diff}. failed check, step {step}, lr {lr} eps "
-            f"{eps} beta1 {beta1} beta2 {beta2} weight_decay {weight_decay} loss_scale {loss_scale} grad_dtype {grad_dtype}",
+            f"p_data diff {data_diff}. failed check, step {step}, lr {lr}, loss_scale {loss_scale}, eps "
+            f"{eps} beta1 {beta1} beta2 {beta2} weight_decay {weight_decay} p_dtype {p_dtype}, g_dtype {g_dtype}",
         )
         max_grad_diff = torch.max(torch.abs(p_grad_copy - p_grad))
-        self.assertTrue(max_grad_diff < threshold, f"diff {max_grad_diff}")
+        assertTrue(max_grad_diff < threshold, f"diff {max_grad_diff}")
         max_exp_avg_diff = torch.max(torch.abs(exp_avg_copy - exp_avg))
-        self.assertTrue(max_exp_avg_diff < threshold, f"max_exp_avg_diff {max_exp_avg_diff}")
+        assertTrue(max_exp_avg_diff < threshold, f"max_exp_avg_diff {max_exp_avg_diff}")
         max_exp_avg_sq_diff = torch.max(torch.abs(exp_avg_sq_copy - exp_avg_sq))
-        self.assertTrue(max_exp_avg_sq_diff < threshold, f"max_exp_avg_sq_diff {max_exp_avg_sq_diff}")
-
-    def test_cpu_adam(self):
-        lr = 0.9
-        eps = 1e-6
-        weight_decay = 0
-        for use_adamw in [False, True]:
-            for shape in [(23,), (8, 24)]:
-                for step in range(1, 2):
-                    for lr in [0.01]:
-                        for eps in [1e-8]:
-                            for beta1 in [0.9]:
-                                for beta2 in [0.999]:
-                                    for weight_decay in [0.001]:
-                                        for grad_dtype in [torch.half, torch.float]:
-                                            for loss_scale in [-1, 2**5]:
-                                                self.check_res(
-                                                    step,
-                                                    lr,
-                                                    eps,
-                                                    beta1,
-                                                    beta2,
-                                                    weight_decay,
-                                                    shape,
-                                                    grad_dtype,
-                                                    loss_scale,
-                                                    use_adamw,
-                                                    cpu_adam,
-                                                )
-
-
-def test_cpu_adam():
-    test_case = Test()
-    test_case.test_cpu_adam()
-
-
-if __name__ == "__main__":
-    test = Test()
-    test.test_cpu_adam()
+        assertTrue(max_exp_avg_sq_diff < threshold, f"max_exp_avg_sq_diff {max_exp_avg_sq_diff}")
diff --git a/tests/test_optimizer/unittest_fused_adam.py b/tests/test_optimizer/unittest_fused_adam.py
new file mode 100644
index 000000000..f3b2b2a7b
--- /dev/null
+++ b/tests/test_optimizer/unittest_fused_adam.py
@@ -0,0 +1,61 @@
+import torch
+import torch.nn as nn
+from torch.optim.adam import Adam
+from torch.optim import AdamW
+
+from colossalai.nn.optimizer.fused_adam import FusedAdam
+from colossalai.testing import parameterize
+
+
+class FC(nn.Module):
+    def __init__(self) -> None:
+        super().__init__()
+        self.fc = nn.Sequential(nn.Linear(64, 64))
+    def forward(self, x):
+        return self.fc(x)
+
+
+@parameterize('adamw', [False, True])
+@parameterize('p_dtype', [torch.float, torch.half])
+@parameterize('g_dtype', [torch.float, torch.half])
+def test_adam(adamw, p_dtype, g_dtype):
+    model = FC().cuda().to(p_dtype)
+    state = model.state_dict()
+    model_copy = FC().cuda().to(p_dtype)
+    model_copy.load_state_dict(state.copy())
+
+    if adamw:
+        optim = FusedAdam(model.parameters(), lr=1e-3, adamw_mode=True)
+        torch_optim = AdamW(model_copy.parameters(), lr=1e-3)
+    else:
+        optim = FusedAdam(model.parameters(), lr=1e-3)
+        torch_optim = Adam(model_copy.parameters(), lr=1e-3)
+
+    data = torch.rand(1024, 64).cuda().to(p_dtype)
+    data_copy = data.clone()
+    label = torch.rand(1024, 64).cuda().to(p_dtype)
+
+    for d, l in zip(data, label):
+        y = model(d)
+        loss = ((l - y) ** 2).sum()
+        optim.zero_grad()
+        loss.backward()
+        if p_dtype != g_dtype:
+            for i in range(len(optim.param_groups[0]['params'])):
+                optim.param_groups[0]['params'][i].grad.data = optim.param_groups[0]['params'][i].grad.data.to(g_dtype)
+        optim.step()
+
+    for d, l in zip(data_copy, label):
+        y = model_copy(d)
+        loss = ((l - y) ** 2).sum()
+        torch_optim.zero_grad()
+        loss.backward()
+        torch_optim.step()
+
+    assert len(optim.param_groups[0]['params']) == len(torch_optim.param_groups[0]['params'])
+    
+    for i in range(len(optim.param_groups[0]['params'])):
+        if torch.isnan(optim.param_groups[0]['params'][i]).any() \
+           or torch.isnan(torch_optim.param_groups[0]['params'][i]).any():
+            continue
+        assert torch.allclose(optim.param_groups[0]['params'][i], torch_optim.param_groups[0]['params'][i], 2e-3, 2e-3)
diff --git a/tests/test_optimizer/unittest_fused_adam_kernel.py b/tests/test_optimizer/unittest_fused_adam_kernel.py
new file mode 100644
index 000000000..7de256e78
--- /dev/null
+++ b/tests/test_optimizer/unittest_fused_adam_kernel.py
@@ -0,0 +1,98 @@
+from numpy import dtype
+import torch
+import torch.nn as nn
+
+import math
+
+from colossalai.testing import parameterize
+from colossalai.utils import multi_tensor_applier
+
+def torch_adam_update(
+    step,
+    lr,
+    beta1,
+    beta2,
+    eps,
+    weight_decay,
+    param,
+    grad,
+    exp_avg,
+    exp_avg_sq,
+    loss_scale,
+    use_adamw,
+):
+    if loss_scale > 0:
+        grad.div_(loss_scale)
+    bias_correction1 = 1 - beta1**step
+    bias_correction2 = 1 - beta2**step
+
+    if weight_decay != 0:
+        if use_adamw:
+            # Perform stepweight decay
+            param.mul_(1 - lr * weight_decay)
+        else:
+            grad = grad.add(param, alpha=weight_decay)
+
+    # Decay the first and second moment running average coefficient
+    exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
+    exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
+    denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(eps)
+
+    step_size = lr / bias_correction1
+
+    param.addcdiv_(exp_avg, denom, value=-step_size)
+
+
+@parameterize('adamw', [False, True])
+@parameterize('step', [1, 2])
+@parameterize('p_dtype', [torch.float, torch.half])
+@parameterize('g_dtype', [torch.float, torch.half])
+def test_adam(adamw, step, p_dtype, g_dtype):
+    try:
+        import colossal_C
+        fused_adam = colossal_C.multi_tensor_adam
+        dummy_overflow_buf = torch.cuda.IntTensor([0])
+    except:
+        raise ImportError("No colossal_C kernel installed.")
+    
+    count = 0
+
+    for i in range(1024):
+        p = torch.rand(64, dtype=p_dtype).cuda()
+        p_copy = p.clone().float()
+        g = torch.rand(p.shape, dtype=g_dtype).cuda()
+        g_copy = g.clone().float()
+        m = torch.rand(p.shape).cuda()
+        m_copy = m.clone()
+        v = torch.rand(p.shape).cuda()
+        v_copy = v.clone()
+
+        lr = 1e-3
+        beta1, beta2 = 0.9, 0.999
+        eps = 1e-8
+        weight_decay = 0
+
+        multi_tensor_applier(fused_adam, dummy_overflow_buf, [[g], [p], [m], [v]],
+                                        lr, beta1, beta2, eps, step, adamw,
+                                        True, weight_decay)
+
+        torch_adam_update(
+                step,
+                lr,
+                beta1,
+                beta2,
+                eps,
+                weight_decay,
+                p_copy,    # fp32 data
+                g_copy,    # fp32 grad
+                m_copy,
+                v_copy,
+                -1,
+                adamw,
+            )
+        
+        if torch.isnan(p).any() or torch.isnan(p_copy).any():
+            count += 1
+            continue
+        assert count < 200, "too many nans"
+        assert torch.allclose(p.to(torch.float), p_copy.to(torch.float), 1e-5, 1e-5), f"failed check, adamw {adamw}, p_dtype {p_dtype}, g_dtype {g_dtype}"