mirror of https://github.com/InternLM/InternLM
communication overlap
parent
c94be64fd2
commit
792b066f15
|
@ -162,7 +162,7 @@ sequence parallel (bool): enable/disable sequence parallel, defaults to False.
|
|||
"""
|
||||
parallel = dict(
|
||||
zero1=dict(size=1, fsdp=False),
|
||||
tensor=dict(size=8, mode='fstp'), # the mode should be 'origin_tp' or 'fstp'. if the mode is 'fstp', the sequence_parallel should be True
|
||||
tensor=dict(size=2, mode='fstp'), # the mode should be 'origin_tp' or 'fstp'. if the mode is 'fstp', the sequence_parallel should be True
|
||||
pipeline=dict(size=1, interleaved_overlap=True),
|
||||
sequence_parallel=True,
|
||||
)
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- encoding: utf-8 -*-
|
||||
|
||||
from typing import Optional
|
||||
from typing import Optional, Union, Any
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
@ -211,6 +211,7 @@ class FeedForward(nn.Module):
|
|||
|
||||
class FSTPLinear(ColumnParallelLinear):
|
||||
def forward(self, x):
|
||||
import pdb; pdb.set_trace()
|
||||
return fstp_fused_dense_func(x, self.weight, self.bias, process_group=self.process_group)
|
||||
|
||||
|
||||
|
@ -278,3 +279,74 @@ class FSTPFeedForward(nn.Module):
|
|||
w2_o = self.w2(x)
|
||||
out = self.w3(F.silu(w1_o) * w2_o)
|
||||
return out
|
||||
|
||||
class FSTPAllGatherSyncHandler:
|
||||
"""
|
||||
All-gather handler for overlapping the all-gather in adjcent FSTP linear.
|
||||
"""
|
||||
|
||||
def __init__(self, model: Union[nn.Module, nn.ModuleList], process_group) -> None:
|
||||
|
||||
self.process_group = process_group
|
||||
self.FSTP_modules = []
|
||||
self.module_name = ["Wqkv", "out_proj", "w1", "w2", "w3"]
|
||||
self.FSTP_global_weights = dict() # key: FSTP module; value: module global weight for forward
|
||||
self.module_handler = dict() # key: FSTP module; value: all-gather handler
|
||||
self.module_block = dict() # key: FSTP module; value: transformer block index
|
||||
self.block_module = dict() # key: transformer block index; value: {name_index: FSTP module}
|
||||
self.module_name_index = dict() # key: FSTP module; value: the name in index in self.module_name
|
||||
|
||||
# just want to share same for loop for ModuleList and Module
|
||||
if not isinstance(model, nn.ModuleList):
|
||||
model = [model]
|
||||
|
||||
for _chunk in model:
|
||||
if isinstance(_chunk, NaiveAMPModel):
|
||||
_chunk = _chunk.model
|
||||
|
||||
for _, children in _chunk.named_children():
|
||||
if isinstance(children, nn.ModuleList):
|
||||
for _, block in enumerate(children):
|
||||
index = 0
|
||||
sub_modules = list(block.children())
|
||||
if len(sub_modules) > 0:
|
||||
for name, child in block.named_children():
|
||||
if isinstance(child, FSTPLinear):
|
||||
self.FSTP_modules.append(child)
|
||||
self.module_block[child] = _
|
||||
self.block_module[_][index] = child
|
||||
self.module_name_index[child] = index
|
||||
index = index + 1
|
||||
else:
|
||||
continue
|
||||
|
||||
|
||||
def _register_sync_parameters_hook(self) -> None:
|
||||
"""
|
||||
register pre_forward_hook and pre_backward_hook for FSTPLinear.
|
||||
"""
|
||||
|
||||
def _hook(module: nn.Module):
|
||||
block_index = self.module_block[module]
|
||||
name_index = self.module_name_index[module]
|
||||
if name_index == 0:
|
||||
next_module = self.block_module[block_index][name_index + 1]
|
||||
self.FSTP_global_weights, weights_handler = all_gather_raw(next_module.weight, self.process_group, async_op=True)
|
||||
self.module_handler[next_module] = weights_handler
|
||||
else:
|
||||
handler = self.module_handler[module]
|
||||
handler.wait()
|
||||
if name_index != 4:
|
||||
next_module = self.block_module[block_index][name_index + 1]
|
||||
self.FSTP_global_weights, weights_handler = all_gather_raw(next_module.weight, self.process_group, async_op=True)
|
||||
self.module_handler[next_module] = weights_handler
|
||||
|
||||
def _pre_forward_hook(module: nn.Module, inputs: Any):
|
||||
_hook(module)
|
||||
|
||||
def _pre_backward_hook(module: nn.Module, grad_input, grad_output):
|
||||
_hook(module)
|
||||
|
||||
for module in self.FSTP_modules:
|
||||
module.register_forward_pre_hook(_pre_forward_hook)
|
||||
module.register_backward_pre_hook(_pre_backward_hook)
|
|
@ -1,7 +1,7 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- encoding: utf-8 -*-
|
||||
|
||||
from typing import Optional
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
import fused_dense_lib as fused_dense_cuda
|
||||
import torch
|
||||
|
@ -379,7 +379,7 @@ class FSTPFusedDenseFunc(torch.autograd.Function):
|
|||
handle_grad_weight.wait()
|
||||
if grad_bias is not None:
|
||||
handle_grad_bias.wait()
|
||||
return grad_input, grad_weight, grad_bias, None, None, None
|
||||
return grad_input, grad_weight, grad_bias, None, None
|
||||
|
||||
|
||||
def fused_dense_func_torch(
|
||||
|
@ -453,3 +453,5 @@ def Silu(w1_o, w2_o):
|
|||
|
||||
|
||||
Silu = torch.jit.script(Silu)
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue