diff --git a/colossalai/zero/sharded_model/sharded_model_v2.py b/colossalai/zero/sharded_model/sharded_model_v2.py index a324be8c5..028d0854c 100644 --- a/colossalai/zero/sharded_model/sharded_model_v2.py +++ b/colossalai/zero/sharded_model/sharded_model_v2.py @@ -106,11 +106,9 @@ class ShardedModelV2(nn.Module): GLOBAL_MODEL_DATA_TRACER.register_model(self) self._memstats_collector = MemStatsCollector() self._stateful_tensor_mgr = StatefulTensorMgr(self._memstats_collector) - # for param in module.parameters(): - for submodule in module.modules(): - for param in submodule.parameters(recurse=False): - if hasattr(param, 'colo_attr'): - self._stateful_tensor_mgr.register_stateful_param(param.colo_attr) + for param in module.parameters(): + if hasattr(param, 'colo_attr'): + self._stateful_tensor_mgr.register_stateful_param(param.colo_attr) self._start_collect_memstats = disposable(self._memstats_collector.start_collection) self._finish_collect_memstats = disposable(self._memstats_collector.finish_collection) else: