diff --git a/internlm/model/overlap_handler.py b/internlm/model/overlap_handler.py index 35d8a59..8462def 100644 --- a/internlm/model/overlap_handler.py +++ b/internlm/model/overlap_handler.py @@ -75,6 +75,8 @@ class FSTPOverlapHandler: if child.bias is not None: setattr(child.bias, "_fstp_reduce_scatter_str", f"{_full_name}.bias") + self.num_blocks = len(self.index_to_fstp_modules) + self._initialize_memory_pool() self._register_sync_parameters_hook() @@ -219,15 +221,25 @@ class FSTPOverlapHandler: self._all_gather_block_weight_memory_pool(block_index - 1) else: # start the all-gather for next block - if block_index + 1 < gpc.config.NUM_LAYER: + if block_index + 1 < self.num_blocks: self._all_gather_block_weight_memory_pool(block_index + 1) def _pre_forward_hook_for_module(module: nn.Module, inputs: Any): # pylint: disable=W0613 - handle = self.fstp_global_handle[module] - handle.wait() - if module.bias is not None: - bias_handle = self.bias_global_handle[module] - bias_handle.wait() + if module in self.fstp_global_handle: + handle = self.fstp_global_handle[module] + handle.wait() + if module.bias is not None: + bias_handle = self.bias_global_handle[module] + bias_handle.wait() + else: + weight_handle = all_gather_raw_memory_pool( + module.weight, + self.process_group, + async_op=True, + module=module, + ) + self.fstp_global_handle[module] = weight_handle + weight_handle.wait() def _post_forward_hook_for_module(module: nn.Module, inputs: Any, output: Any): # pylint: disable=W0613 if module in self.fstp_global_handle: @@ -245,12 +257,22 @@ class FSTPOverlapHandler: def _pre_backward_hook_for_head(module: nn.Module, grad_output): if self.is_forward is False: - self._all_gather_block_weight_memory_pool(gpc.config.NUM_LAYER - 1) + self._all_gather_block_weight_memory_pool(self.num_blocks - 1) def _pre_backward_hook_for_module(module: nn.Module, grad_output): # pylint: disable=W0613 # wait handle for current module - weight_handle = self.fstp_global_handle[module] - weight_handle.wait() + if module in self.fstp_global_handle: + weight_handle = self.fstp_global_handle[module] + weight_handle.wait() + else: + weight_handle = all_gather_raw_memory_pool( + module.weight, + self.process_group, + async_op=True, + module=module, + ) + self.fstp_global_handle[module] = weight_handle + weight_handle.wait() # start the all-gather for next module module_index = self.fstp_modules.index(module)