mirror of https://github.com/InternLM/InternLM
fix(*)/all-reduce for norm in sequence parallel (#443)
* fix all-reduce norm grad
* change the order of dp and sp all-reduce
* fix lint
parent 949a0a1d55
commit 1d7e2d04ec
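For context, a minimal sketch (not the repository code) of what this commit enforces: under sequence parallelism each tensor-parallel rank normalizes a different slice of the sequence, so the norm weight gradient it computes is only a partial sum and must be all-reduced across the tensor-parallel process group. The group construction and the allreduce_norm_grads helper below are illustrative assumptions; the sketch is runnable under torchrun with the gloo backend.

import torch
import torch.distributed as dist


def allreduce_norm_grads(norm_params, tp_group):
    """All-reduce gradients of sequence-parallel norm weights over the TP group (illustrative)."""
    for p in norm_params:
        if p.grad is not None:
            dist.all_reduce(p.grad, op=dist.ReduceOp.SUM, group=tp_group)


if __name__ == "__main__":
    # run with: torchrun --nproc_per_node=2 sketch.py
    dist.init_process_group(backend="gloo")
    tp_group = dist.new_group(ranks=list(range(dist.get_world_size())))
    norm = torch.nn.LayerNorm(8)
    local_chunk = torch.randn(4, 8) + dist.get_rank()  # each rank sees a different sequence shard
    norm(local_chunk).sum().backward()                 # per-rank grads are only partial sums
    allreduce_norm_grads(norm.parameters(), tp_group)  # now every rank holds the same norm grad

The commit implements the same idea through InternLM's reduce_tensor helper inside the optimizer rather than a standalone pass, as shown in the hybrid_zero_optim hunks below.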
@@ -1,4 +1,5 @@
 from .parallel_context import (
+    IS_SEQUENCE_PARALLEL,
     IS_TENSOR_PARALLEL,
     Config,
     ParallelContext,
@@ -29,6 +30,7 @@ from .random import (
 __all__ = [
     "Config",
     "IS_TENSOR_PARALLEL",
+    "IS_SEQUENCE_PARALLEL",
     "global_context",
     "ParallelContext",
     "ParallelMode",
@@ -25,6 +25,7 @@ from .process_group_initializer import ParallelMode
 from .random import add_seed, get_seeds, set_mode

 IS_TENSOR_PARALLEL = "is_tensor_parallel"
+IS_SEQUENCE_PARALLEL = "is_sequence_parallel"

 logger = get_logger(__file__)

@@ -9,7 +9,7 @@ from flash_attn.modules.embedding import ParallelGPT2Embeddings
 from flash_attn.modules.mlp import ParallelFusedMLP
 from torch import nn

-from internlm.core.context import IS_TENSOR_PARALLEL, ParallelMode
+from internlm.core.context import IS_SEQUENCE_PARALLEL, IS_TENSOR_PARALLEL, ParallelMode
 from internlm.core.context.parallel_context import global_context as gpc
 from internlm.initialize.initialize_tensor import normal_, scaled_init_method_normal
 from internlm.model.embedding import Embedding1D
@@ -134,6 +134,12 @@ class PackedFlashBaseLayer1D(nn.Module):
         for _, param in self.mlp.named_parameters():
             if gpc.get_world_size(ParallelMode.TENSOR) > 1:
                 setattr(param, IS_TENSOR_PARALLEL, True)
+        for param in self.norm1.parameters():
+            if gpc.config.parallel.sequence_parallel is True:
+                setattr(param, IS_SEQUENCE_PARALLEL, True)
+        for param in self.norm2.parameters():
+            if gpc.config.parallel.sequence_parallel is True:
+                setattr(param, IS_SEQUENCE_PARALLEL, True)

         self.dropout2 = nn.Dropout(drop_rate)
         self.use_swiglu = use_swiglu
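The hunk above tags the norm parameters with the IS_SEQUENCE_PARALLEL string attribute so a later consumer can recognize them. A small illustrative sketch of that tagging pattern follows; TinyBlock and sequence_parallel_params are hypothetical stand-ins, not InternLM classes.

from torch import nn

IS_SEQUENCE_PARALLEL = "is_sequence_parallel"


class TinyBlock(nn.Module):
    """Hypothetical stand-in for a transformer block with two norm layers."""

    def __init__(self, dim, sequence_parallel=True):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        # mirror the tagging pattern from the diff: mark norm params with a flag attribute
        for param in list(self.norm1.parameters()) + list(self.norm2.parameters()):
            if sequence_parallel is True:
                setattr(param, IS_SEQUENCE_PARALLEL, True)


def sequence_parallel_params(module):
    """Hypothetical helper: yield parameters tagged as sequence-parallel."""
    for p in module.parameters():
        if getattr(p, IS_SEQUENCE_PARALLEL, False):
            yield p


block = TinyBlock(16)
print(sum(1 for _ in sequence_parallel_params(block)))  # 4: weight and bias of each norm

In the diff the same flag is consumed by HybridZeroOptimizer, which checks hasattr(param, IS_SEQUENCE_PARALLEL) before registering the extra reduce hook.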
@@ -356,6 +362,10 @@ class PackedFlashInternLm1D(nn.Module):
                     normal_(std=0.0052)(param)
                 if gpc.get_world_size(ParallelMode.TENSOR) > 1:
                     setattr(param, IS_TENSOR_PARALLEL, True)
+        for param in self.norm.parameters():
+            if gpc.config.parallel.sequence_parallel is True:
+                setattr(param, IS_SEQUENCE_PARALLEL, True)
+
         self.parallel_output = parallel_output

     def forward(self, hidden_states=None, cu_seqlens=None, input_ids=None, indexes=None, inference_params=None):
@@ -8,7 +8,7 @@ import torch
 import torch.distributed as dist
 from torch.optim import Optimizer

-from internlm.core.context import Config, ParallelMode
+from internlm.core.context import IS_SEQUENCE_PARALLEL, Config, ParallelMode
 from internlm.core.context import global_context as gpc
 from internlm.monitor import send_alert_message
 from internlm.solver.optimizer.store import (
@@ -297,6 +297,15 @@ class HybridZeroOptimizer(BaseOptimizer):
                     reduce_rank=reduce_rank,
                 )

+            def reduction_sp_func():
+                handle = reduce_tensor(
+                    param.grad,
+                    dtype=None,
+                    dst_rank=reduce_rank,
+                    parallel_mode=ParallelMode.TENSOR,
+                )
+                handle.wait()
+
             # define hook
             # NOT IMPORTANT BUT GOOD TO KNOW:
             # args here is not grad, but allow_unreacable and accumulate_grad
@@ -304,6 +313,18 @@ class HybridZeroOptimizer(BaseOptimizer):
                 if self.skip_grad_reduce is False:
                     reduction_func()

+            # define hook for sequence_parallel
+            def reduce_grad_hook_sp(*args):  # pylint: disable=W0613
+                if self.skip_grad_reduce is False:
+                    reduction_sp_func()
+
+            # if sequence_parallel is True,
+            # the grad of norm should be all-reduce across the tp process group
+            if gpc.config.parallel.sequence_parallel is True:
+                if hasattr(param, IS_SEQUENCE_PARALLEL) and getattr(param, IS_SEQUENCE_PARALLEL) is True:
+                    accum_grad_obj_sp = get_grad_accumulate_object(param)
+                    accum_grad_obj_sp.register_hook(reduce_grad_hook_sp)
+
             accum_grad_obj.register_hook(reduce_grad_hook)

         _define_and_attach(param, reduce_rank)
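The registration above hooks the parameter's gradient-accumulation node so the extra reduction fires once the gradient has actually been written. A minimal single-process sketch of that mechanism follows; the body of get_grad_accumulate_object is an assumption based on the common expand_as trick, and the printing hook stands in for the real all-reduce.

import torch


def get_grad_accumulate_object(param):
    # expand_as creates an autograd node whose next function is the parameter's
    # AccumulateGrad node, which runs after the gradient has been accumulated.
    return param.expand_as(param).grad_fn.next_functions[0][0]


param = torch.nn.Parameter(torch.randn(4))


def reduce_grad_hook_sp(*args):  # hook args are not the gradient itself
    # In the optimizer this is where the all-reduce of param.grad over the
    # tensor-parallel group would be launched; here we only observe the call.
    print("norm grad ready:", param.grad)


accum_grad_obj_sp = get_grad_accumulate_object(param)
accum_grad_obj_sp.register_hook(reduce_grad_hook_sp)

(param * 2.0).sum().backward()  # the hook fires after param.grad is accumulated

In the diff, reduce_grad_hook_sp wraps reduction_sp_func, which issues the reduce over ParallelMode.TENSOR and waits on the returned handle.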