diff --git a/colossalai/zero/sharded_model/sharded_model_v2.py b/colossalai/zero/sharded_model/sharded_model_v2.py
index a324be8c5..028d0854c 100644
--- a/colossalai/zero/sharded_model/sharded_model_v2.py
+++ b/colossalai/zero/sharded_model/sharded_model_v2.py
@@ -106,11 +106,9 @@ class ShardedModelV2(nn.Module):
             GLOBAL_MODEL_DATA_TRACER.register_model(self)
             self._memstats_collector = MemStatsCollector()
             self._stateful_tensor_mgr = StatefulTensorMgr(self._memstats_collector)
-            # for param in module.parameters():
-            for submodule in module.modules():
-                for param in submodule.parameters(recurse=False):
-                    if hasattr(param, 'colo_attr'):
-                        self._stateful_tensor_mgr.register_stateful_param(param.colo_attr)
+            for param in module.parameters():
+                if hasattr(param, 'colo_attr'):
+                    self._stateful_tensor_mgr.register_stateful_param(param.colo_attr)
             self._start_collect_memstats = disposable(self._memstats_collector.start_collection)
             self._finish_collect_memstats = disposable(self._memstats_collector.finish_collection)
         else: