From 476a24bd9b898a97d6aa9b7bdeb4fe7291e690ad Mon Sep 17 00:00:00 2001
From: yingtongxiong <974106207@qq.com>
Date: Wed, 25 Oct 2023 13:38:46 +0800
Subject: [PATCH] fix lint

---
 internlm/core/context/__init__.py              |  2 +-
 internlm/model/modeling_internlm.py            |  4 ++--
 internlm/solver/optimizer/hybrid_zero_optim.py | 15 ++++++++-------
 3 files changed, 11 insertions(+), 10 deletions(-)

diff --git a/internlm/core/context/__init__.py b/internlm/core/context/__init__.py
index 603f24f..6f1142c 100644
--- a/internlm/core/context/__init__.py
+++ b/internlm/core/context/__init__.py
@@ -1,6 +1,6 @@
 from .parallel_context import (
-    IS_TENSOR_PARALLEL,
     IS_SEQUENCE_PARALLEL,
+    IS_TENSOR_PARALLEL,
     Config,
     ParallelContext,
     global_context,
diff --git a/internlm/model/modeling_internlm.py b/internlm/model/modeling_internlm.py
index 9102225..cbf425c 100644
--- a/internlm/model/modeling_internlm.py
+++ b/internlm/model/modeling_internlm.py
@@ -9,7 +9,7 @@ from flash_attn.modules.embedding import ParallelGPT2Embeddings
 from flash_attn.modules.mlp import ParallelFusedMLP
 from torch import nn
 
-from internlm.core.context import IS_TENSOR_PARALLEL, IS_SEQUENCE_PARALLEL, ParallelMode
+from internlm.core.context import IS_SEQUENCE_PARALLEL, IS_TENSOR_PARALLEL, ParallelMode
 from internlm.core.context.parallel_context import global_context as gpc
 from internlm.initialize.initialize_tensor import normal_, scaled_init_method_normal
 from internlm.model.embedding import Embedding1D
@@ -365,7 +365,7 @@ class PackedFlashInternLm1D(nn.Module):
             for param in self.norm.parameters():
                 if gpc.config.parallel.sequence_parallel is True:
                     setattr(param, IS_SEQUENCE_PARALLEL, True)
-            
+
         self.parallel_output = parallel_output
 
     def forward(self, hidden_states=None, cu_seqlens=None, input_ids=None, indexes=None, inference_params=None):
diff --git a/internlm/solver/optimizer/hybrid_zero_optim.py b/internlm/solver/optimizer/hybrid_zero_optim.py
index 3487324..2817258 100644
--- a/internlm/solver/optimizer/hybrid_zero_optim.py
+++ b/internlm/solver/optimizer/hybrid_zero_optim.py
@@ -8,7 +8,7 @@ import torch
 import torch.distributed as dist
 from torch.optim import Optimizer
 
-from internlm.core.context import Config, ParallelMode, IS_SEQUENCE_PARALLEL
+from internlm.core.context import IS_SEQUENCE_PARALLEL, Config, ParallelMode
 from internlm.core.context import global_context as gpc
 from internlm.monitor import send_alert_message
 from internlm.solver.optimizer.store import (
@@ -296,7 +296,7 @@ class HybridZeroOptimizer(BaseOptimizer):
                             param=param,
                             reduce_rank=reduce_rank,
                         )
-                        
+
                         def reduction_sp_func():
                             handle = reduce_tensor(
                                 param.grad,
@@ -312,18 +312,19 @@
                         def reduce_grad_hook(*args):  # pylint: disable=W0613
                             if self.skip_grad_reduce is False:
                                 reduction_func()
-                        
+
                         # define hook for sequence_parallel
-                        def reduce_grad_hook_sp(*args):
+                        def reduce_grad_hook_sp(*args):  # pylint: disable=W0613
                             if self.skip_grad_reduce is False:
                                 reduction_sp_func()
-                        
-                        # if sequence_parallel is True, the grad of norm should be all-reduce across the tp process group
+
+                        # if sequence_parallel is True,
+                        # the grad of norm should be all-reduce across the tp process group
                         if gpc.config.parallel.sequence_parallel is True:
                             if hasattr(param, IS_SEQUENCE_PARALLEL) and getattr(param, IS_SEQUENCE_PARALLEL) is True:
                                 accum_grad_obj_sp = get_grad_accumulate_object(param)
                                 accum_grad_obj_sp.register_hook(reduce_grad_hook_sp)
-                        
+
                         accum_grad_obj.register_hook(reduce_grad_hook)
 
                     _define_and_attach(param, reduce_rank)
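
Note on the last hunk: when parallel.sequence_parallel is enabled, the norm parameters are tagged with IS_SEQUENCE_PARALLEL and reduce_grad_hook_sp all-reduces their gradients across the tensor-parallel process group once the gradient has been accumulated. Below is a minimal, self-contained sketch of that hook pattern, not the repository's implementation; the string value of IS_SEQUENCE_PARALLEL, the attach_sp_allreduce_hook helper, and the tp_group argument are illustrative assumptions.

# Sketch (assumptions noted above) of the pattern the patch touches: tag a
# parameter as sequence-parallel, reach its AccumulateGrad node, and register
# a hook that all-reduces the gradient across the tensor-parallel group.
import torch
import torch.distributed as dist

IS_SEQUENCE_PARALLEL = "is_sequence_parallel"  # attribute name: an assumption


def get_grad_accumulate_object(param):
    # A common way to reach the AccumulateGrad autograd node of a leaf
    # parameter: take a trivial view and follow its grad_fn backwards.
    return param.expand_as(param).grad_fn.next_functions[0][0]


def attach_sp_allreduce_hook(param, tp_group=None):
    """Register a hook that all-reduces param.grad after accumulation."""
    if getattr(param, IS_SEQUENCE_PARALLEL, False) is not True:
        return None

    accum_grad_obj = get_grad_accumulate_object(param)

    def reduce_grad_hook_sp(*args):  # pylint: disable=W0613
        # Fires during backward, after the gradient has been written to param.grad.
        if param.grad is not None and dist.is_available() and dist.is_initialized():
            dist.all_reduce(param.grad, group=tp_group)

    accum_grad_obj.register_hook(reduce_grad_hook_sp)
    # Return the node so the caller can keep it alive; if it is garbage-collected,
    # the hook is dropped (the repo keeps these objects in its grad store for the
    # same reason).
    return accum_grad_obj


if __name__ == "__main__":
    norm_weight = torch.nn.Parameter(torch.ones(8))
    setattr(norm_weight, IS_SEQUENCE_PARALLEL, True)
    keepalive = attach_sp_allreduce_hook(norm_weight)  # all-reduce is skipped without dist init
    (norm_weight * torch.randn(8)).sum().backward()
    print(norm_weight.grad.shape)  # torch.Size([8])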