communication overlap

pull/407/head
yingtongxiong 2023-10-11 10:57:12 +08:00
parent c94be64fd2
commit 792b066f15
3 changed files with 78 additions and 4 deletions

View File

@ -162,7 +162,7 @@ sequence parallel (bool): enable/disable sequence parallel, defaults to False.
"""
parallel = dict(
    zero1=dict(size=1, fsdp=False),
    tensor=dict(size=8, mode='fstp'),  # the mode should be 'origin_tp' or 'fstp'; if the mode is 'fstp', sequence_parallel must be True
    tensor=dict(size=2, mode='fstp'),  # the mode should be 'origin_tp' or 'fstp'; if the mode is 'fstp', sequence_parallel must be True
    pipeline=dict(size=1, interleaved_overlap=True),
    sequence_parallel=True,
)
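The inline comment above couples mode='fstp' to sequence_parallel=True. Below is a minimal sketch of a guard that enforces this coupling, assuming the parallel settings arrive as the plain dict shown above; the helper name check_parallel_config is illustrative and not part of the repo.

def check_parallel_config(parallel: dict) -> None:
    # Illustrative consistency check for the tensor-parallel settings above.
    tensor_cfg = parallel["tensor"]
    mode = tensor_cfg.get("mode", "origin_tp")
    if mode not in ("origin_tp", "fstp"):
        raise ValueError(f"unknown tensor parallel mode: {mode}")
    if mode == "fstp" and not parallel.get("sequence_parallel", False):
        # fstp shards the linear weights and relies on sequence-parallel activations,
        # so both flags must be enabled together
        raise ValueError("tensor mode 'fstp' requires sequence_parallel=True")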

View File

@ -1,7 +1,7 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
from typing import Optional
from typing import Optional, Union, Any
import torch
import torch.nn.functional as F
@ -211,6 +211,7 @@ class FeedForward(nn.Module):
class FSTPLinear(ColumnParallelLinear):
    def forward(self, x):
        return fstp_fused_dense_func(x, self.weight, self.bias, process_group=self.process_group)
@ -278,3 +279,74 @@ class FSTPFeedForward(nn.Module):
        w2_o = self.w2(x)
        out = self.w3(F.silu(w1_o) * w2_o)
        return out
class FSTPAllGatherSyncHandler:
    """
    All-gather handler for overlapping the all-gathers of adjacent FSTP linear modules
    with ongoing computation.
    """

    def __init__(self, model: Union[nn.Module, nn.ModuleList], process_group) -> None:
        self.process_group = process_group
        self.FSTP_modules = []
        self.module_name = ["Wqkv", "out_proj", "w1", "w2", "w3"]
        self.FSTP_global_weights = dict()  # key: FSTP module; value: gathered global weight used in forward
        self.module_handler = dict()  # key: FSTP module; value: async all-gather handle
        self.module_block = dict()  # key: FSTP module; value: transformer block index
        self.block_module = dict()  # key: transformer block index; value: {name_index: FSTP module}
        self.module_name_index = dict()  # key: FSTP module; value: index of the module name in self.module_name
        # wrap a single nn.Module in a list so the same loop handles both Module and ModuleList
        if not isinstance(model, nn.ModuleList):
            model = [model]

        for _chunk in model:
            if isinstance(_chunk, NaiveAMPModel):
                _chunk = _chunk.model

            for _, children in _chunk.named_children():
                if isinstance(children, nn.ModuleList):
                    for block_index, block in enumerate(children):
                        self.block_module[block_index] = {}  # per-block map from name index to FSTPLinear
                        index = 0
                        for _name, child in block.named_children():
                            if isinstance(child, FSTPLinear):
                                self.FSTP_modules.append(child)
                                self.module_block[child] = block_index
                                self.block_module[block_index][index] = child
                                self.module_name_index[child] = index
                                index = index + 1
    def _register_sync_parameters_hook(self) -> None:
        """
        Register a forward pre-hook and a backward pre-hook on every FSTPLinear, so that the
        all-gather of the next linear's weight is launched asynchronously and overlapped with
        the current module's computation.
        """

        def _hook(module: nn.Module):
            block_index = self.module_block[module]
            name_index = self.module_name_index[module]
            if name_index == 0:
                # first linear in the block: prefetch the weight of the following linear
                next_module = self.block_module[block_index][name_index + 1]
                self.FSTP_global_weights[next_module], weights_handler = all_gather_raw(
                    next_module.weight, self.process_group, async_op=True
                )
                self.module_handler[next_module] = weights_handler
            else:
                # wait for the all-gather launched by the previous linear's hook
                handler = self.module_handler[module]
                handler.wait()
                if name_index != len(self.module_name) - 1:
                    # prefetch the weight of the next linear in this block
                    next_module = self.block_module[block_index][name_index + 1]
                    self.FSTP_global_weights[next_module], weights_handler = all_gather_raw(
                        next_module.weight, self.process_group, async_op=True
                    )
                    self.module_handler[next_module] = weights_handler

        def _pre_forward_hook(module: nn.Module, inputs: Any):
            _hook(module)

        def _pre_backward_hook(module: nn.Module, grad_output: Any):
            _hook(module)

        for module in self.FSTP_modules:
            module.register_forward_pre_hook(_pre_forward_hook)
            module.register_full_backward_pre_hook(_pre_backward_hook)
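As a usage sketch (not part of this commit), the handler is meant to be constructed once after the model is built and its hooks registered before training starts; the names enable_fstp_overlap, model and tp_group below are assumptions about how the caller wires this up.

import torch.distributed as dist
from torch import nn


def enable_fstp_overlap(model: nn.Module, tp_group: dist.ProcessGroup) -> FSTPAllGatherSyncHandler:
    # Illustrative helper: attach the overlap hooks to an already-built model
    # using the caller's tensor-parallel process group.
    handler = FSTPAllGatherSyncHandler(model, tp_group)
    handler._register_sync_parameters_hook()
    return handler


# e.g. right after the trainer builds `model` and its tensor-parallel group `tp_group`:
# overlap_handler = enable_fstp_overlap(model, tp_group)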

View File

@ -1,7 +1,7 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
from typing import Optional
from typing import Any, Optional, Union
import fused_dense_lib as fused_dense_cuda
import torch
@ -379,7 +379,7 @@ class FSTPFusedDenseFunc(torch.autograd.Function):
            handle_grad_weight.wait()
            if grad_bias is not None:
                handle_grad_bias.wait()
        return grad_input, grad_weight, grad_bias, None, None, None
        return grad_input, grad_weight, grad_bias, None, None
def fused_dense_func_torch(
@ -453,3 +453,5 @@ def Silu(w1_o, w2_o):
Silu = torch.jit.script(Silu)