diff --git a/internlm/model/linear.py b/internlm/model/linear.py index 3e37863..56929ee 100644 --- a/internlm/model/linear.py +++ b/internlm/model/linear.py @@ -520,6 +520,10 @@ class CoarseGrainedFSTPAllGatherSyncHandler: for handle in handles: handle.wait() + # start the all-gather for next block + if block_index - 1 >= 0: + self._all_gather_block_weight(block_index - 1) + def _post_backward_hook_for_block(block: nn.Module, grad_input, grad_output): block_index = self.block_to_index[block] fsdp_modules = self.index_to_fsdp_modules[block_index] @@ -537,5 +541,5 @@ class CoarseGrainedFSTPAllGatherSyncHandler: for out_proj in self.FSTP_outs: out_proj.register_forward_pre_hook(_pre_forward_hook_for_out_proj) - for wqkv in self.FSTP_wqkvs: - wqkv.register_full_backward_pre_hook(_pre_backward_hook_for_wqkv) + # for wqkv in self.FSTP_wqkvs: + # wqkv.register_full_backward_pre_hook(_pre_backward_hook_for_wqkv)