mirror of https://github.com/InternLM/InternLM
feat(model/linear.py): change pre backward from wqkv to block
parent
d0b1346993
commit
d0f0c22cac
|
@ -520,6 +520,10 @@ class CoarseGrainedFSTPAllGatherSyncHandler:
|
|||
for handle in handles:
|
||||
handle.wait()
|
||||
|
||||
# start the all-gather for next block
|
||||
if block_index - 1 >= 0:
|
||||
self._all_gather_block_weight(block_index - 1)
|
||||
|
||||
def _post_backward_hook_for_block(block: nn.Module, grad_input, grad_output):
|
||||
block_index = self.block_to_index[block]
|
||||
fsdp_modules = self.index_to_fsdp_modules[block_index]
|
||||
|
@ -537,5 +541,5 @@ class CoarseGrainedFSTPAllGatherSyncHandler:
|
|||
for out_proj in self.FSTP_outs:
|
||||
out_proj.register_forward_pre_hook(_pre_forward_hook_for_out_proj)
|
||||
|
||||
for wqkv in self.FSTP_wqkvs:
|
||||
wqkv.register_full_backward_pre_hook(_pre_backward_hook_for_wqkv)
|
||||
# for wqkv in self.FSTP_wqkvs:
|
||||
# wqkv.register_full_backward_pre_hook(_pre_backward_hook_for_wqkv)
|
||||
|
|
Loading…
Reference in New Issue