mirror of https://github.com/InternLM/InternLM

support sequence parallel

parent 513ebb9c3a
commit ac7f45232b
@@ -6,7 +6,9 @@
 from typing import Any, Callable, Iterable, List, Optional
 
 import torch
+import torch.distributed as dist
 
+from internlm.core.context import ParallelMode
 from internlm.core.context import global_context as gpc
 from internlm.core.engine import Engine
 from internlm.utils.common import conditional_context
@@ -125,6 +127,10 @@ class NonPipelineScheduler(BaseScheduler):
                 if hasattr(gpc.config.model, "num_experts")
                 else torch.tensor(0.0, device=torch.cuda.current_device(), dtype=gpc.config.model.get("dtype"))
             )
+            # the moe_loss is computed among the "tensor" group if sequence parallel is enabled,
+            # so we need to do allreduce
+            if gpc.config.parallel.sequence_parallel:
+                dist.all_reduce(moe_loss, op=dist.ReduceOp.AVG, group=gpc.get_group(ParallelMode.TENSOR))
             moe_loss /= scale_loss
             loss /= scale_loss
             loss += moe_loss
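The comment added here is the heart of the change: with sequence parallelism enabled, each rank in the tensor group only processes its shard of the sequence, so the MoE auxiliary loss it computes locally is a partial value. Averaging it across the tensor group makes the loss identical on every rank before it is folded into the main loss. A minimal sketch of that pattern outside InternLM's scheduler (the helper name is illustrative; ReduceOp.AVG needs a backend that supports it, e.g. NCCL):

import torch
import torch.distributed as dist

def average_partial_loss(partial_loss: torch.Tensor, group=None) -> torch.Tensor:
    """Average a per-rank partial loss across a process group (sketch).

    Each sequence-parallel rank computes the MoE auxiliary loss over its own
    token shard; an AVG all-reduce over the tensor group turns those partial
    values into one identical loss on every rank before it is added to the
    language-model loss.
    """
    if dist.is_initialized() and dist.get_world_size(group) > 1:
        dist.all_reduce(partial_loss, op=dist.ReduceOp.AVG, group=group)
    return partial_loss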
@@ -311,6 +311,9 @@ class PipelineScheduler(BaseScheduler):
                 if hasattr(gpc.config.model, "num_experts")
                 else torch.tensor(0.0, device=torch.cuda.current_device(), dtype=gpc.config.model.get("dtype"))
             )
+            # the moe_loss is computed among the "tensor" group if sequence parallel is enabled, so we need to do allreduce
+            if gpc.config.parallel.sequence_parallel:
+                dist.all_reduce(moe_loss, op=dist.ReduceOp.AVG, group=gpc.get_group(ParallelMode.TENSOR))
             moe_loss /= self.num_microbatches
             accum_moe_loss.add_(moe_loss.detach())
 
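In the pipeline schedulers the same all-reduce happens per microbatch, and the result is then folded into a running accumulator: dividing each microbatch's moe_loss by num_microbatches and adding the detached value gives the mean over the step without pulling the accumulator into the autograd graph. A small self-contained illustration of that accumulation arithmetic (toy numbers, not InternLM code):

import torch

# toy per-microbatch MoE losses standing in for the values the scheduler sees
num_microbatches = 4
micro_losses = [torch.tensor(0.8), torch.tensor(1.2), torch.tensor(1.0), torch.tensor(0.6)]

accum_moe_loss = torch.zeros(1)
for moe_loss in micro_losses:
    moe_loss = moe_loss / num_microbatches      # scale so the accumulated value is the per-step mean
    accum_moe_loss.add_(moe_loss.detach())      # detach: the accumulator is for reporting, not backprop

print(accum_moe_loss)  # tensor([0.9000]) -- the mean of the four microbatch losses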
@@ -858,6 +861,9 @@ class InterleavedPipelineScheduler(PipelineScheduler):
                 if hasattr(gpc.config.model, "num_experts")
                 else torch.tensor(0.0, device=torch.cuda.current_device(), dtype=gpc.config.model.get("dtype"))
             )
+            # the moe_loss is computed among the "tensor" group if sequence parallel is enabled, so we need to do allreduce
+            if gpc.config.parallel.sequence_parallel:
+                dist.all_reduce(moe_loss, op=dist.ReduceOp.AVG, group=gpc.get_group(ParallelMode.TENSOR))
             moe_loss /= self.num_microbatches
 
             if self._accum_moe_loss is not None:
@@ -369,7 +369,6 @@ def args_sanity_check():
             not optim_ckpt.overlap_sync_grad & optim_ckpt.overlap_sync_param
         ), "not support overlap and moe at the same time"
         assert gpc.config.parallel.zero1.size == -1, "moe only support zero1, set zero1=dict(size=-1,...) can fix this"
-        assert not gpc.config.parallel.sequence_parallel, "moe not support sequence parallel for now"
 
 
 def launch(
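Dropping this assertion is what actually unlocks the feature: an MoE model may now be launched with sequence_parallel enabled, while the other MoE constraints checked just above (no overlap_sync_grad/overlap_sync_param, zero1 size of -1) still apply. A hypothetical excerpt of the parallel section of a training config under those constraints (only the fields referenced by the check above are shown; everything else is omitted):

# hypothetical config excerpt -- field names follow args_sanity_check above
parallel = dict(
    zero1=dict(size=-1),       # MoE still requires zero1 size == -1
    sequence_parallel=True,    # no longer rejected for MoE models after this change
)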
@@ -9,7 +9,7 @@ from flash_attn.modules.embedding import ParallelGPT2Embeddings
 from flash_attn.modules.mlp import ParallelFusedMLP
 from torch import nn
 
-from internlm.core.context import IS_TENSOR_PARALLEL, ParallelMode
+from internlm.core.context import IS_SEQUENCE_PARALLEL, IS_TENSOR_PARALLEL, ParallelMode
 from internlm.core.context.parallel_context import global_context as gpc
 from internlm.core.naive_amp import set_fp32_attr_to_module
 from internlm.initialize.initialize_tensor import normal_, scaled_init_method_normal
@@ -189,8 +189,18 @@ class PackedFlashBaseLayer1D(nn.Module):
             for _, param in self.mlp.moe_layer.experts.named_parameters():
                 if gpc.get_world_size(ParallelMode.TENSOR) > 1:
                     setattr(param, IS_TENSOR_PARALLEL, True)
+            for param in self.mlp.moe_layer.gate.parameters():
+                if gpc.config.parallel.sequence_parallel is True:
+                    setattr(param, IS_SEQUENCE_PARALLEL, True)
             set_fp32_attr_to_module(self.mlp.moe_layer.gate)
 
+        for param in self.norm1.parameters():
+            if gpc.config.parallel.sequence_parallel is True:
+                setattr(param, IS_SEQUENCE_PARALLEL, True)
+        for param in self.norm2.parameters():
+            if gpc.config.parallel.sequence_parallel is True:
+                setattr(param, IS_SEQUENCE_PARALLEL, True)
+
         self.dropout2 = nn.Dropout(drop_rate)
         self.use_swiglu = use_swiglu
         self.use_scaled_init = use_scaled_init
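Tagging the MoE gate and the norm parameters with IS_SEQUENCE_PARALLEL marks weights that are replicated across the sequence-parallel ranks but receive gradients computed from different token shards, so those gradients must be synchronized before the optimizer step. The diff only sets the flag; as a rough sketch of how such a flag can be consumed elsewhere (this is not InternLM's actual gradient hook, and the attribute string is illustrative):

import torch
import torch.distributed as dist

IS_SEQUENCE_PARALLEL = "is_sequence_parallel"  # illustrative attribute name

def reduce_sequence_parallel_grads(model: torch.nn.Module, group=None) -> None:
    """Sum gradients of flagged parameters across the sequence-parallel group.

    Norm weights and the MoE gate are identical on every rank, but each rank's
    backward pass only saw its shard of tokens, so the local gradients are
    partial and must be all-reduced before the optimizer step.
    """
    for param in model.parameters():
        if getattr(param, IS_SEQUENCE_PARALLEL, False) and param.grad is not None:
            dist.all_reduce(param.grad, op=dist.ReduceOp.SUM, group=group)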
@@ -446,6 +456,9 @@ class PackedFlashInternLm1D(nn.Module):
                     normal_(std=0.0052)(param)
                 if gpc.get_world_size(ParallelMode.TENSOR) > 1:
                     setattr(param, IS_TENSOR_PARALLEL, True)
+        for param in self.norm.parameters():
+            if gpc.config.parallel.sequence_parallel is True:
+                setattr(param, IS_SEQUENCE_PARALLEL, True)
         self.parallel_output = parallel_output
 
     def forward(self, hidden_states=None, cu_seqlens=None, input_ids=None, indexes=None, inference_params=None):
@@ -12,6 +12,8 @@ import torch.nn.functional as F
 from torch import Tensor
 from torch.nn import Module
 
+from internlm.core.context import ParallelMode
+from internlm.core.context import global_context as gpc
 from internlm.utils.logger import get_logger
 from internlm.utils.megatron_timers import megatron_timer as timer
 
@@ -189,7 +191,7 @@ def top1gating(
     # if we don't want to drop any tokens
     if not drop_tokens:
         new_capacity = torch.max(exp_counts).to(logits.device)
-        dist.all_reduce(new_capacity, op=dist.ReduceOp.MAX, group=dist.get_world_group())
+        dist.all_reduce(new_capacity, op=dist.ReduceOp.MAX, group=gpc.get_group(ParallelMode.GLOBAL))
         capacity = new_capacity
 
     # Compute l_aux
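When drop_tokens is False, the gate grows the expert capacity to the largest per-expert token count observed on any rank, which requires a MAX all-reduce; the change only swaps DeepSpeed's dist.get_world_group() handle for the equivalent global group obtained through gpc. A standalone sketch of that capacity rule (the helper name is illustrative, not part of the gating code):

import torch
import torch.distributed as dist

def no_drop_capacity(exp_counts: torch.Tensor, group=None) -> torch.Tensor:
    """Sketch of the no-token-dropping capacity rule used in top1gating.

    Each rank takes the largest number of tokens routed to any of its experts,
    then a MAX all-reduce makes the capacity identical on all ranks, so no
    rank ever has to drop tokens.
    """
    new_capacity = torch.max(exp_counts)
    if dist.is_initialized():
        dist.all_reduce(new_capacity, op=dist.ReduceOp.MAX, group=group)
    return new_capacity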