diff --git a/internlm/model/overlap_handler.py b/internlm/model/overlap_handler.py index 098fc8c..5cef92f 100644 --- a/internlm/model/overlap_handler.py +++ b/internlm/model/overlap_handler.py @@ -244,7 +244,8 @@ class FSTPOverlapHandler: self.fstp_global_handle[first_backward_module] = weight_handle def _pre_backward_hook_for_head(module: nn.Module, grad_output): - self._all_gather_block_weight_memory_pool(gpc.config.NUM_LAYER - 1) + if self.is_forward is False: + self._all_gather_block_weight_memory_pool(gpc.config.NUM_LAYER - 1) def _pre_backward_hook_for_module(module: nn.Module, grad_output): # pylint: disable=W0613 # wait handle for current module @@ -276,7 +277,7 @@ class FSTPOverlapHandler: for embedding in self.embedding: embedding.register_forward_hook(_post_forward_hook_for_embedding) - if self.model_checkpoint and self.is_forward is False: + if self.model_checkpoint: for head in self.head: head.register_full_backward_pre_hook(_pre_backward_hook_for_head) @@ -291,7 +292,7 @@ class FSTPOverlapHandler: # 1. register post_backward_hook @head module to prefetch for the last block's last module # 2. register pre_backward_hook @fstp_module to wait handle for current module and to prefetch for next module # 3. register post_backward_hook @fstp_module to release resource - if gpc.config.model.checkpoint is False: + if self.model_checkpoint is False: for head in self.head: head.register_full_backward_hook(_post_backward_hook_for_head)