From d0f0c22cace187e62890aa34c3a0595115ceb394 Mon Sep 17 00:00:00 2001 From: huangting4201 <1538303371@qq.com> Date: Fri, 13 Oct 2023 11:10:23 +0800 Subject: [PATCH] feat(model/linear.py): change pre backward from wqkv to block --- internlm/model/linear.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/internlm/model/linear.py b/internlm/model/linear.py index 3e37863..56929ee 100644 --- a/internlm/model/linear.py +++ b/internlm/model/linear.py @@ -520,6 +520,10 @@ class CoarseGrainedFSTPAllGatherSyncHandler: for handle in handles: handle.wait() + # start the all-gather for next block + if block_index - 1 >= 0: + self._all_gather_block_weight(block_index - 1) + def _post_backward_hook_for_block(block: nn.Module, grad_input, grad_output): block_index = self.block_to_index[block] fsdp_modules = self.index_to_fsdp_modules[block_index] @@ -537,5 +541,5 @@ class CoarseGrainedFSTPAllGatherSyncHandler: for out_proj in self.FSTP_outs: out_proj.register_forward_pre_hook(_pre_forward_hook_for_out_proj) - for wqkv in self.FSTP_wqkvs: - wqkv.register_full_backward_pre_hook(_pre_backward_hook_for_wqkv) + # for wqkv in self.FSTP_wqkvs: + # wqkv.register_full_backward_pre_hook(_pre_backward_hook_for_wqkv)