From 297b8baae2ca8aaa42f81f646d8342b1961b66ab Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E3=82=A2=E3=83=9E=E3=83=87=E3=82=A6=E3=82=B9?= Date: Fri, 1 Apr 2022 10:15:52 +0800 Subject: [PATCH] [model checkpoint] add gloo groups for cpu tensor communication (#589) --- colossalai/context/parallel_context.py | 34 +++++++++++++++++-- .../initializer_1d.py | 5 ++- .../initializer_2d.py | 10 ++++-- .../initializer_2p5d.py | 20 ++++++++--- .../initializer_3d.py | 15 ++++++-- .../initializer_data.py | 5 ++- .../initializer_model.py | 6 +++- .../initializer_pipeline.py | 10 +++--- .../initializer_sequence.py | 11 ++++-- .../initializer_tensor.py | 5 ++- 10 files changed, 98 insertions(+), 23 deletions(-) diff --git a/colossalai/context/parallel_context.py b/colossalai/context/parallel_context.py index 1148e905c..959eb4a9a 100644 --- a/colossalai/context/parallel_context.py +++ b/colossalai/context/parallel_context.py @@ -34,6 +34,7 @@ class ParallelContext(metaclass=SingletonMeta): self._local_ranks = dict() self._world_sizes = dict() self._groups = dict() + self._cpu_groups = dict() self._ranks_in_group = dict() # load config from file @@ -290,6 +291,32 @@ class ParallelContext(metaclass=SingletonMeta): self._check_parallel_mode(parallel_mode) self._groups[parallel_mode] = group + def get_cpu_group(self, parallel_mode: ParallelMode): + """Returns the Gloo group of the current device for `parallel_mode`. 
+ + :param parallel_mode: The chosen parallel mode + :type parallel_mode: :class:`colossalai.context.ParallelMode` + :raises AssertionError: Raises an AssertionError if `parallel_mode` is not an instance + of :class:`colossalai.context.ParallelMode` + :return: The group of the current device for `parallel_mode` + :rtype: torch.distributed.ProcessGroup + """ + self._check_parallel_mode(parallel_mode) + return self._cpu_groups[parallel_mode] + + def add_cpu_group(self, parallel_mode: ParallelMode, group: dist.ProcessGroup): + """Adds the Gloo group of the current device for `parallel_mode`. + + :param parallel_mode: The chosen parallel mode + :type parallel_mode: :class:`colossalai.context.ParallelMode` + :param group: The group to be added + :type group: torch.distributed.ProcessGroup + :raises AssertionError: Raises an AssertionError if `parallel_mode` is not an instance + of :class:`colossalai.context.ParallelMode` + """ + self._check_parallel_mode(parallel_mode) + self._cpu_groups[parallel_mode] = group + def get_ranks_in_group(self, parallel_mode: ParallelMode): """Returns the rank of the current device for `parallel_mode` in the group. 
@@ -335,13 +362,16 @@ class ParallelContext(metaclass=SingletonMeta): dist.init_process_group(rank=rank, world_size=world_size, backend=backend, init_method=init_method) # None will give the default global process group for pytorch dist operations - self._register_dist(rank, world_size, None, list(range(world_size)), ParallelMode.GLOBAL) + ranks = list(range(world_size)) + cpu_group = dist.new_group(ranks, backend='gloo') if dist.get_backend() != 'gloo' else None + self._register_dist(rank, world_size, None, cpu_group, ranks, ParallelMode.GLOBAL) self.add_global_rank(ParallelMode.GLOBAL, rank) - def _register_dist(self, local_rank, world_size, process_group, ranks_in_group, mode): + def _register_dist(self, local_rank, world_size, process_group, cpu_group, ranks_in_group, mode): self.add_local_rank(mode, local_rank) self.add_world_size(mode, world_size) self.add_group(mode, process_group) + self.add_cpu_group(mode, cpu_group) self.add_ranks_in_group(mode, ranks_in_group) def check_sanity(self): diff --git a/colossalai/context/process_group_initializer/initializer_1d.py b/colossalai/context/process_group_initializer/initializer_1d.py index a1b8948a0..4c0502804 100644 --- a/colossalai/context/process_group_initializer/initializer_1d.py +++ b/colossalai/context/process_group_initializer/initializer_1d.py @@ -36,6 +36,7 @@ class Initializer_1D(ProcessGroupInitializer): local_rank = None ranks_in_group = None process_group = None + cpu_group = None group_world_size = None mode = ParallelMode.PARALLEL_1D env.parallel_input_1d = False @@ -43,11 +44,13 @@ class Initializer_1D(ProcessGroupInitializer): for i in range(self.num_group): ranks = [i * self.tensor_parallel_size + j for j in range(self.tensor_parallel_size)] group = dist.new_group(ranks) + group_cpu = dist.new_group(ranks, backend='gloo') if dist.get_backend() != 'gloo' else group if self.rank in ranks: local_rank = ranks.index(self.rank) group_world_size = len(ranks) process_group = group + cpu_group = group_cpu 
ranks_in_group = ranks - return local_rank, group_world_size, process_group, ranks_in_group, mode + return local_rank, group_world_size, process_group, cpu_group, ranks_in_group, mode diff --git a/colossalai/context/process_group_initializer/initializer_2d.py b/colossalai/context/process_group_initializer/initializer_2d.py index b46023b97..fe0ba553d 100644 --- a/colossalai/context/process_group_initializer/initializer_2d.py +++ b/colossalai/context/process_group_initializer/initializer_2d.py @@ -48,6 +48,7 @@ class Initializer_2D_Row(ProcessGroupInitializer): local_rank = None ranks_in_group = None process_group = None + cpu_group = None group_world_size = None mode = ParallelMode.PARALLEL_2D_ROW @@ -55,14 +56,16 @@ class Initializer_2D_Row(ProcessGroupInitializer): for j in range(self.summa_dim): ranks = [i * self.tensor_parallel_size + j * self.summa_dim + k for k in range(self.summa_dim)] group = dist.new_group(ranks) + group_cpu = dist.new_group(ranks, backend='gloo') if dist.get_backend() != 'gloo' else group if self.rank in ranks: local_rank = ranks.index(self.rank) group_world_size = len(ranks) process_group = group + cpu_group = group_cpu ranks_in_group = ranks - return local_rank, group_world_size, process_group, ranks_in_group, mode + return local_rank, group_world_size, process_group, cpu_group, ranks_in_group, mode class Initializer_2D_Col(ProcessGroupInitializer): @@ -94,6 +97,7 @@ class Initializer_2D_Col(ProcessGroupInitializer): local_rank = None ranks_in_group = None process_group = None + cpu_group = None group_world_size = None mode = ParallelMode.PARALLEL_2D_COL @@ -101,14 +105,16 @@ class Initializer_2D_Col(ProcessGroupInitializer): for j in range(self.summa_dim): ranks = [i * self.tensor_parallel_size + j + k * self.summa_dim for k in range(self.summa_dim)] group = dist.new_group(ranks) + group_cpu = dist.new_group(ranks, backend='gloo') if dist.get_backend() != 'gloo' else group if self.rank in ranks: local_rank = ranks.index(self.rank) 
group_world_size = len(ranks) process_group = group + cpu_group = group_cpu ranks_in_group = ranks - return local_rank, group_world_size, process_group, ranks_in_group, mode + return local_rank, group_world_size, process_group, cpu_group, ranks_in_group, mode @DIST_GROUP_INITIALIZER.register_module diff --git a/colossalai/context/process_group_initializer/initializer_2p5d.py b/colossalai/context/process_group_initializer/initializer_2p5d.py index b618cf00e..457361ab4 100644 --- a/colossalai/context/process_group_initializer/initializer_2p5d.py +++ b/colossalai/context/process_group_initializer/initializer_2p5d.py @@ -62,6 +62,7 @@ class Initializer_2p5D_ROW(ProcessGroupInitializer): local_rank = None ranks_in_group = None process_group = None + cpu_group = None group_world_size = None mode = ParallelMode.PARALLEL_2P5D_ROW @@ -73,14 +74,16 @@ class Initializer_2p5D_ROW(ProcessGroupInitializer): for i in range(self.tesseract_dim) ] group = dist.new_group(ranks) + group_cpu = dist.new_group(ranks, backend='gloo') if dist.get_backend() != 'gloo' else group if self.rank in ranks: local_rank = ranks.index(self.rank) group_world_size = len(ranks) process_group = group + cpu_group = group_cpu ranks_in_group = ranks - return local_rank, group_world_size, process_group, ranks_in_group, mode + return local_rank, group_world_size, process_group, cpu_group, ranks_in_group, mode class Initializer_2p5D_Col(ProcessGroupInitializer): @@ -115,6 +118,7 @@ class Initializer_2p5D_Col(ProcessGroupInitializer): local_rank = None ranks_in_group = None process_group = None + cpu_group = None group_world_size = None mode = ParallelMode.PARALLEL_2P5D_COL @@ -126,14 +130,16 @@ class Initializer_2p5D_Col(ProcessGroupInitializer): for j in range(self.tesseract_dim) ] group = dist.new_group(ranks) + group_cpu = dist.new_group(ranks, backend='gloo') if dist.get_backend() != 'gloo' else group if self.rank in ranks: local_rank = ranks.index(self.rank) group_world_size = len(ranks) process_group = 
group + cpu_group = group_cpu ranks_in_group = ranks - return local_rank, group_world_size, process_group, ranks_in_group, mode + return local_rank, group_world_size, process_group, cpu_group, ranks_in_group, mode class Initializer_2p5D_Dep(ProcessGroupInitializer): @@ -168,6 +174,7 @@ class Initializer_2p5D_Dep(ProcessGroupInitializer): local_rank = None ranks_in_group = None process_group = None + cpu_group = None group_world_size = None mode = ParallelMode.PARALLEL_2P5D_DEP @@ -179,14 +186,16 @@ class Initializer_2p5D_Dep(ProcessGroupInitializer): for k in range(self.tesseract_dep) ] group = dist.new_group(ranks) + group_cpu = dist.new_group(ranks, backend='gloo') if dist.get_backend() != 'gloo' else group if self.rank in ranks: local_rank = ranks.index(self.rank) group_world_size = len(ranks) process_group = group + cpu_group = group_cpu ranks_in_group = ranks - return local_rank, group_world_size, process_group, ranks_in_group, mode + return local_rank, group_world_size, process_group, cpu_group, ranks_in_group, mode # i row j col k dep @@ -222,6 +231,7 @@ class Initializer_2p5D_XZ(ProcessGroupInitializer): local_rank = None ranks_in_group = None process_group = None + cpu_group = None group_world_size = None mode = ParallelMode.PARALLEL_2P5D_XZ @@ -233,14 +243,16 @@ class Initializer_2p5D_XZ(ProcessGroupInitializer): for j in range(self.tesseract_dim) ] group = dist.new_group(ranks) + group_cpu = dist.new_group(ranks, backend='gloo') if dist.get_backend() != 'gloo' else group if self.rank in ranks: local_rank = ranks.index(self.rank) group_world_size = len(ranks) process_group = group + cpu_group = group_cpu ranks_in_group = ranks - return local_rank, group_world_size, process_group, ranks_in_group, mode + return local_rank, group_world_size, process_group, cpu_group, ranks_in_group, mode @DIST_GROUP_INITIALIZER.register_module diff --git a/colossalai/context/process_group_initializer/initializer_3d.py 
b/colossalai/context/process_group_initializer/initializer_3d.py index 8f0c0cd7b..0cda7a52d 100644 --- a/colossalai/context/process_group_initializer/initializer_3d.py +++ b/colossalai/context/process_group_initializer/initializer_3d.py @@ -52,6 +52,7 @@ class Initializer_3D_Input(ProcessGroupInitializer): local_rank = None ranks_in_group = None process_group = None + cpu_group = None group_world_size = None mode = ParallelMode.PARALLEL_3D_INPUT env.input_group_3d = mode @@ -61,14 +62,16 @@ class Initializer_3D_Input(ProcessGroupInitializer): for k in range(self.depth): ranks = [h * self.depth**3 + i + self.depth * (j + self.depth * k) for j in range(self.depth)] group = dist.new_group(ranks) + group_cpu = dist.new_group(ranks, backend='gloo') if dist.get_backend() != 'gloo' else group if self.rank in ranks: local_rank = ranks.index(self.rank) group_world_size = len(ranks) process_group = group + cpu_group = group_cpu ranks_in_group = ranks - return local_rank, group_world_size, process_group, ranks_in_group, mode + return local_rank, group_world_size, process_group, cpu_group, ranks_in_group, mode class Initializer_3D_Weight(ProcessGroupInitializer): @@ -100,6 +103,7 @@ class Initializer_3D_Weight(ProcessGroupInitializer): local_rank = None ranks_in_group = None process_group = None + cpu_group = None group_world_size = None mode = ParallelMode.PARALLEL_3D_WEIGHT env.weight_group_3d = mode @@ -109,14 +113,16 @@ class Initializer_3D_Weight(ProcessGroupInitializer): for j in range(self.depth): ranks = [h * self.depth**3 + i + self.depth * (j + self.depth * k) for i in range(self.depth)] group = dist.new_group(ranks) + group_cpu = dist.new_group(ranks, backend='gloo') if dist.get_backend() != 'gloo' else group if self.rank in ranks: local_rank = ranks.index(self.rank) group_world_size = len(ranks) process_group = group + cpu_group = group_cpu ranks_in_group = ranks - return local_rank, group_world_size, process_group, ranks_in_group, mode + return local_rank, 
group_world_size, process_group, cpu_group, ranks_in_group, mode class Initializer_3D_Output(ProcessGroupInitializer): @@ -148,6 +154,7 @@ class Initializer_3D_Output(ProcessGroupInitializer): local_rank = None ranks_in_group = None process_group = None + cpu_group = None group_world_size = None mode = ParallelMode.PARALLEL_3D_OUTPUT env.output_group_3d = mode @@ -157,14 +164,16 @@ class Initializer_3D_Output(ProcessGroupInitializer): for j in range(self.depth): ranks = [h * self.depth**3 + i + self.depth * (j + self.depth * k) for k in range(self.depth)] group = dist.new_group(ranks) + group_cpu = dist.new_group(ranks, backend='gloo') if dist.get_backend() != 'gloo' else group if self.rank in ranks: local_rank = ranks.index(self.rank) group_world_size = len(ranks) process_group = group + cpu_group = group_cpu ranks_in_group = ranks - return local_rank, group_world_size, process_group, ranks_in_group, mode + return local_rank, group_world_size, process_group, cpu_group, ranks_in_group, mode @DIST_GROUP_INITIALIZER.register_module diff --git a/colossalai/context/process_group_initializer/initializer_data.py b/colossalai/context/process_group_initializer/initializer_data.py index 6e3e5f7eb..aa4e54148 100644 --- a/colossalai/context/process_group_initializer/initializer_data.py +++ b/colossalai/context/process_group_initializer/initializer_data.py @@ -34,17 +34,20 @@ class Initializer_Data(ProcessGroupInitializer): local_rank = None ranks_in_group = None process_group = None + cpu_group = None group_world_size = None mode = ParallelMode.DATA for i in range(self.num_data_parallel_group): ranks = [i + j * self.num_data_parallel_group for j in range(self.data_parallel_size)] group = dist.new_group(ranks) + group_cpu = dist.new_group(ranks, backend='gloo') if dist.get_backend() != 'gloo' else group if self.rank in ranks: local_rank = ranks.index(self.rank) group_world_size = len(ranks) process_group = group + cpu_group = group_cpu ranks_in_group = ranks - return 
local_rank, group_world_size, process_group, ranks_in_group, mode + return local_rank, group_world_size, process_group, cpu_group, ranks_in_group, mode diff --git a/colossalai/context/process_group_initializer/initializer_model.py b/colossalai/context/process_group_initializer/initializer_model.py index b57c3c7b5..99b9cc0d4 100644 --- a/colossalai/context/process_group_initializer/initializer_model.py +++ b/colossalai/context/process_group_initializer/initializer_model.py @@ -36,16 +36,20 @@ class Initializer_Model(ProcessGroupInitializer): local_rank = None ranks_in_group = None process_group = None + cpu_group = None group_world_size = None mode = ParallelMode.MODEL for i in range(self.num_group): ranks = [i * self.model_parallel_size + j for j in range(self.model_parallel_size)] group = dist.new_group(ranks) + group_cpu = dist.new_group(ranks, backend='gloo') if dist.get_backend() != 'gloo' else group if self.rank in ranks: local_rank = ranks.index(self.rank) group_world_size = len(ranks) process_group = group + cpu_group = group_cpu ranks_in_group = ranks - return local_rank, group_world_size, process_group, ranks_in_group, mode + + return local_rank, group_world_size, process_group, cpu_group, ranks_in_group, mode diff --git a/colossalai/context/process_group_initializer/initializer_pipeline.py b/colossalai/context/process_group_initializer/initializer_pipeline.py index cc784bac2..edd1a3706 100644 --- a/colossalai/context/process_group_initializer/initializer_pipeline.py +++ b/colossalai/context/process_group_initializer/initializer_pipeline.py @@ -20,6 +20,7 @@ class Initializer_Pipeline(ProcessGroupInitializer): pipeline_parallel_size (int): Size of pipeline parallel tensor_parallel_size (int): Size of tensor parallel """ + def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.data_group_size = self.world_size // self.data_parallel_size @@ -36,20 +37,19 @@ class Initializer_Pipeline(ProcessGroupInitializer): for i in 
range(self.data_parallel_size): for j in range(self.pipeline_stage_size): pipe_ranks = list( - range(i * self.data_group_size + j, - (i + 1) * self.data_group_size, - self.pipeline_stage_size)) + range(i * self.data_group_size + j, (i + 1) * self.data_group_size, self.pipeline_stage_size)) pipe_group_size = len(pipe_ranks) pipe_group = dist.new_group(pipe_ranks) + group_cpu = dist.new_group(pipe_ranks, backend='gloo') if dist.get_backend() != 'gloo' else pipe_group if self.rank in pipe_ranks: local_rank = pipe_ranks.index(self.rank) group_world_size = pipe_group_size process_group = pipe_group + cpu_group = group_cpu ranks_in_group = pipe_ranks dist_settings.append( - tuple((local_rank, group_world_size, - process_group, ranks_in_group, + tuple((local_rank, group_world_size, process_group, cpu_group, ranks_in_group, ParallelMode.PIPELINE))) return dist_settings diff --git a/colossalai/context/process_group_initializer/initializer_sequence.py b/colossalai/context/process_group_initializer/initializer_sequence.py index bc1a6fa0a..5bf405a20 100644 --- a/colossalai/context/process_group_initializer/initializer_sequence.py +++ b/colossalai/context/process_group_initializer/initializer_sequence.py @@ -38,19 +38,23 @@ class Initializer_Sequence_DP(ProcessGroupInitializer): local_rank = None ranks_in_group = None process_group = None + cpu_group = None group_world_size = None mode = ParallelMode.SEQUENCE_DP for i in range(self.num_group): ranks = [i * self.dp_size + j for j in range(self.dp_size)] group = dist.new_group(ranks) + group_cpu = dist.new_group(ranks, backend='gloo') if dist.get_backend() != 'gloo' else group if self.rank in ranks: local_rank = ranks.index(self.rank) group_world_size = len(ranks) process_group = group + cpu_group = group_cpu ranks_in_group = ranks - return local_rank, group_world_size, process_group, ranks_in_group, mode + + return local_rank, group_world_size, process_group, cpu_group, ranks_in_group, mode 
@DIST_GROUP_INITIALIZER.register_module @@ -86,10 +90,11 @@ class Initializer_Sequence(ProcessGroupInitializer): parallel_setting = [] - local_rank, group_world_size, process_group, ranks_in_group, mode = self._sequence_initializer.init_dist_group() + local_rank, group_world_size, process_group, cpu_group, ranks_in_group, mode = \ + self._sequence_initializer.init_dist_group() # change mode to sequence mode = ParallelMode.SEQUENCE - parallel_setting.append((local_rank, group_world_size, process_group, ranks_in_group, mode)) + parallel_setting.append((local_rank, group_world_size, process_group, cpu_group, ranks_in_group, mode)) parallel_setting.append(self._sequence_dp_initializer.init_dist_group()) return parallel_setting diff --git a/colossalai/context/process_group_initializer/initializer_tensor.py b/colossalai/context/process_group_initializer/initializer_tensor.py index 50add7adb..3724fc361 100644 --- a/colossalai/context/process_group_initializer/initializer_tensor.py +++ b/colossalai/context/process_group_initializer/initializer_tensor.py @@ -34,17 +34,20 @@ class Initializer_Tensor(ProcessGroupInitializer): local_rank = None ranks_in_group = None process_group = None + cpu_group = None group_world_size = None mode = ParallelMode.TENSOR for i in range(self.num_tensor_parallel_group): ranks = [i * self.tensor_parallel_size + j for j in range(self.tensor_parallel_size)] group = dist.new_group(ranks) + group_cpu = dist.new_group(ranks, backend='gloo') if dist.get_backend() != 'gloo' else group if self.rank in ranks: local_rank = ranks.index(self.rank) group_world_size = len(ranks) process_group = group + cpu_group = group_cpu ranks_in_group = ranks - return local_rank, group_world_size, process_group, ranks_in_group, mode + return local_rank, group_world_size, process_group, cpu_group, ranks_in_group, mode