diff --git a/colossalai/gemini/memory_tracer/model_data_memtracer.py b/colossalai/gemini/memory_tracer/model_data_memtracer.py
index 98228892d..c228bdff4 100644
--- a/colossalai/gemini/memory_tracer/model_data_memtracer.py
+++ b/colossalai/gemini/memory_tracer/model_data_memtracer.py
@@ -106,4 +106,15 @@ class ModelDataTracer(metaclass=SingletonMeta):
         return self._get_mem_usage()
 
 
+class CudaMemInfo(metaclass=SingletonMeta):
+
+    def __init__(self) -> None:
+        self.model_data_list = []
+        self.non_model_data_list = []
+        self.unreleased_grad_flag = {}
+        self.unreleased_grad_volume = 0
+
+
 GLOBAL_MODEL_DATA_TRACER = ModelDataTracer()
+
+GLOBAL_CUDA_MEM_INFO = CudaMemInfo()
\ No newline at end of file
diff --git a/colossalai/gemini/memory_tracer/param_tracer_wrapper.py b/colossalai/gemini/memory_tracer/param_tracer_wrapper.py
index 50cc1451e..f69df73e3 100644
--- a/colossalai/gemini/memory_tracer/param_tracer_wrapper.py
+++ b/colossalai/gemini/memory_tracer/param_tracer_wrapper.py
@@ -1,11 +1,9 @@
 import torch.nn
 
 from colossalai.tensor.param_op_hook import ParamOpHookManager
-from colossalai.gemini.ophooks.param_trace_hook import ParamTracerHook
-from colossalai.gemini.tensor_utils import free_storage
+from colossalai.gemini.ophooks.param_trace_hook import ParamTracerHook, GradHook
+from colossalai.gemini.memory_tracer.model_data_memtracer import GLOBAL_CUDA_MEM_INFO
 from colossalai.nn.parallel.data_parallel import _cast_float
-from functools import partial
-
 
 __all__ = ['ParamTracerWrapper']
 
@@ -15,22 +13,33 @@ class ParamTracerWrapper():
         super().__init__()
         self.module = module
         self.dtype = dtype
-        self.param_op_hook = ParamTracerHook(dtype)
+        self.param_op_hook = ParamTracerHook()
+        self.grad_hook = GradHook(module)
+        self.cpu_param_data_dict = {}
 
         for p in module.parameters():
             p.data = p.data.to(dtype)
-            if p.requires_grad:
-                p.register_hook(partial(self.grad_handle))
 
         self._cast_buffers_to_cuda_dtype()
 
     def __call__(self, *args, **kwargs):
         return self.forward(*args, **kwargs)
 
-    def grad_handle(self, grad):
-        free_storage(grad)
+    def _save_param_data_on_cpu(self):
+        for p in self.module.parameters():
+            self.cpu_param_data_dict[p] = torch.empty(p.data.shape, dtype=self.dtype, device="cpu")
+            self.cpu_param_data_dict[p].copy_(p.data)
+
+    def _restore_param_data(self):
+        for p in self.module.parameters():
+            p.data = torch.empty(p.data.shape, dtype=self.dtype, device="cpu", requires_grad=p.data.requires_grad)
+            p.data.copy_(self.cpu_param_data_dict[p])
+        self.cpu_param_data_dict.clear()
 
     def _pre_forward(self):
+        self._clear_cuda_mem_info()
+        self._save_param_data_on_cpu()
+        self.grad_hook.register_grad_hook()
         self.param_op_hook.mem_monitor.start()
 
     def forward(self, *args, **kwargs):
@@ -48,8 +57,16 @@ class ParamTracerWrapper():
 
     def _post_backward(self):
         cuda_volume = self.param_op_hook.mem_monitor.finish()
-        last_model_data = self.param_op_hook._model_data_list[-1]
-        self.param_op_hook._non_model_data_list.append(cuda_volume - last_model_data)
+        last_model_data = GLOBAL_CUDA_MEM_INFO.model_data_list[-1]
+        GLOBAL_CUDA_MEM_INFO.non_model_data_list.append(cuda_volume - last_model_data)
+        self.grad_hook.remove_grad_hook()
+        self._restore_param_data()
+
+    def _clear_cuda_mem_info(self):
+        GLOBAL_CUDA_MEM_INFO.model_data_list.clear()
+        GLOBAL_CUDA_MEM_INFO.non_model_data_list.clear()
+        GLOBAL_CUDA_MEM_INFO.unreleased_grad_flag.clear()
+        GLOBAL_CUDA_MEM_INFO.unreleased_grad_volume = 0
 
     def _cast_buffers_to_cuda_dtype(self):
         for buffer in self.module.buffers():
diff --git a/colossalai/gemini/ophooks/param_trace_hook.py b/colossalai/gemini/ophooks/param_trace_hook.py
index aef2cdbd7..678927d78 100644
--- a/colossalai/gemini/ophooks/param_trace_hook.py
+++ b/colossalai/gemini/ophooks/param_trace_hook.py
@@ -8,6 +8,7 @@ import torch
 from colossalai.gemini.memory_tracer import SyncCudaMemoryMonitor
 from colossalai.tensor.param_op_hook import ParamOpHook
 from colossalai.gemini.tensor_utils import free_storage, alloc_storage
+from colossalai.gemini.memory_tracer.model_data_memtracer import GLOBAL_CUDA_MEM_INFO
 
 
 class TrainingPhase(Enum):
