Merge pull request #6 from yingtongxiong/fstp/overlap-support-pp

feat(model/overlap_handler.py): fix overlap hander to support pp(non-…
pull/436/head
ytxiong 2023-10-27 20:32:44 +08:00 committed by GitHub
commit bc5a85c624
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 31 additions and 9 deletions

View File

@ -75,6 +75,8 @@ class FSTPOverlapHandler:
if child.bias is not None:
setattr(child.bias, "_fstp_reduce_scatter_str", f"{_full_name}.bias")
self.num_blocks = len(self.index_to_fstp_modules)
self._initialize_memory_pool()
self._register_sync_parameters_hook()
@ -219,15 +221,25 @@ class FSTPOverlapHandler:
self._all_gather_block_weight_memory_pool(block_index - 1)
else:
# start the all-gather for next block
if block_index + 1 < gpc.config.NUM_LAYER:
if block_index + 1 < self.num_blocks:
self._all_gather_block_weight_memory_pool(block_index + 1)
def _pre_forward_hook_for_module(module: nn.Module, inputs: Any): # pylint: disable=W0613
handle = self.fstp_global_handle[module]
handle.wait()
if module.bias is not None:
bias_handle = self.bias_global_handle[module]
bias_handle.wait()
if module in self.fstp_global_handle:
handle = self.fstp_global_handle[module]
handle.wait()
if module.bias is not None:
bias_handle = self.bias_global_handle[module]
bias_handle.wait()
else:
weight_handle = all_gather_raw_memory_pool(
module.weight,
self.process_group,
async_op=True,
module=module,
)
self.fstp_global_handle[module] = weight_handle
weight_handle.wait()
def _post_forward_hook_for_module(module: nn.Module, inputs: Any, output: Any): # pylint: disable=W0613
if module in self.fstp_global_handle:
@ -245,12 +257,22 @@ class FSTPOverlapHandler:
def _pre_backward_hook_for_head(module: nn.Module, grad_output):
if self.is_forward is False:
self._all_gather_block_weight_memory_pool(gpc.config.NUM_LAYER - 1)
self._all_gather_block_weight_memory_pool(self.num_blocks - 1)
def _pre_backward_hook_for_module(module: nn.Module, grad_output): # pylint: disable=W0613
# wait handle for current module
weight_handle = self.fstp_global_handle[module]
weight_handle.wait()
if module in self.fstp_global_handle:
weight_handle = self.fstp_global_handle[module]
weight_handle.wait()
else:
weight_handle = all_gather_raw_memory_pool(
module.weight,
self.process_group,
async_op=True,
module=module,
)
self.fstp_global_handle[module] = weight_handle
weight_handle.wait()
# start the all-gather for next module
module_index = self.fstp_modules.index(module)