support sequence parallel

pull/562/head
Qu Wenwen 2023-12-27 12:03:30 +08:00
parent 513ebb9c3a
commit ac7f45232b
5 changed files with 29 additions and 3 deletions

View File

@ -6,7 +6,9 @@
from typing import Any, Callable, Iterable, List, Optional
import torch
import torch.distributed as dist
from internlm.core.context import ParallelMode
from internlm.core.context import global_context as gpc
from internlm.core.engine import Engine
from internlm.utils.common import conditional_context
@ -125,6 +127,10 @@ class NonPipelineScheduler(BaseScheduler):
if hasattr(gpc.config.model, "num_experts")
else torch.tensor(0.0, device=torch.cuda.current_device(), dtype=gpc.config.model.get("dtype"))
)
# the moe_loss is computed among the "tensor" group if sequence parallel is enabled,
# so we need to do allreduce
if gpc.config.parallel.sequence_parallel:
dist.all_reduce(moe_loss, op=dist.ReduceOp.AVG, group=gpc.get_group(ParallelMode.TENSOR))
moe_loss /= scale_loss
loss /= scale_loss
loss += moe_loss

View File

@ -311,6 +311,9 @@ class PipelineScheduler(BaseScheduler):
if hasattr(gpc.config.model, "num_experts")
else torch.tensor(0.0, device=torch.cuda.current_device(), dtype=gpc.config.model.get("dtype"))
)
# the moe_loss is computed among the "tensor" group if sequence parallel is enabled, so we need to do allreduce
if gpc.config.parallel.sequence_parallel:
dist.all_reduce(moe_loss, op=dist.ReduceOp.AVG, group=gpc.get_group(ParallelMode.TENSOR))
moe_loss /= self.num_microbatches
accum_moe_loss.add_(moe_loss.detach())
@ -858,6 +861,9 @@ class InterleavedPipelineScheduler(PipelineScheduler):
if hasattr(gpc.config.model, "num_experts")
else torch.tensor(0.0, device=torch.cuda.current_device(), dtype=gpc.config.model.get("dtype"))
)
# the moe_loss is computed among the "tensor" group if sequence parallel is enabled, so we need to do allreduce
if gpc.config.parallel.sequence_parallel:
dist.all_reduce(moe_loss, op=dist.ReduceOp.AVG, group=gpc.get_group(ParallelMode.TENSOR))
moe_loss /= self.num_microbatches
if self._accum_moe_loss is not None:

View File

@ -369,7 +369,6 @@ def args_sanity_check():
not optim_ckpt.overlap_sync_grad & optim_ckpt.overlap_sync_param
), "not support overlap and moe at the same time"
assert gpc.config.parallel.zero1.size == -1, "moe only support zero1, set zero1=dict(size=-1,...) can fix this"
assert not gpc.config.parallel.sequence_parallel, "moe not support sequence parallel for now"
def launch(

View File

@ -9,7 +9,7 @@ from flash_attn.modules.embedding import ParallelGPT2Embeddings
from flash_attn.modules.mlp import ParallelFusedMLP
from torch import nn
from internlm.core.context import IS_TENSOR_PARALLEL, ParallelMode
from internlm.core.context import IS_SEQUENCE_PARALLEL, IS_TENSOR_PARALLEL, ParallelMode
from internlm.core.context.parallel_context import global_context as gpc
from internlm.core.naive_amp import set_fp32_attr_to_module
from internlm.initialize.initialize_tensor import normal_, scaled_init_method_normal
@ -189,8 +189,18 @@ class PackedFlashBaseLayer1D(nn.Module):
for _, param in self.mlp.moe_layer.experts.named_parameters():
if gpc.get_world_size(ParallelMode.TENSOR) > 1:
setattr(param, IS_TENSOR_PARALLEL, True)
for param in self.mlp.moe_layer.gate.parameters():
if gpc.config.parallel.sequence_parallel is True:
setattr(param, IS_SEQUENCE_PARALLEL, True)
set_fp32_attr_to_module(self.mlp.moe_layer.gate)
for param in self.norm1.parameters():
if gpc.config.parallel.sequence_parallel is True:
setattr(param, IS_SEQUENCE_PARALLEL, True)
for param in self.norm2.parameters():
if gpc.config.parallel.sequence_parallel is True:
setattr(param, IS_SEQUENCE_PARALLEL, True)
self.dropout2 = nn.Dropout(drop_rate)
self.use_swiglu = use_swiglu
self.use_scaled_init = use_scaled_init
@ -446,6 +456,9 @@ class PackedFlashInternLm1D(nn.Module):
normal_(std=0.0052)(param)
if gpc.get_world_size(ParallelMode.TENSOR) > 1:
setattr(param, IS_TENSOR_PARALLEL, True)
for param in self.norm.parameters():
if gpc.config.parallel.sequence_parallel is True:
setattr(param, IS_SEQUENCE_PARALLEL, True)
self.parallel_output = parallel_output
def forward(self, hidden_states=None, cu_seqlens=None, input_ids=None, indexes=None, inference_params=None):

View File

@ -12,6 +12,8 @@ import torch.nn.functional as F
from torch import Tensor
from torch.nn import Module
from internlm.core.context import ParallelMode
from internlm.core.context import global_context as gpc
from internlm.utils.logger import get_logger
from internlm.utils.megatron_timers import megatron_timer as timer
@ -189,7 +191,7 @@ def top1gating(
# if we don't want to drop any tokens
if not drop_tokens:
new_capacity = torch.max(exp_counts).to(logits.device)
dist.all_reduce(new_capacity, op=dist.ReduceOp.MAX, group=dist.get_world_group())
dist.all_reduce(new_capacity, op=dist.ReduceOp.MAX, group=gpc.get_group(ParallelMode.GLOBAL))
capacity = new_capacity
# Compute l_aux