mirror of https://github.com/InternLM/InternLM
communication overlap
commit 792b066f15
parent c94be64fd2
@@ -162,7 +162,7 @@ sequence parallel (bool): enable/disable sequence parallel, defaults to False.
 """
 parallel = dict(
     zero1=dict(size=1, fsdp=False),
-    tensor=dict(size=8, mode='fstp'), # the mode should be 'origin_tp' or 'fstp'. if the mode is 'fstp', the sequence_parallel should be True
+    tensor=dict(size=2, mode='fstp'), # the mode should be 'origin_tp' or 'fstp'. if the mode is 'fstp', the sequence_parallel should be True
     pipeline=dict(size=1, interleaved_overlap=True),
     sequence_parallel=True,
 )
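Note the coupling called out in the inline comment: `mode='fstp'` only makes sense together with `sequence_parallel=True`. A minimal sketch of a guard for that constraint (hypothetical helper, not part of this commit):

def check_parallel_config(parallel: dict) -> None:
    # Hypothetical sanity check mirroring the comment in the config hunk above:
    # tensor mode 'fstp' requires sequence_parallel=True.
    tensor_cfg = parallel.get("tensor", {})
    mode = tensor_cfg.get("mode", "origin_tp") if isinstance(tensor_cfg, dict) else "origin_tp"
    if mode == "fstp" and not parallel.get("sequence_parallel", False):
        raise ValueError("tensor mode 'fstp' requires sequence_parallel=True")

check_parallel_config(parallel)  # passes for the config above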
@@ -1,7 +1,7 @@
 #!/usr/bin/env python
 # -*- encoding: utf-8 -*-
 
-from typing import Optional
+from typing import Optional, Union, Any
 
 import torch
 import torch.nn.functional as F
@@ -211,6 +211,7 @@ class FeedForward(nn.Module):
 
 class FSTPLinear(ColumnParallelLinear):
     def forward(self, x):
+        import pdb; pdb.set_trace()
         return fstp_fused_dense_func(x, self.weight, self.bias, process_group=self.process_group)
 
 
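`FSTPLinear.forward` delegates to `fstp_fused_dense_func`, which in FSTP keeps only a shard of the weight on each rank of `process_group` and materializes the full weight just before the matmul. A rough conceptual sketch of that forward path, using plain `torch.distributed` rather than the repo's fused kernels (illustration only; the real function goes through the `FSTPFusedDenseFunc` autograd function shown in a later hunk):

import torch
import torch.nn.functional as F
import torch.distributed as dist

def fstp_linear_forward_sketch(x, weight_shard, bias, process_group):
    # Each rank holds a slice of the weight; gather the full matrix, then run
    # an ordinary linear. (Sketch only: no caching, no fused CUDA kernel, and
    # no custom backward, unlike fstp_fused_dense_func.)
    world_size = dist.get_world_size(process_group)
    full_weight = torch.empty(
        world_size * weight_shard.shape[0], weight_shard.shape[1],
        dtype=weight_shard.dtype, device=weight_shard.device,
    )
    dist.all_gather_into_tensor(full_weight, weight_shard.contiguous(), group=process_group)
    return F.linear(x, full_weight, bias)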
@@ -278,3 +279,74 @@ class FSTPFeedForward(nn.Module):
         w2_o = self.w2(x)
         out = self.w3(F.silu(w1_o) * w2_o)
         return out
+
+class FSTPAllGatherSyncHandler:
+    """
+    All-gather handler for overlapping the all-gather in adjcent FSTP linear.
+    """
+
+    def __init__(self, model: Union[nn.Module, nn.ModuleList], process_group) -> None:
+
+        self.process_group = process_group
+        self.FSTP_modules = []
+        self.module_name = ["Wqkv", "out_proj", "w1", "w2", "w3"]
+        self.FSTP_global_weights = dict()  # key: FSTP module; value: module global weight for forward
+        self.module_handler = dict()  # key: FSTP module; value: all-gather handler
+        self.module_block = dict()  # key: FSTP module; value: transformer block index
+        self.block_module = dict()  # key: transformer block index; value: {name_index: FSTP module}
+        self.module_name_index = dict()  # key: FSTP module; value: the name in index in self.module_name
+
+        # just want to share same for loop for ModuleList and Module
+        if not isinstance(model, nn.ModuleList):
+            model = [model]
+
+        for _chunk in model:
+            if isinstance(_chunk, NaiveAMPModel):
+                _chunk = _chunk.model
+
+            for _, children in _chunk.named_children():
+                if isinstance(children, nn.ModuleList):
+                    for _, block in enumerate(children):
+                        index = 0
+                        sub_modules = list(block.children())
+                        if len(sub_modules) > 0:
+                            for name, child in block.named_children():
+                                if isinstance(child, FSTPLinear):
+                                    self.FSTP_modules.append(child)
+                                    self.module_block[child] = _
+                                    self.block_module[_][index] = child
+                                    self.module_name_index[child] = index
+                                    index = index + 1
+                        else:
+                            continue
+
+
+    def _register_sync_parameters_hook(self) -> None:
+        """
+        register pre_forward_hook and pre_backward_hook for FSTPLinear.
+        """
+
+        def _hook(module: nn.Module):
+            block_index = self.module_block[module]
+            name_index = self.module_name_index[module]
+            if name_index == 0:
+                next_module = self.block_module[block_index][name_index + 1]
+                self.FSTP_global_weights, weights_handler = all_gather_raw(next_module.weight, self.process_group, async_op=True)
+                self.module_handler[next_module] = weights_handler
+            else:
+                handler = self.module_handler[module]
+                handler.wait()
+                if name_index != 4:
+                    next_module = self.block_module[block_index][name_index + 1]
+                    self.FSTP_global_weights, weights_handler = all_gather_raw(next_module.weight, self.process_group, async_op=True)
+                    self.module_handler[next_module] = weights_handler
+
+        def _pre_forward_hook(module: nn.Module, inputs: Any):
+            _hook(module)
+
+        def _pre_backward_hook(module: nn.Module, grad_input, grad_output):
+            _hook(module)
+
+        for module in self.FSTP_modules:
+            module.register_forward_pre_hook(_pre_forward_hook)
+            module.register_backward_pre_hook(_pre_backward_hook)
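The handler above is the core of the commit's communication overlap: the five FSTP linears in each transformer block (`Wqkv`, `out_proj`, `w1`, `w2`, `w3`) run in a fixed order, so a pre-forward hook on linear i launches the asynchronous all-gather of linear i+1's weight and only waits on its own, already in-flight gather. The next weight's all-gather therefore overlaps with the current matmul instead of sitting on the critical path. A stripped-down sketch of the same pattern using plain `torch.distributed` instead of the repo's `all_gather_raw` (hypothetical class, not the repo's API):

import torch
import torch.distributed as dist

class PrefetchNextWeight:
    """Overlap the all-gather of layer i+1's weight with layer i's compute."""

    def __init__(self, layers, group):
        self.layers = list(layers)  # sharded linears in execution order
        self.group = group
        self.gathered = {}          # layer -> pre-allocated full-weight buffer
        self.handles = {}           # layer -> pending async all-gather handle

    def _launch(self, layer):
        shard = layer.weight.detach().contiguous()
        world = dist.get_world_size(self.group)
        buf = torch.empty(world * shard.shape[0], *shard.shape[1:],
                          dtype=shard.dtype, device=shard.device)
        # async_op=True returns immediately with a work handle; the collective
        # runs on the communication stream while the current layer computes
        self.handles[layer] = dist.all_gather_into_tensor(buf, shard, group=self.group,
                                                          async_op=True)
        self.gathered[layer] = buf

    def _pre_forward(self, module, inputs):
        i = self.layers.index(module)
        if i == 0:
            # first layer gathers its own weight inside its forward (as in the
            # hook above); here we only start prefetching the next one
            self._launch(self.layers[1])
        else:
            self.handles[module].wait()           # prefetched weight is ready
            if i + 1 < len(self.layers):
                self._launch(self.layers[i + 1])  # keep the pipeline full

    def register(self):
        for layer in self.layers:
            layer.register_forward_pre_hook(self._pre_forward)

The handler in the diff additionally tracks modules per transformer block and registers the same `_hook` as a backward pre-hook, so the prefetch schedule is wired into both passes.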
@@ -1,7 +1,7 @@
 #!/usr/bin/env python
 # -*- encoding: utf-8 -*-
 
-from typing import Optional
+from typing import Any, Optional, Union
 
 import fused_dense_lib as fused_dense_cuda
 import torch
@@ -379,7 +379,7 @@ class FSTPFusedDenseFunc(torch.autograd.Function):
             handle_grad_weight.wait()
            if grad_bias is not None:
                handle_grad_bias.wait()
-        return grad_input, grad_weight, grad_bias, None, None, None
+        return grad_input, grad_weight, grad_bias, None, None
 
 
 def fused_dense_func_torch(
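The shrinking return tuple reflects the PyTorch contract for `torch.autograd.Function`: `backward` must return exactly one value per argument of `forward` (after `ctx`), with `None` for arguments that need no gradient, so removing a `forward` argument means removing one trailing `None` here. A minimal generic illustration of that rule (unrelated to the repo's fused function):

import torch

class Scale(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, factor, process_group=None):  # three inputs after ctx
        ctx.factor = factor
        return x * factor

    @staticmethod
    def backward(ctx, grad_out):
        # exactly three return values: grad for x, None for the non-tensor
        # `factor`, None for `process_group`
        return grad_out * ctx.factor, None, None

y = Scale.apply(torch.randn(3, requires_grad=True), 2.0, None)
y.sum().backward()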
@@ -453,3 +453,5 @@ def Silu(w1_o, w2_o):
 
 
 Silu = torch.jit.script(Silu)
+
+
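For context, the hunk header shows this file ends by compiling the small gating helper `Silu` with TorchScript. A self-contained sketch of that pattern (the body is assumed from the `F.silu(w1_o) * w2_o` expression used in the feed-forward code earlier in this diff):

import torch
import torch.nn.functional as F

def Silu(w1_o, w2_o):
    # gated activation used by the MLP blocks: silu(w1(x)) * w2(x)
    return F.silu(w1_o) * w2_o

Silu = torch.jit.script(Silu)  # compile once; callable exactly like the plain function

out = Silu(torch.randn(4, 8), torch.randn(4, 8))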