feat(model/overlap_handler.py): fix overlap hander to support pp(non-interleaved)

pull/456/head
huangting4201 2023-10-27 20:04:23 +08:00
parent aa3840fc38
commit 3778c66660
1 changed files with 31 additions and 9 deletions

View File

@ -75,6 +75,8 @@ class FSTPOverlapHandler:
if child.bias is not None:
setattr(child.bias, "_fstp_reduce_scatter_str", f"{_full_name}.bias")
self.num_blocks = len(self.index_to_fstp_modules)
self._initialize_memory_pool()
self._register_sync_parameters_hook()
@ -219,15 +221,25 @@ class FSTPOverlapHandler:
self._all_gather_block_weight_memory_pool(block_index - 1)
else:
# start the all-gather for next block
if block_index + 1 < gpc.config.NUM_LAYER:
if block_index + 1 < self.num_blocks:
self._all_gather_block_weight_memory_pool(block_index + 1)
def _pre_forward_hook_for_module(module: nn.Module, inputs: Any): # pylint: disable=W0613
if module in self.fstp_global_handle:
handle = self.fstp_global_handle[module]
handle.wait()
if module.bias is not None:
bias_handle = self.bias_global_handle[module]
bias_handle.wait()
else:
weight_handle = all_gather_raw_memory_pool(
module.weight,
self.process_group,
async_op=True,
module=module,
)
self.fstp_global_handle[module] = weight_handle
weight_handle.wait()
def _post_forward_hook_for_module(module: nn.Module, inputs: Any, output: Any): # pylint: disable=W0613
if module in self.fstp_global_handle:
@ -245,12 +257,22 @@ class FSTPOverlapHandler:
def _pre_backward_hook_for_head(module: nn.Module, grad_output):
if self.is_forward is False:
self._all_gather_block_weight_memory_pool(gpc.config.NUM_LAYER - 1)
self._all_gather_block_weight_memory_pool(self.num_blocks - 1)
def _pre_backward_hook_for_module(module: nn.Module, grad_output): # pylint: disable=W0613
# wait handle for current module
if module in self.fstp_global_handle:
weight_handle = self.fstp_global_handle[module]
weight_handle.wait()
else:
weight_handle = all_gather_raw_memory_pool(
module.weight,
self.process_group,
async_op=True,
module=module,
)
self.fstp_global_handle[module] = weight_handle
weight_handle.wait()
# start the all-gather for next module
module_index = self.fstp_modules.index(module)