mirror of https://github.com/InternLM/InternLM
feat(model/linear.py): move the pre-backward hook from wqkv to block
parent
d0b1346993
commit
d0f0c22cac
|
@ -520,6 +520,10 @@ class CoarseGrainedFSTPAllGatherSyncHandler:
|
||||||
for handle in handles:
|
for handle in handles:
|
||||||
handle.wait()
|
handle.wait()
|
||||||
|
|
||||||
|
# start the all-gather for next block
|
||||||
|
if block_index - 1 >= 0:
|
||||||
|
self._all_gather_block_weight(block_index - 1)
|
||||||
|
|
||||||
def _post_backward_hook_for_block(block: nn.Module, grad_input, grad_output):
|
def _post_backward_hook_for_block(block: nn.Module, grad_input, grad_output):
|
||||||
block_index = self.block_to_index[block]
|
block_index = self.block_to_index[block]
|
||||||
fsdp_modules = self.index_to_fsdp_modules[block_index]
|
fsdp_modules = self.index_to_fsdp_modules[block_index]
|
||||||
|
@ -537,5 +541,5 @@ class CoarseGrainedFSTPAllGatherSyncHandler:
|
||||||
for out_proj in self.FSTP_outs:
|
for out_proj in self.FSTP_outs:
|
||||||
out_proj.register_forward_pre_hook(_pre_forward_hook_for_out_proj)
|
out_proj.register_forward_pre_hook(_pre_forward_hook_for_out_proj)
|
||||||
|
|
||||||
for wqkv in self.FSTP_wqkvs:
|
# for wqkv in self.FSTP_wqkvs:
|
||||||
wqkv.register_full_backward_pre_hook(_pre_backward_hook_for_wqkv)
|
# wqkv.register_full_backward_pre_hook(_pre_backward_hook_for_wqkv)
|
||||||
|
|
Loading…
Reference in New Issue