mirror of https://github.com/hpcaitech/ColossalAI
[hotfix] remove duplicated param register to stateful tensor manager (#728)
parent
600e769a42
commit
7db3ccc79b
|
@ -106,11 +106,9 @@ class ShardedModelV2(nn.Module):
|
|||
GLOBAL_MODEL_DATA_TRACER.register_model(self)
|
||||
self._memstats_collector = MemStatsCollector()
|
||||
self._stateful_tensor_mgr = StatefulTensorMgr(self._memstats_collector)
|
||||
# for param in module.parameters():
|
||||
for submodule in module.modules():
|
||||
for param in submodule.parameters(recurse=False):
|
||||
if hasattr(param, 'colo_attr'):
|
||||
self._stateful_tensor_mgr.register_stateful_param(param.colo_attr)
|
||||
for param in module.parameters():
|
||||
if hasattr(param, 'colo_attr'):
|
||||
self._stateful_tensor_mgr.register_stateful_param(param.colo_attr)
|
||||
self._start_collect_memstats = disposable(self._memstats_collector.start_collection)
|
||||
self._finish_collect_memstats = disposable(self._memstats_collector.finish_collection)
|
||||
else:
|
||||
|
|
Loading…
Reference in New Issue