mirror of https://github.com/InternLM/InternLM
support sequence parallel
parent 513ebb9c3a
commit ac7f45232b
@@ -6,7 +6,9 @@
 from typing import Any, Callable, Iterable, List, Optional

 import torch
+import torch.distributed as dist

+from internlm.core.context import ParallelMode
 from internlm.core.context import global_context as gpc
 from internlm.core.engine import Engine
 from internlm.utils.common import conditional_context
@@ -125,6 +127,10 @@ class NonPipelineScheduler(BaseScheduler):
             if hasattr(gpc.config.model, "num_experts")
             else torch.tensor(0.0, device=torch.cuda.current_device(), dtype=gpc.config.model.get("dtype"))
         )
+        # the moe_loss is computed among the "tensor" group if sequence parallel is enabled,
+        # so we need to do allreduce
+        if gpc.config.parallel.sequence_parallel:
+            dist.all_reduce(moe_loss, op=dist.ReduceOp.AVG, group=gpc.get_group(ParallelMode.TENSOR))
         moe_loss /= scale_loss
         loss /= scale_loss
         loss += moe_loss
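The hunk above adds a collective so that every rank in the tensor-parallel group agrees on the MoE auxiliary loss when sequence parallelism splits the sequence across those ranks. Below is a minimal standalone sketch of that averaging step, assuming an already initialized torch.distributed process group; the function name and the group handle are placeholders, not InternLM APIs.

import torch
import torch.distributed as dist

def average_moe_loss_across_group(moe_loss: torch.Tensor, group) -> torch.Tensor:
    # With sequence parallelism each rank in the tensor-parallel group sees a
    # different slice of the sequence, so its locally computed MoE auxiliary
    # loss differs. Averaging in place makes all ranks agree on one value.
    # Note: ReduceOp.AVG requires the NCCL backend on recent PyTorch.
    dist.all_reduce(moe_loss, op=dist.ReduceOp.AVG, group=group)
    return moe_loss

On backends without ReduceOp.AVG, the equivalent is a SUM all-reduce followed by division by the group size.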
@@ -311,6 +311,9 @@ class PipelineScheduler(BaseScheduler):
             if hasattr(gpc.config.model, "num_experts")
             else torch.tensor(0.0, device=torch.cuda.current_device(), dtype=gpc.config.model.get("dtype"))
         )
+        # the moe_loss is computed among the "tensor" group if sequence parallel is enabled, so we need to do allreduce
+        if gpc.config.parallel.sequence_parallel:
+            dist.all_reduce(moe_loss, op=dist.ReduceOp.AVG, group=gpc.get_group(ParallelMode.TENSOR))
         moe_loss /= self.num_microbatches
         accum_moe_loss.add_(moe_loss.detach())

@@ -858,6 +861,9 @@ class InterleavedPipelineScheduler(PipelineScheduler):
             if hasattr(gpc.config.model, "num_experts")
             else torch.tensor(0.0, device=torch.cuda.current_device(), dtype=gpc.config.model.get("dtype"))
         )
+        # the moe_loss is computed among the "tensor" group if sequence parallel is enabled, so we need to do allreduce
+        if gpc.config.parallel.sequence_parallel:
+            dist.all_reduce(moe_loss, op=dist.ReduceOp.AVG, group=gpc.get_group(ParallelMode.TENSOR))
         moe_loss /= self.num_microbatches

         if self._accum_moe_loss is not None:
@@ -369,7 +369,6 @@ def args_sanity_check():
             not optim_ckpt.overlap_sync_grad & optim_ckpt.overlap_sync_param
         ), "not support overlap and moe at the same time"
         assert gpc.config.parallel.zero1.size == -1, "moe only support zero1, set zero1=dict(size=-1,...) can fix this"
-        assert not gpc.config.parallel.sequence_parallel, "moe not support sequence parallel for now"


 def launch(
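Dropping the assert above is what actually unlocks MoE together with sequence parallelism; the surrounding checks (overlap and zero1) stay in place. A hypothetical config fragment that would now pass args_sanity_check, showing only the keys referenced by these asserts (the full parallel schema may differ by version):

parallel = dict(
    zero1=dict(size=-1),      # MoE still requires zero1 size -1, per the assert kept above
    sequence_parallel=True,   # no longer rejected for MoE models after this change
    # other parallel fields (tensor, pipeline, ...) omitted here
)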
@@ -9,7 +9,7 @@ from flash_attn.modules.embedding import ParallelGPT2Embeddings
 from flash_attn.modules.mlp import ParallelFusedMLP
 from torch import nn

-from internlm.core.context import IS_TENSOR_PARALLEL, ParallelMode
+from internlm.core.context import IS_SEQUENCE_PARALLEL, IS_TENSOR_PARALLEL, ParallelMode
 from internlm.core.context.parallel_context import global_context as gpc
 from internlm.core.naive_amp import set_fp32_attr_to_module
 from internlm.initialize.initialize_tensor import normal_, scaled_init_method_normal
@@ -189,8 +189,18 @@ class PackedFlashBaseLayer1D(nn.Module):
             for _, param in self.mlp.moe_layer.experts.named_parameters():
                 if gpc.get_world_size(ParallelMode.TENSOR) > 1:
                     setattr(param, IS_TENSOR_PARALLEL, True)
+            for param in self.mlp.moe_layer.gate.parameters():
+                if gpc.config.parallel.sequence_parallel is True:
+                    setattr(param, IS_SEQUENCE_PARALLEL, True)
             set_fp32_attr_to_module(self.mlp.moe_layer.gate)

+        for param in self.norm1.parameters():
+            if gpc.config.parallel.sequence_parallel is True:
+                setattr(param, IS_SEQUENCE_PARALLEL, True)
+        for param in self.norm2.parameters():
+            if gpc.config.parallel.sequence_parallel is True:
+                setattr(param, IS_SEQUENCE_PARALLEL, True)
+
         self.dropout2 = nn.Dropout(drop_rate)
         self.use_swiglu = use_swiglu
         self.use_scaled_init = use_scaled_init
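The loops above only tag parameters; the tag is consumed later when gradients are synchronized. A rough illustration of the tag-then-sync pattern, assuming a generic handler rather than InternLM's actual gradient hooks (the IS_SEQUENCE_PARALLEL string and sync_sequence_parallel_grads below are stand-ins for this sketch):

import torch
import torch.distributed as dist

IS_SEQUENCE_PARALLEL = "is_sequence_parallel"  # illustrative attribute name

def tag_sequence_parallel(module: torch.nn.Module) -> None:
    # Mirror of the pattern in the hunk above: mark every parameter of a
    # module so a later gradient handler can find it.
    for param in module.parameters():
        setattr(param, IS_SEQUENCE_PARALLEL, True)

def sync_sequence_parallel_grads(model: torch.nn.Module, tp_group) -> None:
    # Hypothetical consumer of the tag: norm and gate weights are replicated
    # across the tensor-parallel group, but each rank only saw its own
    # sequence slice, so their gradients must be reduced across that group.
    for param in model.parameters():
        if getattr(param, IS_SEQUENCE_PARALLEL, False) and param.grad is not None:
            dist.all_reduce(param.grad, op=dist.ReduceOp.SUM, group=tp_group)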
@@ -446,6 +456,9 @@ class PackedFlashInternLm1D(nn.Module):
                     normal_(std=0.0052)(param)
                 if gpc.get_world_size(ParallelMode.TENSOR) > 1:
                     setattr(param, IS_TENSOR_PARALLEL, True)
+        for param in self.norm.parameters():
+            if gpc.config.parallel.sequence_parallel is True:
+                setattr(param, IS_SEQUENCE_PARALLEL, True)
         self.parallel_output = parallel_output

     def forward(self, hidden_states=None, cu_seqlens=None, input_ids=None, indexes=None, inference_params=None):
@@ -12,6 +12,8 @@ import torch.nn.functional as F
 from torch import Tensor
 from torch.nn import Module

+from internlm.core.context import ParallelMode
+from internlm.core.context import global_context as gpc
 from internlm.utils.logger import get_logger
 from internlm.utils.megatron_timers import megatron_timer as timer

@@ -189,7 +191,7 @@ def top1gating(
     # if we don't want to drop any tokens
     if not drop_tokens:
         new_capacity = torch.max(exp_counts).to(logits.device)
-        dist.all_reduce(new_capacity, op=dist.ReduceOp.MAX, group=dist.get_world_group())
+        dist.all_reduce(new_capacity, op=dist.ReduceOp.MAX, group=gpc.get_group(ParallelMode.GLOBAL))
         capacity = new_capacity

     # Compute l_aux
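The only change in this hunk is the process group handle: the MAX all-reduce that sizes expert capacity in no-drop mode now goes through gpc.get_group(ParallelMode.GLOBAL) instead of dist.get_world_group(). A small sketch of what that collective accomplishes, with a placeholder function name and an explicitly passed group:

import torch
import torch.distributed as dist

def agree_on_capacity(exp_counts: torch.Tensor, group) -> int:
    # In no-drop mode every rank must size its expert buffers identically,
    # so the largest per-expert token count seen on any rank becomes the
    # shared capacity.
    new_capacity = torch.max(exp_counts).clone()
    dist.all_reduce(new_capacity, op=dist.ReduceOp.MAX, group=group)
    return int(new_capacity.item())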