From 5c38cb64095513c3740e9618c41e143608169ab5 Mon Sep 17 00:00:00 2001 From: yingtongxiong <974106207@qq.com> Date: Tue, 17 Oct 2023 15:38:24 +0800 Subject: [PATCH] add head overlap --- internlm/model/linear.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/internlm/model/linear.py b/internlm/model/linear.py index 16b0c85..71bdf05 100644 --- a/internlm/model/linear.py +++ b/internlm/model/linear.py @@ -596,8 +596,11 @@ class CoarseGrainedFSTPAllGatherSyncHandler: if block_index - 1 > 0: self._all_gather_block_weight(block_index - 1) - # def _post_backward_hook_for_head(module: nn.Module, grad_input, grad_output): - # self._all_gather_block_weight(gpc.config.NUM_LAYER - 1) + def _post_backward_hook_for_head(module: nn.Module, grad_input, grad_output): + first_module = self.block_module[gpc.config.NUM_LAYER - 1][4] + total_weight, weight_handler = all_gather_raw(first_module.weight, self.process_group, async_op=True) + self.FSTP_global_handle[first_module] = weight_handler + self.FSTP_global_weights[first_module] = total_weight def _post_backward_hook_for_block(block: nn.Module, grad_input, grad_output): block_index = self.block_to_index[block] @@ -612,9 +615,10 @@ class CoarseGrainedFSTPAllGatherSyncHandler: name_index = self.module_name_index[module] if block_index != 0: if name_index == 4 and block_index == gpc.config.NUM_LAYER - 1: - total_weight, weight_handler = all_gather_raw(module.weight, self.process_group, async_op=True) + # total_weight, weight_handler = all_gather_raw(module.weight, self.process_group, async_op=True) + weight_handler = self.FSTP_global_handle[module] weight_handler.wait() - self.FSTP_global_weights[module] = total_weight + # self.FSTP_global_weights[module] = total_weight # start the all-gather for next module next_module = self.block_module[block_index][name_index - 1] @@ -651,8 +655,8 @@ class CoarseGrainedFSTPAllGatherSyncHandler: if module in self.FSTP_global_handle: del self.FSTP_global_handle[module] - # for head in self.head: - # head.register_full_backward_hook(_post_backward_hook_for_head) + for head in self.head: + head.register_full_backward_hook(_post_backward_hook_for_head) # for block in self.FSTP_blocks: # block.register_forward_pre_hook(_pre_forward_hook_for_block)