From a1a7899cae6940c1adcef09d115dcc15a4d97518 Mon Sep 17 00:00:00 2001 From: ver217 Date: Thu, 16 Jun 2022 17:17:27 +0800 Subject: [PATCH] [hotfix] fix zero init ctx numel (#1128) --- colossalai/zero/init_ctx/init_context.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/colossalai/zero/init_ctx/init_context.py b/colossalai/zero/init_ctx/init_context.py index 93d5a455f..f4142da08 100644 --- a/colossalai/zero/init_ctx/init_context.py +++ b/colossalai/zero/init_ctx/init_context.py @@ -78,6 +78,9 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses): ZeroContextMgr().current_context = self + self.param_numel = {} + self.top_module = None + @property def target_device(self): return self.config.target_device @@ -169,11 +172,18 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses): torch.set_rng_state(self.cpu_rng_state) torch.cuda.set_rng_state(self.cuda_rng_state) + params = frozenset(self.top_module.parameters()) + for param in self.param_numel.keys(): + if param not in params: + self.param_numel[param] = 0 + self.model_numel_tensor.fill_(sum(self.param_numel.values())) + def _post_init_method(self, module: torch.nn.Module, *args, **kwargs): """ The function to call at the end of the constructor of each module. NOTE() The module may be passed to this function multiple times. """ + self.top_module = module def half_fn(t: torch.Tensor): return t.half() if t.is_floating_point() else t @@ -183,7 +193,7 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses): if hasattr(param, 'colo_attr'): continue - self.model_numel_tensor += param.numel() + self.param_numel[param] = param.numel() # convert parameters to half param_half = half_fn(param)