mirror of https://github.com/hpcaitech/ColossalAI
[model checkpoint] add gloo groups for cpu tensor communication (#589)
parent 54e688b623
commit 297b8baae2
@@ -34,6 +34,7 @@ class ParallelContext(metaclass=SingletonMeta):
         self._local_ranks = dict()
         self._world_sizes = dict()
         self._groups = dict()
+        self._cpu_groups = dict()
         self._ranks_in_group = dict()
 
         # load config from file
@@ -290,6 +291,32 @@ class ParallelContext(metaclass=SingletonMeta):
         self._check_parallel_mode(parallel_mode)
         self._groups[parallel_mode] = group
 
+    def get_cpu_group(self, parallel_mode: ParallelMode):
+        """Returns the Gloo group of the current device for `parallel_mode`.
+
+        :param parallel_mode: The chosen parallel mode
+        :type parallel_mode: :class:`colossalai.context.ParallelMode`
+        :raises AssertionError: Raises an AssertionError if `parallel_mode` is not an instance
+            of :class:`colossalai.context.ParallelMode`
+        :return: The group of the current device for `parallel_mode`
+        :rtype: torch.distributed.ProcessGroup
+        """
+        self._check_parallel_mode(parallel_mode)
+        return self._cpu_groups[parallel_mode]
+
+    def add_cpu_group(self, parallel_mode: ParallelMode, group: dist.ProcessGroup):
+        """Adds the Gloo group of the current device for `parallel_mode`.
+
+        :param parallel_mode: The chosen parallel mode
+        :type parallel_mode: :class:`colossalai.context.ParallelMode`
+        :param group: The group to be added
+        :type group: torch.distributed.ProcessGroup
+        :raises AssertionError: Raises an AssertionError if `parallel_mode` is not an instance
+            of :class:`colossalai.context.ParallelMode`
+        """
+        self._check_parallel_mode(parallel_mode)
+        self._cpu_groups[parallel_mode] = group
+
     def get_ranks_in_group(self, parallel_mode: ParallelMode):
         """Returns the rank of the current device for `parallel_mode` in the group.
 
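
Note: the two methods above are the checkpoint-facing API of this change. NCCL process groups only accept CUDA tensors, so a collective over CPU tensors (for example, a state dict staged on the host for saving) needs a Gloo group. A minimal sketch of the kind of call this enables; the function is illustrative and not part of the commit:

    import torch
    import torch.distributed as dist

    def broadcast_cpu_tensor(tensor: torch.Tensor, cpu_group: dist.ProcessGroup, src: int = 0) -> torch.Tensor:
        # The tensor stays on the host: Gloo supports CPU tensors natively,
        # so no staging copy through GPU memory is required.
        dist.broadcast(tensor, src=src, group=cpu_group)
        return tensor
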
@@ -335,13 +362,16 @@ class ParallelContext(metaclass=SingletonMeta):
         dist.init_process_group(rank=rank, world_size=world_size, backend=backend, init_method=init_method)
 
         # None will give the default global process group for pytorch dist operations
-        self._register_dist(rank, world_size, None, list(range(world_size)), ParallelMode.GLOBAL)
+        ranks = list(range(world_size))
+        cpu_group = dist.new_group(ranks, backend='gloo') if dist.get_backend() != 'gloo' else None
+        self._register_dist(rank, world_size, None, cpu_group, ranks, ParallelMode.GLOBAL)
         self.add_global_rank(ParallelMode.GLOBAL, rank)
 
-    def _register_dist(self, local_rank, world_size, process_group, ranks_in_group, mode):
+    def _register_dist(self, local_rank, world_size, process_group, cpu_group, ranks_in_group, mode):
         self.add_local_rank(mode, local_rank)
         self.add_world_size(mode, world_size)
         self.add_group(mode, process_group)
+        self.add_cpu_group(mode, cpu_group)
         self.add_ranks_in_group(mode, ranks_in_group)
 
     def check_sanity(self):
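
Note: for ParallelMode.GLOBAL the guard falls back to None rather than a new group, because passing group=None to torch.distributed collectives selects the default process group, which in that branch is already Gloo. The same three-line creation pattern recurs in every initializer below; a hypothetical helper (not in the commit) makes the invariant explicit. Keep in mind that torch.distributed.new_group must be called by every process, in the same order, even for groups the calling rank does not belong to, which is why each initializer builds all groups unconditionally and only keeps the one containing self.rank:

    import torch.distributed as dist

    def new_group_with_cpu(ranks):
        # Hypothetical consolidation of the repeated pattern: every device
        # group gets a Gloo companion with identical membership, so CPU
        # tensors can reuse the same communication topology.
        group = dist.new_group(ranks)
        group_cpu = dist.new_group(ranks, backend='gloo') if dist.get_backend() != 'gloo' else group
        return group, group_cpu
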

@@ -36,6 +36,7 @@ class Initializer_1D(ProcessGroupInitializer):
         local_rank = None
         ranks_in_group = None
         process_group = None
+        cpu_group = None
         group_world_size = None
         mode = ParallelMode.PARALLEL_1D
         env.parallel_input_1d = False
@@ -43,11 +44,13 @@ class Initializer_1D(ProcessGroupInitializer):
         for i in range(self.num_group):
             ranks = [i * self.tensor_parallel_size + j for j in range(self.tensor_parallel_size)]
             group = dist.new_group(ranks)
+            group_cpu = dist.new_group(ranks, backend='gloo') if dist.get_backend() != 'gloo' else group
 
             if self.rank in ranks:
                 local_rank = ranks.index(self.rank)
                 group_world_size = len(ranks)
                 process_group = group
+                cpu_group = group_cpu
                 ranks_in_group = ranks
 
-        return local_rank, group_world_size, process_group, ranks_in_group, mode
+        return local_rank, group_world_size, process_group, cpu_group, ranks_in_group, mode
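
Note: the ranks formula above tiles the world into contiguous tensor-parallel blocks. A worked instance with assumed sizes (world_size = 8, tensor_parallel_size = 2, so num_group = 4):

    tensor_parallel_size = 2
    num_group = 8 // tensor_parallel_size
    groups = [[i * tensor_parallel_size + j for j in range(tensor_parallel_size)]
              for i in range(num_group)]
    assert groups == [[0, 1], [2, 3], [4, 5], [6, 7]]
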

@@ -48,6 +48,7 @@ class Initializer_2D_Row(ProcessGroupInitializer):
         local_rank = None
         ranks_in_group = None
         process_group = None
+        cpu_group = None
         group_world_size = None
         mode = ParallelMode.PARALLEL_2D_ROW
 
@@ -55,14 +56,16 @@ class Initializer_2D_Row(ProcessGroupInitializer):
             for j in range(self.summa_dim):
                 ranks = [i * self.tensor_parallel_size + j * self.summa_dim + k for k in range(self.summa_dim)]
                 group = dist.new_group(ranks)
+                group_cpu = dist.new_group(ranks, backend='gloo') if dist.get_backend() != 'gloo' else group
 
                 if self.rank in ranks:
                     local_rank = ranks.index(self.rank)
                     group_world_size = len(ranks)
                     process_group = group
+                    cpu_group = group_cpu
                     ranks_in_group = ranks
 
-        return local_rank, group_world_size, process_group, ranks_in_group, mode
+        return local_rank, group_world_size, process_group, cpu_group, ranks_in_group, mode
 
 
 class Initializer_2D_Col(ProcessGroupInitializer):
@@ -94,6 +97,7 @@ class Initializer_2D_Col(ProcessGroupInitializer):
         local_rank = None
         ranks_in_group = None
         process_group = None
+        cpu_group = None
         group_world_size = None
         mode = ParallelMode.PARALLEL_2D_COL
 
@@ -101,14 +105,16 @@ class Initializer_2D_Col(ProcessGroupInitializer):
             for j in range(self.summa_dim):
                 ranks = [i * self.tensor_parallel_size + j + k * self.summa_dim for k in range(self.summa_dim)]
                 group = dist.new_group(ranks)
+                group_cpu = dist.new_group(ranks, backend='gloo') if dist.get_backend() != 'gloo' else group
 
                 if self.rank in ranks:
                     local_rank = ranks.index(self.rank)
                     group_world_size = len(ranks)
                     process_group = group
+                    cpu_group = group_cpu
                     ranks_in_group = ranks
 
-        return local_rank, group_world_size, process_group, ranks_in_group, mode
+        return local_rank, group_world_size, process_group, cpu_group, ranks_in_group, mode
 
 
 @DIST_GROUP_INITIALIZER.register_module
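
Note: the two formulas above slice the same summa_dim x summa_dim process grid two ways: row groups vary the fastest index k, column groups stride by summa_dim. A worked instance with assumed sizes (tensor_parallel_size = 4, summa_dim = 2, one group, i = 0):

    tensor_parallel_size, summa_dim, i = 4, 2, 0
    rows = [[i * tensor_parallel_size + j * summa_dim + k for k in range(summa_dim)]
            for j in range(summa_dim)]
    cols = [[i * tensor_parallel_size + j + k * summa_dim for k in range(summa_dim)]
            for j in range(summa_dim)]
    assert rows == [[0, 1], [2, 3]]   # ranks in the same process-grid row
    assert cols == [[0, 2], [1, 3]]   # ranks in the same process-grid column
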
@@ -62,6 +62,7 @@ class Initializer_2p5D_ROW(ProcessGroupInitializer):
         local_rank = None
         ranks_in_group = None
         process_group = None
+        cpu_group = None
         group_world_size = None
         mode = ParallelMode.PARALLEL_2P5D_ROW
 
@@ -73,14 +74,16 @@ class Initializer_2p5D_ROW(ProcessGroupInitializer):
                     for i in range(self.tesseract_dim)
                 ]
                 group = dist.new_group(ranks)
+                group_cpu = dist.new_group(ranks, backend='gloo') if dist.get_backend() != 'gloo' else group
 
                 if self.rank in ranks:
                     local_rank = ranks.index(self.rank)
                     group_world_size = len(ranks)
                     process_group = group
+                    cpu_group = group_cpu
                     ranks_in_group = ranks
 
-        return local_rank, group_world_size, process_group, ranks_in_group, mode
+        return local_rank, group_world_size, process_group, cpu_group, ranks_in_group, mode
 
 
 class Initializer_2p5D_Col(ProcessGroupInitializer):
@@ -115,6 +118,7 @@ class Initializer_2p5D_Col(ProcessGroupInitializer):
         local_rank = None
         ranks_in_group = None
         process_group = None
+        cpu_group = None
         group_world_size = None
         mode = ParallelMode.PARALLEL_2P5D_COL
 
@@ -126,14 +130,16 @@ class Initializer_2p5D_Col(ProcessGroupInitializer):
                     for j in range(self.tesseract_dim)
                 ]
                 group = dist.new_group(ranks)
+                group_cpu = dist.new_group(ranks, backend='gloo') if dist.get_backend() != 'gloo' else group
 
                 if self.rank in ranks:
                     local_rank = ranks.index(self.rank)
                     group_world_size = len(ranks)
                     process_group = group
+                    cpu_group = group_cpu
                     ranks_in_group = ranks
 
-        return local_rank, group_world_size, process_group, ranks_in_group, mode
+        return local_rank, group_world_size, process_group, cpu_group, ranks_in_group, mode
 
 
 class Initializer_2p5D_Dep(ProcessGroupInitializer):
@@ -168,6 +174,7 @@ class Initializer_2p5D_Dep(ProcessGroupInitializer):
         local_rank = None
         ranks_in_group = None
         process_group = None
+        cpu_group = None
         group_world_size = None
         mode = ParallelMode.PARALLEL_2P5D_DEP
 
@@ -179,14 +186,16 @@ class Initializer_2p5D_Dep(ProcessGroupInitializer):
                     for k in range(self.tesseract_dep)
                 ]
                 group = dist.new_group(ranks)
+                group_cpu = dist.new_group(ranks, backend='gloo') if dist.get_backend() != 'gloo' else group
 
                 if self.rank in ranks:
                     local_rank = ranks.index(self.rank)
                     group_world_size = len(ranks)
                     process_group = group
+                    cpu_group = group_cpu
                     ranks_in_group = ranks
 
-        return local_rank, group_world_size, process_group, ranks_in_group, mode
+        return local_rank, group_world_size, process_group, cpu_group, ranks_in_group, mode
 
 
 # i row j col k dep
@@ -222,6 +231,7 @@ class Initializer_2p5D_XZ(ProcessGroupInitializer):
         local_rank = None
         ranks_in_group = None
         process_group = None
+        cpu_group = None
         group_world_size = None
         mode = ParallelMode.PARALLEL_2P5D_XZ
 
@@ -233,14 +243,16 @@ class Initializer_2p5D_XZ(ProcessGroupInitializer):
                     for j in range(self.tesseract_dim)
                 ]
                 group = dist.new_group(ranks)
+                group_cpu = dist.new_group(ranks, backend='gloo') if dist.get_backend() != 'gloo' else group
 
                 if self.rank in ranks:
                     local_rank = ranks.index(self.rank)
                     group_world_size = len(ranks)
                     process_group = group
+                    cpu_group = group_cpu
                     ranks_in_group = ranks
 
-        return local_rank, group_world_size, process_group, ranks_in_group, mode
+        return local_rank, group_world_size, process_group, cpu_group, ranks_in_group, mode
 
 
 @DIST_GROUP_INITIALIZER.register_module
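
Note: the rank formulas for the four 2.5D modes are only partially visible in these hunks, but the comment `# i row j col k dep` records the coordinate convention. A sketch of decoding those coordinates under an assumed row-major packing within one tesseract (tesseract_dim = 2, tesseract_dep = 2; the packing order is an assumption, not taken from the diff):

    tesseract_dim, tesseract_dep = 2, 2

    def coords(rank_in_tesseract):
        i = rank_in_tesseract % tesseract_dim                      # row
        j = (rank_in_tesseract // tesseract_dim) % tesseract_dim   # column
        k = rank_in_tesseract // (tesseract_dim * tesseract_dim)   # depth
        return i, j, k

    assert coords(0) == (0, 0, 0)
    assert coords(5) == (1, 0, 1)
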
@@ -52,6 +52,7 @@ class Initializer_3D_Input(ProcessGroupInitializer):
         local_rank = None
         ranks_in_group = None
         process_group = None
+        cpu_group = None
         group_world_size = None
         mode = ParallelMode.PARALLEL_3D_INPUT
         env.input_group_3d = mode
@@ -61,14 +62,16 @@ class Initializer_3D_Input(ProcessGroupInitializer):
             for k in range(self.depth):
                 ranks = [h * self.depth**3 + i + self.depth * (j + self.depth * k) for j in range(self.depth)]
                 group = dist.new_group(ranks)
+                group_cpu = dist.new_group(ranks, backend='gloo') if dist.get_backend() != 'gloo' else group
 
                 if self.rank in ranks:
                     local_rank = ranks.index(self.rank)
                     group_world_size = len(ranks)
                     process_group = group
+                    cpu_group = group_cpu
                     ranks_in_group = ranks
 
-        return local_rank, group_world_size, process_group, ranks_in_group, mode
+        return local_rank, group_world_size, process_group, cpu_group, ranks_in_group, mode
 
 
 class Initializer_3D_Weight(ProcessGroupInitializer):
@@ -100,6 +103,7 @@ class Initializer_3D_Weight(ProcessGroupInitializer):
         local_rank = None
         ranks_in_group = None
         process_group = None
+        cpu_group = None
         group_world_size = None
         mode = ParallelMode.PARALLEL_3D_WEIGHT
         env.weight_group_3d = mode
@@ -109,14 +113,16 @@ class Initializer_3D_Weight(ProcessGroupInitializer):
             for j in range(self.depth):
                 ranks = [h * self.depth**3 + i + self.depth * (j + self.depth * k) for i in range(self.depth)]
                 group = dist.new_group(ranks)
+                group_cpu = dist.new_group(ranks, backend='gloo') if dist.get_backend() != 'gloo' else group
 
                 if self.rank in ranks:
                     local_rank = ranks.index(self.rank)
                     group_world_size = len(ranks)
                     process_group = group
+                    cpu_group = group_cpu
                     ranks_in_group = ranks
 
-        return local_rank, group_world_size, process_group, ranks_in_group, mode
+        return local_rank, group_world_size, process_group, cpu_group, ranks_in_group, mode
 
 
 class Initializer_3D_Output(ProcessGroupInitializer):
@@ -148,6 +154,7 @@ class Initializer_3D_Output(ProcessGroupInitializer):
         local_rank = None
         ranks_in_group = None
         process_group = None
+        cpu_group = None
         group_world_size = None
         mode = ParallelMode.PARALLEL_3D_OUTPUT
         env.output_group_3d = mode
@@ -157,14 +164,16 @@ class Initializer_3D_Output(ProcessGroupInitializer):
             for j in range(self.depth):
                 ranks = [h * self.depth**3 + i + self.depth * (j + self.depth * k) for k in range(self.depth)]
                 group = dist.new_group(ranks)
+                group_cpu = dist.new_group(ranks, backend='gloo') if dist.get_backend() != 'gloo' else group
 
                 if self.rank in ranks:
                     local_rank = ranks.index(self.rank)
                     group_world_size = len(ranks)
                     process_group = group
+                    cpu_group = group_cpu
                     ranks_in_group = ranks
 
-        return local_rank, group_world_size, process_group, ranks_in_group, mode
+        return local_rank, group_world_size, process_group, cpu_group, ranks_in_group, mode
 
 
 @DIST_GROUP_INITIALIZER.register_module
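
Note: all three 3D initializers share one rank formula, h * depth**3 + i + depth * (j + depth * k), and differ only in which coordinate the comprehension varies (j for input, i for weight, k for output). A worked instance of the input case with assumed values (depth = 2, h = 0, i = 0, k = 0):

    depth, h, i, k = 2, 0, 0, 0
    ranks = [h * depth**3 + i + depth * (j + depth * k) for j in range(depth)]
    assert ranks == [0, 2]   # the two ranks that differ only in j
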

@@ -34,17 +34,20 @@ class Initializer_Data(ProcessGroupInitializer):
         local_rank = None
         ranks_in_group = None
         process_group = None
+        cpu_group = None
         group_world_size = None
         mode = ParallelMode.DATA
 
         for i in range(self.num_data_parallel_group):
             ranks = [i + j * self.num_data_parallel_group for j in range(self.data_parallel_size)]
             group = dist.new_group(ranks)
+            group_cpu = dist.new_group(ranks, backend='gloo') if dist.get_backend() != 'gloo' else group
 
             if self.rank in ranks:
                 local_rank = ranks.index(self.rank)
                 group_world_size = len(ranks)
                 process_group = group
+                cpu_group = group_cpu
                 ranks_in_group = ranks
 
-        return local_rank, group_world_size, process_group, ranks_in_group, mode
+        return local_rank, group_world_size, process_group, cpu_group, ranks_in_group, mode
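
Note: unlike the contiguous blocks used for tensor- and model-parallel groups, data-parallel groups stride across the world so that each group spans all model replicas (contrast with Initializer_Model just below). A worked instance with assumed sizes (world_size = 8, data_parallel_size = 2, so num_data_parallel_group = 4):

    data_parallel_size = 2
    num_data_parallel_group = 8 // data_parallel_size
    groups = [[i + j * num_data_parallel_group for j in range(data_parallel_size)]
              for i in range(num_data_parallel_group)]
    assert groups == [[0, 4], [1, 5], [2, 6], [3, 7]]
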

@@ -36,16 +36,20 @@ class Initializer_Model(ProcessGroupInitializer):
         local_rank = None
         ranks_in_group = None
         process_group = None
+        cpu_group = None
         group_world_size = None
         mode = ParallelMode.MODEL
 
         for i in range(self.num_group):
             ranks = [i * self.model_parallel_size + j for j in range(self.model_parallel_size)]
             group = dist.new_group(ranks)
+            group_cpu = dist.new_group(ranks, backend='gloo') if dist.get_backend() != 'gloo' else group
 
             if self.rank in ranks:
                 local_rank = ranks.index(self.rank)
                 group_world_size = len(ranks)
                 process_group = group
+                cpu_group = group_cpu
                 ranks_in_group = ranks
-        return local_rank, group_world_size, process_group, ranks_in_group, mode
+
+        return local_rank, group_world_size, process_group, cpu_group, ranks_in_group, mode

@@ -20,6 +20,7 @@ class Initializer_Pipeline(ProcessGroupInitializer):
         pipeline_parallel_size (int): Size of pipeline parallel
         tensor_parallel_size (int): Size of tensor parallel
     """
+
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
         self.data_group_size = self.world_size // self.data_parallel_size
@@ -36,20 +37,19 @@ class Initializer_Pipeline(ProcessGroupInitializer):
         for i in range(self.data_parallel_size):
             for j in range(self.pipeline_stage_size):
                 pipe_ranks = list(
-                    range(i * self.data_group_size + j,
-                          (i + 1) * self.data_group_size,
-                          self.pipeline_stage_size))
+                    range(i * self.data_group_size + j, (i + 1) * self.data_group_size, self.pipeline_stage_size))
                 pipe_group_size = len(pipe_ranks)
                 pipe_group = dist.new_group(pipe_ranks)
+                group_cpu = dist.new_group(pipe_ranks, backend='gloo') if dist.get_backend() != 'gloo' else pipe_group
 
                 if self.rank in pipe_ranks:
                     local_rank = pipe_ranks.index(self.rank)
                     group_world_size = pipe_group_size
                     process_group = pipe_group
+                    cpu_group = group_cpu
                     ranks_in_group = pipe_ranks
                     dist_settings.append(
-                        tuple((local_rank, group_world_size,
-                               process_group, ranks_in_group,
+                        tuple((local_rank, group_world_size, process_group, cpu_group, ranks_in_group,
                                ParallelMode.PIPELINE)))
 
         return dist_settings
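
Note: the pipe_ranks expression picks, within one data-parallel block, every pipeline_stage_size-th rank starting at offset j. A worked instance with assumed sizes (world_size = 8, data_parallel_size = 2 so data_group_size = 4, pipeline_stage_size = 2):

    data_group_size, pipeline_stage_size = 4, 2
    i, j = 0, 1
    pipe_ranks = list(range(i * data_group_size + j, (i + 1) * data_group_size, pipeline_stage_size))
    assert pipe_ranks == [1, 3]
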

@@ -38,19 +38,23 @@ class Initializer_Sequence_DP(ProcessGroupInitializer):
         local_rank = None
         ranks_in_group = None
         process_group = None
+        cpu_group = None
         group_world_size = None
         mode = ParallelMode.SEQUENCE_DP
 
         for i in range(self.num_group):
             ranks = [i * self.dp_size + j for j in range(self.dp_size)]
             group = dist.new_group(ranks)
+            group_cpu = dist.new_group(ranks, backend='gloo') if dist.get_backend() != 'gloo' else group
 
             if self.rank in ranks:
                 local_rank = ranks.index(self.rank)
                 group_world_size = len(ranks)
                 process_group = group
+                cpu_group = group_cpu
                 ranks_in_group = ranks
-        return local_rank, group_world_size, process_group, ranks_in_group, mode
+
+        return local_rank, group_world_size, process_group, cpu_group, ranks_in_group, mode
 
 
 @DIST_GROUP_INITIALIZER.register_module
@@ -86,10 +90,11 @@ class Initializer_Sequence(ProcessGroupInitializer):
 
         parallel_setting = []
 
-        local_rank, group_world_size, process_group, ranks_in_group, mode = self._sequence_initializer.init_dist_group()
+        local_rank, group_world_size, process_group, cpu_group, ranks_in_group, mode = \
+            self._sequence_initializer.init_dist_group()
         # change mode to sequence
         mode = ParallelMode.SEQUENCE
 
-        parallel_setting.append((local_rank, group_world_size, process_group, ranks_in_group, mode))
+        parallel_setting.append((local_rank, group_world_size, process_group, cpu_group, ranks_in_group, mode))
         parallel_setting.append(self._sequence_dp_initializer.init_dist_group())
         return parallel_setting
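
Note: after this commit every init_dist_group implementation returns a six-tuple, and the two appends above keep Initializer_Sequence consistent with that contract. A sketch of how the tuple lines up with the widened _register_dist shown earlier (`ctx` and `setting` are stand-ins for illustration):

    def register_setting(ctx, setting):
        # `ctx` is assumed to be a ParallelContext; `setting` is the
        # six-tuple returned by any initializer after this commit.
        local_rank, group_world_size, process_group, cpu_group, ranks_in_group, mode = setting
        ctx._register_dist(local_rank, group_world_size, process_group, cpu_group, ranks_in_group, mode)
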

@@ -34,17 +34,20 @@ class Initializer_Tensor(ProcessGroupInitializer):
         local_rank = None
         ranks_in_group = None
         process_group = None
+        cpu_group = None
         group_world_size = None
         mode = ParallelMode.TENSOR
 
         for i in range(self.num_tensor_parallel_group):
             ranks = [i * self.tensor_parallel_size + j for j in range(self.tensor_parallel_size)]
             group = dist.new_group(ranks)
+            group_cpu = dist.new_group(ranks, backend='gloo') if dist.get_backend() != 'gloo' else group
 
             if self.rank in ranks:
                 local_rank = ranks.index(self.rank)
                 group_world_size = len(ranks)
                 process_group = group
+                cpu_group = group_cpu
                 ranks_in_group = ranks
 
-        return local_rank, group_world_size, process_group, ranks_in_group, mode
+        return local_rank, group_world_size, process_group, cpu_group, ranks_in_group, mode