2021-10-28 16:21:23 +00:00
|
|
|
#!/usr/bin/env python
|
|
|
|
# -*- encoding: utf-8 -*-
|
2022-01-20 05:44:51 +00:00
|
|
|
import torch.distributed as dist
|
2021-10-28 16:21:23 +00:00
|
|
|
|
|
|
|
from colossalai.registry import DIST_GROUP_INITIALIZER
|
2023-02-15 02:53:38 +00:00
|
|
|
|
|
|
|
from ..parallel_mode import ParallelMode
|
2021-10-28 16:21:23 +00:00
|
|
|
from .initializer_tensor import Initializer_Tensor
|
|
|
|
from .process_group_initializer import ProcessGroupInitializer
|
|
|
|
|
|
|
|
|
2022-01-20 05:44:51 +00:00
|
|
|
@DIST_GROUP_INITIALIZER.register_module
class Initializer_Sequence_DP(ProcessGroupInitializer):
    """Process group initializer for the gradient all-reduce of sequence parallelism.

    With sequence parallelism every GPU keeps a full copy of the model
    weights, so gradients are all-reduced across all processes that belong
    to the same pipeline stage.

    Args:
        rank (int): The rank of current process.
        world_size (int): Size of whole communication world.
        config (Config): Running configuration.
        data_parallel_size (int): Size of data parallel.
        pipeline_parallel_size (int): Size of pipeline parallel.
        tensor_parallel_size (int): Size of tensor parallel.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # One group per pipeline stage; each group takes an equal,
        # contiguous share of the world.
        self.num_group = self.pipeline_parallel_size
        self.dp_size = self.world_size // self.pipeline_parallel_size

    def init_dist_group(self):
        """Create the sequence data-parallel groups and locate this rank's group.

        Every process creates every group, not only its own:
        ``torch.distributed.new_group`` has to be entered by all processes.

        Returns:
            Tuple: ``(local_rank, group_world_size, process_group, cpu_group,
            ranks_in_group, mode)``. ``cpu_group`` is a ``gloo`` group over the
            same ranks, or ``process_group`` itself when the default backend is
            already ``gloo``. Every field except ``mode`` stays ``None`` if the
            current rank is not a member of any created group.
        """
        mode = ParallelMode.SEQUENCE_DP
        local_rank = group_world_size = None
        process_group = cpu_group = None
        ranks_in_group = None

        for group_idx in range(self.num_group):
            # Ranks of one group are contiguous: [start, start + dp_size).
            start = group_idx * self.dp_size
            members = list(range(start, start + self.dp_size))

            default_group = dist.new_group(members)
            if dist.get_backend() == 'gloo':
                # The default group already runs on gloo, so reuse it.
                gloo_group = default_group
            else:
                gloo_group = dist.new_group(members, backend='gloo')

            if self.rank not in members:
                continue

            local_rank = members.index(self.rank)
            group_world_size = len(members)
            process_group = default_group
            cpu_group = gloo_group
            ranks_in_group = members

        return local_rank, group_world_size, process_group, cpu_group, ranks_in_group, mode
|
2022-01-20 05:44:51 +00:00
|
|
|
|
|
|
|
|
2021-10-28 16:21:23 +00:00
|
|
|
@DIST_GROUP_INITIALIZER.register_module
class Initializer_Sequence(ProcessGroupInitializer):
    """Process group initializer for sequence parallelism.

    Args:
        rank (int): The rank of current process.
        world_size (int): Size of whole communication world.
        config (Config): Running configuration.
        data_parallel_size (int): Size of data parallel.
        pipeline_parallel_size (int): Size of pipeline parallel.
        tensor_parallel_size (int): Size of tensor parallel.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # The sequence groups are built by the tensor parallel initializer;
        # the gradient all-reduce groups by the sequence DP initializer.
        # Both receive the same constructor arguments as this object.
        self._sequence_initializer = Initializer_Tensor(*args, **kwargs)
        self._sequence_dp_initializer = Initializer_Sequence_DP(*args, **kwargs)

    def init_dist_group(self):
        """Initialize the sequence parallel process groups for the current process.

        Sequence parallelism needs two process groups. The first is used in
        the model forward pass, where several processes exchange partial
        query, key and value embeddings to compute self attention values.
        The second is used for the all-reduce that synchronizes the model
        parameters.

        Returns:
            List[Tuple]: Two entries, each shaped ``(local_rank,
            group_world_size, process_group, cpu_group, ranks_in_group, mode)``.
            The first is the tensor initializer's result with ``mode`` replaced
            by ``ParallelMode.SEQUENCE``; the second is the result of
            ``Initializer_Sequence_DP.init_dist_group`` unchanged.
        """
        tensor_setting = self._sequence_initializer.init_dist_group()
        # Keep every field of the tensor parallel result except its mode,
        # which is relabelled as SEQUENCE below.
        local_rank, group_world_size, process_group, cpu_group, ranks_in_group, _ = tensor_setting
        sequence_setting = (
            local_rank,
            group_world_size,
            process_group,
            cpu_group,
            ranks_in_group,
            ParallelMode.SEQUENCE,
        )

        sequence_dp_setting = self._sequence_dp_initializer.init_dist_group()
        return [sequence_setting, sequence_dp_setting]
|