mirror of https://github.com/InternLM/InternLM
feat(model/overlap_handler.py): release weight
parent 74754397df
commit 3c07423151
@@ -258,10 +258,12 @@ class FSTPOverlapHandler:
                del self.weight_global_handle[module]
            if module in self.bias_global_handle:
                del self.bias_global_handle[module]
            # if module in self.weight_global_output:
            #     del self.weight_global_output[module]
            # if module in self.bias_global_output:
            #     del self.bias_global_output[module]

        def _clear_weight(module):
            if module in self.weight_global_output:
                del self.weight_global_output[module]
            if module in self.bias_global_output:
                del self.bias_global_output[module]

        def _post_forward_hook_for_embedding(module: nn.Module, inputs: Any, output: Any):  # pylint: disable=W0613
            self._all_gather_block_weight(0)
@@ -290,6 +292,8 @@ class FSTPOverlapHandler:
        def _post_forward_hook_for_module(module: nn.Module, inputs: Any, output: Any):  # pylint: disable=W0613
            _clear_handle(module)
            if not self.model_checkpoint:
                _clear_weight(module)

        def _post_backward_hook_for_head(module: nn.Module, grad_input, grad_output):  # pylint: disable=W0613
            self._all_gather_module_weight(self.fstp_modules[-1])
@@ -313,6 +317,7 @@ class FSTPOverlapHandler:
        def _post_backward_hook_for_module(module, grad_input, grad_output):  # pylint: disable=W0613
            _clear_handle(module)
            _clear_weight(module)

        # register forward hooks
        # 1. register post_forward_hook @embedding module to prefetch for block 0
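The hooks in this commit free the all-gathered weight buffers as soon as they are no longer needed: right after a module's forward pass when activation checkpointing is disabled, otherwise only after its backward pass, since recomputation still needs the gathered weight. Below is a minimal, self-contained sketch of that release pattern, not the InternLM implementation; `weight_global_output` mirrors the identifier in the diff, and a plain nn.Linear with a cloned tensor stands in for an FSTP module and its all-gathered full weight.

import torch
from torch import nn

# Stand-in for the handler's buffer of all-gathered full weights (assumption,
# named after the identifier visible in the diff).
weight_global_output = {}  # module -> all-gathered full weight


def _clear_weight(module):
    if module in weight_global_output:
        del weight_global_output[module]


def _post_forward_hook_for_module(module, inputs, output):  # pylint: disable=W0613
    # Release right after forward when no recomputation will follow.
    _clear_weight(module)


def _post_backward_hook_for_module(module, grad_input, grad_output):  # pylint: disable=W0613
    # With activation checkpointing the weight would instead be kept until here.
    _clear_weight(module)


linear = nn.Linear(8, 8)
weight_global_output[linear] = linear.weight.detach().clone()  # pretend all-gather result

linear.register_forward_hook(_post_forward_hook_for_module)
linear.register_full_backward_hook(_post_backward_hook_for_module)

out = linear(torch.randn(2, 8))
out.sum().backward()
assert linear not in weight_global_output  # buffer released by the hooks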