From 5d8313693b01769a4239d7938667c3d01a5a3d90 Mon Sep 17 00:00:00 2001
From: huangting4201 <1538303371@qq.com>
Date: Tue, 24 Oct 2023 17:29:09 +0800
Subject: [PATCH] feat(model/overlap_handler.py): fix head post backward hook
 when activation

---
 internlm/model/overlap_handler.py | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/internlm/model/overlap_handler.py b/internlm/model/overlap_handler.py
index 098fc8c..5cef92f 100644
--- a/internlm/model/overlap_handler.py
+++ b/internlm/model/overlap_handler.py
@@ -244,7 +244,8 @@ class FSTPOverlapHandler:
             self.fstp_global_handle[first_backward_module] = weight_handle
 
         def _pre_backward_hook_for_head(module: nn.Module, grad_output):
-            self._all_gather_block_weight_memory_pool(gpc.config.NUM_LAYER - 1)
+            if self.is_forward is False:
+                self._all_gather_block_weight_memory_pool(gpc.config.NUM_LAYER - 1)
 
         def _pre_backward_hook_for_module(module: nn.Module, grad_output):  # pylint: disable=W0613
             # wait handle for current module
@@ -276,7 +277,7 @@ class FSTPOverlapHandler:
         for embedding in self.embedding:
             embedding.register_forward_hook(_post_forward_hook_for_embedding)
 
-        if self.model_checkpoint and self.is_forward is False:
+        if self.model_checkpoint:
             for head in self.head:
                 head.register_full_backward_pre_hook(_pre_backward_hook_for_head)
 
@@ -291,7 +292,7 @@ class FSTPOverlapHandler:
         # 1. register post_backward_hook @head module to prefetch for the last block's last module
         # 2. register pre_backward_hook @fstp_module to wait handle for current module and to prefetch for next module
        # 3. register post_backward_hook @fstp_module to release resource
-        if gpc.config.model.checkpoint is False:
+        if self.model_checkpoint is False:
             for head in self.head:
                 head.register_full_backward_hook(_post_backward_hook_for_head)
 
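
The change above hinges on two PyTorch module hook APIs, register_full_backward_pre_hook and register_full_backward_hook, plus the handler's is_forward flag: with activation checkpointing enabled, the head's pre-backward hook is now always registered and the prefetch is gated inside the hook instead of gating the registration itself. The sketch below is a minimal, self-contained illustration of that registration pattern only, not the InternLM implementation: the OverlapDemo class, its prefetch() stand-in for _all_gather_block_weight_memory_pool, and the manual is_forward toggling in the usage lines are illustrative assumptions.

# Minimal sketch of the hook-registration pattern in the patch above.
# NOT the InternLM implementation: OverlapDemo, prefetch(), and the manual
# is_forward toggling below are illustrative assumptions.
import torch
from torch import nn


class OverlapDemo:
    def __init__(self, head: nn.Module, num_layers: int, model_checkpoint: bool):
        self.head = head
        self.num_layers = num_layers
        self.model_checkpoint = model_checkpoint
        self.is_forward = True  # the training loop flips this around forward/backward

    def prefetch(self, layer_idx: int):
        # Stand-in for _all_gather_block_weight_memory_pool(layer_idx).
        print(f"prefetch weights for block {layer_idx}")

    def register_head_hooks(self):
        if self.model_checkpoint:
            # Mirrors the patch: register the hook unconditionally, but only
            # prefetch when the handler is in its backward phase (is_forward is
            # False), replacing the old "register only when is_forward is False"
            # gate on the registration itself.
            def _pre_backward_hook_for_head(module, grad_output):
                if self.is_forward is False:
                    self.prefetch(self.num_layers - 1)

            self.head.register_full_backward_pre_hook(_pre_backward_hook_for_head)
        else:
            # Without activation checkpointing, a plain full backward hook on
            # the head triggers the prefetch for the last block.
            def _post_backward_hook_for_head(module, grad_input, grad_output):
                self.prefetch(self.num_layers - 1)

            self.head.register_full_backward_hook(_post_backward_hook_for_head)


if __name__ == "__main__":
    head = nn.Linear(8, 8)
    demo = OverlapDemo(head, num_layers=4, model_checkpoint=True)
    demo.register_head_hooks()

    x = torch.randn(2, 8, requires_grad=True)
    demo.is_forward = True
    out = head(x)
    demo.is_forward = False  # entering the backward phase
    out.sum().backward()     # pre-backward hook fires and prefetches block 3

Running the sketch with model_checkpoint=True prints the prefetch line once, during backward; constructing OverlapDemo with model_checkpoint=False exercises the post-backward branch instead.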