From 82204eea59862b01c5aca68cad26c5060b1b7b16 Mon Sep 17 00:00:00 2001
From: yingtongxiong <974106207@qq.com>
Date: Mon, 16 Oct 2023 16:35:14 +0800
Subject: [PATCH] support hybrid overlap

---
 configs/7B_sft.py                   |  4 +-
 internlm/model/linear.py            | 82 +++++++++++++++++++++++++----
 internlm/train/training_internlm.py |  3 +-
 3 files changed, 75 insertions(+), 14 deletions(-)

diff --git a/configs/7B_sft.py b/configs/7B_sft.py
index 814966b..98bceeb 100644
--- a/configs/7B_sft.py
+++ b/configs/7B_sft.py
@@ -2,10 +2,10 @@ JOB_NAME = "7b_train"
 DO_ALERT = False
 
 SEQ_LEN = 4096
-HIDDEN_SIZE = 4096
+HIDDEN_SIZE = 8192
 NUM_ATTENTION_HEAD = 32
 MLP_RATIO = 8 / 3
-NUM_LAYER = 32
+NUM_LAYER = 8
 VOCAB_SIZE = 103168
 
 MODEL_ONLY_FOLDER = "local:llm_ckpts/xxxx"
diff --git a/internlm/model/linear.py b/internlm/model/linear.py
index 56929ee..890f1cb 100644
--- a/internlm/model/linear.py
+++ b/internlm/model/linear.py
@@ -360,10 +360,12 @@ class FSTPAllGatherSyncHandler:
                     self.module_handler[next_module] = weights_handler
 
         def _post_forward_hook(module: nn.Module, input, output):
-            del self.FSTP_global_weights[module]
-            del self.module_handler[module]
+            if module in self.FSTP_global_weights:
+                del self.FSTP_global_weights[module]
+            if module in self.module_handler:
+                del self.module_handler[module]
 
-        def _pre_backward_hook(module: nn.Module, grad_input, grad_output):
+        def _pre_backward_hook(module: nn.Module, grad_output):
             block_index = self.module_block[module]
             name_index = self.module_name_index[module]
             if name_index == 4:
@@ -396,7 +398,8 @@ class FSTPAllGatherSyncHandler:
             module.register_forward_hook(_post_forward_hook)
             # module.register_backward_pre_hook(_pre_backward_hook)
             # module.register_backward_hook(_post_backward_hook)
-            module.register_module_full_backward_pre_hook(_pre_backward_hook)
+            module.register_full_backward_pre_hook(_pre_backward_hook)
+            module.register_full_backward_hook(_post_backward_hook)
 
 
 class CoarseGrainedFSTPAllGatherSyncHandler:
@@ -410,6 +413,7 @@ class CoarseGrainedFSTPAllGatherSyncHandler:
         self.FSTP_blocks = []
         self.FSTP_outs = []
         self.FSTP_wqkvs = []
+        self.FSTP_modules = []
         self.module_name = ["Wqkv", "out_proj", "w1", "w2", "w3"]
         self.FSTP_global_handle = dict()  # key: FSTP module; value: module global all-gather op handle
        self.FSTP_global_weights = dict()  # key: FSTP module; value: module global weight for forward
@@ -418,6 +422,8 @@ class CoarseGrainedFSTPAllGatherSyncHandler:
         self.block_to_index = dict()  # key: transformer block; value: transformer block index
         self.index_to_block = dict()  # key: transformer block index; value: transformer block
         self.index_to_fsdp_modules = dict()  # key: transformer block index; value: fsdp modules
+        self.module_name_index = dict()  # key: FSTP module; value: index of the module name in self.module_name
+        self.block_module = dict()  # key: transformer block index; value: {name_index: FSTP module}
 
         # just want to share same for loop for ModuleList and Module
         if not isinstance(model, nn.ModuleList):
@@ -430,6 +436,8 @@ class CoarseGrainedFSTPAllGatherSyncHandler:
             for _, children in _chunk.named_children():
                 if isinstance(children, nn.ModuleList):
                     for idx, block in enumerate(children):
+                        index = 0
+                        self.block_module[idx] = {}
                         self.FSTP_blocks.append(block)
                         self.block_to_index[block] = idx
                         self.index_to_block[idx] = block
@@ -441,12 +449,17 @@ class CoarseGrainedFSTPAllGatherSyncHandler:
                                     # print(f"name: {name}", flush=True)
                                     if name == "out_proj":
                                         self.FSTP_outs.append(child)
-                                        self.module_to_index[child] = idx
+                                        # self.module_to_index[child] = idx
                                     if name == "Wqkv":
                                         self.FSTP_wqkvs.append(child)
-                                        self.module_to_index[child] = idx
+                                        # self.module_to_index[child] = idx
                                     if isinstance(child, FSTPLinear):
+                                        self.module_to_index[child] = idx
+                                        self.block_module[idx][index] = child
+                                        self.FSTP_modules.append(child)
                                         self.index_to_fsdp_modules[idx].append(child)
+                                        self.module_name_index[child] = index
+                                        index = index + 1
                             else:
                                 continue
 
@@ -457,6 +470,7 @@ class CoarseGrainedFSTPAllGatherSyncHandler:
         for module in fsdp_modules:
             total_weight, weight_handle = all_gather_raw(module.weight, self.process_group, async_op=True)
             self.FSTP_global_weights[module] = total_weight
+            self.FSTP_global_handle[module] = weight_handle
             self.block_handles[block].append(weight_handle)
 
     def _register_sync_parameters_hook(self) -> None:
@@ -498,6 +512,19 @@ class CoarseGrainedFSTPAllGatherSyncHandler:
                 del self.block_handles[block]
             for module in fsdp_modules:
                 del self.FSTP_global_weights[module]
+
+
+        def _pre_forward_hook_for_module(module: nn.Module, inputs: Any):
+            block_index = self.module_to_index[module]
+            if block_index != 0:
+                handler = self.FSTP_global_handle[module]
+                handler.wait()
+
+        def _post_forward_hook_for_module(module: nn.Module, input, output):
+            if module in self.FSTP_global_weights:
+                del self.FSTP_global_weights[module]
+            if module in self.FSTP_global_handle:
+                del self.FSTP_global_handle[module]
 
         def _pre_backward_hook_for_wqkv(module: nn.Module, grad_output):
             block_index = self.module_to_index[module]
@@ -531,15 +558,48 @@ class CoarseGrainedFSTPAllGatherSyncHandler:
                 del self.block_handles[block]
             for module in fsdp_modules:
                 del self.FSTP_global_weights[module]
+
+        def _pre_backward_hook_for_module(module: nn.Module, grad_output):
+            block_index = self.module_to_index[module]
+            name_index = self.module_name_index[module]
+            if name_index == 4:
+                total_weight, weight_handler = all_gather_raw(module.weight, self.process_group, async_op=True)
+                weight_handler.wait()
+                self.FSTP_global_weights[module] = total_weight
 
-        for block in self.FSTP_blocks:
-            block.register_forward_pre_hook(_pre_forward_hook_for_block)
-            block.register_forward_hook(_post_forward_hook_for_block)
-            block.register_full_backward_pre_hook(_pre_backward_hook_for_block)
-            block.register_full_backward_hook(_post_backward_hook_for_block)
+                # start the all-gather for next module
+                next_module = self.block_module[block_index][name_index - 1]
+                self.FSTP_global_weights[next_module], weights_handler = all_gather_raw(
+                    next_module.weight, self.process_group, async_op=True
+                )
+                self.FSTP_global_handle[next_module] = weights_handler
+            else:
+                handler = self.FSTP_global_handle[module]
+                handler.wait()
+                if name_index != 0:
+                    next_module = self.block_module[block_index][name_index - 1]
+                    self.FSTP_global_weights[next_module], weights_handler = all_gather_raw(
+                        next_module.weight, self.process_group, async_op=True
+                    )
+                    self.FSTP_global_handle[next_module] = weights_handler
+
+        def _post_backward_hook_for_module(module, grad_input, grad_output):
+            del self.FSTP_global_weights[module]
+
+        # for block in self.FSTP_blocks:
+        #     block.register_forward_pre_hook(_pre_forward_hook_for_block)
+        #     block.register_forward_hook(_post_forward_hook_for_block)
+        #     block.register_full_backward_pre_hook(_pre_backward_hook_for_block)
+        #     block.register_full_backward_hook(_post_backward_hook_for_block)
 
         for out_proj in self.FSTP_outs:
             out_proj.register_forward_pre_hook(_pre_forward_hook_for_out_proj)
 
         # for wqkv in self.FSTP_wqkvs:
         #     wqkv.register_full_backward_pre_hook(_pre_backward_hook_for_wqkv)
+
+        for module in self.FSTP_modules:
+            module.register_forward_pre_hook(_pre_forward_hook_for_module)
+            module.register_forward_hook(_post_forward_hook_for_module)
+            module.register_full_backward_pre_hook(_pre_backward_hook_for_module)
+            module.register_full_backward_hook(_post_backward_hook_for_module)
diff --git a/internlm/train/training_internlm.py b/internlm/train/training_internlm.py
index da59803..572adba 100644
--- a/internlm/train/training_internlm.py
+++ b/internlm/train/training_internlm.py
@@ -110,7 +110,8 @@ def initialize_model():
         model = wrap_FSDP_model(model)
 
     if gpc.config.parallel["tensor"]["mode"] == "fstp":
-        handler = CoarseGrainedFSTPAllGatherSyncHandler(model, gpc.get_group(ParallelMode.TENSOR))
+        # handler = CoarseGrainedFSTPAllGatherSyncHandler(model, gpc.get_group(ParallelMode.TENSOR))
+        handler = FSTPAllGatherSyncHandler(model, gpc.get_group(ParallelMode.TENSOR))
         handler._register_sync_parameters_hook()
         gpc.config.fstp_handler = handler
     return model
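
Note for reviewers (not part of the patch): the per-module hooks above implement a backward prefetch schedule. Before a module's gradient is computed, the hook waits on the all-gather that was already launched for that module's weight shard and then immediately launches the all-gather for the previous linear in the block, so communication overlaps with the current module's gradient computation. The sketch below is a minimal, self-contained illustration of that scheduling pattern only; it uses plain torch.distributed (all_gather_into_tensor) instead of InternLM's all_gather_raw, the class name BackwardPrefetcher and its bookkeeping are illustrative rather than the API added by this patch, and it deliberately omits how FSTPLinear consumes the gathered weight.

# Minimal sketch of the backward prefetch pattern (assumes torch >= 2.0 for
# register_full_backward_pre_hook, an initialized torch.distributed process group,
# and that each module's weight is an equal shard of the full weight).
import torch
import torch.distributed as dist
from torch import nn


class BackwardPrefetcher:
    """Backward visits modules in reverse forward order: wait for the current
    module's gathered weight, then start the all-gather for the previous one."""

    def __init__(self, modules, process_group=None):
        self.modules = list(modules)   # FSTP linears of one block, in forward order
        self.group = process_group
        self.gathered = {}             # module -> gathered (full) weight buffer
        self.handles = {}              # module -> async all-gather work handle

    def _launch(self, module):
        world_size = dist.get_world_size(self.group)
        shard = module.weight.detach().contiguous()
        full = torch.empty(world_size * shard.numel(), dtype=shard.dtype, device=shard.device)
        handle = dist.all_gather_into_tensor(full, shard.flatten(), group=self.group, async_op=True)
        self.gathered[module] = full
        self.handles[module] = handle

    def register(self):
        for idx, module in enumerate(self.modules):

            def pre_backward(mod, grad_output, idx=idx):
                if mod not in self.handles:     # last module of the block: nothing prefetched yet
                    self._launch(mod)
                self.handles.pop(mod).wait()    # full weight for `mod` is now ready
                if idx > 0:                     # overlap: prefetch the module visited next
                    self._launch(self.modules[idx - 1])

            def post_backward(mod, grad_input, grad_output):
                self.gathered.pop(mod, None)    # free the gathered weight promptly

            module.register_full_backward_pre_hook(pre_backward)
            module.register_full_backward_hook(post_backward)

For one transformer block this would be used roughly as BackwardPrefetcher([wqkv, out_proj, w1, w2, w3], tp_group).register(), mirroring the name_index order that _pre_backward_hook_for_module walks through in the patch.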