diff --git a/colossalai/gemini/memory_tracer/memstats_collector.py b/colossalai/gemini/memory_tracer/memstats_collector.py
index 07ef74fbf..0885fb168 100644
--- a/colossalai/gemini/memory_tracer/memstats_collector.py
+++ b/colossalai/gemini/memory_tracer/memstats_collector.py
@@ -1,6 +1,7 @@
 from colossalai.gemini.memory_tracer import SyncCudaMemoryMonitor
 from colossalai.utils.memory import colo_device_memory_used
 from colossalai.gemini.stateful_tensor import StatefulTensor
+from colossalai.tensor import ChunkManager
 
 import torch
 import time
@@ -128,3 +129,19 @@ class MemStatsCollector:
         self._start_flag = False
         self._step_idx = 0
         self._step_total = 0
+
+
+class MemStatsCollectorV2(MemStatsCollector):
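+    """A MemStatsCollector that samples model data statistics from the given ChunkManager."""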
+
+    def __init__(self, chunk_manager: ChunkManager) -> None:
+        super().__init__()
+        self._chunk_manager = chunk_manager
+
+    def sample_model_data(self) -> None:
+        """Sampling model data statistics.
+        """
+        if self._start_flag:
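+            # Model data usage is read from the chunk manager's aggregated per-device totals.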
+            cuda_mem = self._chunk_manager.total_mem['cuda']
+            cpu_mem = self._chunk_manager.total_mem['cpu']
+            self._model_data_cuda_list.append(cuda_mem)
+            self._model_data_cpu_list.append(cpu_mem)
diff --git a/colossalai/nn/parallel/data_parallel.py b/colossalai/nn/parallel/data_parallel.py
index e5b687248..63d275954 100644
--- a/colossalai/nn/parallel/data_parallel.py
+++ b/colossalai/nn/parallel/data_parallel.py
@@ -60,7 +60,7 @@ class ColoDDP(torch.nn.Module):
             else:
                 ColoDDP._save_grad(p, grad)
             return empty_grad
-            
+
         else:
             group = gpc.get_cpu_group(ParallelMode.DATA)
             dist.all_reduce(grad, group=group)
@@ -113,7 +113,7 @@ class ColoDDPV2(ColoDDP):
     def _post_backward(self):
         self.chunk_manager.exec_lazy_release()
         for p in self.module.parameters():
-            if self.chunk_manager.is_chunk_free(p) or not p.requires_grad:
+            if self.chunk_manager.get_chunk(p).is_free or not p.requires_grad:
                 p.grad = None
             else:
                 p.grad = p.data
@@ -137,8 +137,8 @@ class ColoDDPV2(ColoDDP):
                 grad = grad / self.dp_world_size
             self.chunk_manager.copy_tensor_to_chunk_slice(p, grad)
             chunk = self.chunk_manager.get_chunk(p)
-            reduced = self.chunk_manager.reduce_chunk(p)
-            self.chunk_manager.release_chunk(p)
+            reduced = self.chunk_manager.reduce_chunk(chunk)
+            self.chunk_manager.release_chunk(chunk)
             if reduced and not chunk.is_free:
                 self.overflow_counter += chunk.has_inf_or_nan
         return empty_grad
diff --git a/colossalai/tensor/chunk.py b/colossalai/tensor/chunk.py
index b9f1ad466..86a51282b 100644
--- a/colossalai/tensor/chunk.py
+++ b/colossalai/tensor/chunk.py
@@ -2,7 +2,7 @@ import torch
 import torch.distributed as dist
 from dataclasses import dataclass
 from enum import Enum
-from typing import Optional, Dict, Deque, Set, List
+from typing import Optional, Dict, Deque, Set, List, Tuple, Iterable
 from collections import deque
 from colossalai.core import global_context as gpc
 from colossalai.context import ParallelMode
@@ -172,6 +172,12 @@ class Chunk:
     def device_type(self) -> str:
         return self.data.device.type
 
+    def __hash__(self) -> int:
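+        # Hash and compare chunks by identity so they deduplicate correctly in sets
+        # (e.g. ChunkManager.accessed_chunks) and in ChunkManager.get_chunks().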
+        return hash(id(self))
+
+    def __eq__(self, other: object) -> bool:
+        return self is other
+
 
 class ChunkManager:
 
@@ -226,8 +232,7 @@ class ChunkManager:
             src_rank = chunk_idx % gpc.get_world_size(ParallelMode.DATA)
         return src_rank
 
-    def access_chunk(self, tensor: torch.Tensor) -> None:
-        chunk = self.tensor_chunk_map[tensor]
+    def access_chunk(self, chunk: Chunk) -> None:
         if chunk in self.accessed_chunks:
             return
         if not chunk.is_free:
@@ -236,10 +241,9 @@ class ChunkManager:
         self.accessed_chunks.add(chunk)
         self.total_mem[chunk.device_type] += chunk.mem
 
-    def release_chunk(self, tensor: torch.Tensor) -> None:
+    def release_chunk(self, chunk: Chunk) -> None:
         if not self.enable_distributed_storage:
             return
-        chunk = self.tensor_chunk_map[tensor]
         if chunk not in self.accessed_chunks:
             return
         if chunk.can_release:
@@ -248,8 +252,7 @@ class ChunkManager:
             if chunk.is_free:
                 self.total_mem[chunk.device_type] -= chunk.mem
 
-    def move_chunk(self, tensor: torch.Tensor, device: torch.device) -> None:
-        chunk = self.tensor_chunk_map[tensor]
+    def move_chunk(self, chunk: Chunk, device: torch.device) -> None:
         if chunk.data.device == device:
             return
         if chunk.can_move_device and not chunk.is_free:
@@ -261,8 +264,7 @@ class ChunkManager:
         chunk = self.tensor_chunk_map[tensor]
         chunk.tensor_trans_state(tensor, state)
 
-    def reduce_chunk(self, tensor: torch.Tensor) -> bool:
-        chunk = self.tensor_chunk_map[tensor]
+    def reduce_chunk(self, chunk: Chunk) -> bool:
         if not chunk.can_reduce:
             return False
         self.total_mem[chunk.device_type] -= chunk.mem
@@ -274,10 +276,6 @@ class ChunkManager:
         chunk = self.tensor_chunk_map[tensor]
         chunk.copy_tensor_to_chunk_slice(tensor, data)
 
-    def is_chunk_free(self, tensor: torch.Tensor) -> bool:
-        chunk = self.tensor_chunk_map[tensor]
-        return chunk.is_free
-
     def get_chunk(self, tensor: torch.Tensor) -> Chunk:
         return self.tensor_chunk_map[tensor]
 
@@ -285,8 +283,8 @@ class ChunkManager:
         self.lazy_release_tensors.extend(tensors)
 
     def exec_lazy_release(self) -> None:
-        for tensor in self.lazy_release_tensors:
-            self.release_chunk(tensor)
+        for chunk in self.get_chunks(self.lazy_release_tensors):
+            self.release_chunk(chunk)
         self.lazy_release_tensors.clear()
 
     def __repr__(self) -> str:
@@ -340,3 +338,23 @@ class ChunkManager:
         for dest_chunk, src_chunk in zip(self.chunk_groups[dest_group_name], self.chunk_groups[src_group_name]):
             if not dest_chunk.is_free:
                 dest_chunk.copy_(src_chunk)
+
+    def get_chunks(self, tensors: Iterable[torch.Tensor]) -> Tuple[Chunk, ...]:
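+        """Map tensors to their owning chunks, deduplicated while preserving insertion order."""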
+        chunks = []
+        for tensor in tensors:
+            chunk = self.get_chunk(tensor)
+            if chunk not in chunks:
+                chunks.append(chunk)
+        return tuple(chunks)
+
+    def add_extern_static_tensor(self, tensor: torch.Tensor) -> None:
+        """Add extern static tensor to chunk manager. 
+        Those tensors won't be managed by chunk manager, but we want to monitor memory usage of them.
+        They are "static", which means their shape, dtype, device never change.
+        Thus, their memory usage never changes.
+
+        Args:
+            tensor (torch.Tensor): An extern static tensor. E.g. optimizer state.
+        """
+        assert tensor not in self.tensor_chunk_map
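+        # Only account for its memory footprint; the tensor itself is never registered in tensor_chunk_map.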
+        self.total_mem[tensor.device.type] += tensor.numel() * tensor.element_size()
diff --git a/colossalai/zero/utils/zero_hook_v2.py b/colossalai/zero/utils/zero_hook_v2.py
index 7b44ca1a6..737feedc4 100644
--- a/colossalai/zero/utils/zero_hook_v2.py
+++ b/colossalai/zero/utils/zero_hook_v2.py
@@ -20,12 +20,13 @@ class ZeROHookV2(ParamOpHook):
         self._training_phase = TrainingPhase.FORWARD
 
     def pre_op(self, params):
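+        # Collect the distinct chunks once so each chunk is accessed only once below.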
+        chunks = self._chunk_manager.get_chunks(params)
         for p in params:
             self._chunk_manager.trans_tensor_state(p, TensorState.COMPUTE)
         self._chunk_manager.exec_lazy_release()
         # TODO: evict chunks
-        for p in params:
-            self._chunk_manager.access_chunk(p)
+        for chunk in chunks:
+            self._chunk_manager.access_chunk(chunk)
 
     def post_op(self, params):
         for p in params:
diff --git a/colossalai/zero/zero_optimizer.py b/colossalai/zero/zero_optimizer.py
index 8cc21515d..0f84b4d9f 100644
--- a/colossalai/zero/zero_optimizer.py
+++ b/colossalai/zero/zero_optimizer.py
@@ -48,7 +48,7 @@ class ZeroOptimizer(ColossalaiOptimizer):
     def _update_params_ptr(self):
         for group in self.optim.param_groups:
             for p in group['params']:
-                if not self.module.chunk_manager.is_chunk_free(p):
+                if not self.module.chunk_manager.get_chunk(p).is_free:
                     p.data = self.fp16_param_to_fp32_param[p]
                 else:
                     assert p.grad is None
diff --git a/tests/test_tensor/test_chunk.py b/tests/test_tensor/test_chunk.py
index 515fb710c..f1d508d83 100644
--- a/tests/test_tensor/test_chunk.py
+++ b/tests/test_tensor/test_chunk.py
@@ -32,7 +32,7 @@ HAS_TENSORS = {
     }
 }
 
-TOTAL_MEM = {True: {True: [8192, 8192], False: [16384, 16384]}, False: {True: [8192, 4096], False: [12288, 12288]}}
+TOTAL_MEM = {True: {True: [512, 512], False: [1024, 1024]}, False: {True: [512, 256], False: [768, 768]}}
 
 
 @parameterize('use_chunk', [False, True])
@@ -41,8 +41,8 @@ def run_chunk_zero(use_chunk, use_zero):
     rank = gpc.get_local_rank(ParallelMode.DATA)
     if rank == 0:
         print(f'use_chunk={use_chunk}, use_zero={use_zero}')
-    params = [torch.rand(32, 32) for _ in range(3)]
-    chunk_size = 2048 if use_chunk else None
+    params = [torch.rand(8, 8) for _ in range(3)]
+    chunk_size = 128 if use_chunk else None
     chunk_manager = ChunkManager(chunk_size, enable_distributed_storage=use_zero)
     assert chunk_manager.total_mem['cpu'] == 0
     assert chunk_manager.total_mem['cuda'] == 0
@@ -51,18 +51,19 @@ def run_chunk_zero(use_chunk, use_zero):
     check_has_params(params, HAS_TENSORS[use_chunk][use_zero][rank])
     assert chunk_manager.total_mem['cpu'] == 0
     assert chunk_manager.total_mem['cuda'] == TOTAL_MEM[use_chunk][use_zero][rank]
-    for p in params:
-        chunk_manager.access_chunk(p)
+    chunks = chunk_manager.get_chunks(params)
+    for chunk in chunks:
+        chunk_manager.access_chunk(chunk)
     check_has_params(params, [True, True, True])
     assert chunk_manager.total_mem['cpu'] == 0
     assert chunk_manager.total_mem['cuda'] == TOTAL_MEM[use_chunk][False][rank]
-    for p in params:
-        chunk_manager.release_chunk(p)
+    for chunk in chunks:
+        chunk_manager.release_chunk(chunk)
     check_has_params(params, HAS_TENSORS[use_chunk][use_zero][rank])
     assert chunk_manager.total_mem['cpu'] == 0
     assert chunk_manager.total_mem['cuda'] == TOTAL_MEM[use_chunk][use_zero][rank], chunk_manager.total_mem['cuda']
-    for p in params:
-        chunk_manager.move_chunk(p, torch.device('cpu'))
+    for chunk in chunks:
+        chunk_manager.move_chunk(chunk, torch.device('cpu'))
     assert chunk_manager.total_mem['cpu'] == TOTAL_MEM[use_chunk][use_zero][rank], chunk_manager.total_mem['cuda']
     assert chunk_manager.total_mem['cuda'] == 0