# ColossalAI/colossalai/context/process_group_initializer/initializer_moe.py
# Captured from a repository web view (120 lines, 4.4 KiB, Python; "Raw / Normal View / History").
# Blame annotation: 2022-01-07 07:08:36 +00:00
import torch.distributed as dist
from colossalai.registry import DIST_GROUP_INITIALIZER
from colossalai.global_variables import moe_env
from .process_group_initializer import ProcessGroupInitializer
from ..parallel_mode import ParallelMode
@DIST_GROUP_INITIALIZER.register_module
class Initializer_Moemodel(ProcessGroupInitializer):
    """Model parallel initialization for MoE system.

    Ranks ``0 .. moe_model * moe_data - 1`` are split into ``moe_data``
    consecutive blocks of ``moe_model`` ranks; each block forms one MoE
    model-parallel process group.

    :param moe_model: Size of moe model parallel
    :param moe_data: Size of moe data parallel
    :param args: Args used in base class
    :param kwargs: Kwargs used in base class
    :type moe_model: int
    :type moe_data: int
    """

    def __init__(self, moe_model, moe_data, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.moe_model = moe_model
        self.moe_data = moe_data

    def init_dist_group(self):
        """Initialize model parallel groups in moe parallel environment,
        and assign local_ranks and groups to each gpu.

        ``dist.new_group`` is called for every group, including groups the
        calling rank is not a member of. torch.distributed requires every
        process to enter ``new_group``. If the calling rank belongs to no
        group, every returned field except ``mode`` is ``None``.

        :return: MoE model parallelism's information
        :rtype: Tuple(local_rank, group_world_size, process_group, ranks_in_group, mode)
        """
        local_rank = None
        ranks_in_group = None
        process_group = None
        group_world_size = None
        mode = ParallelMode.MOE_MODEL

        for i in range(self.moe_data):
            # Group i holds the contiguous ranks [i * moe_model, (i + 1) * moe_model).
            start = i * self.moe_model
            ranks = list(range(start, start + self.moe_model))
            # Created unconditionally so that all processes take part in every new_group call.
            group = dist.new_group(ranks)
            if self.rank in ranks:
                local_rank = ranks.index(self.rank)
                group_world_size = len(ranks)
                process_group = group
                ranks_in_group = ranks

        return local_rank, group_world_size, process_group, ranks_in_group, mode


@DIST_GROUP_INITIALIZER.register_module
class Initializer_Moedata(ProcessGroupInitializer):
    """Data parallel initialization for MoE system.

    For each offset ``i`` in ``0 .. moe_model - 1`` one group is formed from
    the ranks ``i, i + moe_model, i + 2 * moe_model, ...`` (``moe_data`` ranks
    in total). Each group therefore takes one rank from every consecutive
    block of ``moe_model`` ranks.

    :param moe_model: Size of moe model parallel
    :param moe_data: Size of moe data parallel
    :param args: Args used in base class
    :param kwargs: Kwargs used in base class
    :type moe_model: int
    :type moe_data: int
    """

    def __init__(self, moe_model, moe_data, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.moe_model = moe_model
        self.moe_data = moe_data

    def init_dist_group(self):
        """Initialize data parallel groups in moe parallel environment,
        and assign local_ranks and groups to each gpu.

        ``dist.new_group`` is called for every group, including groups the
        calling rank is not a member of. torch.distributed requires every
        process to enter ``new_group``. If the calling rank belongs to no
        group, every returned field except ``mode`` is ``None``.

        :return: MoE data parallelism's information
        :rtype: Tuple(local_rank, group_world_size, process_group, ranks_in_group, mode)
        """
        local_rank = None
        ranks_in_group = None
        process_group = None
        group_world_size = None
        mode = ParallelMode.MOE_DATA

        for i in range(self.moe_model):
            # Strided ranks i, i + moe_model, ..., i + (moe_data - 1) * moe_model.
            # The step is never 0 here: this loop body only runs when moe_model >= 1.
            ranks = list(range(i, i + self.moe_data * self.moe_model, self.moe_model))
            # Created unconditionally so that all processes take part in every new_group call.
            group = dist.new_group(ranks)
            if self.rank in ranks:
                local_rank = ranks.index(self.rank)
                group_world_size = len(ranks)
                process_group = group
                ranks_in_group = ranks

        return local_rank, group_world_size, process_group, ranks_in_group, mode


@DIST_GROUP_INITIALIZER.register_module
class Initializer_Moe(ProcessGroupInitializer):
    """Single entry point for MoE parallel initialization.

    Reads the MoE model/data parallel sizes from ``moe_env`` and delegates
    group creation to ``Initializer_Moemodel`` and ``Initializer_Moedata``.
    Both are constructed with the same forwarded arguments.

    :param args: Args used to initialize ProcessGroupInitializer
    :param kwargs: Kwargs used to initialize ProcessGroupInitializer
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        model_size = moe_env.model_parallel_size
        data_size = moe_env.data_parallel_size
        self.moe_model = model_size
        self.moe_data = data_size

        # Build the model-parallel initializer first, then the data-parallel one.
        built = []
        for initializer_cls in (Initializer_Moemodel, Initializer_Moedata):
            built.append(initializer_cls(model_size, data_size, *args, **kwargs))
        self.model_initializer, self.data_initializer = built

    def init_dist_group(self):
        """Create all MoE parallel communication groups.

        The model-parallel groups are initialized before the data-parallel
        groups. The result list keeps that order.

        :return: MoE parallelism's information
        :rtype: list of Tuples (local_rank, group_world_size, process_group, ranks_in_group, mode)
        """
        ordered = (self.model_initializer, self.data_initializer)
        return [initializer.init_dist_group() for initializer in ordered]