From 6b6002962a960ba18d23deaa58bc087159c31970 Mon Sep 17 00:00:00 2001
From: Jiarui Fang
Date: Thu, 10 Mar 2022 16:31:02 +0800
Subject: [PATCH] [zero] zero init context collect numel of model (#375)

---
 colossalai/zero/init_ctx/init_context.py           | 6 +++++-
 tests/test_zero_data_parallel/test_init_context.py | 6 ++++--
 2 files changed, 9 insertions(+), 3 deletions(-)

diff --git a/colossalai/zero/init_ctx/init_context.py b/colossalai/zero/init_ctx/init_context.py
index 340206661..3cc32f49e 100644
--- a/colossalai/zero/init_ctx/init_context.py
+++ b/colossalai/zero/init_ctx/init_context.py
@@ -100,7 +100,8 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
                  shard_strategy: BaseShardStrategy,
                  shard_param: bool = False,
                  shard_grad: bool = False,
-                 rm_torch_payload_on_the_fly=False):
+                 rm_torch_payload_on_the_fly=False,
+                 model_numel_tensor: torch.Tensor = torch.zeros(1, dtype=torch.int)):
         super().__init__()
         self.convert_fp16 = convert_fp16
         self.target_device = target_device
@@ -110,6 +111,7 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
         # FIXME(jiaruifang) now setting it to True is invalid.
         self.rm_torch_payload_on_the_fly = False
         self.initialized_param_list = []
+        self.model_numel_tensor = model_numel_tensor
 
     def _post_context_exec(self):
         """The callback function when the context exits.
@@ -129,6 +131,8 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
             if hasattr(param, 'col_attr'):
                 continue
 
+            self.model_numel_tensor += param.numel()
+
             target_device = self.target_device
 
             # convert to fp16 if necessary
diff --git a/tests/test_zero_data_parallel/test_init_context.py b/tests/test_zero_data_parallel/test_init_context.py
index 0a2f0d960..335fa9933 100644
--- a/tests/test_zero_data_parallel/test_init_context.py
+++ b/tests/test_zero_data_parallel/test_init_context.py
@@ -23,10 +23,12 @@ def run_dist(rank, world_size, port, init_device):
 
     for get_components_func in non_distributed_component_funcs:
         model_builder, _, _, _, _ = get_components_func()
+        model_numel_tensor = torch.zeros(1, dtype=torch.int)
         with ZeroInitContext(convert_fp16=True,
                              target_device=init_device,
                              shard_strategy=TensorShardStrategy(),
-                             shard_param=True):
+                             shard_param=True,
+                             model_numel_tensor=model_numel_tensor):
             model = model_builder(checkpoint=True)
 
         for param in model.parameters():
@@ -38,7 +40,7 @@ def run_dist(rank, world_size, port, init_device):
 
         print(f'cpu usgae {GLOBAL_MODEL_DATA_TRACER.cpu_usage}')
         print(f'cuda usgae {GLOBAL_MODEL_DATA_TRACER.cuda_usage}')
-
+        print(f'numel {model_numel_tensor}')
         if init_device.type == 'cuda':
            assert (GLOBAL_MODEL_DATA_TRACER.cuda_usage > 0)
         elif init_device.type == 'cpu':
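
---

Usage sketch (editorial note, not part of the patch): with this change, a caller
passes a one-element tensor into ZeroInitContext, and the context accumulates
param.numel() into it for every parameter it initializes, so the total model size
can be read back after the context exits. The snippet below is a minimal sketch,
not a definitive example: it assumes colossalai.launch() has already set up the
distributed environment (as run_dist does in the test above), the import paths are
assumed from the repository layout of the time, and nn.Linear stands in for the
test's model_builder.

    import torch
    import torch.nn as nn
    from colossalai.zero.init_ctx import ZeroInitContext
    from colossalai.zero.shard_utils import TensorShardStrategy

    # One-element accumulator; the context adds param.numel() for every
    # parameter initialized inside it (see _post_context_exec in the patch).
    model_numel_tensor = torch.zeros(1, dtype=torch.int)

    with ZeroInitContext(convert_fp16=True,
                         target_device=torch.device('cuda'),
                         shard_strategy=TensorShardStrategy(),
                         shard_param=True,
                         model_numel_tensor=model_numel_tensor):
        model = nn.Linear(128, 64)    # stand-in for model_builder(checkpoint=True)

    # 128 * 64 weights + 64 biases = 8256 elements
    print(f'numel {model_numel_tensor.item()}')

Note the accumulator is a tensor rather than a plain int so it can be updated in
place through the context and, if needed, reduced across ranks later.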