mirror of https://github.com/hpcaitech/ColossalAI
[hotfix] fix zero init ctx numel (#1128)
parent
f0a954f16d
commit
a1a7899cae
|
@ -78,6 +78,9 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
|
|||
|
||||
ZeroContextMgr().current_context = self
|
||||
|
||||
self.param_numel = {}
|
||||
self.top_module = None
|
||||
|
||||
@property
|
||||
def target_device(self):
|
||||
return self.config.target_device
|
||||
|
@ -169,11 +172,18 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
|
|||
torch.set_rng_state(self.cpu_rng_state)
|
||||
torch.cuda.set_rng_state(self.cuda_rng_state)
|
||||
|
||||
params = frozenset(self.top_module.parameters())
|
||||
for param in self.param_numel.keys():
|
||||
if param not in params:
|
||||
self.param_numel[param] = 0
|
||||
self.model_numel_tensor.fill_(sum(self.param_numel.values()))
|
||||
|
||||
def _post_init_method(self, module: torch.nn.Module, *args, **kwargs):
|
||||
"""
|
||||
The function to call at the end of the constructor of each module.
|
||||
NOTE() The module may be passed to this function multiple times.
|
||||
"""
|
||||
self.top_module = module
|
||||
|
||||
def half_fn(t: torch.Tensor):
|
||||
return t.half() if t.is_floating_point() else t
|
||||
|
@ -183,7 +193,7 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
|
|||
if hasattr(param, 'colo_attr'):
|
||||
continue
|
||||
|
||||
self.model_numel_tensor += param.numel()
|
||||
self.param_numel[param] = param.numel()
|
||||
|
||||
# convert parameters to half
|
||||
param_half = half_fn(param)
|
||||
|
|
Loading…
Reference in New Issue