mirror of https://github.com/InternLM/InternLM

feat(model/overlap_handler.py): fix lint error

parent 0d693cf3a1
commit 03cc7f9b80
@@ -204,27 +204,27 @@ class FSTPOverlapHandler:
         register forward hooks and backward hooks for fstp modules.
         """

-        def _post_forward_hook_for_embedding(module: nn.Module, inputs: Any, output: Any):
+        def _post_forward_hook_for_embedding(module: nn.Module, inputs: Any, output: Any):  # pylint: disable=W0613
             self._all_gather_block_weight_memory_pool(0)

-        def _pre_forward_hook_for_out_proj(module: nn.Module, inputs: Any):
+        def _pre_forward_hook_for_out_proj(module: nn.Module, inputs: Any):  # pylint: disable=W0613
             block_index = self.module_to_index[module]
             # start the all-gather for next block
             if block_index + 1 < gpc.config.NUM_LAYER:
                 self._all_gather_block_weight_memory_pool(block_index + 1)

-        def _pre_forward_hook_for_module(module: nn.Module, inputs: Any):
+        def _pre_forward_hook_for_module(module: nn.Module, inputs: Any):  # pylint: disable=W0613
             handle = self.fstp_global_handle[module]
             handle.wait()
             if module.bias is not None:
                 bias_handle = self.bias_global_handle[module]
                 bias_handle.wait()

-        def _post_forward_hook_for_module(module: nn.Module, inputs: Any, output: Any):
+        def _post_forward_hook_for_module(module: nn.Module, inputs: Any, output: Any):  # pylint: disable=W0613
             if module in self.fstp_global_handle:
                 del self.fstp_global_handle[module]

-        def _post_backward_hook_for_head(module: nn.Module, grad_input, grad_output):
+        def _post_backward_hook_for_head(module: nn.Module, grad_input, grad_output):  # pylint: disable=W0613
             first_backward_module = self.fstp_modules[-1]
             weight_handle = all_gather_raw_memory_pool(
                 first_backward_module.weight,
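For context on the W0613 disables added above: these are the standard PyTorch forward-hook signatures, so the unused `inputs`/`output` arguments must be accepted and are silenced rather than removed. Below is a minimal sketch of how such prefetch hooks are typically attached; it is not the InternLM code, and `attach_prefetch_hooks`, `block_index`, and `prefetch_block_weights` are hypothetical names used only for illustration.

from typing import Any

import torch.nn as nn


def attach_prefetch_hooks(blocks: nn.ModuleList, prefetch_block_weights) -> None:
    """Illustrative only: start gathering block i+1's weights while block i runs forward."""

    def _pre_forward_hook(module: nn.Module, inputs: Any):  # pylint: disable=W0613
        next_index = module.block_index + 1
        if next_index < len(blocks):
            prefetch_block_weights(next_index)  # hypothetical async all-gather launcher

    for index, block in enumerate(blocks):
        block.block_index = index
        block.register_forward_pre_hook(_pre_forward_hook)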
@@ -234,7 +234,7 @@ class FSTPOverlapHandler:
             )
             self.fstp_global_handle[first_backward_module] = weight_handle

-        def _pre_backward_hook_for_module(module: nn.Module, grad_output):
+        def _pre_backward_hook_for_module(module: nn.Module, grad_output):  # pylint: disable=W0613
             # wait handle for current module
             weight_handle = self.fstp_global_handle[module]
             weight_handle.wait()
@@ -251,7 +251,7 @@ class FSTPOverlapHandler:
             )
             self.fstp_global_handle[next_module] = weight_handle

-        def _post_backward_hook_for_module(module, grad_input, grad_output):
+        def _post_backward_hook_for_module(module, grad_input, grad_output):  # pylint: disable=W0613
             if module in self.fstp_global_handle:
                 del self.fstp_global_handle[module]

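Likewise, the backward-side hooks follow the fixed signatures of PyTorch's full backward hooks, which is where the remaining W0613 warnings come from. A hedged sketch of the wait-then-release pattern those hooks implement, with a hypothetical `handles` dict standing in for the handler's bookkeeping (requires PyTorch >= 2.0 for the pre-hook API):

from typing import Any, Dict

import torch.nn as nn

handles: Dict[nn.Module, Any] = {}  # hypothetical map: module -> async all-gather handle


def _pre_backward_hook(module: nn.Module, grad_output: Any):  # pylint: disable=W0613
    # The gathered weight must be ready before this module's backward runs.
    handles[module].wait()


def _post_backward_hook(module: nn.Module, grad_input: Any, grad_output: Any):  # pylint: disable=W0613
    # Drop the handle once backward for this module is done.
    handles.pop(module, None)


def register_backward_hooks(module: nn.Module) -> None:
    module.register_full_backward_pre_hook(_pre_backward_hook)
    module.register_full_backward_hook(_post_backward_hook)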
@@ -129,7 +129,6 @@ def all_gather_raw_memory_pool(
     input_: Tensor,
     process_group: ProcessGroup,
     async_op: bool = False,
-    gather_dim: int = 0,
     module: nn.Module = None,
 ):
     handle = torch.distributed.all_gather_into_tensor(
@@ -145,7 +144,6 @@ def all_gather_raw_bias_memory_pool(
     input_: Tensor,
     process_group: ProcessGroup,
     async_op: bool = False,
-    gather_dim: int = 0,
     module: nn.Module = None,
 ):
     handle = torch.distributed.all_gather_into_tensor(
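The two hunks above simply drop the now-unused `gather_dim` parameter; both functions gather into a pre-allocated buffer with `torch.distributed.all_gather_into_tensor` and return an async handle. A self-contained sketch of that pattern follows; the buffer shape, group choice, and the `all_gather_async` helper name are illustrative, not the memory-pool implementation:

import torch
import torch.distributed as dist


def all_gather_async(input_: torch.Tensor, output_buffer: torch.Tensor, process_group) -> tuple:
    # Launch the collective without blocking; the caller waits on the handle only
    # right before the gathered tensor is needed, overlapping comm with compute.
    handle = dist.all_gather_into_tensor(
        output_buffer, input_.contiguous(), group=process_group, async_op=True
    )
    return output_buffer, handle


# Usage sketch (assumes dist.init_process_group has been called):
#   world_size = dist.get_world_size()
#   buffer = torch.empty((world_size,) + tuple(local.shape), dtype=local.dtype, device=local.device)
#   gathered, handle = all_gather_async(local, buffer, dist.group.WORLD)
#   ... other compute ...
#   handle.wait()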
@@ -283,8 +281,8 @@ class FusedDenseFunc(torch.autograd.Function):
 class MegatronFusedDenseFunc(torch.autograd.Function):
     """
     FusedDenseFunc for tensor parallel in megatron implementation.
-    The diffenrence between the implementation of flash-attn and megatron is that the total_x could be saved for backward in megatron,
-    so that the all-gather in backward is ommited.
+    The diffenrence between the implementation of flash-attn and megatron is that the total_x could be
+    saved for backward in megatron, so that the all-gather in backward is ommited.
     """

     @staticmethod
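The docstring change above is only a line-length rewrap; the idea it describes, namely that the Megatron-style path keeps the already-gathered `total_x` for backward so no second all-gather is needed, can be shown with a stripped-down `torch.autograd.Function`. This is a sketch under that assumption (2-D inputs, no bias, no tensor-parallel plumbing), not the InternLM or flash-attn implementation:

import torch


class GatheredLinearSketch(torch.autograd.Function):
    """Illustrative: save the gathered activation so backward skips the all-gather."""

    @staticmethod
    def forward(ctx, total_x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
        # total_x: (tokens, in_features), already all-gathered; weight: (out_features, in_features)
        ctx.save_for_backward(total_x, weight)
        return torch.nn.functional.linear(total_x, weight)

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
        total_x, weight = ctx.saved_tensors                # no re-gather of total_x here
        grad_input = grad_output.matmul(weight)            # (tokens, in_features)
        grad_weight = grad_output.t().matmul(total_x)      # (out_features, in_features)
        return grad_input, grad_weight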
@@ -433,7 +431,6 @@ class MegatronFusedDenseFuncTorch(FusedDenseFunc):
         grad_input = grad_input.contiguous()
         process_group = ctx.process_group
         sequence_parallel = ctx.sequence_parallel
-        gather_dim = ctx.gather_dim
         if ctx.compute_weight_gradient:
             total_x, weight = ctx.saved_tensors
         else: