From ac7f45232b82e1d9345111f5b708bed84c538d19 Mon Sep 17 00:00:00 2001
From: Qu Wenwen
Date: Wed, 27 Dec 2023 12:03:30 +0800
Subject: [PATCH] support sequence parallel

---
 internlm/core/scheduler/no_pipeline_scheduler.py |  6 ++++++
 internlm/core/scheduler/pipeline_scheduler.py    |  6 ++++++
 internlm/initialize/launch.py                    |  1 -
 internlm/model/modeling_moe.py                   | 15 ++++++++++++++-
 internlm/moe/sharded_moe.py                      |  4 +++-
 5 files changed, 29 insertions(+), 3 deletions(-)

diff --git a/internlm/core/scheduler/no_pipeline_scheduler.py b/internlm/core/scheduler/no_pipeline_scheduler.py
index 79a6f62..f2d55be 100644
--- a/internlm/core/scheduler/no_pipeline_scheduler.py
+++ b/internlm/core/scheduler/no_pipeline_scheduler.py
@@ -6,7 +6,9 @@
 from typing import Any, Callable, Iterable, List, Optional
 
 import torch
+import torch.distributed as dist
 
+from internlm.core.context import ParallelMode
 from internlm.core.context import global_context as gpc
 from internlm.core.engine import Engine
 from internlm.utils.common import conditional_context
@@ -125,6 +127,10 @@ class NonPipelineScheduler(BaseScheduler):
                 if hasattr(gpc.config.model, "num_experts")
                 else torch.tensor(0.0, device=torch.cuda.current_device(), dtype=gpc.config.model.get("dtype"))
             )
+            # when sequence parallel is enabled, moe_loss is computed within the "tensor" group,
+            # so we need to all-reduce it across that group
+            if gpc.config.parallel.sequence_parallel:
+                dist.all_reduce(moe_loss, op=dist.ReduceOp.AVG, group=gpc.get_group(ParallelMode.TENSOR))
             moe_loss /= scale_loss
             loss /= scale_loss
             loss += moe_loss
diff --git a/internlm/core/scheduler/pipeline_scheduler.py b/internlm/core/scheduler/pipeline_scheduler.py
index 5b864ff..9978835 100644
--- a/internlm/core/scheduler/pipeline_scheduler.py
+++ b/internlm/core/scheduler/pipeline_scheduler.py
@@ -311,6 +311,9 @@ class PipelineScheduler(BaseScheduler):
                 if hasattr(gpc.config.model, "num_experts")
                 else torch.tensor(0.0, device=torch.cuda.current_device(), dtype=gpc.config.model.get("dtype"))
             )
+            # when sequence parallel is enabled, moe_loss is computed within the "tensor" group, so all-reduce it
+            if gpc.config.parallel.sequence_parallel:
+                dist.all_reduce(moe_loss, op=dist.ReduceOp.AVG, group=gpc.get_group(ParallelMode.TENSOR))
             moe_loss /= self.num_microbatches
             accum_moe_loss.add_(moe_loss.detach())
 
@@ -858,6 +861,9 @@ class InterleavedPipelineScheduler(PipelineScheduler):
                 if hasattr(gpc.config.model, "num_experts")
                 else torch.tensor(0.0, device=torch.cuda.current_device(), dtype=gpc.config.model.get("dtype"))
             )
+            # when sequence parallel is enabled, moe_loss is computed within the "tensor" group, so all-reduce it
+            if gpc.config.parallel.sequence_parallel:
+                dist.all_reduce(moe_loss, op=dist.ReduceOp.AVG, group=gpc.get_group(ParallelMode.TENSOR))
             moe_loss /= self.num_microbatches
 
             if self._accum_moe_loss is not None:
diff --git a/internlm/initialize/launch.py b/internlm/initialize/launch.py
index 491e2b0..523f838 100644
--- a/internlm/initialize/launch.py
+++ b/internlm/initialize/launch.py
@@ -369,7 +369,6 @@ def args_sanity_check():
             not optim_ckpt.overlap_sync_grad & optim_ckpt.overlap_sync_param
         ), "not support overlap and moe at the same time"
         assert gpc.config.parallel.zero1.size == -1, "moe only support zero1, set zero1=dict(size=-1,...) can fix this"
