feat(model/linear.py): change pre backward from wqkv to block

pull/407/head
huangting4201 2023-10-13 11:10:23 +08:00
parent d0b1346993
commit d0f0c22cac
1 changed files with 6 additions and 2 deletions

View File

@ -520,6 +520,10 @@ class CoarseGrainedFSTPAllGatherSyncHandler:
for handle in handles:
handle.wait()
# start the all-gather for next block
if block_index - 1 >= 0:
self._all_gather_block_weight(block_index - 1)
def _post_backward_hook_for_block(block: nn.Module, grad_input, grad_output):
block_index = self.block_to_index[block]
fsdp_modules = self.index_to_fsdp_modules[block_index]
@ -537,5 +541,5 @@ class CoarseGrainedFSTPAllGatherSyncHandler:
for out_proj in self.FSTP_outs:
out_proj.register_forward_pre_hook(_pre_forward_hook_for_out_proj)
for wqkv in self.FSTP_wqkvs:
wqkv.register_full_backward_pre_hook(_pre_backward_hook_for_wqkv)
# for wqkv in self.FSTP_wqkvs:
# wqkv.register_full_backward_pre_hook(_pre_backward_hook_for_wqkv)