[hotfix] fix zero init ctx numel (#1128)

pull/1131/head
ver217 2022-06-16 17:17:27 +08:00 committed by GitHub
parent f0a954f16d
commit a1a7899cae
1 changed file with 11 additions and 1 deletion

@@ -78,6 +78,9 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
         ZeroContextMgr().current_context = self
 
+        self.param_numel = {}
+        self.top_module = None
+
     @property
     def target_device(self):
         return self.config.target_device
@@ -169,11 +172,18 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
         torch.set_rng_state(self.cpu_rng_state)
         torch.cuda.set_rng_state(self.cuda_rng_state)
 
+        params = frozenset(self.top_module.parameters())
+        for param in self.param_numel.keys():
+            if param not in params:
+                self.param_numel[param] = 0
+        self.model_numel_tensor.fill_(sum(self.param_numel.values()))
+
     def _post_init_method(self, module: torch.nn.Module, *args, **kwargs):
         """
         The function to call at the end of the constructor of each module.
         NOTE() The module may be passed to this function multiple times.
         """
+        self.top_module = module
 
         def half_fn(t: torch.Tensor):
             return t.half() if t.is_floating_point() else t
@@ -183,7 +193,7 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
             if hasattr(param, 'colo_attr'):
                 continue
 
-            self.model_numel_tensor += param.numel()
+            self.param_numel[param] = param.numel()
 
             # convert parameters to half
             param_half = half_fn(param)
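
A minimal sketch of the accounting this diff introduces, assuming a simplified stand-in class (NumelTracker, post_init, and finalize are hypothetical names; the real logic lives in ZeroInitContext._post_init_method and the context-exit path). The idea suggested by the diff: record each parameter's numel in a dict during construction instead of accumulating with +=, then on exit zero out entries for parameters that no longer belong to the final top-level module before writing the total once.

import torch

class NumelTracker:
    """Hypothetical stand-in for the numel bookkeeping in ZeroInitContext."""

    def __init__(self):
        self.param_numel = {}          # param -> numel recorded during init
        self.top_module = None         # outermost module, set last
        self.model_numel_tensor = torch.zeros(1, dtype=torch.long)

    def post_init(self, module: torch.nn.Module):
        # Mirrors _post_init_method: called for every submodule, innermost
        # first, so the last assignment leaves top_module as the whole model.
        self.top_module = module
        for param in module.parameters(recurse=False):
            self.param_numel[param] = param.numel()

    def finalize(self):
        # Mirrors the context-exit fix: parameters recorded during init but
        # no longer reachable from the final model must not be counted.
        live = frozenset(self.top_module.parameters())
        for param in self.param_numel:
            if param not in live:
                self.param_numel[param] = 0
        self.model_numel_tensor.fill_(sum(self.param_numel.values()))

tracker = NumelTracker()
model = torch.nn.Linear(4, 2)
tracker.post_init(model)
tracker.finalize()
print(tracker.model_numel_tensor.item())  # 10: a 4x2 weight plus 2 biases

Keying the dict by the parameter object and reconciling once on exit tolerates _post_init_method seeing the same module more than once, which the docstring's NOTE warns about; the old += accumulation had no way to retract a count for a parameter that was later replaced.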