mirror of https://github.com/InternLM/InternLM
feat(model/overlap_handler.py): release weight
parent
74754397df
commit
3c07423151
|
@ -258,10 +258,12 @@ class FSTPOverlapHandler:
|
||||||
del self.weight_global_handle[module]
|
del self.weight_global_handle[module]
|
||||||
if module in self.bias_global_handle:
|
if module in self.bias_global_handle:
|
||||||
del self.bias_global_handle[module]
|
del self.bias_global_handle[module]
|
||||||
# if module in self.weight_global_output:
|
|
||||||
# del self.weight_global_output[module]
|
def _clear_weight(module):
|
||||||
# if module in self.bias_global_output:
|
if module in self.weight_global_output:
|
||||||
# del self.bias_global_output[module]
|
del self.weight_global_output[module]
|
||||||
|
if module in self.bias_global_output:
|
||||||
|
del self.bias_global_output[module]
|
||||||
|
|
||||||
def _post_forward_hook_for_embedding(module: nn.Module, inputs: Any, output: Any): # pylint: disable=W0613
|
def _post_forward_hook_for_embedding(module: nn.Module, inputs: Any, output: Any): # pylint: disable=W0613
|
||||||
self._all_gather_block_weight(0)
|
self._all_gather_block_weight(0)
|
||||||
|
@ -290,6 +292,8 @@ class FSTPOverlapHandler:
|
||||||
|
|
||||||
def _post_forward_hook_for_module(module: nn.Module, inputs: Any, output: Any): # pylint: disable=W0613
|
def _post_forward_hook_for_module(module: nn.Module, inputs: Any, output: Any): # pylint: disable=W0613
|
||||||
_clear_handle(module)
|
_clear_handle(module)
|
||||||
|
if not self.model_checkpoint:
|
||||||
|
_clear_weight(module)
|
||||||
|
|
||||||
def _post_backward_hook_for_head(module: nn.Module, grad_input, grad_output): # pylint: disable=W0613
|
def _post_backward_hook_for_head(module: nn.Module, grad_input, grad_output): # pylint: disable=W0613
|
||||||
self._all_gather_module_weight(self.fstp_modules[-1])
|
self._all_gather_module_weight(self.fstp_modules[-1])
|
||||||
|
@ -313,6 +317,7 @@ class FSTPOverlapHandler:
|
||||||
|
|
||||||
def _post_backward_hook_for_module(module, grad_input, grad_output): # pylint: disable=W0613
|
def _post_backward_hook_for_module(module, grad_input, grad_output): # pylint: disable=W0613
|
||||||
_clear_handle(module)
|
_clear_handle(module)
|
||||||
|
_clear_weight(module)
|
||||||
|
|
||||||
# register forward hooks
|
# register forward hooks
|
||||||
# 1. register post_forward_hook @embedding module to prefetch for block 0
|
# 1. register post_forward_hook @embedding module to prefetch for block 0
|
||||||
|
|
Loading…
Reference in New Issue