mirror of https://github.com/InternLM/InternLM
fix lint
parent 1bc3c33b75
commit 476a24bd9b

@@ -1,6 +1,6 @@
 from .parallel_context import (
-    IS_TENSOR_PARALLEL,
     IS_SEQUENCE_PARALLEL,
+    IS_TENSOR_PARALLEL,
     Config,
     ParallelContext,
     global_context,

@@ -9,7 +9,7 @@ from flash_attn.modules.embedding import ParallelGPT2Embeddings
 from flash_attn.modules.mlp import ParallelFusedMLP
 from torch import nn
 
-from internlm.core.context import IS_TENSOR_PARALLEL, IS_SEQUENCE_PARALLEL, ParallelMode
+from internlm.core.context import IS_SEQUENCE_PARALLEL, IS_TENSOR_PARALLEL, ParallelMode
 from internlm.core.context.parallel_context import global_context as gpc
 from internlm.initialize.initialize_tensor import normal_, scaled_init_method_normal
 from internlm.model.embedding import Embedding1D

@@ -8,7 +8,7 @@ import torch
 import torch.distributed as dist
 from torch.optim import Optimizer
 
-from internlm.core.context import Config, ParallelMode, IS_SEQUENCE_PARALLEL
+from internlm.core.context import IS_SEQUENCE_PARALLEL, Config, ParallelMode
 from internlm.core.context import global_context as gpc
 from internlm.monitor import send_alert_message
 from internlm.solver.optimizer.store import (
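
All three import hunks are the same mechanical lint fix: the names in each
from-import list are reordered so that UPPER_CASE constants come first and
everything is alphabetical within its group (IS_SEQUENCE_PARALLEL before
IS_TENSOR_PARALLEL, both before Config). This matches what the import sorter
isort produces with its default order_by_type setting; that isort is the tool
behind this "fix lint" commit is an assumption, since the commit does not name
its linter. A minimal sketch of reproducing the fix through isort's Python API:

# Sketch only; assumes isort >= 5 is installed. Exact line wrapping of the
# output depends on isort's configuration.
import isort

src = "from internlm.core.context import Config, ParallelMode, IS_SEQUENCE_PARALLEL\n"
print(isort.code(src), end="")
# Expected: from internlm.core.context import IS_SEQUENCE_PARALLEL, Config, ParallelMode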

@@ -314,11 +314,12 @@ class HybridZeroOptimizer(BaseOptimizer):
                     reduction_func()
 
             # define hook for sequence_parallel
-            def reduce_grad_hook_sp(*args):
+            def reduce_grad_hook_sp(*args):  # pylint: disable=W0613
                 if self.skip_grad_reduce is False:
                     reduction_sp_func()
 
-            # if sequence_parallel is True, the grad of norm should be all-reduce across the tp process group
+            # if sequence_parallel is True,
+            # the grad of norm should be all-reduce across the tp process group
             if gpc.config.parallel.sequence_parallel is True:
                 if hasattr(param, IS_SEQUENCE_PARALLEL) and getattr(param, IS_SEQUENCE_PARALLEL) is True:
                     accum_grad_obj_sp = get_grad_accumulate_object(param)
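
For context, the hook touched by this hunk follows the usual PyTorch pattern of
registering a callback on a parameter's AccumulateGrad autograd node, so that
once the gradient has been fully written it can be all-reduced across the
tensor-parallel process group, as the split comment states. The added
"# pylint: disable=W0613" silences pylint's unused-argument warning: autograd
passes gradient tensors to the hook as *args, but the hook never reads them.
A minimal sketch of the pattern in plain PyTorch; get_grad_accumulate_object
and tp_group are reconstructed here as assumptions, not InternLM's actual
implementation:

# Sketch only; assumes torch.distributed has been initialized and tp_group
# is the tensor-parallel process group.
import torch.distributed as dist


def get_grad_accumulate_object(param):
    # expand_as(param) yields a view whose grad_fn chains back to the
    # parameter's AccumulateGrad node; hooks on that node fire after
    # param.grad has been fully accumulated in the backward pass.
    return param.expand_as(param).grad_fn.next_functions[0][0]


def attach_sp_reduce_hook(param, tp_group):
    accum_grad_obj_sp = get_grad_accumulate_object(param)

    def reduce_grad_hook_sp(*args):  # pylint: disable=W0613
        # Only the side effect on param.grad matters, so *args is unused.
        dist.all_reduce(param.grad, group=tp_group)

    accum_grad_obj_sp.register_hook(reduce_grad_hook_sp)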