mirror of https://github.com/InternLM/InternLM
add head overlap
parent
a5c6e457b9
commit
5c38cb6409
|
@ -596,8 +596,11 @@ class CoarseGrainedFSTPAllGatherSyncHandler:
|
|||
if block_index - 1 > 0:
|
||||
self._all_gather_block_weight(block_index - 1)
|
||||
|
||||
# def _post_backward_hook_for_head(module: nn.Module, grad_input, grad_output):
|
||||
# self._all_gather_block_weight(gpc.config.NUM_LAYER - 1)
|
||||
# Full-backward hook registered on each module in `self.head`. When it runs,
# it launches an asynchronous all-gather of one module's weight from the last
# block so that the communication can overlap with the head's backward work.
# The hook's own arguments (module, grad_input, grad_output) are not used.
def _post_backward_hook_for_head(module: nn.Module, grad_input, grad_output):
|
||||
# Module at position 4 of the last block (index gpc.config.NUM_LAYER - 1).
# NOTE(review): despite the name `first_module`, this is read from the LAST
# block; presumably it is the first module reached in backward order — confirm.
first_module = self.block_module[gpc.config.NUM_LAYER - 1][4]
|
||||
# async_op=True: the call returns the output tensor together with a handle
# instead of blocking here; the handle must be waited on before the weight is used.
total_weight, weight_handler = all_gather_raw(first_module.weight, self.process_group, async_op=True)
|
||||
# Stash the pending handle keyed by module; the per-module backward hook in
# this class looks up FSTP_global_handle[module] and calls wait() on it.
self.FSTP_global_handle[first_module] = weight_handler
|
||||
# Stash the gathered weight tensor under the same module key.
self.FSTP_global_weights[first_module] = total_weight
|
||||
|
||||
def _post_backward_hook_for_block(block: nn.Module, grad_input, grad_output):
|
||||
block_index = self.block_to_index[block]
|
||||
|
@ -612,9 +615,10 @@ class CoarseGrainedFSTPAllGatherSyncHandler:
|
|||
name_index = self.module_name_index[module]
|
||||
if block_index != 0:
|
||||
if name_index == 4 and block_index == gpc.config.NUM_LAYER - 1:
|
||||
total_weight, weight_handler = all_gather_raw(module.weight, self.process_group, async_op=True)
|
||||
# total_weight, weight_handler = all_gather_raw(module.weight, self.process_group, async_op=True)
|
||||
weight_handler = self.FSTP_global_handle[module]
|
||||
weight_handler.wait()
|
||||
self.FSTP_global_weights[module] = total_weight
|
||||
# self.FSTP_global_weights[module] = total_weight
|
||||
|
||||
# start the all-gather for next module
|
||||
next_module = self.block_module[block_index][name_index - 1]
|
||||
|
@ -651,8 +655,8 @@ class CoarseGrainedFSTPAllGatherSyncHandler:
|
|||
if module in self.FSTP_global_handle:
|
||||
del self.FSTP_global_handle[module]
|
||||
|
||||
# for head in self.head:
|
||||
# head.register_full_backward_hook(_post_backward_hook_for_head)
|
||||
for head in self.head:
|
||||
head.register_full_backward_hook(_post_backward_hook_for_head)
|
||||
|
||||
# for block in self.FSTP_blocks:
|
||||
# block.register_forward_pre_hook(_pre_forward_hook_for_block)
|
||||
|
|
Loading…
Reference in New Issue