From 792b066f151c438a6ba653a8aafe9207a459907a Mon Sep 17 00:00:00 2001
From: yingtongxiong <974106207@qq.com>
Date: Wed, 11 Oct 2023 10:57:12 +0800
Subject: [PATCH] communication overlap

---
 configs/7B_sft.py        |  2 +-
 internlm/model/linear.py | 82 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++-
 internlm/model/utils.py  |  4 ++--
 3 files changed, 84 insertions(+), 4 deletions(-)

diff --git a/configs/7B_sft.py b/configs/7B_sft.py
index 814966b..e8be167 100644
--- a/configs/7B_sft.py
+++ b/configs/7B_sft.py
@@ -162,7 +162,7 @@ sequence parallel (bool): enable/disable sequence parallel, defaults to False.
 """
 parallel = dict(
     zero1=dict(size=1, fsdp=False),
-    tensor=dict(size=8, mode='fstp'), # the mode should be 'origin_tp' or 'fstp'. if the mode is 'fstp', the sequence_parallel should be True
+    tensor=dict(size=2, mode='fstp'), # the mode should be 'origin_tp' or 'fstp'. if the mode is 'fstp', the sequence_parallel should be True
     pipeline=dict(size=1, interleaved_overlap=True),
     sequence_parallel=True,
 )
diff --git a/internlm/model/linear.py b/internlm/model/linear.py
index 8e23871..36f64f3 100644
--- a/internlm/model/linear.py
+++ b/internlm/model/linear.py
@@ -1,7 +1,10 @@
 #!/usr/bin/env python
 # -*- encoding: utf-8 -*-
 
-from typing import Optional
+from typing import Any, Optional, Union
 
 import torch
 import torch.nn.functional as F
+
+from internlm.core.naive_amp import NaiveAMPModel
+from internlm.model.utils import all_gather_raw
@@ -278,3 +278,80 @@ class FSTPFeedForward(nn.Module):
         w2_o = self.w2(x)
         out = self.w3(F.silu(w1_o) * w2_o)
         return out
+
+
+class FSTPAllGatherSyncHandler:
+    """
+    All-gather handler for overlapping the weight all-gathers of adjacent FSTP linear modules.
+    """
+
+    def __init__(self, model: Union[nn.Module, nn.ModuleList], process_group) -> None:
+        self.process_group = process_group
+        self.FSTP_modules = []
+        self.module_name = ["Wqkv", "out_proj", "w1", "w2", "w3"]
+        self.FSTP_global_weights = dict()  # key: FSTP module; value: all-gathered weight for forward
+        self.module_handler = dict()  # key: FSTP module; value: all-gather communication handler
+        self.module_block = dict()  # key: FSTP module; value: transformer block index
+        self.block_module = dict()  # key: transformer block index; value: {name_index: FSTP module}
+        self.module_name_index = dict()  # key: FSTP module; value: index of the module's name in self.module_name
+
+        # wrap a bare Module so that Module and nn.ModuleList share the same loop below
+        if not isinstance(model, nn.ModuleList):
+            model = [model]
+
+        for _chunk in model:
+            if isinstance(_chunk, NaiveAMPModel):
+                _chunk = _chunk.model
+
+            for _, children in _chunk.named_children():
+                if isinstance(children, nn.ModuleList):
+                    for block_index, block in enumerate(children):
+                        index = 0
+                        self.block_module[block_index] = {}
+                        for _, child in block.named_children():
+                            if isinstance(child, FSTPLinear):
+                                self.FSTP_modules.append(child)
+                                self.module_block[child] = block_index
+                                self.block_module[block_index][index] = child
+                                self.module_name_index[child] = index
+                                index = index + 1
+
+    def _register_sync_parameters_hook(self) -> None:
+        """
+        Register a forward pre-hook and a backward pre-hook for every FSTPLinear module.
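+
+        For FSTP linear i in a block, the hook waits on the all-gather handle registered
+        for module i (if any) and prefetches the weight of linear i + 1 with an
+        asynchronous all-gather, so that communication overlaps with computation.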
+        """
+
+        def _hook(module: nn.Module):
+            block_index = self.module_block[module]
+            name_index = self.module_name_index[module]
+            if name_index == 0:
+                # first linear in the block: only prefetch the next linear's weight
+                next_module = self.block_module[block_index][name_index + 1]
+                self.FSTP_global_weights[next_module], weights_handler = all_gather_raw(
+                    next_module.weight, self.process_group, async_op=True
+                )
+                self.module_handler[next_module] = weights_handler
+            else:
+                # wait for the all-gather of this module's weight issued by the previous hook
+                handler = self.module_handler[module]
+                handler.wait()
+                if name_index != 4:
+                    next_module = self.block_module[block_index][name_index + 1]
+                    self.FSTP_global_weights[next_module], weights_handler = all_gather_raw(
+                        next_module.weight, self.process_group, async_op=True
+                    )
+                    self.module_handler[next_module] = weights_handler
+
+        def _pre_forward_hook(module: nn.Module, inputs: Any):
+            _hook(module)
+
+        def _pre_backward_hook(module: nn.Module, grad_output):
+            _hook(module)
+
+        for module in self.FSTP_modules:
+            module.register_forward_pre_hook(_pre_forward_hook)
+            module.register_full_backward_pre_hook(_pre_backward_hook)
diff --git a/internlm/model/utils.py b/internlm/model/utils.py
index 3885488..5768f00 100644
--- a/internlm/model/utils.py
+++ b/internlm/model/utils.py
@@ -1,7 +1,7 @@
 #!/usr/bin/env python
 # -*- encoding: utf-8 -*-
 
-from typing import Optional
+from typing import Any, Optional, Union
 
 import fused_dense_lib as fused_dense_cuda
 import torch
@@ -379,7 +379,7 @@ class FSTPFusedDenseFunc(torch.autograd.Function):
             handle_grad_weight.wait()
             if grad_bias is not None:
                 handle_grad_bias.wait()
-        return grad_input, grad_weight, grad_bias, None, None, None
+        return grad_input, grad_weight, grad_bias, None, None
 
 
 def fused_dense_func_torch(
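
Usage note: a minimal wiring sketch (not part of this patch) of how the handler might be
attached once the model is built. The helper name setup_fstp_overlap is hypothetical, and
looking up the tensor-parallel group via gpc.get_group(ParallelMode.TENSOR) is an assumption
based on how process groups are obtained elsewhere in InternLM.

    # hypothetical wiring sketch; `model` is the (possibly NaiveAMP-wrapped) module or
    # nn.ModuleList that contains the transformer blocks
    from internlm.core.context import ParallelMode
    from internlm.core.context import global_context as gpc
    from internlm.model.linear import FSTPAllGatherSyncHandler


    def setup_fstp_overlap(model):
        # build the handler over every FSTPLinear in the model and register the
        # forward/backward pre-hooks, so that the all-gather of linear i + 1 is
        # launched before linear i starts computing
        handler = FSTPAllGatherSyncHandler(model, gpc.get_group(ParallelMode.TENSOR))
        handler._register_sync_parameters_hook()
        return handler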