feat(model/overlap_handler.py): release weight

pull/436/head
huangting4201 2023-11-14 11:30:26 +08:00
parent 74754397df
commit 3c07423151
1 changed file with 9 additions and 4 deletions

@@ -258,10 +258,12 @@ class FSTPOverlapHandler:
                 del self.weight_global_handle[module]
             if module in self.bias_global_handle:
                 del self.bias_global_handle[module]
-            # if module in self.weight_global_output:
-            #     del self.weight_global_output[module]
-            # if module in self.bias_global_output:
-            #     del self.bias_global_output[module]
+
+        def _clear_weight(module):
+            if module in self.weight_global_output:
+                del self.weight_global_output[module]
+            if module in self.bias_global_output:
+                del self.bias_global_output[module]
 
         def _post_forward_hook_for_embedding(module: nn.Module, inputs: Any, output: Any):  # pylint: disable=W0613
             self._all_gather_block_weight(0)
@@ -290,6 +292,8 @@ class FSTPOverlapHandler:
 
         def _post_forward_hook_for_module(module: nn.Module, inputs: Any, output: Any):  # pylint: disable=W0613
             _clear_handle(module)
+            if not self.model_checkpoint:
+                _clear_weight(module)
 
         def _post_backward_hook_for_head(module: nn.Module, grad_input, grad_output):  # pylint: disable=W0613
             self._all_gather_module_weight(self.fstp_modules[-1])
@@ -313,6 +317,7 @@ class FSTPOverlapHandler:
 
         def _post_backward_hook_for_module(module, grad_input, grad_output):  # pylint: disable=W0613
             _clear_handle(module)
+            _clear_weight(module)
 
         # register forward hooks
         # 1. register post_forward_hook @embedding module to prefetch for block 0
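
In effect, the commit releases the all-gathered full weights as soon as a module's forward pass finishes (only when activation checkpointing is off, since checkpointed recomputation would need them again) and unconditionally after its backward pass. Below is a minimal, self-contained sketch of that release pattern under those assumptions; ToyOverlapHandler, its dict name, and the model_checkpoint flag are simplified stand-ins for illustration, not the real FSTPOverlapHandler API.

from typing import Any

import torch
from torch import nn


class ToyOverlapHandler:
    """Caches a full per-module weight and frees it as eagerly as possible."""

    def __init__(self, model: nn.Module, model_checkpoint: bool = False):
        self.model_checkpoint = model_checkpoint
        self.weight_global_output = {}  # module -> "all-gathered" full weight

        def _clear_weight(module):
            # Release the cached full weight; only the parameter owned by
            # the module itself stays resident.
            if module in self.weight_global_output:
                del self.weight_global_output[module]

        def _pre_forward_hook(module, inputs):
            # Stand-in for the all-gather: cache a full copy of the weight.
            self.weight_global_output[module] = module.weight.detach().clone()

        def _post_forward_hook(module, inputs: Any, output: Any):
            # Without activation checkpointing there is no forward
            # recomputation, so the gathered weight can be dropped now.
            if not self.model_checkpoint:
                _clear_weight(module)

        def _post_backward_hook(module, grad_input, grad_output):
            # Backward for this module is finished: always release.
            _clear_weight(module)

        for module in model.modules():
            if isinstance(module, nn.Linear):
                module.register_forward_pre_hook(_pre_forward_hook)
                module.register_forward_hook(_post_forward_hook)
                module.register_full_backward_hook(_post_backward_hook)


model = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 8))
handler = ToyOverlapHandler(model, model_checkpoint=False)
model(torch.randn(4, 8)).sum().backward()
assert not handler.weight_global_output  # every cached weight was released

The design choice mirrors the diff: with model_checkpoint enabled, the post-forward release is skipped so the cached weight survives until the post-backward hook, which is exactly why _post_backward_hook_for_module above calls _clear_weight unconditionally.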