-        assert not gpc.config.parallel.sequence_parallel, "moe not support sequence parallel for now"
 
 
 def launch(
diff --git a/internlm/model/modeling_moe.py b/internlm/model/modeling_moe.py
index df6c7a8..2dc4f71 100644
--- a/internlm/model/modeling_moe.py
+++ b/internlm/model/modeling_moe.py
@@ -9,7 +9,7 @@ from flash_attn.modules.embedding import ParallelGPT2Embeddings
 from flash_attn.modules.mlp import ParallelFusedMLP
 from torch import nn
 
-from internlm.core.context import IS_TENSOR_PARALLEL, ParallelMode
+from internlm.core.context import IS_SEQUENCE_PARALLEL, IS_TENSOR_PARALLEL, ParallelMode
 from internlm.core.context.parallel_context import global_context as gpc
 from internlm.core.naive_amp import set_fp32_attr_to_module
 from internlm.initialize.initialize_tensor import normal_, scaled_init_method_normal
@@ -189,8 +189,18 @@ class PackedFlashBaseLayer1D(nn.Module):
             for _, param in self.mlp.moe_layer.experts.named_parameters():
                 if gpc.get_world_size(ParallelMode.TENSOR) > 1:
                     setattr(param, IS_TENSOR_PARALLEL, True)
+            for param in self.mlp.moe_layer.gate.parameters():
+                if gpc.config.parallel.sequence_parallel is True:
+                    setattr(param, IS_SEQUENCE_PARALLEL, True)
             set_fp32_attr_to_module(self.mlp.moe_layer.gate)
 
+        for param in self.norm1.parameters():
+            if gpc.config.parallel.sequence_parallel is True:
+                setattr(param, IS_SEQUENCE_PARALLEL, True)
+        for param in self.norm2.parameters():
+            if gpc.config.parallel.sequence_parallel is True:
+                setattr(param, IS_SEQUENCE_PARALLEL, True)
+
         self.dropout2 = nn.Dropout(drop_rate)
         self.use_swiglu = use_swiglu
         self.use_scaled_init = use_scaled_init
@@ -446,6 +456,9 @@ class PackedFlashInternLm1D(nn.Module):
                     normal_(std=0.0052)(param)
                 if gpc.get_world_size(ParallelMode.TENSOR) > 1:
                     setattr(param, IS_TENSOR_PARALLEL, True)
+        for param in self.norm.parameters():
+            if gpc.config.parallel.sequence_parallel is True:
+                setattr(param, IS_SEQUENCE_PARALLEL, True)
         self.parallel_output = parallel_output
 
     def forward(self, hidden_states=None, cu_seqlens=None, input_ids=None, indexes=None, inference_params=None):
diff --git a/internlm/moe/sharded_moe.py b/internlm/moe/sharded_moe.py
index 5d695ac..9ebac02 100644
--- a/internlm/moe/sharded_moe.py
+++ b/internlm/moe/sharded_moe.py
@@ -12,6 +12,8 @@ import torch.nn.functional as F
 from torch import Tensor
 from torch.nn import Module
 
+from internlm.core.context import ParallelMode
+from internlm.core.context import global_context as gpc
 from internlm.utils.logger import get_logger
 from internlm.utils.megatron_timers import megatron_timer as timer
 
@@ -189,7 +191,7 @@ def top1gating(
     # if we don't want to drop any tokens
     if not drop_tokens:
         new_capacity = torch.max(exp_counts).to(logits.device)
-        dist.all_reduce(new_capacity, op=dist.ReduceOp.MAX, group=dist.get_world_group())
+        dist.all_reduce(new_capacity, op=dist.ReduceOp.MAX, group=gpc.get_group(ParallelMode.GLOBAL))
         capacity = new_capacity
 
     # Compute l_aux