diff --git a/internlm/model/overlap_handler.py b/internlm/model/overlap_handler.py
index cb00d22..715fa46 100644
--- a/internlm/model/overlap_handler.py
+++ b/internlm/model/overlap_handler.py
@@ -258,10 +258,12 @@ class FSTPOverlapHandler:
             del self.weight_global_handle[module]
             if module in self.bias_global_handle:
                 del self.bias_global_handle[module]
-            # if module in self.weight_global_output:
-            #     del self.weight_global_output[module]
-            # if module in self.bias_global_output:
-            #     del self.bias_global_output[module]
+
+        def _clear_weight(module):
+            if module in self.weight_global_output:
+                del self.weight_global_output[module]
+            if module in self.bias_global_output:
+                del self.bias_global_output[module]
 
         def _post_forward_hook_for_embedding(module: nn.Module, inputs: Any, output: Any):  # pylint: disable=W0613
             self._all_gather_block_weight(0)
@@ -290,6 +292,8 @@
 
         def _post_forward_hook_for_module(module: nn.Module, inputs: Any, output: Any):  # pylint: disable=W0613
             _clear_handle(module)
+            if not self.model_checkpoint:
+                _clear_weight(module)
 
         def _post_backward_hook_for_head(module: nn.Module, grad_input, grad_output):  # pylint: disable=W0613
             self._all_gather_module_weight(self.fstp_modules[-1])
@@ -313,6 +317,7 @@
 
         def _post_backward_hook_for_module(module, grad_input, grad_output):  # pylint: disable=W0613
             _clear_handle(module)
+            _clear_weight(module)
 
         # register forward hooks
         # 1. register post_forward_hook @embedding module to prefetch for block 0
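
Below is a minimal, self-contained sketch of the pattern this patch applies, not the actual FSTPOverlapHandler: per-module caches of all-gathered weights are freed in post-forward and post-backward hooks so the unsharded copies do not outlive the computation that needs them. The class name ToyOverlapHandler and the cloned-weight stand-in for the real all-gather are hypothetical; only the hook wiring and the clearing logic mirror the diff.

import torch
from torch import nn


class ToyOverlapHandler:
    """Toy stand-in for an FSTP-style overlap handler (illustration only)."""

    def __init__(self, model: nn.Module, model_checkpoint: bool = False):
        self.model_checkpoint = model_checkpoint
        self.weight_global_output = {}  # module -> "gathered" full weight
        self.bias_global_output = {}    # module -> "gathered" full bias

        for module in model.modules():
            if isinstance(module, nn.Linear):
                module.register_forward_pre_hook(self._pre_forward_hook)
                module.register_forward_hook(self._post_forward_hook)
                module.register_full_backward_hook(self._post_backward_hook)

    def _pre_forward_hook(self, module, inputs):
        # Stand-in for the real all-gather: cache a full copy of the params.
        self.weight_global_output[module] = module.weight.detach().clone()
        if module.bias is not None:
            self.bias_global_output[module] = module.bias.detach().clone()

    def _clear_weight(self, module):
        # Mirrors the helper added by the patch: drop the cached full tensors.
        self.weight_global_output.pop(module, None)
        self.bias_global_output.pop(module, None)

    def _post_forward_hook(self, module, inputs, output):
        # Without activation checkpointing there is no recomputed forward, so
        # the gathered weight can be released right after forward; with
        # checkpointing it is kept for the recompute and freed after backward.
        if not self.model_checkpoint:
            self._clear_weight(module)

    def _post_backward_hook(self, module, grad_input, grad_output):
        # After backward the gathered weight is never needed again.
        self._clear_weight(module)


model = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 4))
handler = ToyOverlapHandler(model, model_checkpoint=False)
model(torch.randn(2, 8)).sum().backward()
assert not handler.weight_global_output  # caches were freed by the hooks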