From 3cb3fc275e39152a7bcc60eb6b65dfc614e91456 Mon Sep 17 00:00:00 2001
From: ver217
Date: Mon, 21 Mar 2022 11:18:55 +0800
Subject: [PATCH] zero init ctx receives a dp process group (#471)

---
 colossalai/zero/init_ctx/init_context.py | 14 ++++++++++----
 1 file changed, 10 insertions(+), 4 deletions(-)

diff --git a/colossalai/zero/init_ctx/init_context.py b/colossalai/zero/init_ctx/init_context.py
index 6e1466df1..2a43d240d 100644
--- a/colossalai/zero/init_ctx/init_context.py
+++ b/colossalai/zero/init_ctx/init_context.py
@@ -1,11 +1,15 @@
 import functools
+from typing import Optional

 import torch
+from colossalai.context.parallel_mode import ParallelMode
+from colossalai.core import global_context as gpc
 from colossalai.utils.memory_tracer.model_data_memtracer import \
     GLOBAL_MODEL_DATA_TRACER
 from colossalai.zero.shard_utils import BaseShardStrategy
 from colossalai.zero.sharded_model._zero3_utils import cast_tensor_to_fp16
 from colossalai.zero.sharded_param import ShardedParamV2
+from torch.distributed import ProcessGroup


 # Inserts _post_init_method at the end of init method
@@ -103,8 +107,9 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
                  shard_strategy: BaseShardStrategy,
                  shard_param: bool = False,
                  shard_grad: bool = False,
-                 rm_torch_payload_on_the_fly=False,
-                 model_numel_tensor: torch.Tensor = torch.zeros(1, dtype=torch.int)):
+                 rm_torch_payload_on_the_fly: bool = False,
+                 model_numel_tensor: torch.Tensor = torch.zeros(1, dtype=torch.int),
+                 dp_process_group: Optional[ProcessGroup] = None):
         super().__init__()
         self.convert_fp16 = convert_fp16
         self.target_device = target_device
@@ -115,6 +120,7 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
         self.rm_torch_payload_on_the_fly = False
         self.initialized_param_list = []
         self.model_numel_tensor = model_numel_tensor
+        self.dp_process_group = dp_process_group or gpc.get_group(ParallelMode.DATA)

     def _post_context_exec(self):
         """The callback function when the context exits.
@@ -154,10 +160,10 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
             self.initialized_param_list.append(param)

             if self.shard_param:
-                self.shard_strategy.shard(tensor_list=[param.col_attr._data_sharded_tensor])
+                self.shard_strategy.shard([param.col_attr._data_sharded_tensor], self.dp_process_group)
                 GLOBAL_MODEL_DATA_TRACER.add_tensor(param.col_attr._data_sharded_tensor.payload)
             # if param.col_attr.grad and self.shard_grad:
-            #     self.shard_strategy.shard(tensor_list=[param.col_attr._grad_sharded_tensor])
+            #     self.shard_strategy.shard([param.col_attr._grad_sharded_tensor], self.dp_process_group)
             #     GLOBAL_MODEL_DATA_TRACER.add_tensor(param.col_attr._grad_sharded_tensor.payload)
         # We must cast buffers
         # If we use BN, buffers may be on CPU and Float
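
Usage note: after this patch, ZeroInitContext accepts an optional dp_process_group and forwards it to shard_strategy.shard(); when the argument is omitted, it falls back to gpc.get_group(ParallelMode.DATA). Below is a minimal, hedged usage sketch. The import paths and TensorShardStrategy, as well as the torch.distributed initialization step, are assumptions based on the file layout and the colossalai API of that era, not something this patch confirms; only the parameters visible in the hunks (convert_fp16, target_device, shard_strategy, shard_param, dp_process_group) are taken from the source.

    import torch
    import torch.distributed as dist
    from colossalai.zero.init_ctx import ZeroInitContext          # assumed import path
    from colossalai.zero.shard_utils import TensorShardStrategy   # assumed concrete BaseShardStrategy

    # Assumption: the default process group has already been initialized
    # (e.g. by colossalai.launch or dist.init_process_group).
    dp_group = dist.new_group(ranks=list(range(dist.get_world_size())))

    with ZeroInitContext(convert_fp16=True,
                         target_device=torch.device('cuda'),
                         shard_strategy=TensorShardStrategy(),
                         shard_param=True,
                         dp_process_group=dp_group):  # new in this patch; defaults to gpc.get_group(ParallelMode.DATA)
        # Parameters created inside the context are sharded across dp_group on the fly.
        model = torch.nn.Linear(1024, 1024)

Passing the group explicitly lets the ZeRO init context shard over a caller-chosen data-parallel group (for example, a subgroup in a hybrid-parallel setup) instead of always using the global DATA group from gpc.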