add head overlap

pull/407/head
yingtongxiong 2023-10-17 15:38:24 +08:00
parent a5c6e457b9
commit 5c38cb6409
1 changed files with 10 additions and 6 deletions

View File

@@ -596,8 +596,11 @@ class CoarseGrainedFSTPAllGatherSyncHandler:
if block_index - 1 > 0:
self._all_gather_block_weight(block_index - 1)
# def _post_backward_hook_for_head(module: nn.Module, grad_input, grad_output):
# self._all_gather_block_weight(gpc.config.NUM_LAYER - 1)
def _post_backward_hook_for_head(module: nn.Module, grad_input, grad_output):
first_module = self.block_module[gpc.config.NUM_LAYER - 1][4]
total_weight, weight_handler = all_gather_raw(first_module.weight, self.process_group, async_op=True)
self.FSTP_global_handle[first_module] = weight_handler
self.FSTP_global_weights[first_module] = total_weight
def _post_backward_hook_for_block(block: nn.Module, grad_input, grad_output):
block_index = self.block_to_index[block]
@@ -612,9 +615,10 @@ class CoarseGrainedFSTPAllGatherSyncHandler:
name_index = self.module_name_index[module]
if block_index != 0:
if name_index == 4 and block_index == gpc.config.NUM_LAYER - 1:
total_weight, weight_handler = all_gather_raw(module.weight, self.process_group, async_op=True)
# total_weight, weight_handler = all_gather_raw(module.weight, self.process_group, async_op=True)
weight_handler = self.FSTP_global_handle[module]
weight_handler.wait()
self.FSTP_global_weights[module] = total_weight
# self.FSTP_global_weights[module] = total_weight
# start the all-gather for next module
next_module = self.block_module[block_index][name_index - 1]
@@ -651,8 +655,8 @@ class CoarseGrainedFSTPAllGatherSyncHandler:
if module in self.FSTP_global_handle:
del self.FSTP_global_handle[module]
# for head in self.head:
# head.register_full_backward_hook(_post_backward_hook_for_head)
for head in self.head:
head.register_full_backward_hook(_post_backward_hook_for_head)
# for block in self.FSTP_blocks:
# block.register_forward_pre_hook(_pre_forward_hook_for_block)