mirror of https://github.com/InternLM/InternLM
Merge pull request #6 from yingtongxiong/fstp/overlap-support-pp
feat(model/overlap_handler.py): fix overlap handler to support pp (non-…)  [pull/436/head]
commit bc5a85c624
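In short (as reconstructed from the diff below): the FSTP overlap handler previously indexed prefetches against the global depth gpc.config.NUM_LAYER, but under pipeline parallelism (pp) each stage owns only a slice of the transformer blocks. This commit (a) tracks the stage-local block count as self.num_blocks = len(self.index_to_fstp_modules) and uses it for prefetch bounds, and (b) lets the forward/backward pre-hooks fall back to launching the weight all-gather on demand when no prefetched communication handle exists for a module.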
@@ -75,6 +75,8 @@ class FSTPOverlapHandler:
                 if child.bias is not None:
                     setattr(child.bias, "_fstp_reduce_scatter_str", f"{_full_name}.bias")
 
+        self.num_blocks = len(self.index_to_fstp_modules)
+
         self._initialize_memory_pool()
         self._register_sync_parameters_hook()
 
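Why the local counter matters, as a rough sketch (NUM_LAYER=32 and pipeline_size=4 below are assumed example values, not taken from this commit): with pp, a stage registers only its own blocks, so len(self.index_to_fstp_modules) can be much smaller than gpc.config.NUM_LAYER, and any prefetch index has to be bounded by the local count:

    # Illustration only; the depth and pipeline size are assumed values.
    num_layer = 32                           # global gpc.config.NUM_LAYER
    pipeline_size = 4                        # assumed number of pp stages
    num_blocks = num_layer // pipeline_size  # stage-local blocks: 8

    for block_index in range(num_blocks):
        # Bounding by num_layer here would try to prefetch blocks
        # this stage never built; num_blocks keeps the index local.
        if block_index + 1 < num_blocks:
            print(f"prefetch weights of local block {block_index + 1}")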
@@ -219,15 +221,25 @@ class FSTPOverlapHandler:
                 self._all_gather_block_weight_memory_pool(block_index - 1)
             else:
                 # start the all-gather for next block
-                if block_index + 1 < gpc.config.NUM_LAYER:
+                if block_index + 1 < self.num_blocks:
                     self._all_gather_block_weight_memory_pool(block_index + 1)
 
         def _pre_forward_hook_for_module(module: nn.Module, inputs: Any):  # pylint: disable=W0613
-            handle = self.fstp_global_handle[module]
-            handle.wait()
-            if module.bias is not None:
-                bias_handle = self.bias_global_handle[module]
-                bias_handle.wait()
+            if module in self.fstp_global_handle:
+                handle = self.fstp_global_handle[module]
+                handle.wait()
+                if module.bias is not None:
+                    bias_handle = self.bias_global_handle[module]
+                    bias_handle.wait()
+            else:
+                weight_handle = all_gather_raw_memory_pool(
+                    module.weight,
+                    self.process_group,
+                    async_op=True,
+                    module=module,
+                )
+                self.fstp_global_handle[module] = weight_handle
+                weight_handle.wait()
 
         def _post_forward_hook_for_module(module: nn.Module, inputs: Any, output: Any):  # pylint: disable=W0613
             if module in self.fstp_global_handle:
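The hook change above is a check-then-gather fallback. A minimal sketch of the pattern, using only the names visible in the diff (fstp_global_handle, all_gather_raw_memory_pool, process_group) but with a hypothetical helper name, might read:

    def _wait_or_gather(self, module):
        # Sketch of the fallback logic shared by the pre-forward and
        # pre-backward hooks; not the handler's actual method.
        if module in self.fstp_global_handle:
            # Weight was prefetched by an earlier block's hook: just wait.
            self.fstp_global_handle[module].wait()
        else:
            # No prefetch happened (e.g. the first module touched on this
            # pp stage): launch the all-gather now and block until done.
            weight_handle = all_gather_raw_memory_pool(
                module.weight, self.process_group, async_op=True, module=module
            )
            self.fstp_global_handle[module] = weight_handle
            weight_handle.wait()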
@@ -245,12 +257,22 @@ class FSTPOverlapHandler:
 
         def _pre_backward_hook_for_head(module: nn.Module, grad_output):
             if self.is_forward is False:
-                self._all_gather_block_weight_memory_pool(gpc.config.NUM_LAYER - 1)
+                self._all_gather_block_weight_memory_pool(self.num_blocks - 1)
 
         def _pre_backward_hook_for_module(module: nn.Module, grad_output):  # pylint: disable=W0613
             # wait handle for current module
-            weight_handle = self.fstp_global_handle[module]
-            weight_handle.wait()
+            if module in self.fstp_global_handle:
+                weight_handle = self.fstp_global_handle[module]
+                weight_handle.wait()
+            else:
+                weight_handle = all_gather_raw_memory_pool(
+                    module.weight,
+                    self.process_group,
+                    async_op=True,
+                    module=module,
+                )
+                self.fstp_global_handle[module] = weight_handle
+                weight_handle.wait()
 
             # start the all-gather for next module
             module_index = self.fstp_modules.index(module)