support sequence parallel

pull/562/head
Qu Wenwen 2023-12-27 12:03:30 +08:00
parent 513ebb9c3a
commit ac7f45232b
5 changed files with 29 additions and 3 deletions

View File

@@ -6,7 +6,9 @@
 from typing import Any, Callable, Iterable, List, Optional

 import torch
+import torch.distributed as dist

+from internlm.core.context import ParallelMode
 from internlm.core.context import global_context as gpc
 from internlm.core.engine import Engine
 from internlm.utils.common import conditional_context
@@ -125,6 +127,10 @@ class NonPipelineScheduler(BaseScheduler):
             if hasattr(gpc.config.model, "num_experts")
             else torch.tensor(0.0, device=torch.cuda.current_device(), dtype=gpc.config.model.get("dtype"))
         )
+        # the moe_loss is computed among the "tensor" group if sequence parallel is enabled,
+        # so we need to do allreduce
+        if gpc.config.parallel.sequence_parallel:
+            dist.all_reduce(moe_loss, op=dist.ReduceOp.AVG, group=gpc.get_group(ParallelMode.TENSOR))
         moe_loss /= scale_loss
         loss /= scale_loss
         loss += moe_loss
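For readers less familiar with torch.distributed collectives, here is a minimal, self-contained sketch (not InternLM code) of the averaging pattern the hunk above applies to moe_loss over the tensor group. ReduceOp.AVG is only provided by the NCCL backend, so this CPU/gloo demo sums and divides by the group size, which gives the same result; the two-rank setup and the loss values are purely illustrative.

import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def worker(rank: int, world_size: int) -> None:
    # each process joins the same (illustrative) gloo group
    dist.init_process_group(
        backend="gloo", init_method="tcp://127.0.0.1:29500", rank=rank, world_size=world_size
    )
    # stand-in for the per-rank moe_loss computed inside the tensor group
    moe_loss = torch.tensor(float(rank + 1))
    dist.all_reduce(moe_loss, op=dist.ReduceOp.SUM)
    moe_loss /= world_size  # same result as ReduceOp.AVG on NCCL
    print(f"rank {rank}: averaged moe_loss = {moe_loss.item()}")  # 1.5 on both ranks
    dist.destroy_process_group()


if __name__ == "__main__":
    mp.spawn(worker, args=(2,), nprocs=2)

Under sequence parallelism each tensor-parallel rank sees only its own slice of the sequence, so the locally computed gating loss differs per rank; averaging it across the tensor group restores a single consistent moe_loss before it is folded into the total loss.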

View File

@@ -311,6 +311,9 @@ class PipelineScheduler(BaseScheduler):
             if hasattr(gpc.config.model, "num_experts")
             else torch.tensor(0.0, device=torch.cuda.current_device(), dtype=gpc.config.model.get("dtype"))
         )
+        # the moe_loss is computed among the "tensor" group if sequence parallel is enabled, so we need to do allreduce
+        if gpc.config.parallel.sequence_parallel:
+            dist.all_reduce(moe_loss, op=dist.ReduceOp.AVG, group=gpc.get_group(ParallelMode.TENSOR))
         moe_loss /= self.num_microbatches
         accum_moe_loss.add_(moe_loss.detach())
@@ -858,6 +861,9 @@ class InterleavedPipelineScheduler(PipelineScheduler):
             if hasattr(gpc.config.model, "num_experts")
             else torch.tensor(0.0, device=torch.cuda.current_device(), dtype=gpc.config.model.get("dtype"))
         )
+        # the moe_loss is computed among the "tensor" group if sequence parallel is enabled, so we need to do allreduce
+        if gpc.config.parallel.sequence_parallel:
+            dist.all_reduce(moe_loss, op=dist.ReduceOp.AVG, group=gpc.get_group(ParallelMode.TENSOR))
         moe_loss /= self.num_microbatches
         if self._accum_moe_loss is not None:
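Both pipeline schedulers apply the same fix and then fold the averaged result into a per-microbatch running total. As a rough sketch of that accumulation pattern (illustrative values, not InternLM code): dividing by the number of microbatches and detaching before accumulation keeps the running total out of the autograd graph.

import torch

num_microbatches = 4
accum_moe_loss = torch.zeros(1)
for _ in range(num_microbatches):
    # stand-in for the auxiliary MoE loss of one microbatch
    moe_loss = torch.rand(1, requires_grad=True) * 2
    moe_loss = moe_loss / num_microbatches
    # detach so the accumulator does not extend the autograd graph
    accum_moe_loss.add_(moe_loss.detach())
print(accum_moe_loss)  # average auxiliary loss over the microbatches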

View File

@@ -369,7 +369,6 @@ def args_sanity_check():
             not optim_ckpt.overlap_sync_grad & optim_ckpt.overlap_sync_param
         ), "not support overlap and moe at the same time"
         assert gpc.config.parallel.zero1.size == -1, "moe only support zero1, set zero1=dict(size=-1,...) can fix this"
-        assert not gpc.config.parallel.sequence_parallel, "moe not support sequence parallel for now"


 def launch(
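With this assertion removed, MoE models may now run with sequence parallelism enabled; the checks that remain still require ZeRO-1 with size=-1 and no overlapped gradient/parameter sync. A hypothetical config fragment that would now pass args_sanity_check (the parallel field names come from the asserts above; the hybrid_zero_optimizer section name is assumed and the values are illustrative):

parallel = dict(
    zero1=dict(size=-1),      # still required for MoE by the assert kept above
    sequence_parallel=True,   # no longer rejected for MoE after this change
)
hybrid_zero_optimizer = dict(  # assumed section name behind optim_ckpt
    overlap_sync_grad=False,   # "not support overlap and moe at the same time"
    overlap_sync_param=False,
)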

View File

@@ -9,7 +9,7 @@ from flash_attn.modules.embedding import ParallelGPT2Embeddings
 from flash_attn.modules.mlp import ParallelFusedMLP
 from torch import nn

-from internlm.core.context import IS_TENSOR_PARALLEL, ParallelMode
+from internlm.core.context import IS_SEQUENCE_PARALLEL, IS_TENSOR_PARALLEL, ParallelMode
 from internlm.core.context.parallel_context import global_context as gpc
 from internlm.core.naive_amp import set_fp32_attr_to_module
 from internlm.initialize.initialize_tensor import normal_, scaled_init_method_normal
@@ -189,8 +189,18 @@ class PackedFlashBaseLayer1D(nn.Module):
             for _, param in self.mlp.moe_layer.experts.named_parameters():
                 if gpc.get_world_size(ParallelMode.TENSOR) > 1:
                     setattr(param, IS_TENSOR_PARALLEL, True)
+            for param in self.mlp.moe_layer.gate.parameters():
+                if gpc.config.parallel.sequence_parallel is True:
+                    setattr(param, IS_SEQUENCE_PARALLEL, True)
             set_fp32_attr_to_module(self.mlp.moe_layer.gate)
+        for param in self.norm1.parameters():
+            if gpc.config.parallel.sequence_parallel is True:
+                setattr(param, IS_SEQUENCE_PARALLEL, True)
+        for param in self.norm2.parameters():
+            if gpc.config.parallel.sequence_parallel is True:
+                setattr(param, IS_SEQUENCE_PARALLEL, True)

         self.dropout2 = nn.Dropout(drop_rate)
         self.use_swiglu = use_swiglu
         self.use_scaled_init = use_scaled_init
@@ -446,6 +456,9 @@ class PackedFlashInternLm1D(nn.Module):
                     normal_(std=0.0052)(param)
                 if gpc.get_world_size(ParallelMode.TENSOR) > 1:
                     setattr(param, IS_TENSOR_PARALLEL, True)
+        for param in self.norm.parameters():
+            if gpc.config.parallel.sequence_parallel is True:
+                setattr(param, IS_SEQUENCE_PARALLEL, True)
         self.parallel_output = parallel_output

     def forward(self, hidden_states=None, cu_seqlens=None, input_ids=None, indexes=None, inference_params=None):
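The modeling changes add no computation; they only tag the gate and normalization parameters with IS_SEQUENCE_PARALLEL, presumably so later stages (e.g. gradient synchronization) can single out parameters that every sequence-parallel rank holds a full copy of. A minimal sketch of that tag-and-filter pattern (the attribute key below is illustrative; InternLM imports the real IS_SEQUENCE_PARALLEL constant from internlm.core.context):

from torch import nn

IS_SEQ_PARALLEL_ATTR = "is_sequence_parallel"  # assumed attribute key

norm = nn.LayerNorm(16)
for param in norm.parameters():
    # mark parameters that are replicated across the sequence-parallel group
    setattr(param, IS_SEQ_PARALLEL_ATTR, True)

# later, e.g. before reducing gradients, pick out only the tagged parameters
tagged = [p for p in norm.parameters() if getattr(p, IS_SEQ_PARALLEL_ATTR, False)]
print(len(tagged))  # 2: the LayerNorm weight and bias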

View File

@@ -12,6 +12,8 @@ import torch.nn.functional as F
 from torch import Tensor
 from torch.nn import Module

+from internlm.core.context import ParallelMode
+from internlm.core.context import global_context as gpc
 from internlm.utils.logger import get_logger
 from internlm.utils.megatron_timers import megatron_timer as timer
@@ -189,7 +191,7 @@ def top1gating(
     # if we don't want to drop any tokens
     if not drop_tokens:
         new_capacity = torch.max(exp_counts).to(logits.device)
-        dist.all_reduce(new_capacity, op=dist.ReduceOp.MAX, group=dist.get_world_group())
+        dist.all_reduce(new_capacity, op=dist.ReduceOp.MAX, group=gpc.get_group(ParallelMode.GLOBAL))
         capacity = new_capacity

     # Compute l_aux
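The gating change swaps dist.get_world_group() for the group handle managed by InternLM's parallel context; both cover all ranks here. The MAX all-reduce itself exists so that, when drop_tokens is False, every rank sizes its expert buffers for the busiest expert anywhere in the group. A small self-contained sketch of that synchronization (illustrative two-rank gloo setup, hypothetical token counts):

import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def worker(rank: int, world_size: int) -> None:
    dist.init_process_group(
        backend="gloo", init_method="tcp://127.0.0.1:29501", rank=rank, world_size=world_size
    )
    # per-rank token counts for two experts (hypothetical)
    exp_counts = torch.tensor([3, 7]) if rank == 0 else torch.tensor([9, 2])
    new_capacity = torch.max(exp_counts)
    # agree on the largest count seen on any rank so no rank drops tokens
    dist.all_reduce(new_capacity, op=dist.ReduceOp.MAX)
    print(f"rank {rank}: capacity = {new_capacity.item()}")  # 9 on both ranks
    dist.destroy_process_group()


if __name__ == "__main__":
    mp.spawn(worker, args=(2,), nprocs=2)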