feat(model/overlap_handler.py): fix lint error

pull/456/head
huangting4201 2023-10-23 15:28:34 +08:00
parent 0d693cf3a1
commit 03cc7f9b80
2 changed files with 9 additions and 12 deletions

View File

@@ -204,27 +204,27 @@ class FSTPOverlapHandler:
         register forward hooks and backward hooks for fstp modules.
         """
 
-        def _post_forward_hook_for_embedding(module: nn.Module, inputs: Any, output: Any):
+        def _post_forward_hook_for_embedding(module: nn.Module, inputs: Any, output: Any):  # pylint: disable=W0613
             self._all_gather_block_weight_memory_pool(0)
 
-        def _pre_forward_hook_for_out_proj(module: nn.Module, inputs: Any):
+        def _pre_forward_hook_for_out_proj(module: nn.Module, inputs: Any):  # pylint: disable=W0613
             block_index = self.module_to_index[module]
             # start the all-gather for next block
             if block_index + 1 < gpc.config.NUM_LAYER:
                 self._all_gather_block_weight_memory_pool(block_index + 1)
 
-        def _pre_forward_hook_for_module(module: nn.Module, inputs: Any):
+        def _pre_forward_hook_for_module(module: nn.Module, inputs: Any):  # pylint: disable=W0613
             handle = self.fstp_global_handle[module]
             handle.wait()
             if module.bias is not None:
                 bias_handle = self.bias_global_handle[module]
                 bias_handle.wait()
 
-        def _post_forward_hook_for_module(module: nn.Module, inputs: Any, output: Any):
+        def _post_forward_hook_for_module(module: nn.Module, inputs: Any, output: Any):  # pylint: disable=W0613
             if module in self.fstp_global_handle:
                 del self.fstp_global_handle[module]
 
-        def _post_backward_hook_for_head(module: nn.Module, grad_input, grad_output):
+        def _post_backward_hook_for_head(module: nn.Module, grad_input, grad_output):  # pylint: disable=W0613
             first_backward_module = self.fstp_modules[-1]
             weight_handle = all_gather_raw_memory_pool(
                 first_backward_module.weight,
@@ -234,7 +234,7 @@ class FSTPOverlapHandler:
             )
             self.fstp_global_handle[first_backward_module] = weight_handle
 
-        def _pre_backward_hook_for_module(module: nn.Module, grad_output):
+        def _pre_backward_hook_for_module(module: nn.Module, grad_output):  # pylint: disable=W0613
             # wait handle for current module
             weight_handle = self.fstp_global_handle[module]
             weight_handle.wait()
@@ -251,7 +251,7 @@ class FSTPOverlapHandler:
             )
             self.fstp_global_handle[next_module] = weight_handle
 
-        def _post_backward_hook_for_module(module, grad_input, grad_output):
+        def _post_backward_hook_for_module(module, grad_input, grad_output):  # pylint: disable=W0613
             if module in self.fstp_global_handle:
                 del self.fstp_global_handle[module]
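W0613 is pylint's unused-argument check: torch fixes the hook signatures (a forward hook receives module, inputs, output; a backward hook receives module, grad_input, grad_output), so the parameters must be accepted even when a hook ignores them, and the inline disables silence the warning without changing behaviour. For context, a minimal sketch of the prefetch pattern these hooks implement, assuming an already-initialized process group; PrefetchOverlap and its buffer handling are illustrative stand-ins, not the real FSTPOverlapHandler.

import torch
import torch.distributed as dist
from torch import nn


class PrefetchOverlap:
    """Hypothetical helper; the dict names mirror the handler in the diff above."""

    def __init__(self, blocks: list, process_group):
        self.blocks = blocks
        self.process_group = process_group
        self.module_to_index = {m: i for i, m in enumerate(blocks)}
        self.fstp_global_handle = {}  # module -> pending async all-gather work handle
        self.gathered = {}            # module -> buffer holding the all-gathered weight
        for block in blocks:
            block.register_forward_pre_hook(self._pre_forward_hook)

    def _launch_all_gather(self, module: nn.Module):
        # start an asynchronous all-gather of the module's sharded weight
        world_size = dist.get_world_size(self.process_group)
        out = torch.empty(world_size * module.weight.numel(),
                          dtype=module.weight.dtype, device=module.weight.device)
        self.gathered[module] = out
        self.fstp_global_handle[module] = dist.all_gather_into_tensor(
            out, module.weight.flatten().contiguous(),
            group=self.process_group, async_op=True,
        )

    def _pre_forward_hook(self, module: nn.Module, inputs):  # pylint: disable=W0613
        # wait for this module's prefetched weight, then kick off the next block's
        # all-gather so communication overlaps with the current block's compute
        if module in self.fstp_global_handle:
            self.fstp_global_handle.pop(module).wait()
        nxt = self.module_to_index[module] + 1
        if nxt < len(self.blocks):
            self._launch_all_gather(self.blocks[nxt])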

View File

@@ -129,7 +129,6 @@ def all_gather_raw_memory_pool(
     input_: Tensor,
     process_group: ProcessGroup,
     async_op: bool = False,
-    gather_dim: int = 0,
     module: nn.Module = None,
 ):
     handle = torch.distributed.all_gather_into_tensor(
@@ -145,7 +144,6 @@ def all_gather_raw_bias_memory_pool(
     input_: Tensor,
     process_group: ProcessGroup,
     async_op: bool = False,
-    gather_dim: int = 0,
     module: nn.Module = None,
 ):
     handle = torch.distributed.all_gather_into_tensor(
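The two helpers above gather into pooled buffers instead of allocating a fresh output tensor on every call; the removed gather_dim parameter was never read inside them. Below is a rough sketch of the gather-into-a-reused-buffer idea, with a deliberately simplified pool keyed by output shape and dtype; the real pool management is not visible in this diff.

import torch
import torch.distributed as dist
from torch import Tensor
from torch.distributed import ProcessGroup

_memory_pool = {}  # (output shape, dtype) -> reusable buffer, kept across iterations


def all_gather_into_pool(input_: Tensor, process_group: ProcessGroup, async_op: bool = False):
    # the output must hold world_size equally sized shards, concatenated along dim 0
    world_size = dist.get_world_size(process_group)
    key = ((world_size * input_.shape[0], *input_.shape[1:]), input_.dtype)
    if key not in _memory_pool:
        _memory_pool[key] = torch.empty(*key[0], dtype=input_.dtype, device=input_.device)
    output = _memory_pool[key]
    handle = dist.all_gather_into_tensor(
        output, input_.contiguous(), group=process_group, async_op=async_op
    )
    return output, handle  # handle is a Work object when async_op=True, else None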
@@ -283,8 +281,8 @@ class FusedDenseFunc(torch.autograd.Function):
 class MegatronFusedDenseFunc(torch.autograd.Function):
     """
     FusedDenseFunc for tensor parallel in megatron implementation.
-    The diffenrence between the implementation of flash-attn and megatron is that the total_x could be saved for backward in megatron,
-    so that the all-gather in backward is ommited.
+    The diffenrence between the implementation of flash-attn and megatron is that the total_x could be
+    saved for backward in megatron, so that the all-gather in backward is ommited.
     """
 
     @staticmethod
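A condensed sketch of the trade-off this docstring describes, under assumed 2-D shapes and a sequence-parallel all-gather along dim 0: the megatron-style Function keeps the gathered total_x for backward (more activation memory, no second all-gather), whereas a flash-attn-style Function would save only the local shard and re-gather it inside backward. This is an illustration, not the library's implementation.

import torch
import torch.distributed as dist


class MegatronStyleLinear(torch.autograd.Function):
    """x_shard: (seq_local, in_features); weight: (out_local, in_features)."""

    @staticmethod
    def forward(ctx, x_shard, weight, process_group):
        ctx.process_group = process_group
        ctx.world_size = dist.get_world_size(process_group)
        total_x = x_shard.new_empty(ctx.world_size * x_shard.shape[0], x_shard.shape[1])
        dist.all_gather_into_tensor(total_x, x_shard.contiguous(), group=process_group)
        # megatron style: keep the gathered activation; flash-attn style would instead
        # save x_shard and repeat the all-gather at the start of backward
        ctx.save_for_backward(total_x, weight)
        return total_x @ weight.t()

    @staticmethod
    def backward(ctx, grad_output):
        total_x, weight = ctx.saved_tensors          # already gathered, no communication
        grad_total_x = grad_output @ weight          # (seq_total, in_features)
        shard_len = grad_total_x.shape[0] // ctx.world_size
        grad_x_shard = grad_total_x.new_empty(shard_len, grad_total_x.shape[1])
        # total_x was used on every rank, so its gradient is summed across ranks and
        # re-split along the sequence dimension
        dist.reduce_scatter_tensor(grad_x_shard, grad_total_x, group=ctx.process_group)
        grad_weight = grad_output.t() @ total_x      # (out_local, in_features)
        return grad_x_shard, grad_weight, None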
@@ -433,7 +431,6 @@ class MegatronFusedDenseFuncTorch(FusedDenseFunc):
         grad_input = grad_input.contiguous()
         process_group = ctx.process_group
         sequence_parallel = ctx.sequence_parallel
-        gather_dim = ctx.gather_dim
         if ctx.compute_weight_gradient:
             total_x, weight = ctx.saved_tensors
         else:
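Taken together, the lint fixes fall into two buckets: unused names that can simply be deleted (the gather_dim parameters and the gather_dim = ctx.gather_dim local above), and unused hook arguments that must stay because torch dictates the hook signature, which instead get an inline # pylint: disable=W0613. A small, hypothetical illustration of both remedies:

from typing import Any

from torch import nn


def scale(x, factor, unused_flag):  # pylint reports W0613 (unused-argument) for unused_flag
    return x * factor


def scale_fixed(x, factor):  # remedy 1: the dead parameter is simply removed
    return x * factor


def post_forward_hook(module: nn.Module, inputs: Any, output: Any):  # pylint: disable=W0613
    # remedy 2: torch's forward-hook signature requires all three parameters,
    # so the warning is silenced rather than the signature changed
    return output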