@@ -15,42 +16,69 @@ class TrainingPhase(Enum):
     BACKWARD = 1
 
 
+class GradHook():
+    def __init__(self, module: torch.nn.Module):
+        self.module = module
+        self.grad_hook_list = []
+
+    def grad_handle(self, p, grad):
+        assert GLOBAL_CUDA_MEM_INFO.unreleased_grad_flag[p]
+        free_storage(grad)
+        GLOBAL_CUDA_MEM_INFO.unreleased_grad_volume -= grad.numel() * grad.element_size()
+        GLOBAL_CUDA_MEM_INFO.unreleased_grad_flag[p] = False
+
+    def register_grad_hook(self):
+        for p in self.module.parameters():
+            if p.requires_grad:
+                self.grad_hook_list.append(p.register_hook(partial(self.grad_handle, p)))
+                GLOBAL_CUDA_MEM_INFO.unreleased_grad_flag[p] = False
+
+    def remove_grad_hook(self):
+        for hook in self.grad_hook_list:
+            hook.remove()
+
+
 class ParamTracerHook(ParamOpHook):
 
-    def __init__(self, dtype: torch.dtype = torch.half) -> None:
+    def __init__(self) -> None:
         super().__init__()
         self._training_phase = TrainingPhase.FORWARD
         self.mem_monitor = SyncCudaMemoryMonitor()
-        self._non_model_data_list = []
-        self._model_data_list = []
-        self.dtype = dtype
 
     def _free_cuda_params(self, params):
         for p in params:
+            if p.data.device.type == "cpu":
+                raise NotImplementedError("Only free cuda memory")
             free_storage(p.data)
 
     def _allocate_params_on_cuda(self, params):
         for p in params:
             cur_dev = p.data.device.type
             if cur_dev == "cpu":
-                # p.data = p.data.to("cuda")
-                p.data = torch.randn(p.data.shape, device="cuda", dtype=self.dtype)
+                if p.grad is not None and p.grad.device.type == "cpu":
+                    raise NotImplementedError("Only run in forward propagation")
+                p.data = torch.empty(p.data.shape, device="cuda", dtype=p.data.dtype,
+                                     requires_grad=p.data.requires_grad)
             elif cur_dev == "cuda":
                 alloc_storage(p.data)
 
     def sample_model_data(self, params):
-        data_volume = 0
+        data_volume = GLOBAL_CUDA_MEM_INFO.unreleased_grad_volume
         for p in params:
-            data_volume += p.data.numel() * p.data.element_size()
-        if self._training_phase == TrainingPhase.BACKWARD:
-            # add param.grad, actually param.grad is None in this time
-            data_volume *= 2
-        self._model_data_list.append(data_volume)
+            cur_model_data_volume = p.data.numel() * p.data.element_size()
+            data_volume += cur_model_data_volume
+            if self._training_phase == TrainingPhase.BACKWARD and p.requires_grad:
+                # add param.grad, actually param.grad is None in this time
+                data_volume += cur_model_data_volume
+                if not GLOBAL_CUDA_MEM_INFO.unreleased_grad_flag[p]:
+                    GLOBAL_CUDA_MEM_INFO.unreleased_grad_volume += cur_model_data_volume
+                    GLOBAL_CUDA_MEM_INFO.unreleased_grad_flag[p] = True
+        GLOBAL_CUDA_MEM_INFO.model_data_list.append(data_volume)
 
     def pre_op(self, params):
         cuda_volume = self.mem_monitor.finish()
-        if len(self._model_data_list):
-            self._non_model_data_list.append(cuda_volume - self._model_data_list[-1])
+        if len(GLOBAL_CUDA_MEM_INFO.model_data_list):
+            GLOBAL_CUDA_MEM_INFO.non_model_data_list.append(cuda_volume - GLOBAL_CUDA_MEM_INFO.model_data_list[-1])
         self._allocate_params_on_cuda(params)
         self.sample_model_data(params)
         self.mem_monitor.start()
diff --git a/tests/test_gemini/test_param_tracer.py b/tests/test_gemini/test_param_tracer.py
index d82778271..7e4c6dff5 100644
--- a/tests/test_gemini/test_param_tracer.py
+++ b/tests/test_gemini/test_param_tracer.py
@@ -2,6 +2,7 @@ import numpy as np
 import torch
 
 from colossalai.gemini.memory_tracer.param_tracer_wrapper import ParamTracerWrapper
+from colossalai.gemini.memory_tracer.model_data_memtracer import GLOBAL_CUDA_MEM_INFO
 from colossalai.utils.model.colo_init_context import ColoInitContext
 from tests.components_to_test.registry import non_distributed_component_funcs
 
@@ -35,9 +36,9 @@ def run_param_wrapper_testing():
 
             run_fwd_bwd(model, data, label, criterion, False)
 
-        cuda_non_model_data_list = np.array(model.param_op_hook._non_model_data_list) / 1024 ** 2
+        cuda_non_model_data_list = np.array(GLOBAL_CUDA_MEM_INFO.non_model_data_list) / 1024 ** 2
         print("cuda_non_model_data_list", len(cuda_non_model_data_list))
-        # print(model.param_op_hook._non_model_data_list)
+        # print(GLOBAL_CUDA_MEM_INFO.non_model_data_list)
 
         del model