From 03cc7f9b80bc94c4b3234da8d32674189c66aa5f Mon Sep 17 00:00:00 2001
From: huangting4201 <1538303371@qq.com>
Date: Mon, 23 Oct 2023 15:28:34 +0800
Subject: [PATCH] feat(model/overlap_handler.py): fix lint error

---
 internlm/model/overlap_handler.py | 14 +++++++-------
 internlm/model/utils.py           |  7 ++-----
 2 files changed, 9 insertions(+), 12 deletions(-)

diff --git a/internlm/model/overlap_handler.py b/internlm/model/overlap_handler.py
index 3f7ee05..6870fe6 100644
--- a/internlm/model/overlap_handler.py
+++ b/internlm/model/overlap_handler.py
@@ -204,27 +204,27 @@ class FSTPOverlapHandler:
         register forward hooks and backward hooks for fstp modules.
         """
 
-        def _post_forward_hook_for_embedding(module: nn.Module, inputs: Any, output: Any):
+        def _post_forward_hook_for_embedding(module: nn.Module, inputs: Any, output: Any):  # pylint: disable=W0613
             self._all_gather_block_weight_memory_pool(0)
 
-        def _pre_forward_hook_for_out_proj(module: nn.Module, inputs: Any):
+        def _pre_forward_hook_for_out_proj(module: nn.Module, inputs: Any):  # pylint: disable=W0613
             block_index = self.module_to_index[module]
             # start the all-gather for next block
             if block_index + 1 < gpc.config.NUM_LAYER:
                 self._all_gather_block_weight_memory_pool(block_index + 1)
 
-        def _pre_forward_hook_for_module(module: nn.Module, inputs: Any):
+        def _pre_forward_hook_for_module(module: nn.Module, inputs: Any):  # pylint: disable=W0613
             handle = self.fstp_global_handle[module]
             handle.wait()
             if module.bias is not None:
                 bias_handle = self.bias_global_handle[module]
                 bias_handle.wait()
 
-        def _post_forward_hook_for_module(module: nn.Module, inputs: Any, output: Any):
+        def _post_forward_hook_for_module(module: nn.Module, inputs: Any, output: Any):  # pylint: disable=W0613
             if module in self.fstp_global_handle:
                 del self.fstp_global_handle[module]
 
-        def _post_backward_hook_for_head(module: nn.Module, grad_input, grad_output):
+        def _post_backward_hook_for_head(module: nn.Module, grad_input, grad_output):  # pylint: disable=W0613
             first_backward_module = self.fstp_modules[-1]
             weight_handle = all_gather_raw_memory_pool(
                 first_backward_module.weight,
@@ -234,7 +234,7 @@ class FSTPOverlapHandler:
             )
             self.fstp_global_handle[first_backward_module] = weight_handle
 
-        def _pre_backward_hook_for_module(module: nn.Module, grad_output):
+        def _pre_backward_hook_for_module(module: nn.Module, grad_output):  # pylint: disable=W0613
             # wait handle for current module
             weight_handle = self.fstp_global_handle[module]
             weight_handle.wait()
@@ -251,7 +251,7 @@ class FSTPOverlapHandler:
             )
             self.fstp_global_handle[next_module] = weight_handle
 
-        def _post_backward_hook_for_module(module, grad_input, grad_output):
+        def _post_backward_hook_for_module(module, grad_input, grad_output):  # pylint: disable=W0613
             if module in self.fstp_global_handle:
                 del self.fstp_global_handle[module]
 
diff --git a/internlm/model/utils.py b/internlm/model/utils.py
index 42a8400..982c0e0 100644
--- a/internlm/model/utils.py
+++ b/internlm/model/utils.py
@@ -129,7 +129,6 @@ def all_gather_raw_memory_pool(
     input_: Tensor,
     process_group: ProcessGroup,
     async_op: bool = False,
-    gather_dim: int = 0,
     module: nn.Module = None,
 ):
     handle = torch.distributed.all_gather_into_tensor(
@@ -145,7 +144,6 @@ def all_gather_raw_bias_memory_pool(
     input_: Tensor,
     process_group: ProcessGroup,
     async_op: bool = False,
-    gather_dim: int = 0,
     module: nn.Module = None,
 ):
     handle = torch.distributed.all_gather_into_tensor(
@@ -283,8 +281,8 @@ class FusedDenseFunc(torch.autograd.Function):
 
 class MegatronFusedDenseFunc(torch.autograd.Function):
     """
     FusedDenseFunc for tensor parallel in megatron implementation.
-    The diffenrence between the implementation of flash-attn and megatron is that the total_x could be saved for backward in megatron,
-    so that the all-gather in backward is ommited.
+    The difference between the implementation of flash-attn and megatron is that the total_x could be
+    saved for backward in megatron, so that the all-gather in backward is omitted.
     """
 
     @staticmethod
@@ -433,7 +431,6 @@ class MegatronFusedDenseFuncTorch(FusedDenseFunc):
         grad_input = grad_input.contiguous()
         process_group = ctx.process_group
         sequence_parallel = ctx.sequence_parallel
-        gather_dim = ctx.gather_dim
         if ctx.compute_weight_gradient:
             total_x, weight = ctx.saved_tensors
         else: