mirror of https://github.com/InternLM/InternLM
feat(model/overlap_handler.py): fix head post backward hook when activation
parent
97dcefc389
commit
5d8313693b
|
@ -244,6 +244,7 @@ class FSTPOverlapHandler:
|
|||
self.fstp_global_handle[first_backward_module] = weight_handle
|
||||
|
||||
def _pre_backward_hook_for_head(module: nn.Module, grad_output):
|
||||
if self.is_forward is False:
|
||||
self._all_gather_block_weight_memory_pool(gpc.config.NUM_LAYER - 1)
|
||||
|
||||
def _pre_backward_hook_for_module(module: nn.Module, grad_output): # pylint: disable=W0613
|
||||
|
@ -276,7 +277,7 @@ class FSTPOverlapHandler:
|
|||
for embedding in self.embedding:
|
||||
embedding.register_forward_hook(_post_forward_hook_for_embedding)
|
||||
|
||||
if self.model_checkpoint and self.is_forward is False:
|
||||
if self.model_checkpoint:
|
||||
for head in self.head:
|
||||
head.register_full_backward_pre_hook(_pre_backward_hook_for_head)
|
||||
|
||||
|
@ -291,7 +292,7 @@ class FSTPOverlapHandler:
|
|||
# 1. register post_backward_hook @head module to prefetch for the last block's last module
|
||||
# 2. register pre_backward_hook @fstp_module to wait handle for current module and to prefetch for next module
|
||||
# 3. register post_backward_hook @fstp_module to release resource
|
||||
if gpc.config.model.checkpoint is False:
|
||||
if self.model_checkpoint is False:
|
||||
for head in self.head:
|
||||
head.register_full_backward_hook(_post_backward_hook_for_head)
|
||||
|
||||
|
|
Loading…
Reference in New Issue