pull/443/head
yingtongxiong 2023-10-25 13:38:46 +08:00
parent 1bc3c33b75
commit 476a24bd9b
3 changed files with 11 additions and 10 deletions

View File

@@ -1,6 +1,6 @@
from .parallel_context import (
-IS_TENSOR_PARALLEL,
IS_SEQUENCE_PARALLEL,
+IS_TENSOR_PARALLEL,
Config,
ParallelContext,
global_context,

View File

@@ -9,7 +9,7 @@ from flash_attn.modules.embedding import ParallelGPT2Embeddings
from flash_attn.modules.mlp import ParallelFusedMLP
from torch import nn
-from internlm.core.context import IS_TENSOR_PARALLEL, IS_SEQUENCE_PARALLEL, ParallelMode
+from internlm.core.context import IS_SEQUENCE_PARALLEL, IS_TENSOR_PARALLEL, ParallelMode
from internlm.core.context.parallel_context import global_context as gpc
from internlm.initialize.initialize_tensor import normal_, scaled_init_method_normal
from internlm.model.embedding import Embedding1D
@@ -365,7 +365,7 @@ class PackedFlashInternLm1D(nn.Module):
for param in self.norm.parameters():
if gpc.config.parallel.sequence_parallel is True:
setattr(param, IS_SEQUENCE_PARALLEL, True)
self.parallel_output = parallel_output
def forward(self, hidden_states=None, cu_seqlens=None, input_ids=None, indexes=None, inference_params=None):
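Note on the hunk above: the norm layer's parameters are tagged with the IS_SEQUENCE_PARALLEL attribute so the optimizer can later single them out for an extra all-reduce. A minimal, self-contained sketch of that tagging pattern follows; the string value of the constant, the helper name tag_sequence_parallel_params, and the plain LayerNorm are illustrative assumptions, not the repository's code.

from torch import nn

IS_SEQUENCE_PARALLEL = "is_sequence_parallel"  # assumed value, for illustration only

def tag_sequence_parallel_params(norm: nn.Module, sequence_parallel: bool) -> None:
    # Mark every parameter of the norm layer so a later pass (the optimizer's hook
    # registration) can recognise it as needing a tensor-parallel all-reduce.
    if sequence_parallel is True:
        for param in norm.parameters():
            setattr(param, IS_SEQUENCE_PARALLEL, True)

norm = nn.LayerNorm(8)
tag_sequence_parallel_params(norm, sequence_parallel=True)
assert all(getattr(p, IS_SEQUENCE_PARALLEL, False) for p in norm.parameters())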

View File

@@ -8,7 +8,7 @@ import torch
import torch.distributed as dist
from torch.optim import Optimizer
-from internlm.core.context import Config, ParallelMode, IS_SEQUENCE_PARALLEL
+from internlm.core.context import IS_SEQUENCE_PARALLEL, Config, ParallelMode
from internlm.core.context import global_context as gpc
from internlm.monitor import send_alert_message
from internlm.solver.optimizer.store import (
@@ -296,7 +296,7 @@ class HybridZeroOptimizer(BaseOptimizer):
param=param,
reduce_rank=reduce_rank,
)
def reduction_sp_func():
handle = reduce_tensor(
param.grad,
@@ -312,18 +312,19 @@ class HybridZeroOptimizer(BaseOptimizer):
def reduce_grad_hook(*args): # pylint: disable=W0613
if self.skip_grad_reduce is False:
reduction_func()
# define hook for sequence_parallel
-def reduce_grad_hook_sp(*args):
+def reduce_grad_hook_sp(*args):  # pylint: disable=W0613
if self.skip_grad_reduce is False:
reduction_sp_func()
-# if sequence_parallel is True, the grad of norm should be all-reduce across the tp process group
+# if sequence_parallel is True,
+# the grad of norm should be all-reduce across the tp process group
if gpc.config.parallel.sequence_parallel is True:
if hasattr(param, IS_SEQUENCE_PARALLEL) and getattr(param, IS_SEQUENCE_PARALLEL) is True:
accum_grad_obj_sp = get_grad_accumulate_object(param)
accum_grad_obj_sp.register_hook(reduce_grad_hook_sp)
accum_grad_obj.register_hook(reduce_grad_hook)
_define_and_attach(param, reduce_rank)
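
Note on the hunk above: for parameters carrying the IS_SEQUENCE_PARALLEL tag, a second hook is registered on the gradient accumulator so that, once the gradient has been accumulated, it is all-reduced across the tensor-parallel process group in addition to the usual ZeRO bucket reduction. The sketch below illustrates that hook pattern under stated assumptions: a single-process gloo group stands in for gpc.get_group(ParallelMode.TENSOR), the expand_as trick is used in place of get_grad_accumulate_object, and attach_sp_grad_hook, the constant's string value, and the TCP port are made up for the example.

import torch
import torch.distributed as dist
from torch import nn

IS_SEQUENCE_PARALLEL = "is_sequence_parallel"  # assumed value, for illustration only

def attach_sp_grad_hook(param: nn.Parameter, tp_group):
    # Reach the AccumulateGrad node of the parameter; hooks registered on it fire
    # after the gradient has been accumulated into param.grad during backward.
    accum_grad_obj = param.expand_as(param).grad_fn.next_functions[0][0]

    def reduce_grad_hook_sp(*args):  # args unused, mirroring the hook in the diff
        # All-reduce the accumulated gradient across the (stand-in) tensor-parallel group.
        dist.all_reduce(param.grad, group=tp_group)

    if hasattr(param, IS_SEQUENCE_PARALLEL) and getattr(param, IS_SEQUENCE_PARALLEL) is True:
        accum_grad_obj.register_hook(reduce_grad_hook_sp)
    # The caller must keep this object alive or the hook can be lost; the optimizer
    # in the diff stores the accumulator object for the same reason.
    return accum_grad_obj

if __name__ == "__main__":
    # Single-process "gloo" group so the sketch runs standalone; a real setup would
    # pass the tensor-parallel process group instead.
    dist.init_process_group("gloo", init_method="tcp://127.0.0.1:29512", rank=0, world_size=1)
    norm = nn.LayerNorm(4)
    for p in norm.parameters():
        setattr(p, IS_SEQUENCE_PARALLEL, True)
    keep_alive = [attach_sp_grad_hook(p, dist.group.WORLD) for p in norm.parameters()]
    norm(torch.randn(2, 4)).sum().backward()  # the hooks all-reduce the norm gradients
    dist.destroy_process_group()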