From 5fd5a8a32b0e045a499e8beb3f0438cb0bd49408 Mon Sep 17 00:00:00 2001 From: yingtongxiong <974106207@qq.com> Date: Wed, 11 Oct 2023 17:36:41 +0800 Subject: [PATCH] support fine-grained overlap --- configs/7B_sft.py | 2 +- internlm/model/linear.py | 78 ++++++++++++++++++-------- internlm/model/multi_head_attention.py | 3 +- internlm/model/utils.py | 35 +++++++----- internlm/train/training_internlm.py | 8 ++- 5 files changed, 86 insertions(+), 40 deletions(-) diff --git a/configs/7B_sft.py b/configs/7B_sft.py index e8be167..814966b 100644 --- a/configs/7B_sft.py +++ b/configs/7B_sft.py @@ -162,7 +162,7 @@ sequence parallel (bool): enable/disable sequence parallel, defaults to False. """ parallel = dict( zero1=dict(size=1, fsdp=False), - tensor=dict(size=2, mode='fstp'), # the mode should be 'origin_tp' or 'fstp'. if the mode is 'fstp', the sequence_parallel should be True + tensor=dict(size=8, mode='fstp'), # the mode should be 'origin_tp' or 'fstp'. if the mode is 'fstp', the sequence_parallel should be True pipeline=dict(size=1, interleaved_overlap=True), sequence_parallel=True, ) diff --git a/internlm/model/linear.py b/internlm/model/linear.py index 36f64f3..42bd9f0 100644 --- a/internlm/model/linear.py +++ b/internlm/model/linear.py @@ -11,7 +11,8 @@ from torch import nn from internlm.core.context import ParallelMode from internlm.core.context import global_context as gpc -from internlm.model.utils import Silu, fstp_fused_dense_func, fused_dense_func_torch +from internlm.core.naive_amp import NaiveAMPModel +from internlm.model.utils import Silu, fstp_fused_dense_func, fused_dense_func_torch, all_gather_raw class ScaleColumnParallelLinear(nn.Linear): @@ -211,8 +212,7 @@ class FeedForward(nn.Module): class FSTPLinear(ColumnParallelLinear): def forward(self, x): - import pdb; pdb.set_trace() - return fstp_fused_dense_func(x, self.weight, self.bias, process_group=self.process_group) + return fstp_fused_dense_func(x, self.weight, self.bias, process_group=self.process_group, module=self, handler=gpc.config.fstp_handler) class FSTPFeedForward(nn.Module): @@ -287,6 +287,7 @@ class FSTPAllGatherSyncHandler: def __init__(self, model: Union[nn.Module, nn.ModuleList], process_group) -> None: + # import pdb; pdb.set_trace() self.process_group = process_group self.FSTP_modules = [] self.module_name = ["Wqkv", "out_proj", "w1", "w2", "w3"] @@ -306,19 +307,21 @@ class FSTPAllGatherSyncHandler: for _, children in _chunk.named_children(): if isinstance(children, nn.ModuleList): - for _, block in enumerate(children): + for idx, block in enumerate(children): index = 0 - sub_modules = list(block.children()) - if len(sub_modules) > 0: - for name, child in block.named_children(): - if isinstance(child, FSTPLinear): - self.FSTP_modules.append(child) - self.module_block[child] = _ - self.block_module[_][index] = child - self.module_name_index[child] = index - index = index + 1 - else: - continue + self.block_module[idx] = {} + for _, sub in block.named_children(): + sub_modules = list(sub.children()) + if len(sub_modules) > 0: + for name, child in sub.named_children(): + if isinstance(child, FSTPLinear): + self.FSTP_modules.append(child) + self.module_block[child] = idx + self.block_module[idx][index] = child + self.module_name_index[child] = index + index = index + 1 + else: + continue def _register_sync_parameters_hook(self) -> None: @@ -326,27 +329,58 @@ class FSTPAllGatherSyncHandler: register pre_forward_hook and pre_backward_hook for FSTPLinear. 
""" - def _hook(module: nn.Module): + def _pre_forward_hook(module: nn.Module, inputs: Any): block_index = self.module_block[module] name_index = self.module_name_index[module] if name_index == 0: + total_weight, weight_handler = all_gather_raw(module.weight, self.process_group, async_op=True) + weight_handler.wait() + self.FSTP_global_weights[module] = total_weight + + # start the all-gather for next module next_module = self.block_module[block_index][name_index + 1] - self.FSTP_global_weights, weights_handler = all_gather_raw(next_module.weight, self.process_group, async_op=True) + self.FSTP_global_weights[next_module], weights_handler = all_gather_raw(next_module.weight, self.process_group, async_op=True) self.module_handler[next_module] = weights_handler else: handler = self.module_handler[module] handler.wait() if name_index != 4: next_module = self.block_module[block_index][name_index + 1] - self.FSTP_global_weights, weights_handler = all_gather_raw(next_module.weight, self.process_group, async_op=True) + self.FSTP_global_weights[next_module], weights_handler = all_gather_raw(next_module.weight, self.process_group, async_op=True) self.module_handler[next_module] = weights_handler - def _pre_forward_hook(module: nn.Module, inputs: Any): - _hook(module) + def _post_forward_hook(module: nn.Module, input, output): + del self.FSTP_global_weights[module] + del self.module_handler[module] def _pre_backward_hook(module: nn.Module, grad_input, grad_output): - _hook(module) + block_index = self.module_block[module] + name_index = self.module_name_index[module] + if name_index == 4: + total_weight, weight_handler = all_gather_raw(module.weight, self.process_group, async_op=True) + weight_handler.wait() + self.FSTP_global_weights[module] = total_weight + + # start the all-gather for next module + next_module = self.block_module[block_index][name_index - 1] + self.FSTP_global_weights[next_module], weights_handler = all_gather_raw(next_module.weight, self.process_group, async_op=True) + self.module_handler[next_module] = weights_handler + else: + handler = self.module_handler[module] + handler.wait() + if name_index != 0: + next_module = self.block_module[block_index][name_index - 1] + self.FSTP_global_weights[next_module], weights_handler = all_gather_raw(next_module.weight, self.process_group, async_op=True) + self.module_handler[next_module] = weights_handler + + def _post_backward_hook(module, grad_input, grad_output): + del self.FSTP_global_weights[module] for module in self.FSTP_modules: + # import pdb; pdb.set_trace() module.register_forward_pre_hook(_pre_forward_hook) - module.register_backward_pre_hook(_pre_backward_hook) \ No newline at end of file + module.register_forward_hook(_post_forward_hook) + # module.register_backward_pre_hook(_pre_backward_hook) + # module.register_backward_hook(_post_backward_hook) + module.register_module_full_backward_pre_hook(_pre_backward_hook) + \ No newline at end of file diff --git a/internlm/model/multi_head_attention.py b/internlm/model/multi_head_attention.py index 436caf7..1db98d7 100644 --- a/internlm/model/multi_head_attention.py +++ b/internlm/model/multi_head_attention.py @@ -210,7 +210,7 @@ class MHA(nn.Module): embed_dim, 3 * embed_dim, process_group, - bias=True, + bias=False, sequence_parallel=gpc.config.parallel.sequence_parallel, **factory_kwargs, ) # according to https://spaces.ac.cn/archives/9577 @@ -231,6 +231,7 @@ class MHA(nn.Module): embed_dim, embed_dim, process_group, + bias=False, 
sequence_parallel=gpc.config.parallel.sequence_parallel, **factory_kwargs, ) diff --git a/internlm/model/utils.py b/internlm/model/utils.py index 5768f00..50b9bbd 100644 --- a/internlm/model/utils.py +++ b/internlm/model/utils.py @@ -283,11 +283,13 @@ class FSTPFusedDenseFunc(torch.autograd.Function): @staticmethod @custom_fwd - def forward(ctx, x, weight, bias, return_residual=False, process_group=None): + def forward(ctx, x, weight, bias, return_residual=False, process_group=None, module=None, all_gather_handler=None): ctx.compute_weight_gradient = weight.requires_grad ctx.return_residual = return_residual ctx.process_group = process_group + ctx.all_gather_handler = all_gather_handler + ctx.module = module if torch.is_autocast_enabled(): x = x.to(dtype=torch.get_autocast_gpu_dtype()) @@ -295,14 +297,16 @@ class FSTPFusedDenseFunc(torch.autograd.Function): world_size = gpc.get_world_size(ParallelMode.TENSOR) if world_size > 1: - # do all_gather for weight and bias before actual computation - total_weight, handle_weight = all_gather_raw(weight, process_group, async_op=True) - if bias is not None: - total_bias, handle_bias = all_gather_raw(bias, process_group, async_op=True) - handle_bias.wait() - else: - total_bias = bias - handle_weight.wait() + total_weight = all_gather_handler.FSTP_global_weights[module] + total_bias = bias + # # do all_gather for weight and bias before actual computation + # total_weight, handle_weight = all_gather_raw(weight, process_group, async_op=True) + # if bias is not None: + # total_bias, handle_bias = all_gather_raw(bias, process_group, async_op=True) + # handle_bias.wait() + # else: + # total_bias = bias + # handle_weight.wait() else: total_weight = weight total_bias = bias @@ -332,6 +336,8 @@ class FSTPFusedDenseFunc(torch.autograd.Function): (grad_input,) = args grad_input = grad_input.contiguous() process_group = ctx.process_group + all_gather_handler = ctx.all_gather_handler + module = ctx.module if ctx.compute_weight_gradient: x, weight = ctx.saved_tensors total_x = x @@ -345,8 +351,9 @@ class FSTPFusedDenseFunc(torch.autograd.Function): world_size = gpc.get_world_size(ParallelMode.TENSOR) if world_size > 1: # do all-gather for weight before backward - total_weight, handle_weight = all_gather_raw(weight, process_group, async_op=True) - handle_weight.wait() + # total_weight, handle_weight = all_gather_raw(weight, process_group, async_op=True) + # handle_weight.wait() + total_weight = all_gather_handler.FSTP_global_weights[module] else: total_weight = weight @@ -379,7 +386,7 @@ class FSTPFusedDenseFunc(torch.autograd.Function): handle_grad_weight.wait() if grad_bias is not None: handle_grad_bias.wait() - return grad_input, grad_weight, grad_bias, None, None + return grad_input, grad_weight, grad_bias, None, None, None, None def fused_dense_func_torch( @@ -401,13 +408,13 @@ def fused_dense_func_torch( def fstp_fused_dense_func( - x: Tensor, weight: Tensor, bias: Optional[Tensor] = None, return_residual: bool = False, process_group=None + x: Tensor, weight: Tensor, bias: Optional[Tensor] = None, return_residual: bool = False, process_group=None, module=None, handler=None ): dtype_eligible = x.dtype in [torch.float16, torch.bfloat16] or ( x.dtype == torch.float32 and torch.is_autocast_enabled() ) if x.is_cuda and weight.is_cuda and (bias is None or bias.is_cuda) and dtype_eligible: - return FSTPFusedDenseFunc.apply(x, weight, bias, return_residual, process_group) + return FSTPFusedDenseFunc.apply(x, weight, bias, return_residual, process_group, module, 
handler) else: assert process_group is None out = F.linear(x, weight, bias) diff --git a/internlm/train/training_internlm.py b/internlm/train/training_internlm.py index 7af58dd..5deb023 100644 --- a/internlm/train/training_internlm.py +++ b/internlm/train/training_internlm.py @@ -39,6 +39,7 @@ from internlm.model.linear import ( FeedForward, RewardModelLinear, ScaleColumnParallelLinear, + FSTPAllGatherSyncHandler, ) from internlm.model.multi_head_attention import MHA from internlm.model.utils import try_import_RMSNorm @@ -106,10 +107,13 @@ def initialize_model(): # if fsdp enabled, wrap the model model = wrap_FSDP_model(model) - + + if gpc.config.parallel["tensor"]["mode"] == "fstp": + handler = FSTPAllGatherSyncHandler(model, gpc.get_group(ParallelMode.TENSOR)) + handler._register_sync_parameters_hook() + gpc.config.fstp_handler = handler return model - def wrap_FSDP_model(model: Union[nn.Module, nn.ModuleList]): if gpc.config.parallel.zero1.fsdp: # set wrap_policy for fsdp wrap
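
Note on the overlap scheme: the hooks added in FSTPAllGatherSyncHandler implement a fine-grained prefetch, where the all-gather for the next FSTPLinear's weight is issued asynchronously while the current linear is still computing, and FSTPFusedDenseFunc consumes the pre-gathered weight instead of gathering it itself. Below is a minimal, self-contained sketch of that pattern, not the code in this patch: the names PrefetchHandler, all_gather_weight, full_weights and handles are illustrative, and it assumes PyTorch >= 2.0 for register_full_backward_pre_hook and torch.distributed.all_gather_into_tensor. The real implementation is FSTPAllGatherSyncHandler in internlm/model/linear.py, which discovers the FSTPLinear children named Wqkv, out_proj, w1, w2, w3 inside each block.

import torch
import torch.distributed as dist
from torch import nn


def all_gather_weight(weight: torch.Tensor, group, async_op: bool = True):
    # Gather a dim-0-sharded weight from all tensor-parallel ranks.
    world_size = dist.get_world_size(group)
    output = torch.empty(world_size * weight.shape[0], *weight.shape[1:],
                         dtype=weight.dtype, device=weight.device)
    handle = dist.all_gather_into_tensor(output, weight.contiguous(),
                                         group=group, async_op=async_op)
    return output, handle


class PrefetchHandler:
    """Overlap the all-gather of linear i+1 (forward) / i-1 (backward)
    with the matmul of linear i inside one transformer block."""

    def __init__(self, ordered_linears, group):
        self.order = list(ordered_linears)      # forward visitation order
        self.index = {m: i for i, m in enumerate(self.order)}
        self.group = group
        self.full_weights = {}                  # module -> gathered full weight
        self.handles = {}                       # module -> pending all-gather work

    def register(self):
        for m in self.order:
            m.register_forward_pre_hook(self._pre_forward)
            m.register_forward_hook(self._post_forward)
            m.register_full_backward_pre_hook(self._pre_backward)

    def _prefetch(self, module):
        self.full_weights[module], self.handles[module] = all_gather_weight(
            module.weight, self.group)

    def _pre_forward(self, module, inputs):
        i = self.index[module]
        if i == 0:
            # First linear of the block: its own gather cannot be hidden yet.
            self._prefetch(module)
        self.handles[module].wait()
        if i + 1 < len(self.order):
            self._prefetch(self.order[i + 1])   # hidden behind this module's GEMM

    def _post_forward(self, module, inputs, output):
        # Drop the gathered copy as soon as the forward has consumed it.
        self.full_weights.pop(module, None)
        self.handles.pop(module, None)

    def _pre_backward(self, module, grad_output):
        i = self.index[module]
        if i == len(self.order) - 1:
            # Backward enters the block at its last linear: gather eagerly.
            self._prefetch(module)
        self.handles[module].wait()
        if i > 0:
            self._prefetch(self.order[i - 1])   # hidden behind this module's backward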
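
On the consumption side, FSTPFusedDenseFunc no longer issues its own all-gather: both forward and backward read the pre-gathered weight from gpc.config.fstp_handler.FSTP_global_weights[module], and backward returns two extra None entries to match the new module/handler arguments of forward. The gathered copies are freed in the post-forward hook; the matching post-backward cleanup hook is defined but left unregistered. initialize_model() builds the handler and registers the hooks only when parallel.tensor.mode == "fstp". A hypothetical wiring of the sketch above (attribute paths are illustrative; the real handler discovers FSTPLinear children via named_children()):

    # hypothetical: collect the five linears of one block in forward order
    linears = [blk.mixer.Wqkv, blk.mixer.out_proj, blk.mlp.w1, blk.mlp.w2, blk.mlp.w3]
    handler = PrefetchHandler(linears, tp_group)   # tp_group: the tensor-parallel process group
    handler.register()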