diff --git a/colossalai/nn/optimizer/cpu_adam.py b/colossalai/nn/optimizer/cpu_adam.py
index 59ee5f9bd..ba3c188ac 100644
--- a/colossalai/nn/optimizer/cpu_adam.py
+++ b/colossalai/nn/optimizer/cpu_adam.py
@@ -56,6 +56,8 @@ class CPUAdam(torch.optim.Optimizer):
                           bias_correction2,
                           loss_scale,
                           use_adamw=False):
+        # FIXME(ver217): remove the line below when torch adam is replaced with fused adam
+        grad = grad.float()
         if loss_scale is not None:
             grad.div_(loss_scale)
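# A minimal sketch (editorial illustration, not lines added by this patch) of why the
# cast above precedes unscaling: with `reuse_fp16_shard` the grad handed to this
# torch-fallback path may be fp16, while the update math below assumes fp32. The
# tensor shape and loss scale used here are illustrative only.
import torch

grad = torch.randn(4, dtype=torch.float16)  # fp16 grad shard, as reduce-scatter would leave it
grad = grad.float()                         # upcast first, mirroring the FIXME'd line
grad.div_(2.0 ** 5)                         # then unscale safely in fp32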
 
diff --git a/colossalai/zero/sharded_model/sharded_model_v2.py b/colossalai/zero/sharded_model/sharded_model_v2.py
index 70e14548b..afce76933 100644
--- a/colossalai/zero/sharded_model/sharded_model_v2.py
+++ b/colossalai/zero/sharded_model/sharded_model_v2.py
@@ -29,24 +29,22 @@ class ShardedModelV2(nn.Module):
     compared to classic data parallelism while the computational granularity and communication efficiency are retained.
     Note that you must use `ShardedModelV2` with `ShardedOptimizerV2`.
 
-    :param module: A sharded module, which must be initialized by `ZeroInitContext`.
-    :type module: nn.Module
-    :param shard_strategy: A shard strategy to manage shard behavior.
-    :type shard_strategy: BaseShardStrategy
-    :param process_group: Data parallel process group, defaults to None
-    :type process_group: Optional[ProcessGroup], optional
-    :param reduce_scatter_process_group: Reduce-scatter process group, defaults to None. Generally, it should be `None`.
-    :type reduce_scatter_process_group: Optional[ProcessGroup], optional
-    :param reduce_scatter_bucket_size_mb: Reduce-scatter bucket size in *MB*, defaults to 25
-    :type reduce_scatter_bucket_size_mb: int, optional
-    :param fp32_reduce_scatter: If set to `True`, gradients are forced to FP32 before reduce-scatter, defaults to False
-    :type fp32_reduce_scatter: bool, optional
-    :param offload_config: We currently only support CPU offload. Set to `{"device": "cpu"}` to enable CPU offload, defaults to None
-    :type offload_config: Optional[dict], optional
-    :param gradient_predivide_factor: Gradient is divived by this value before reduce-scatter, defaults to 1.0
-    :type gradient_predivide_factor: Optional[float], optional
-    :param use_memory_tracer: Whether to use memoty tracer, defaults to False
-    :type use_memory_tracer: bool, optional
+    Args:
+        module (nn.Module): A sharded module, which must be initialized by `ZeroInitContext`.
+        shard_strategy (BaseShardStrategy): A shard strategy to manage shard behavior.
+        process_group (Optional[ProcessGroup], optional): Data parallel process group. Defaults to None.
+        reduce_scatter_process_group (Optional[ProcessGroup], optional): Reduce-scatter process group.
+            Generally, it should be `None`, which means it is the same as `process_group`. Defaults to None.
+        reduce_scatter_bucket_size_mb (int, optional): Reduce-scatter bucket size in *MB*. Defaults to 25.
+        fp32_reduce_scatter (bool, optional): If set to `True`, gradients are forced to FP32 before reduce-scatter. Defaults to False.
+        offload_config (Optional[dict], optional): We currently only support CPU offload. Set to `{"device": "cpu"}` to enable CPU offload. Defaults to None.
+        gradient_predivide_factor (Optional[float], optional): Gradient is divided by this value before reduce-scatter. Defaults to 1.0.
+        use_memory_tracer (bool, optional): Whether to use memory tracer. Defaults to False.
+        reuse_fp16_shard (bool, optional): Whether to reuse the fp16 shard for both param and grad.
+            Enabling this can reduce GPU memory usage, but it must be disabled when using gradient accumulation.
+            In this mode, grad will be fp16, so make sure your optimizer supports mixed precision (fp32 param and fp16 grad).
+            We find that PyTorch's built-in optimizers don't support mixed precision,
+            so we recommend enabling this only when using our CPUAdam with CPU offload. Defaults to False.
     """
 
     def __init__(self,
@@ -58,7 +56,8 @@ class ShardedModelV2(nn.Module):
                  fp32_reduce_scatter: bool = False,
                  offload_config: Optional[dict] = None,
                  gradient_predivide_factor: Optional[float] = 1.0,
-                 use_memory_tracer: bool = False):
+                 use_memory_tracer: bool = False,
+                 reuse_fp16_shard: bool = False):
         super().__init__()
         self.logger = get_dist_logger()
 
@@ -97,8 +96,8 @@ class ShardedModelV2(nn.Module):
         self.fp32_reduce_scatter = fp32_reduce_scatter
         self._cpu_offload: bool = offload_config.get('device', None) == 'cpu' if offload_config else False
         for param in module.parameters():
-            # Init `offload_fp32_grad`
-            param.col_attr.offload_fp32_grad = self._cpu_offload
+            # Init `offload_grad`
+            param.col_attr.offload_grad = self._cpu_offload
 
         # We find if gradient_predivide_factor != 1.0, there may be wrong precision problem
         # So we use 1.0 as the default gradient_predivide_factor
@@ -114,6 +113,7 @@ class ShardedModelV2(nn.Module):
         self._require_backward_grad_sync: bool = True
 
         self._cuda_margin_space = 0
+        self.reuse_fp16_shard = reuse_fp16_shard
 
     @property
     def cuda_margin_space(self):
@@ -143,11 +143,7 @@ class ShardedModelV2(nn.Module):
         for ophook in self._ophook_list:
             ophook.post_iter()
 
-    @torch.no_grad()
-    def _post_backward_operations(self) -> None:
-        """
-        The method includes operations required to be processed after backward
-        """
+    def _update_memstats(self):
         if self._iter_cnter == 0 and self._memstats_collector:
             self._memstats_collector.finish_collection()
         if self._memstats_collector:
@@ -160,6 +156,13 @@ class ShardedModelV2(nn.Module):
 
         self._iter_cnter += 1
 
+    @torch.no_grad()
+    def _post_backward_operations(self) -> None:
+        """
+        Operations that must be performed after the backward pass.
+        """
+        self._update_memstats()
+
         if self._require_backward_grad_sync:
             # Flush any unreduced buckets in the post_backward stream.
             with torch.cuda.stream(self.comm_stream):
@@ -171,9 +174,11 @@ class ShardedModelV2(nn.Module):
         self.reducer.free()
         # In case some post bwd hook is not fired
         if self.shard_param:
+            tensor_list = []
             for p in self.module.parameters():
                 if not p.col_attr.param_is_sharded:
-                    self.shard_strategy.shard([p.col_attr.sharded_data_tensor], self.process_group)
+                    tensor_list.append(p.col_attr.sharded_data_tensor)
+            self.shard_strategy.shard(tensor_list, self.process_group)
         for p in self.module.parameters():
             p.col_attr.bwd_count = 0
             if not p.requires_grad:
@@ -191,13 +196,17 @@ class ShardedModelV2(nn.Module):
             # If world size == 1 and sharded param,
             # the shape `grad` is the same as unsharded param
             # So we can just use `view(-1)` to ensure grad is a flat tensor shard
-            grad = cast_tensor_to_fp32(p.col_attr.fp16_grad)
-            if p.col_attr.offload_fp32_grad:
+            if self.reuse_fp16_shard:
+                grad = p.col_attr.sharded_data_tensor.payload
+            else:
+                grad = cast_tensor_to_fp32(p.col_attr.fp16_grad)
+            if p.col_attr.offload_grad:
                 col_move_to_cpu(grad)
             if p.col_attr.fp32_grad is not None:
+                assert not self.reuse_fp16_shard, 'Gradient accumulation is not supported when reuse_fp16_shard=True'
                 p.col_attr.fp32_grad.add_(grad.view_as(p.col_attr.fp32_grad))
                 grad = p.col_attr.fp32_grad
-            p.grad.data = grad.view(-1)
+            p.grad.data = grad
             p.col_attr.fp16_grad = None
             p.col_attr.fp32_grad = None
 
@@ -250,11 +259,15 @@ class ShardedModelV2(nn.Module):
         return empty_grad
 
     def _reduce_scatter_callback(self, param: Parameter, reduced_grad: torch.Tensor) -> None:
+        reduced_grad = reduced_grad.view(-1)
         if self.gradient_postdivide_factor > 1:
             # Average grad by world_size for consistency with PyTorch DDP.
             reduced_grad.data.div_(self.gradient_postdivide_factor)
-
-        param.col_attr.fp16_grad = reduced_grad.data
+        if self.reuse_fp16_shard:
+            param.col_attr.sharded_data_tensor.reset_payload(reduced_grad.data)
+            param.col_attr.sharded_data_tensor.is_sharded = True
+        else:
+            param.col_attr.fp16_grad = reduced_grad.data
 
     def state_dict(self, destination=None, prefix='', keep_vars=False) -> 'OrderedDict[str, torch.Tensor]':
         self.shard_strategy.gather([p.col_attr.sharded_data_tensor for p in self.module.parameters()],
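# A minimal usage sketch of the new `reuse_fp16_shard` flag, following the docstring's
# recommendation to pair it with CPUAdam and CPU offload. This is an editorial
# illustration, not part of the patch: it assumes a distributed environment is already
# launched, `build_sharded_module()` is a hypothetical helper standing in for a module
# constructed under `ZeroInitContext`, and the import paths (inferred from the file
# layout above) may differ between versions.
from colossalai.nn.optimizer import CPUAdam
from colossalai.zero.shard_utils import TensorShardStrategy
from colossalai.zero.sharded_model import ShardedModelV2

module = build_sharded_module()                      # hypothetical: module built under ZeroInitContext
zero_model = ShardedModelV2(module,
                            TensorShardStrategy(),
                            offload_config=dict(device='cpu'),  # offload grads to CPU
                            reuse_fp16_shard=True)              # fp16 param shard doubles as the grad buffer
optimizer = CPUAdam(zero_model.parameters(), lr=1e-3)           # fp16 grads are upcast inside CPUAdam (see grad.float() above)
# The optimizer would then be wrapped with ShardedOptimizerV2, as in the tests below.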
diff --git a/colossalai/zero/sharded_optim/sharded_optim_v2.py b/colossalai/zero/sharded_optim/sharded_optim_v2.py
index 2109d4499..b3507c132 100644
--- a/colossalai/zero/sharded_optim/sharded_optim_v2.py
+++ b/colossalai/zero/sharded_optim/sharded_optim_v2.py
@@ -224,5 +224,5 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
                     if fp32_shards_used_cuda_margin_mem + shard_mem < fp32_shards_available_cuda_margin_mem:
                         self.master_params[p] = self.master_params[p].to(torch.cuda.current_device())
                         p.grad.data = p.grad.data.to(torch.cuda.current_device())
-                        p.col_attr.offload_fp32_grad = False
+                        p.col_attr.offload_grad = False
                         fp32_shards_used_cuda_margin_mem += shard_mem
diff --git a/colossalai/zero/sharded_param/sharded_param.py b/colossalai/zero/sharded_param/sharded_param.py
index 2a777d14e..5826dde96 100644
--- a/colossalai/zero/sharded_param/sharded_param.py
+++ b/colossalai/zero/sharded_param/sharded_param.py
@@ -14,7 +14,7 @@ class ShardedParamV2(object):
         self.fp16_grad: Optional[torch.Tensor] = None
         self.fp32_grad: Optional[torch.Tensor] = None
         # This attribute must be initialized in ShardedModel
-        self.offload_fp32_grad: bool = False
+        self.offload_grad: bool = False
 
         # make sure the shared param is the only owner of payload
         # The param.data maybe used to init the other part of the model.
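# A plain-tensor sketch (editorial illustration, not part of the patch) of the memory
# trade-off behind reusing the fp16 shard: with `reuse_fp16_shard=False` a separate
# fp32 grad buffer is kept next to the fp16 shard, while with `reuse_fp16_shard=True`
# the reduced fp16 grad simply takes the shard's place. `reset_payload` is only
# paraphrased here by rebinding a tensor; the real ShardedTensor bookkeeping is richer.
import torch

shard_numel = 1024
param_shard = torch.randn(shard_numel, dtype=torch.float16)    # fp16 param shard payload
reduced_grad = torch.randn(shard_numel, dtype=torch.float16)   # fp16 grad after reduce-scatter

# reuse_fp16_shard=False: upcast and keep a separate fp32 grad alongside the fp16 shard (6 bytes/elem)
fp32_grad = reduced_grad.float()

# reuse_fp16_shard=True: the grad takes over the shard, no extra buffer (2 bytes/elem)
param_shard = reduced_grad
grad_seen_by_optimizer = param_shard                           # stays fp16, so the optimizer must cope (e.g. CPUAdam)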
diff --git a/tests/test_zero_data_parallel/common.py b/tests/test_zero_data_parallel/common.py
index 7e6f881dc..70166e121 100644
--- a/tests/test_zero_data_parallel/common.py
+++ b/tests/test_zero_data_parallel/common.py
@@ -16,7 +16,8 @@ _ZERO_MODEL_CONFIG = dict(reduce_scatter_bucket_size_mb=25,
                           offload_config=None,
                           gradient_predivide_factor=1.0,
                           use_memory_tracer=False,
-                          shard_strategy=TensorShardStrategy())
+                          shard_strategy=TensorShardStrategy(),
+                          reuse_fp16_shard=False)
 
 _ZERO_OPTIMIZER_CONFIG = dict(cpu_offload=False,
                               initial_scale=2**5,
@@ -116,10 +117,13 @@ def check_params_padding(model, zero_model, loose=False):
         assert allclose(p, zero_p, loose=loose)
 
 
-def check_sharded_params_padding(model, zero_model, loose=False):
+def check_sharded_model_params(model, zero_model, loose=False, reuse_fp16_shard=False):
     rank = dist.get_rank()
     for p, zero_p in zip(model.parameters(), zero_model.parameters()):
-        zero_p = zero_p.col_attr.sharded_data_tensor.payload.to(p.device).float()
+        if reuse_fp16_shard:
+            zero_p = zero_p.data.to(p.device).float()
+        else:
+            zero_p = zero_p.col_attr.sharded_data_tensor.payload.to(p.device).float()
         chunks = torch.flatten(p).chunk(dist.get_world_size())
         if rank >= len(chunks):
             continue
diff --git a/tests/test_zero_data_parallel/test_sharded_optim_v2.py b/tests/test_zero_data_parallel/test_sharded_optim_v2.py
index 6de799c80..a8d9c0874 100644
--- a/tests/test_zero_data_parallel/test_sharded_optim_v2.py
+++ b/tests/test_zero_data_parallel/test_sharded_optim_v2.py
@@ -18,7 +18,7 @@ from colossalai.zero.sharded_optim._utils import has_inf_or_nan
 from tests.components_to_test.registry import non_distributed_component_funcs
 from torch.nn.parallel import DistributedDataParallel as DDP
 
-from common import CONFIG, check_sharded_params_padding
+from common import CONFIG, check_sharded_model_params
 
 
 def _run_step(model, optimizer, data, label, criterion, enable_autocast=False):
@@ -65,7 +65,8 @@ def _run_test_sharded_optim_v2(cpu_offload, shard_strategy_class, use_cpuadam, g
         zero_model = ShardedModelV2(zero_model,
                                     shard_strategy,
                                     offload_config=dict(device='cpu') if cpu_offload else None,
-                                    use_memory_tracer=gpu_margin_mem_ratio > 0.0)
+                                    use_memory_tracer=gpu_margin_mem_ratio > 0.0,
+                                    reuse_fp16_shard=use_cpuadam)
 
         model = model_builder(checkpoint=True).half()
         col_model_deepcopy(zero_model, model)
@@ -92,7 +93,7 @@ def _run_test_sharded_optim_v2(cpu_offload, shard_strategy_class, use_cpuadam, g
             data, label = data.cuda(), label.cuda()
             _run_step(apex_model, apex_optimizer, data, label, criterion, False)
             _run_step(zero_model, sharded_optim, data, label, criterion, False)
-            check_sharded_params_padding(model, zero_model, loose=True)
+            check_sharded_model_params(model, zero_model, loose=True, reuse_fp16_shard=use_cpuadam)
             for param in model.parameters():
                 assert not has_inf_or_nan(param)
 
diff --git a/tests/test_zero_data_parallel/test_zero_engine.py b/tests/test_zero_data_parallel/test_zero_engine.py
index 56ad85203..c1fb6b2bb 100644
--- a/tests/test_zero_data_parallel/test_zero_engine.py
+++ b/tests/test_zero_data_parallel/test_zero_engine.py
@@ -16,7 +16,7 @@ from colossalai.zero.sharded_optim._utils import has_inf_or_nan
 from tests.components_to_test.registry import non_distributed_component_funcs
 from torch.nn.parallel import DistributedDataParallel as DDP
 
-from common import (MP_PARALLEL_CONFIG, ZERO_PARALLEL_CONFIG, check_params, check_sharded_params_padding)
+from common import (MP_PARALLEL_CONFIG, ZERO_PARALLEL_CONFIG, check_params, check_sharded_model_params)
 
 
 def run_dist(rank, world_size, port, parallel_config):
@@ -87,7 +87,7 @@ def run_dist(rank, world_size, port, parallel_config):
         if parallel_config == MP_PARALLEL_CONFIG:
             check_params(torch_model, colo_model, loose=True)
         elif parallel_config == ZERO_PARALLEL_CONFIG:
-            check_sharded_params_padding(torch_model, colo_model, loose=True)
+            check_sharded_model_params(torch_model, colo_model, loose=True)
 
 
 # FIXME: enable this test in next PR