feat(model/overlap_handler.py): fix head post backward hook when activation

pull/456/head
huangting4201 2023-10-24 17:29:09 +08:00
parent 97dcefc389
commit 5d8313693b
1 changed file with 4 additions and 3 deletions


@@ -244,7 +244,8 @@ class FSTPOverlapHandler:
             self.fstp_global_handle[first_backward_module] = weight_handle
 
         def _pre_backward_hook_for_head(module: nn.Module, grad_output):
-            self._all_gather_block_weight_memory_pool(gpc.config.NUM_LAYER - 1)
+            if self.is_forward is False:
+                self._all_gather_block_weight_memory_pool(gpc.config.NUM_LAYER - 1)
 
         def _pre_backward_hook_for_module(module: nn.Module, grad_output):  # pylint: disable=W0613
             # wait handle for current module
@@ -276,7 +277,7 @@ class FSTPOverlapHandler:
         for embedding in self.embedding:
             embedding.register_forward_hook(_post_forward_hook_for_embedding)
 
-        if self.model_checkpoint and self.is_forward is False:
+        if self.model_checkpoint:
             for head in self.head:
                 head.register_full_backward_pre_hook(_pre_backward_hook_for_head)
@@ -291,7 +292,7 @@ class FSTPOverlapHandler:
         # 1. register post_backward_hook @head module to prefetch for the last block's last module
         # 2. register pre_backward_hook @fstp_module to wait handle for current module and to prefetch for next module
         # 3. register post_backward_hook @fstp_module to release resource
-        if gpc.config.model.checkpoint is False:
+        if self.model_checkpoint is False:
             for head in self.head:
                 head.register_full_backward_hook(_post_backward_hook_for_head)
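For context, here is a minimal standalone sketch (not the InternEvo code) of the hook pattern this commit adjusts, built only on the standard PyTorch nn.Module hook APIs (register_full_backward_pre_hook needs PyTorch 2.0+): with activation checkpointing enabled the head gets a full backward pre-hook so the prefetch can run before its gradients are computed, otherwise a full backward post-hook is used. The names prefetch_last_block_weights, register_head_backward_hooks, and the state dict are illustrative stand-ins for the handler's members, not the real ones.

import torch
from torch import nn


def prefetch_last_block_weights():
    # Stand-in for the handler's all-gather of the last block's weights.
    print("prefetch last block weights")


def register_head_backward_hooks(head: nn.Module, model_checkpoint: bool, state: dict):
    if model_checkpoint:
        # Fires before gradients w.r.t. the head's inputs are computed.
        def _pre_backward_hook_for_head(module, grad_output):  # pylint: disable=W0613
            # Only prefetch in the real backward phase (mirrors the
            # `is_forward is False` guard added in this commit).
            if state["is_forward"] is False:
                prefetch_last_block_weights()

        head.register_full_backward_pre_hook(_pre_backward_hook_for_head)
    else:
        # Fires after gradients for the head have been computed.
        def _post_backward_hook_for_head(module, grad_input, grad_output):  # pylint: disable=W0613
            prefetch_last_block_weights()

        head.register_full_backward_hook(_post_backward_hook_for_head)


head = nn.Linear(8, 4)
state = {"is_forward": True}
register_head_backward_hooks(head, model_checkpoint=True, state=state)

out = head(torch.randn(2, 8)).sum()
state["is_forward"] = False  # the training loop flips this before backward
out.backward()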