[model checkpoint] add gloo groups for cpu tensor communication (#589)

pull/621/head
アマデウス 3 years ago committed by GitHub
parent 54e688b623
commit 297b8baae2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -34,6 +34,7 @@ class ParallelContext(metaclass=SingletonMeta):
self._local_ranks = dict()
self._world_sizes = dict()
self._groups = dict()
self._cpu_groups = dict()
self._ranks_in_group = dict()
# load config from file
@ -290,6 +291,32 @@ class ParallelContext(metaclass=SingletonMeta):
self._check_parallel_mode(parallel_mode)
self._groups[parallel_mode] = group
def get_cpu_group(self, parallel_mode: ParallelMode):
"""Returns the Gloo group of the current device for `parallel_mode`.
:param parallel_mode: The chosen parallel mode
:type parallel_mode: :class:`colossalai.context.ParallelMode`
:raises AssertionError: Raises an AssertionError if `parallel_mode` is not an instance
of :class:`colossalai.context.ParallelMode`
:return: The group of the current device for `parallel_mode`
:rtype: torch.distributed.ProcessGroup
"""
self._check_parallel_mode(parallel_mode)
return self._cpu_groups[parallel_mode]
def add_cpu_group(self, parallel_mode: ParallelMode, group: dist.ProcessGroup):
"""Adds the Gloo group of the current device for `parallel_mode`.
:param parallel_mode: The chosen parallel mode
:type parallel_mode: :class:`colossalai.context.ParallelMode`
:param group: The group to be added
:type group: torch.distributed.ProcessGroup
:raises AssertionError: Raises an AssertionError if `parallel_mode` is not an instance
of :class:`colossalai.context.ParallelMode`
"""
self._check_parallel_mode(parallel_mode)
self._cpu_groups[parallel_mode] = group
def get_ranks_in_group(self, parallel_mode: ParallelMode):
"""Returns the rank of the current device for `parallel_mode` in the group.
@ -335,13 +362,16 @@ class ParallelContext(metaclass=SingletonMeta):
dist.init_process_group(rank=rank, world_size=world_size, backend=backend, init_method=init_method)
# None will give the default global process group for pytorch dist operations
self._register_dist(rank, world_size, None, list(range(world_size)), ParallelMode.GLOBAL)
ranks = list(range(world_size))
cpu_group = dist.new_group(ranks, backend='gloo') if dist.get_backend() != 'gloo' else None
self._register_dist(rank, world_size, None, cpu_group, ranks, ParallelMode.GLOBAL)
self.add_global_rank(ParallelMode.GLOBAL, rank)
def _register_dist(self, local_rank, world_size, process_group, ranks_in_group, mode):
def _register_dist(self, local_rank, world_size, process_group, cpu_group, ranks_in_group, mode):
self.add_local_rank(mode, local_rank)
self.add_world_size(mode, world_size)
self.add_group(mode, process_group)
self.add_cpu_group(mode, cpu_group)
self.add_ranks_in_group(mode, ranks_in_group)
def check_sanity(self):

@ -36,6 +36,7 @@ class Initializer_1D(ProcessGroupInitializer):
local_rank = None
ranks_in_group = None
process_group = None
cpu_group = None
group_world_size = None
mode = ParallelMode.PARALLEL_1D
env.parallel_input_1d = False
@ -43,11 +44,13 @@ class Initializer_1D(ProcessGroupInitializer):
for i in range(self.num_group):
ranks = [i * self.tensor_parallel_size + j for j in range(self.tensor_parallel_size)]
group = dist.new_group(ranks)
group_cpu = dist.new_group(ranks, backend='gloo') if dist.get_backend() != 'gloo' else group
if self.rank in ranks:
local_rank = ranks.index(self.rank)
group_world_size = len(ranks)
process_group = group
cpu_group = group_cpu
ranks_in_group = ranks
return local_rank, group_world_size, process_group, ranks_in_group, mode
return local_rank, group_world_size, process_group, cpu_group, ranks_in_group, mode

@ -48,6 +48,7 @@ class Initializer_2D_Row(ProcessGroupInitializer):
local_rank = None
ranks_in_group = None
process_group = None
cpu_group = None
group_world_size = None
mode = ParallelMode.PARALLEL_2D_ROW
@ -55,14 +56,16 @@ class Initializer_2D_Row(ProcessGroupInitializer):
for j in range(self.summa_dim):
ranks = [i * self.tensor_parallel_size + j * self.summa_dim + k for k in range(self.summa_dim)]
group = dist.new_group(ranks)
group_cpu = dist.new_group(ranks, backend='gloo') if dist.get_backend() != 'gloo' else group
if self.rank in ranks:
local_rank = ranks.index(self.rank)
group_world_size = len(ranks)
process_group = group
cpu_group = group_cpu
ranks_in_group = ranks
return local_rank, group_world_size, process_group, ranks_in_group, mode
return local_rank, group_world_size, process_group, cpu_group, ranks_in_group, mode
class Initializer_2D_Col(ProcessGroupInitializer):
@ -94,6 +97,7 @@ class Initializer_2D_Col(ProcessGroupInitializer):
local_rank = None
ranks_in_group = None
process_group = None
cpu_group = None
group_world_size = None
mode = ParallelMode.PARALLEL_2D_COL
@ -101,14 +105,16 @@ class Initializer_2D_Col(ProcessGroupInitializer):
for j in range(self.summa_dim):
ranks = [i * self.tensor_parallel_size + j + k * self.summa_dim for k in range(self.summa_dim)]
group = dist.new_group(ranks)
group_cpu = dist.new_group(ranks, backend='gloo') if dist.get_backend() != 'gloo' else group
if self.rank in ranks:
local_rank = ranks.index(self.rank)
group_world_size = len(ranks)
process_group = group
cpu_group = group_cpu
ranks_in_group = ranks
return local_rank, group_world_size, process_group, ranks_in_group, mode
return local_rank, group_world_size, process_group, cpu_group, ranks_in_group, mode
@DIST_GROUP_INITIALIZER.register_module

@ -62,6 +62,7 @@ class Initializer_2p5D_ROW(ProcessGroupInitializer):
local_rank = None
ranks_in_group = None
process_group = None
cpu_group = None
group_world_size = None
mode = ParallelMode.PARALLEL_2P5D_ROW
@ -73,14 +74,16 @@ class Initializer_2p5D_ROW(ProcessGroupInitializer):
for i in range(self.tesseract_dim)
]
group = dist.new_group(ranks)
group_cpu = dist.new_group(ranks, backend='gloo') if dist.get_backend() != 'gloo' else group
if self.rank in ranks:
local_rank = ranks.index(self.rank)
group_world_size = len(ranks)
process_group = group
cpu_group = group_cpu
ranks_in_group = ranks
return local_rank, group_world_size, process_group, ranks_in_group, mode
return local_rank, group_world_size, process_group, cpu_group, ranks_in_group, mode
class Initializer_2p5D_Col(ProcessGroupInitializer):
@ -115,6 +118,7 @@ class Initializer_2p5D_Col(ProcessGroupInitializer):
local_rank = None
ranks_in_group = None
process_group = None
cpu_group = None
group_world_size = None
mode = ParallelMode.PARALLEL_2P5D_COL
@ -126,14 +130,16 @@ class Initializer_2p5D_Col(ProcessGroupInitializer):
for j in range(self.tesseract_dim)
]
group = dist.new_group(ranks)
group_cpu = dist.new_group(ranks, backend='gloo') if dist.get_backend() != 'gloo' else group
if self.rank in ranks:
local_rank = ranks.index(self.rank)
group_world_size = len(ranks)
process_group = group
cpu_group = group_cpu
ranks_in_group = ranks
return local_rank, group_world_size, process_group, ranks_in_group, mode
return local_rank, group_world_size, process_group, cpu_group, ranks_in_group, mode
class Initializer_2p5D_Dep(ProcessGroupInitializer):
@ -168,6 +174,7 @@ class Initializer_2p5D_Dep(ProcessGroupInitializer):
local_rank = None
ranks_in_group = None
process_group = None
cpu_group = None
group_world_size = None
mode = ParallelMode.PARALLEL_2P5D_DEP
@ -179,14 +186,16 @@ class Initializer_2p5D_Dep(ProcessGroupInitializer):
for k in range(self.tesseract_dep)
]
group = dist.new_group(ranks)
group_cpu = dist.new_group(ranks, backend='gloo') if dist.get_backend() != 'gloo' else group
if self.rank in ranks:
local_rank = ranks.index(self.rank)
group_world_size = len(ranks)
process_group = group
cpu_group = group_cpu
ranks_in_group = ranks
return local_rank, group_world_size, process_group, ranks_in_group, mode
return local_rank, group_world_size, process_group, cpu_group, ranks_in_group, mode
# i row j col k dep
@ -222,6 +231,7 @@ class Initializer_2p5D_XZ(ProcessGroupInitializer):
local_rank = None
ranks_in_group = None
process_group = None
cpu_group = None
group_world_size = None
mode = ParallelMode.PARALLEL_2P5D_XZ
@ -233,14 +243,16 @@ class Initializer_2p5D_XZ(ProcessGroupInitializer):
for j in range(self.tesseract_dim)
]
group = dist.new_group(ranks)
group_cpu = dist.new_group(ranks, backend='gloo') if dist.get_backend() != 'gloo' else group
if self.rank in ranks:
local_rank = ranks.index(self.rank)
group_world_size = len(ranks)
process_group = group
cpu_group = group_cpu
ranks_in_group = ranks
return local_rank, group_world_size, process_group, ranks_in_group, mode
return local_rank, group_world_size, process_group, cpu_group, ranks_in_group, mode
@DIST_GROUP_INITIALIZER.register_module

@ -52,6 +52,7 @@ class Initializer_3D_Input(ProcessGroupInitializer):
local_rank = None
ranks_in_group = None
process_group = None
cpu_group = None
group_world_size = None
mode = ParallelMode.PARALLEL_3D_INPUT
env.input_group_3d = mode
@ -61,14 +62,16 @@ class Initializer_3D_Input(ProcessGroupInitializer):
for k in range(self.depth):
ranks = [h * self.depth**3 + i + self.depth * (j + self.depth * k) for j in range(self.depth)]
group = dist.new_group(ranks)
group_cpu = dist.new_group(ranks, backend='gloo') if dist.get_backend() != 'gloo' else group
if self.rank in ranks:
local_rank = ranks.index(self.rank)
group_world_size = len(ranks)
process_group = group
cpu_group = group_cpu
ranks_in_group = ranks
return local_rank, group_world_size, process_group, ranks_in_group, mode
return local_rank, group_world_size, process_group, cpu_group, ranks_in_group, mode
class Initializer_3D_Weight(ProcessGroupInitializer):
@ -100,6 +103,7 @@ class Initializer_3D_Weight(ProcessGroupInitializer):
local_rank = None
ranks_in_group = None
process_group = None
cpu_group = None
group_world_size = None
mode = ParallelMode.PARALLEL_3D_WEIGHT
env.weight_group_3d = mode
@ -109,14 +113,16 @@ class Initializer_3D_Weight(ProcessGroupInitializer):
for j in range(self.depth):
ranks = [h * self.depth**3 + i + self.depth * (j + self.depth * k) for i in range(self.depth)]
group = dist.new_group(ranks)
group_cpu = dist.new_group(ranks, backend='gloo') if dist.get_backend() != 'gloo' else group
if self.rank in ranks:
local_rank = ranks.index(self.rank)
group_world_size = len(ranks)
process_group = group
cpu_group = group_cpu
ranks_in_group = ranks
return local_rank, group_world_size, process_group, ranks_in_group, mode
return local_rank, group_world_size, process_group, cpu_group, ranks_in_group, mode
class Initializer_3D_Output(ProcessGroupInitializer):
@ -148,6 +154,7 @@ class Initializer_3D_Output(ProcessGroupInitializer):
local_rank = None
ranks_in_group = None
process_group = None
cpu_group = None
group_world_size = None
mode = ParallelMode.PARALLEL_3D_OUTPUT
env.output_group_3d = mode
@ -157,14 +164,16 @@ class Initializer_3D_Output(ProcessGroupInitializer):
for j in range(self.depth):
ranks = [h * self.depth**3 + i + self.depth * (j + self.depth * k) for k in range(self.depth)]
group = dist.new_group(ranks)
group_cpu = dist.new_group(ranks, backend='gloo') if dist.get_backend() != 'gloo' else group
if self.rank in ranks:
local_rank = ranks.index(self.rank)
group_world_size = len(ranks)
process_group = group
cpu_group = group_cpu
ranks_in_group = ranks
return local_rank, group_world_size, process_group, ranks_in_group, mode
return local_rank, group_world_size, process_group, cpu_group, ranks_in_group, mode
@DIST_GROUP_INITIALIZER.register_module

@ -34,17 +34,20 @@ class Initializer_Data(ProcessGroupInitializer):
local_rank = None
ranks_in_group = None
process_group = None
cpu_group = None
group_world_size = None
mode = ParallelMode.DATA
for i in range(self.num_data_parallel_group):
ranks = [i + j * self.num_data_parallel_group for j in range(self.data_parallel_size)]
group = dist.new_group(ranks)
group_cpu = dist.new_group(ranks, backend='gloo') if dist.get_backend() != 'gloo' else group
if self.rank in ranks:
local_rank = ranks.index(self.rank)
group_world_size = len(ranks)
process_group = group
cpu_group = group_cpu
ranks_in_group = ranks
return local_rank, group_world_size, process_group, ranks_in_group, mode
return local_rank, group_world_size, process_group, cpu_group, ranks_in_group, mode

@ -36,16 +36,20 @@ class Initializer_Model(ProcessGroupInitializer):
local_rank = None
ranks_in_group = None
process_group = None
cpu_group = None
group_world_size = None
mode = ParallelMode.MODEL
for i in range(self.num_group):
ranks = [i * self.model_parallel_size + j for j in range(self.model_parallel_size)]
group = dist.new_group(ranks)
group_cpu = dist.new_group(ranks, backend='gloo') if dist.get_backend() != 'gloo' else group
if self.rank in ranks:
local_rank = ranks.index(self.rank)
group_world_size = len(ranks)
process_group = group
cpu_group = group_cpu
ranks_in_group = ranks
return local_rank, group_world_size, process_group, ranks_in_group, mode
return local_rank, group_world_size, process_group, cpu_group, ranks_in_group, mode

@ -20,6 +20,7 @@ class Initializer_Pipeline(ProcessGroupInitializer):
pipeline_parallel_size (int): Size of pipeline parallel
tensor_parallel_size (int): Size of tensor parallel
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.data_group_size = self.world_size // self.data_parallel_size
@ -36,20 +37,19 @@ class Initializer_Pipeline(ProcessGroupInitializer):
for i in range(self.data_parallel_size):
for j in range(self.pipeline_stage_size):
pipe_ranks = list(
range(i * self.data_group_size + j,
(i + 1) * self.data_group_size,
self.pipeline_stage_size))
range(i * self.data_group_size + j, (i + 1) * self.data_group_size, self.pipeline_stage_size))
pipe_group_size = len(pipe_ranks)
pipe_group = dist.new_group(pipe_ranks)
group_cpu = dist.new_group(pipe_ranks, backend='gloo') if dist.get_backend() != 'gloo' else pipe_group
if self.rank in pipe_ranks:
local_rank = pipe_ranks.index(self.rank)
group_world_size = pipe_group_size
process_group = pipe_group
cpu_group = group_cpu
ranks_in_group = pipe_ranks
dist_settings.append(
tuple((local_rank, group_world_size,
process_group, ranks_in_group,
tuple((local_rank, group_world_size, process_group, cpu_group, ranks_in_group,
ParallelMode.PIPELINE)))
return dist_settings

@ -38,19 +38,23 @@ class Initializer_Sequence_DP(ProcessGroupInitializer):
local_rank = None
ranks_in_group = None
process_group = None
cpu_group = None
group_world_size = None
mode = ParallelMode.SEQUENCE_DP
for i in range(self.num_group):
ranks = [i * self.dp_size + j for j in range(self.dp_size)]
group = dist.new_group(ranks)
group_cpu = dist.new_group(ranks, backend='gloo') if dist.get_backend() != 'gloo' else group
if self.rank in ranks:
local_rank = ranks.index(self.rank)
group_world_size = len(ranks)
process_group = group
cpu_group = group_cpu
ranks_in_group = ranks
return local_rank, group_world_size, process_group, ranks_in_group, mode
return local_rank, group_world_size, process_group, cpu_group, ranks_in_group, mode
@DIST_GROUP_INITIALIZER.register_module
@ -86,10 +90,11 @@ class Initializer_Sequence(ProcessGroupInitializer):
parallel_setting = []
local_rank, group_world_size, process_group, ranks_in_group, mode = self._sequence_initializer.init_dist_group()
local_rank, group_world_size, process_group, cpu_grop, ranks_in_group, mode = \
self._sequence_initializer.init_dist_group()
# change mode to sequence
mode = ParallelMode.SEQUENCE
parallel_setting.append((local_rank, group_world_size, process_group, ranks_in_group, mode))
parallel_setting.append((local_rank, group_world_size, process_group, cpu_grop, ranks_in_group, mode))
parallel_setting.append(self._sequence_dp_initializer.init_dist_group())
return parallel_setting

@ -34,17 +34,20 @@ class Initializer_Tensor(ProcessGroupInitializer):
local_rank = None
ranks_in_group = None
process_group = None
cpu_group = None
group_world_size = None
mode = ParallelMode.TENSOR
for i in range(self.num_tensor_parallel_group):
ranks = [i * self.tensor_parallel_size + j for j in range(self.tensor_parallel_size)]
group = dist.new_group(ranks)
group_cpu = dist.new_group(ranks, backend='gloo') if dist.get_backend() != 'gloo' else group
if self.rank in ranks:
local_rank = ranks.index(self.rank)
group_world_size = len(ranks)
process_group = group
cpu_group = group_cpu
ranks_in_group = ranks
return local_rank, group_world_size, process_group, ranks_in_group, mode
return local_rank, group_world_size, process_group, cpu_group, ranks_in_group, mode

Loading…
Cancel
Save