diff --git a/internlm/core/context/process_group_initializer.py b/internlm/core/context/process_group_initializer.py
index 09be064..0244b61 100644
--- a/internlm/core/context/process_group_initializer.py
+++ b/internlm/core/context/process_group_initializer.py
@@ -34,6 +34,9 @@ class ParallelMode(Enum):
     # expert parallel
     EXPERT = "expert"
 
+    # expert data parallel
+    EXPERT_DATA = "expert_data"
+
 
 class ProcessGroupInitializer(ABC):
     """An object, knowing the parallelism configuration, that initializes parallel groups.
@@ -400,3 +403,92 @@ class Initializer_Expert(ProcessGroupInitializer):
             ranks_in_group = ranks
 
         return local_rank, group_world_size, process_group, cpu_group, ranks_in_group, mode
+
+
+class Initializer_Expert_Data(ProcessGroupInitializer):
+    """A ProcessGroupInitializer for expert data parallelism.
+
+    Args:
+        rank (int): The rank of current process.
+        world_size (int): Size of whole communication world.
+        data_parallel_size (int): Size of data parallel.
+        pipeline_parallel_size (int): Size of pipeline parallel.
+        tensor_parallel_size (int): Size of tensor parallel.
+        zero1_parallel_size (int): Size of zero-1 parallel.
+        expert_parallel_size (int): Size of expert parallel.
+    """
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.model_parallel_size = self.tensor_parallel_size * self.pipeline_parallel_size
+        self.num_expert_parallel_group = self.world_size // self.expert_parallel_size
+
+        assert self.world_size % self.expert_parallel_size == 0
+
+    def _get_expert_parallel_ranks(self):
+        """
+        Create expert parallel and expert data parallel groups.
+        Example: world_size = 8, model_parallel_size = 2, expert_parallel_size = 2
+        model_parallel_group = [0,1], [2,3], [4,5], [6,7]
+        data_parallel_group = [0,2,4,6], [1,3,5,7]
+        expert_parallel_group = [0,2], [4,6], [1,3], [5,7]
+        expert_data_parallel_group = [0,4], [2,6], [1,5], [3,7]
+        """
+        data_parallel_groups = []
+        for i in range(self.model_parallel_size):
+            data_parallel_groups.append(list(range(i, self.world_size, self.model_parallel_size)))
+
+        expert_parallel_groups = []
+        expert_data_parallel_groups = []
+        for dp_ranks in data_parallel_groups:
+            # partition each data parallel group into expert parallel groups, e.g. [0,2], [4,6]
+            part_ep_group = []
+            for i in range(0, self.data_parallel_size, self.expert_parallel_size):
+                part_ep_group.append(dp_ranks[i : i + self.expert_parallel_size])
+            expert_parallel_groups.extend(part_ep_group)
+
+            # ranks at the same position of each expert parallel group form an expert data parallel group
+            for expert_dp_ranks in zip(*part_ep_group):
+                expert_data_parallel_groups.append(list(expert_dp_ranks))
+
+        return expert_parallel_groups, expert_data_parallel_groups
+
+    def init_dist_group(self, use_cpu: bool = False):
+        """Initialize expert parallel and expert data parallel groups, and assign local_ranks and groups to each gpu.
+
+        Returns:
+            list: [(local_rank, group_world_size, process_group, cpu_group, ranks_in_group, mode), ...]:
+                A length-2 list consisting of the expert parallel and expert data parallel information tuples.
+ """ + expert_parallel_groups, expert_data_parallel_groups = self._get_expert_parallel_ranks() + + groups = [] + for ranks in expert_parallel_groups: + group = dist.new_group(ranks) + if use_cpu: + group_cpu = dist.new_group(ranks, backend="gloo") if dist.get_backend() != "gloo" else group + else: + group_cpu = None + if self.rank in ranks: + local_rank = ranks.index(self.rank) + group_world_size = len(ranks) + process_group = group + cpu_group = group_cpu + ranks_in_group = ranks + groups.append((local_rank, group_world_size, process_group, cpu_group, ranks_in_group, ParallelMode.EXPERT)) + + for ranks in expert_data_parallel_groups: + group = dist.new_group(ranks) + if use_cpu: + group_cpu = dist.new_group(ranks, backend="gloo") if dist.get_backend() != "gloo" else group + else: + group_cpu = None + if self.rank in ranks: + local_rank = ranks.index(self.rank) + group_world_size = len(ranks) + process_group = group + cpu_group = group_cpu + ranks_in_group = ranks + groups.append( + (local_rank, group_world_size, process_group, cpu_group, ranks_in_group, ParallelMode.EXPERT_DATA) + ) + + return groups diff --git a/internlm/solver/optimizer/hybrid_zero_optim.py b/internlm/solver/optimizer/hybrid_zero_optim.py index 1bf3499..397b9d2 100644 --- a/internlm/solver/optimizer/hybrid_zero_optim.py +++ b/internlm/solver/optimizer/hybrid_zero_optim.py @@ -492,7 +492,7 @@ class HybridZeroOptimizer(BaseOptimizer): if not self._overlap_communication: for group_id in range(len(self._fp16_param_groups)): for param in self._fp16_param_groups[group_id]: - if param.grad is not None: + if param.grad is not None and not is_moe_param(param): self._store_and_try_reduce_grads_by_bucket(param) # we need to reduce the gradients left in the communication bucket diff --git a/internlm/utils/parallel.py b/internlm/utils/parallel.py index 63b190d..5a9e4c6 100644 --- a/internlm/utils/parallel.py +++ b/internlm/utils/parallel.py @@ -30,6 +30,32 @@ def sync_model_param(model, parallel_mode): dist.broadcast(param, src=ranks[0], group=gpc.get_group(parallel_mode)) +def sync_tensor(tensor, parallel_mode): + r"""Make sure data tensor(parameters) are consistent during Data and Expert Parallel Mode. + + Args: + tensor (:class:`torch.Tensor`): A parameters you check the consistency. + parallel_mode (:class:`internlm.core.context.ParallelMode`): Parallel mode to be checked. + """ + if gpc.is_initialized(parallel_mode) and gpc.get_world_size(parallel_mode) > 1: + ranks = gpc.get_ranks_in_group(parallel_mode) + dist.broadcast(tensor, src=ranks[0], group=gpc.get_group(parallel_mode)) + + +# TODO: will be used in expert data parallel, may can also used in sync_model_param_within_tp +def sync_model_param_within_ep(model): + r"""Make sure data parameters are consistent during Data Parallel Mode. + + Args: + model (:class:`torch.nn.Module`): A pyTorch model on whose parameters you check the consistency. + """ + for param in model.parameters(): + if is_moe_param(param): + sync_tensor(param, ParallelMode.EXPERT_DATA) + else: + sync_tensor(param, ParallelMode.DATA) + + def sync_model_param_within_tp(model): r"""This function is changed from colossalai, which is ``sync_model_param``.