import torch.distributed as dist

from colossalai.registry import DIST_GROUP_INITIALIZER
from colossalai.global_variables import moe_env

from .process_group_initializer import ProcessGroupInitializer
from ..parallel_mode import ParallelMode


@DIST_GROUP_INITIALIZER.register_module
class Initializer_Moemodel(ProcessGroupInitializer):
    """Model parallel initialization for the MoE system.

    :param moe_model: Size of MoE model parallelism
    :param moe_data: Size of MoE data parallelism
    :param args: Args used in the base class
    :param kwargs: Kwargs used in the base class
    :type moe_model: int
    :type moe_data: int
    """

    def __init__(self, moe_model, moe_data, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.moe_model = moe_model
        self.moe_data = moe_data

    def init_dist_group(self):
        """Initialize model parallel groups in the MoE parallel environment
        and assign local ranks and groups to each GPU.

        :return: MoE model parallelism's information
        :rtype: Tuple(local_rank, group_world_size, process_group, ranks_in_group, mode)
        """
        local_rank = None
        ranks_in_group = None
        process_group = None
        group_world_size = None
        mode = ParallelMode.MOE_MODEL

        for i in range(self.moe_data):
            # Each model parallel group holds moe_model consecutive ranks.
            ranks = [i * self.moe_model + j for j in range(self.moe_model)]
            group = dist.new_group(ranks)

            if self.rank in ranks:
                local_rank = ranks.index(self.rank)
                group_world_size = len(ranks)
                process_group = group
                ranks_in_group = ranks

        return local_rank, group_world_size, process_group, ranks_in_group, mode


@DIST_GROUP_INITIALIZER.register_module
class Initializer_Moedata(ProcessGroupInitializer):
    """Data parallel initialization for the MoE system.

    :param moe_model: Size of MoE model parallelism
    :param moe_data: Size of MoE data parallelism
    :param args: Args used in the base class
    :param kwargs: Kwargs used in the base class
    :type moe_model: int
    :type moe_data: int
    """

    def __init__(self, moe_model, moe_data, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.moe_model = moe_model
        self.moe_data = moe_data

    def init_dist_group(self):
        """Initialize data parallel groups in the MoE parallel environment
        and assign local ranks and groups to each GPU.

        :return: MoE data parallelism's information
        :rtype: Tuple(local_rank, group_world_size, process_group, ranks_in_group, mode)
        """
        local_rank = None
        ranks_in_group = None
        process_group = None
        group_world_size = None
        mode = ParallelMode.MOE_DATA

        for i in range(self.moe_model):
            # Each data parallel group holds moe_data ranks strided by moe_model.
            ranks = [i + j * self.moe_model for j in range(self.moe_data)]
            group = dist.new_group(ranks)

            if self.rank in ranks:
                local_rank = ranks.index(self.rank)
                group_world_size = len(ranks)
                process_group = group
                ranks_in_group = ranks

        return local_rank, group_world_size, process_group, ranks_in_group, mode


@DIST_GROUP_INITIALIZER.register_module
class Initializer_Moe(ProcessGroupInitializer):
    """Serves as the single entry point to MoE parallel initialization.

    :param args: Args used to initialize ProcessGroupInitializer
    :param kwargs: Kwargs used to initialize ProcessGroupInitializer
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.moe_model = moe_env.model_parallel_size
        self.moe_data = moe_env.data_parallel_size
        self.model_initializer = Initializer_Moemodel(
            self.moe_model, self.moe_data, *args, **kwargs)
        self.data_initializer = Initializer_Moedata(
            self.moe_model, self.moe_data, *args, **kwargs)

    def init_dist_group(self):
        """Initialize MoE parallel communication groups.

        :return: MoE parallelism's information
        :rtype: list of Tuples (local_rank, group_world_size, process_group, ranks_in_group, mode)
        """
        parallel_setting = [self.model_initializer.init_dist_group(),
                            self.data_initializer.init_dist_group()]
        return parallel_setting
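

# Illustrative sketch (an assumption for documentation purposes, not part of the
# original ColossalAI module): with a hypothetical world size of 8, moe_model = 2
# and moe_data = 4, the two initializers above carve ranks into the groups shown
# below. The snippet only mirrors the rank arithmetic in init_dist_group with
# plain list comprehensions; it makes no torch.distributed calls and needs no
# initialized process group.
if __name__ == '__main__':
    moe_model, moe_data = 2, 4  # hypothetical parallel sizes, world size = 8

    # Model parallel groups: consecutive blocks of moe_model ranks.
    model_groups = [[i * moe_model + j for j in range(moe_model)]
                    for i in range(moe_data)]
    # Data parallel groups: moe_data ranks strided by moe_model.
    data_groups = [[i + j * moe_model for j in range(moe_data)]
                   for i in range(moe_model)]

    print(model_groups)  # [[0, 1], [2, 3], [4, 5], [6, 7]]
    print(data_groups)   # [[0, 2, 4, 6], [1, 3, 5, 7]]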