From 1d7e2d04ec77a7ae2575fa1df07cee3737d7b540 Mon Sep 17 00:00:00 2001
From: ytxiong <45058324+yingtongxiong@users.noreply.github.com>
Date: Wed, 25 Oct 2023 14:16:32 +0800
Subject: [PATCH] fix(*)/all-reduce for norm in sequence parallel (#443)

* fix all-reduce norm grad

* change the order of dp and sp all-reduce

* fix lint
---
 internlm/core/context/__init__.py              |  2 ++
 internlm/core/context/parallel_context.py      |  1 +
 internlm/model/modeling_internlm.py            | 12 +++++++++-
 .../solver/optimizer/hybrid_zero_optim.py      | 23 ++++++++++++++++++-
 4 files changed, 36 insertions(+), 2 deletions(-)

diff --git a/internlm/core/context/__init__.py b/internlm/core/context/__init__.py
index 5cbb832..6f1142c 100644
--- a/internlm/core/context/__init__.py
+++ b/internlm/core/context/__init__.py
@@ -1,4 +1,5 @@
 from .parallel_context import (
+    IS_SEQUENCE_PARALLEL,
     IS_TENSOR_PARALLEL,
     Config,
     ParallelContext,
@@ -29,6 +30,7 @@ from .random import (
 __all__ = [
     "Config",
     "IS_TENSOR_PARALLEL",
+    "IS_SEQUENCE_PARALLEL",
     "global_context",
     "ParallelContext",
     "ParallelMode",
diff --git a/internlm/core/context/parallel_context.py b/internlm/core/context/parallel_context.py
index 915905a..633dfe4 100644
--- a/internlm/core/context/parallel_context.py
+++ b/internlm/core/context/parallel_context.py
@@ -25,6 +25,7 @@ from .process_group_initializer import ParallelMode
 from .random import add_seed, get_seeds, set_mode
 
 IS_TENSOR_PARALLEL = "is_tensor_parallel"
+IS_SEQUENCE_PARALLEL = "is_sequence_parallel"
 
 logger = get_logger(__file__)
 
diff --git a/internlm/model/modeling_internlm.py b/internlm/model/modeling_internlm.py
index 2856a78..cbf425c 100644
--- a/internlm/model/modeling_internlm.py
+++ b/internlm/model/modeling_internlm.py
@@ -9,7 +9,7 @@ from flash_attn.modules.embedding import ParallelGPT2Embeddings
 from flash_attn.modules.mlp import ParallelFusedMLP
 from torch import nn
 
-from internlm.core.context import IS_TENSOR_PARALLEL, ParallelMode
+from internlm.core.context import IS_SEQUENCE_PARALLEL, IS_TENSOR_PARALLEL, ParallelMode
 from internlm.core.context.parallel_context import global_context as gpc
 from internlm.initialize.initialize_tensor import normal_, scaled_init_method_normal
 from internlm.model.embedding import Embedding1D
@@ -134,6 +134,12 @@ class PackedFlashBaseLayer1D(nn.Module):
         for _, param in self.mlp.named_parameters():
             if gpc.get_world_size(ParallelMode.TENSOR) > 1:
                 setattr(param, IS_TENSOR_PARALLEL, True)
+        for param in self.norm1.parameters():
+            if gpc.config.parallel.sequence_parallel is True:
+                setattr(param, IS_SEQUENCE_PARALLEL, True)
+        for param in self.norm2.parameters():
+            if gpc.config.parallel.sequence_parallel is True:
+                setattr(param, IS_SEQUENCE_PARALLEL, True)
 
         self.dropout2 = nn.Dropout(drop_rate)
         self.use_swiglu = use_swiglu
@@ -356,6 +362,10 @@ class PackedFlashInternLm1D(nn.Module):
                 normal_(std=0.0052)(param)
                 if gpc.get_world_size(ParallelMode.TENSOR) > 1:
                     setattr(param, IS_TENSOR_PARALLEL, True)
+        for param in self.norm.parameters():
+            if gpc.config.parallel.sequence_parallel is True:
+                setattr(param, IS_SEQUENCE_PARALLEL, True)
+
         self.parallel_output = parallel_output
 
     def forward(self, hidden_states=None, cu_seqlens=None, input_ids=None, indexes=None, inference_params=None):
diff --git a/internlm/solver/optimizer/hybrid_zero_optim.py b/internlm/solver/optimizer/hybrid_zero_optim.py
index e3f9608..2817258 100644
--- a/internlm/solver/optimizer/hybrid_zero_optim.py
+++ b/internlm/solver/optimizer/hybrid_zero_optim.py
@@ -8,7 +8,7 @@ import torch
 import torch.distributed as dist
 from torch.optim import Optimizer
 
-from internlm.core.context import Config, ParallelMode
+from internlm.core.context import IS_SEQUENCE_PARALLEL, Config, ParallelMode
 from internlm.core.context import global_context as gpc
 from internlm.monitor import send_alert_message
 from internlm.solver.optimizer.store import (
@@ -297,6 +297,15 @@ class HybridZeroOptimizer(BaseOptimizer):
                             reduce_rank=reduce_rank,
                         )
 
+                        def reduction_sp_func():
+                            handle = reduce_tensor(
+                                param.grad,
+                                dtype=None,
+                                dst_rank=reduce_rank,
+                                parallel_mode=ParallelMode.TENSOR,
+                            )
+                            handle.wait()
+
                         # define hook
                         # NOT IMPORTANT BUT GOOD TO KNOW:
                         # args here is not grad, but allow_unreacable and accumulate_grad
@@ -304,6 +313,18 @@ class HybridZeroOptimizer(BaseOptimizer):
                             if self.skip_grad_reduce is False:
                                 reduction_func()
 
+                        # define hook for sequence_parallel
+                        def reduce_grad_hook_sp(*args):  # pylint: disable=W0613
+                            if self.skip_grad_reduce is False:
+                                reduction_sp_func()
+
+                        # if sequence_parallel is True,
+                        # the grad of norm should be all-reduce across the tp process group
+                        if gpc.config.parallel.sequence_parallel is True:
+                            if hasattr(param, IS_SEQUENCE_PARALLEL) and getattr(param, IS_SEQUENCE_PARALLEL) is True:
+                                accum_grad_obj_sp = get_grad_accumulate_object(param)
+                                accum_grad_obj_sp.register_hook(reduce_grad_hook_sp)
+
                         accum_grad_obj.register_hook(reduce_grad_hook)
 
                     _define_and_attach(param, reduce_rank)
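
Note on the mechanism (illustration, not part of the patch): the change registers a backward
hook on each norm parameter's AccumulateGrad node, so that once the local gradient has been
accumulated it is summed across the tensor-parallel ranks, which under sequence parallel each
see only their slice of the sequence. A minimal standalone sketch of that mechanism follows.
The helper name attach_sp_norm_hook, the tp_group argument, and the use of plain
dist.all_reduce are assumptions made for illustration; the repo itself goes through
reduce_tensor and get_grad_accumulate_object as shown in the diff above.

import torch
import torch.distributed as dist


def attach_sp_norm_hook(param: torch.nn.Parameter, tp_group: dist.ProcessGroup) -> None:
    """Sketch: all-reduce a sequence-parallel norm param's grad over the TP group."""
    # The AccumulateGrad node of a leaf tensor is reachable through the grad_fn of a
    # trivial view; this is the same trick used to get the grad-accumulation object.
    accum_grad_obj = param.expand_as(param).grad_fn.next_functions[0][0]

    def reduce_grad_hook_sp(*args):  # hook args are not the gradient itself
        if param.grad is not None:
            # Sum the partial gradients computed on each tensor-parallel rank.
            dist.all_reduce(param.grad, op=dist.ReduceOp.SUM, group=tp_group)

    # Keep a reference to the AccumulateGrad object alive, otherwise the hook may never fire.
    param._accum_grad_obj = accum_grad_obj
    accum_grad_obj.register_hook(reduce_grad_hook_sp)

Such a hook would be attached only to parameters tagged with IS_SEQUENCE_PARALLEL (the norm
weights above) and must run inside an initialized torch.distributed job, e.g. launched with
torchrun, where tp_group is the tensor-parallel process group.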