mirror of https://github.com/hpcaitech/ColossalAI
[hotfix] fix zero init ctx numel (#1128)
parent
f0a954f16d
commit
a1a7899cae
|
@ -78,6 +78,9 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
|
||||||
|
|
||||||
ZeroContextMgr().current_context = self
|
ZeroContextMgr().current_context = self
|
||||||
|
|
||||||
|
self.param_numel = {}
|
||||||
|
self.top_module = None
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def target_device(self):
|
def target_device(self):
|
||||||
return self.config.target_device
|
return self.config.target_device
|
||||||
|
@ -169,11 +172,18 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
|
||||||
torch.set_rng_state(self.cpu_rng_state)
|
torch.set_rng_state(self.cpu_rng_state)
|
||||||
torch.cuda.set_rng_state(self.cuda_rng_state)
|
torch.cuda.set_rng_state(self.cuda_rng_state)
|
||||||
|
|
||||||
|
params = frozenset(self.top_module.parameters())
|
||||||
|
for param in self.param_numel.keys():
|
||||||
|
if param not in params:
|
||||||
|
self.param_numel[param] = 0
|
||||||
|
self.model_numel_tensor.fill_(sum(self.param_numel.values()))
|
||||||
|
|
||||||
def _post_init_method(self, module: torch.nn.Module, *args, **kwargs):
|
def _post_init_method(self, module: torch.nn.Module, *args, **kwargs):
|
||||||
"""
|
"""
|
||||||
The function to call at the end of the constructor of each module.
|
The function to call at the end of the constructor of each module.
|
||||||
NOTE() The module may be passed to this function multiple times.
|
NOTE() The module may be passed to this function multiple times.
|
||||||
"""
|
"""
|
||||||
|
self.top_module = module
|
||||||
|
|
||||||
def half_fn(t: torch.Tensor):
|
def half_fn(t: torch.Tensor):
|
||||||
return t.half() if t.is_floating_point() else t
|
return t.half() if t.is_floating_point() else t
|
||||||
|
@ -183,7 +193,7 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
|
||||||
if hasattr(param, 'colo_attr'):
|
if hasattr(param, 'colo_attr'):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
self.model_numel_tensor += param.numel()
|
self.param_numel[param] = param.numel()
|
||||||
|
|
||||||
# convert parameters to half
|
# convert parameters to half
|
||||||
param_half = half_fn(param)
|
param_half = half_fn(param)
|
||||||
|
|
Loading…
Reference in New Issue