mirror of https://github.com/InternLM/InternLM

feat(model/overlap_handler.py): fix lint error

parent 0d693cf3a1
commit 03cc7f9b80
@@ -204,27 +204,27 @@ class FSTPOverlapHandler:
         register forward hooks and backward hooks for fstp modules.
         """

-        def _post_forward_hook_for_embedding(module: nn.Module, inputs: Any, output: Any):
+        def _post_forward_hook_for_embedding(module: nn.Module, inputs: Any, output: Any):  # pylint: disable=W0613
             self._all_gather_block_weight_memory_pool(0)

-        def _pre_forward_hook_for_out_proj(module: nn.Module, inputs: Any):
+        def _pre_forward_hook_for_out_proj(module: nn.Module, inputs: Any):  # pylint: disable=W0613
             block_index = self.module_to_index[module]
             # start the all-gather for next block
             if block_index + 1 < gpc.config.NUM_LAYER:
                 self._all_gather_block_weight_memory_pool(block_index + 1)

-        def _pre_forward_hook_for_module(module: nn.Module, inputs: Any):
+        def _pre_forward_hook_for_module(module: nn.Module, inputs: Any):  # pylint: disable=W0613
             handle = self.fstp_global_handle[module]
             handle.wait()
             if module.bias is not None:
                 bias_handle = self.bias_global_handle[module]
                 bias_handle.wait()

-        def _post_forward_hook_for_module(module: nn.Module, inputs: Any, output: Any):
+        def _post_forward_hook_for_module(module: nn.Module, inputs: Any, output: Any):  # pylint: disable=W0613
             if module in self.fstp_global_handle:
                 del self.fstp_global_handle[module]

-        def _post_backward_hook_for_head(module: nn.Module, grad_input, grad_output):
+        def _post_backward_hook_for_head(module: nn.Module, grad_input, grad_output):  # pylint: disable=W0613
             first_backward_module = self.fstp_modules[-1]
             weight_handle = all_gather_raw_memory_pool(
                 first_backward_module.weight,
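For context, hook functions with these signatures are normally attached through PyTorch's standard registration APIs; the registration calls sit outside this hunk, so the sketch below is only an illustrative reconstruction with made-up names. It also shows why the unused `inputs`/`output` arguments trip pylint's W0613 (unused-argument), which is exactly what the inline disables added in this commit silence.

import torch
import torch.nn as nn


class TinyHandler:
    """Illustrative stand-in for the overlap handler; not the real FSTPOverlapHandler."""

    def __init__(self, blocks):
        self.blocks = blocks

    def _pre_forward_hook(self, module, inputs):  # pylint: disable=W0613
        # In the real handler this is where the all-gather for the next block is launched.
        pass

    def _post_forward_hook(self, module, inputs, output):  # pylint: disable=W0613
        # In the real handler this is where the finished communication handle is dropped.
        pass

    def register(self):
        for block in self.blocks:
            block.register_forward_pre_hook(self._pre_forward_hook)
            block.register_forward_hook(self._post_forward_hook)


blocks = nn.ModuleList([nn.Linear(8, 8) for _ in range(2)])
TinyHandler(blocks).register()
_ = blocks[0](torch.randn(1, 8))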
@@ -234,7 +234,7 @@ class FSTPOverlapHandler:
             )
             self.fstp_global_handle[first_backward_module] = weight_handle

-        def _pre_backward_hook_for_module(module: nn.Module, grad_output):
+        def _pre_backward_hook_for_module(module: nn.Module, grad_output):  # pylint: disable=W0613
             # wait handle for current module
             weight_handle = self.fstp_global_handle[module]
             weight_handle.wait()
@@ -251,7 +251,7 @@ class FSTPOverlapHandler:
                 )
                 self.fstp_global_handle[next_module] = weight_handle

-        def _post_backward_hook_for_module(module, grad_input, grad_output):
+        def _post_backward_hook_for_module(module, grad_input, grad_output):  # pylint: disable=W0613
             if module in self.fstp_global_handle:
                 del self.fstp_global_handle[module]

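The backward-direction hooks in the hunks above follow a wait-then-prefetch pattern: the head's post-backward hook launches the all-gather for the last fstp module, and each module's pre-backward hook waits on its own handle before launching one for the module that runs backward next. The minimal sketch below mimics that bookkeeping with made-up names and plain placeholders instead of real communication handles (`register_full_backward_pre_hook` requires PyTorch >= 2.0); it is not the handler's actual registration code.

import torch
import torch.nn as nn


class PrefetchState:
    """Illustrative bookkeeping only; real handles come from async all-gathers."""

    def __init__(self, modules):
        self.modules = list(modules)
        self.handles = {}

    def pre_backward(self, module, grad_output):  # pylint: disable=W0613
        # Wait for this module's previously launched gather, then launch the next one.
        if module in self.handles:
            self.handles.pop(module)  # stand-in for handle.wait()
        idx = self.modules.index(module)
        if idx - 1 >= 0:
            self.handles[self.modules[idx - 1]] = "pending"  # stand-in for an async all-gather


mods = [nn.Linear(4, 4) for _ in range(3)]
state = PrefetchState(mods)
for m in mods:
    m.register_full_backward_pre_hook(state.pre_backward)

x = torch.randn(2, 4, requires_grad=True)
y = x
for m in mods:
    y = m(y)
y.sum().backward()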
@@ -129,7 +129,6 @@ def all_gather_raw_memory_pool(
     input_: Tensor,
     process_group: ProcessGroup,
     async_op: bool = False,
-    gather_dim: int = 0,
     module: nn.Module = None,
 ):
     handle = torch.distributed.all_gather_into_tensor(
@@ -145,7 +144,6 @@ def all_gather_raw_bias_memory_pool(
     input_: Tensor,
     process_group: ProcessGroup,
     async_op: bool = False,
-    gather_dim: int = 0,
     module: nn.Module = None,
 ):
     handle = torch.distributed.all_gather_into_tensor(
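Both helpers above write into a preallocated buffer via `torch.distributed.all_gather_into_tensor` and return the async work handle that the hooks later `wait()` on. The sketch below shows the same launch-now-wait-later pattern; it deliberately uses the list-based `all_gather` with a single-process gloo group so it stays self-contained and runnable, and the buffer/variable names are illustrative rather than taken from the repository.

import torch
import torch.distributed as dist


def main():
    # Single-process gloo group only so the example can run standalone.
    dist.init_process_group(
        backend="gloo", init_method="tcp://127.0.0.1:29501", rank=0, world_size=1
    )
    world_size = dist.get_world_size()

    local = torch.randn(4, 8)                     # this rank's shard
    pool = torch.empty(world_size * 4, 8)         # preallocated "memory pool" buffer
    chunks = list(pool.chunk(world_size, dim=0))  # per-rank views into the pool

    handle = dist.all_gather(chunks, local, async_op=True)  # returns a work handle
    # ... overlap computation with the in-flight all-gather here ...
    handle.wait()  # same role as handle.wait() in the hooks above

    dist.destroy_process_group()


if __name__ == "__main__":
    main()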
@@ -283,8 +281,8 @@ class FusedDenseFunc(torch.autograd.Function):
 class MegatronFusedDenseFunc(torch.autograd.Function):
     """
     FusedDenseFunc for tensor parallel in megatron implementation.
-    The diffenrence between the implementation of flash-attn and megatron is that the total_x could be saved for backward in megatron,
-    so that the all-gather in backward is ommited.
+    The diffenrence between the implementation of flash-attn and megatron is that the total_x could be
+    saved for backward in megatron, so that the all-gather in backward is ommited.
     """

     @staticmethod
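The re-wrapped docstring states the trade-off this class makes: the flash-attn style implementation saves only the local shard and re-all-gathers it in backward, whereas the megatron style saves the already gathered `total_x` so backward can skip that all-gather at the cost of extra activation memory. Below is a heavily simplified single-process illustration of the "save the gathered input" side (made-up class name, no real communication or bias handling).

import torch


class SaveGatheredLinear(torch.autograd.Function):
    """Megatron-style sketch: keep the gathered activation for backward, so no re-gather."""

    @staticmethod
    def forward(ctx, total_x, weight):
        ctx.save_for_backward(total_x, weight)  # total_x is already gathered in forward
        return total_x @ weight.t()

    @staticmethod
    def backward(ctx, grad_output):
        total_x, weight = ctx.saved_tensors     # no all-gather needed here
        grad_input = grad_output @ weight
        grad_weight = grad_output.t() @ total_x
        return grad_input, grad_weight


x = torch.randn(4, 8, requires_grad=True)
w = torch.randn(16, 8, requires_grad=True)
SaveGatheredLinear.apply(x, w).sum().backward()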
@@ -433,7 +431,6 @@ class MegatronFusedDenseFuncTorch(FusedDenseFunc):
         grad_input = grad_input.contiguous()
         process_group = ctx.process_group
         sequence_parallel = ctx.sequence_parallel
-        gather_dim = ctx.gather_dim
         if ctx.compute_weight_gradient:
             total_x, weight = ctx.saved_tensors
         else: