mirror of https://github.com/InternLM/InternLM
support fine grained
parent
d1af0d6aee
commit
6408b944c2
|
@ -449,6 +449,7 @@ class CoarseGrainedFSTPAllGatherSyncHandler:
|
||||||
self.index_to_fsdp_modules = dict() # key: transformer block index; value: fsdp modules
|
self.index_to_fsdp_modules = dict() # key: transformer block index; value: fsdp modules
|
||||||
self.module_name_index = dict() # key: FSTP module; value: the name in index in self.module_name
|
self.module_name_index = dict() # key: FSTP module; value: the name in index in self.module_name
|
||||||
self.block_module = dict() # key: transformer block index; value: {name_index: FSTP module}
|
self.block_module = dict() # key: transformer block index; value: {name_index: FSTP module}
|
||||||
|
self.head = []
|
||||||
|
|
||||||
# just want to share same for loop for ModuleList and Module
|
# just want to share same for loop for ModuleList and Module
|
||||||
if not isinstance(model, nn.ModuleList):
|
if not isinstance(model, nn.ModuleList):
|
||||||
|
@ -487,16 +488,18 @@ class CoarseGrainedFSTPAllGatherSyncHandler:
|
||||||
index = index + 1
|
index = index + 1
|
||||||
else:
|
else:
|
||||||
continue
|
continue
|
||||||
|
elif isinstance(children, ScaleColumnParallelLinear):
|
||||||
|
self.head.append(children)
|
||||||
|
|
||||||
def _all_gather_block_weight(self, block_index: int):
|
def _all_gather_block_weight(self, block_index: int):
|
||||||
block = self.index_to_block[block_index]
|
block = self.index_to_block[block_index]
|
||||||
fsdp_modules = self.index_to_fsdp_modules[block_index]
|
fsdp_modules = self.index_to_fsdp_modules[block_index]
|
||||||
self.block_handles[block] = []
|
# self.block_handles[block] = []
|
||||||
for module in fsdp_modules:
|
for module in fsdp_modules:
|
||||||
total_weight, weight_handle = all_gather_raw(module.weight, self.process_group, async_op=True)
|
total_weight, weight_handle = all_gather_raw(module.weight, self.process_group, async_op=True)
|
||||||
self.FSTP_global_weights[module] = total_weight
|
self.FSTP_global_weights[module] = total_weight
|
||||||
self.FSTP_global_handle[module] = weight_handle
|
self.FSTP_global_handle[module] = weight_handle
|
||||||
self.block_handles[block].append(weight_handle)
|
# self.block_handles[block].append(weight_handle)
|
||||||
|
|
||||||
def _register_sync_parameters_hook(self) -> None:
|
def _register_sync_parameters_hook(self) -> None:
|
||||||
"""
|
"""
|
||||||
|
@ -558,6 +561,7 @@ class CoarseGrainedFSTPAllGatherSyncHandler:
|
||||||
self._all_gather_block_weight(block_index - 1)
|
self._all_gather_block_weight(block_index - 1)
|
||||||
|
|
||||||
def _pre_backward_hook_for_block(block: nn.Module, grad_output):
|
def _pre_backward_hook_for_block(block: nn.Module, grad_output):
|
||||||
|
# import pdb; pdb.set_trace()
|
||||||
block_index = self.block_to_index[block]
|
block_index = self.block_to_index[block]
|
||||||
# if block_index == gpc.config.NUM_LAYER - 1:
|
# if block_index == gpc.config.NUM_LAYER - 1:
|
||||||
# # all gather weight for the last block
|
# # all gather weight for the last block
|
||||||
|
@ -571,11 +575,15 @@ class CoarseGrainedFSTPAllGatherSyncHandler:
|
||||||
# handles = self.block_handles[block]
|
# handles = self.block_handles[block]
|
||||||
# for handle in handles:
|
# for handle in handles:
|
||||||
# handle.wait()
|
# handle.wait()
|
||||||
|
# if block_index == gpc.config.NUM_LAYER - 1:
|
||||||
|
# self._all_gather_block_weight(block_index)
|
||||||
# start the all-gather for next block
|
# start the all-gather for next block
|
||||||
if block_index - 1 > 0:
|
if block_index - 1 > 0:
|
||||||
self._all_gather_block_weight(block_index - 1)
|
self._all_gather_block_weight(block_index - 1)
|
||||||
|
|
||||||
|
# def _post_backward_hook_for_head(module: nn.Module, grad_input, grad_output):
|
||||||
|
# self._all_gather_block_weight(gpc.config.NUM_LAYER - 1)
|
||||||
|
|
||||||
def _post_backward_hook_for_block(block: nn.Module, grad_input, grad_output):
|
def _post_backward_hook_for_block(block: nn.Module, grad_input, grad_output):
|
||||||
block_index = self.block_to_index[block]
|
block_index = self.block_to_index[block]
|
||||||
fsdp_modules = self.index_to_fsdp_modules[block_index]
|
fsdp_modules = self.index_to_fsdp_modules[block_index]
|
||||||
|
@ -588,40 +596,53 @@ class CoarseGrainedFSTPAllGatherSyncHandler:
|
||||||
block_index = self.module_to_index[module]
|
block_index = self.module_to_index[module]
|
||||||
name_index = self.module_name_index[module]
|
name_index = self.module_name_index[module]
|
||||||
if block_index != 0:
|
if block_index != 0:
|
||||||
# if name_index == 4:
|
if name_index == 4 and block_index == gpc.config.NUM_LAYER - 1:
|
||||||
# total_weight, weight_handler = all_gather_raw(module.weight, self.process_group, async_op=True)
|
total_weight, weight_handler = all_gather_raw(module.weight, self.process_group, async_op=True)
|
||||||
# weight_handler.wait()
|
weight_handler.wait()
|
||||||
# self.FSTP_global_weights[module] = total_weight
|
self.FSTP_global_weights[module] = total_weight
|
||||||
|
|
||||||
# # start the all-gather for next module
|
# start the all-gather for next module
|
||||||
# next_module = self.block_module[block_index][name_index - 1]
|
next_module = self.block_module[block_index][name_index - 1]
|
||||||
# self.FSTP_global_weights[next_module], weights_handler = all_gather_raw(
|
self.FSTP_global_weights[next_module], weights_handler = all_gather_raw(
|
||||||
# next_module.weight, self.process_group, async_op=True
|
next_module.weight, self.process_group, async_op=True
|
||||||
# )
|
)
|
||||||
# self.FSTP_global_handle[next_module] = weights_handler
|
self.FSTP_global_handle[next_module] = weights_handler
|
||||||
# else:
|
elif name_index == 0:
|
||||||
# handler = self.FSTP_global_handle[module]
|
|
||||||
# handler.wait()
|
|
||||||
# if name_index != 0:
|
|
||||||
# next_module = self.block_module[block_index][name_index - 1]
|
|
||||||
# self.FSTP_global_weights[next_module], weights_handler = all_gather_raw(
|
|
||||||
# next_module.weight, self.process_group, async_op=True
|
|
||||||
# )
|
|
||||||
# self.FSTP_global_handle[next_module] = weights_handler
|
|
||||||
if module in self.FSTP_global_handle:
|
|
||||||
handler = self.FSTP_global_handle[module]
|
handler = self.FSTP_global_handle[module]
|
||||||
handler.wait()
|
handler.wait()
|
||||||
|
|
||||||
|
if block_index - 1 > 0:
|
||||||
|
next_module = self.block_module[block_index - 1][4]
|
||||||
|
self.FSTP_global_weights[next_module], weights_handler = all_gather_raw(
|
||||||
|
next_module.weight, self.process_group, async_op=True
|
||||||
|
)
|
||||||
|
self.FSTP_global_handle[next_module] = weights_handler
|
||||||
|
else:
|
||||||
|
handler = self.FSTP_global_handle[module]
|
||||||
|
handler.wait()
|
||||||
|
if name_index != 0:
|
||||||
|
next_module = self.block_module[block_index][name_index - 1]
|
||||||
|
self.FSTP_global_weights[next_module], weights_handler = all_gather_raw(
|
||||||
|
next_module.weight, self.process_group, async_op=True
|
||||||
|
)
|
||||||
|
self.FSTP_global_handle[next_module] = weights_handler
|
||||||
|
# if module in self.FSTP_global_handle:
|
||||||
|
# handler = self.FSTP_global_handle[module]
|
||||||
|
# handler.wait()
|
||||||
|
|
||||||
def _post_backward_hook_for_module(module, grad_input, grad_output):
|
def _post_backward_hook_for_module(module, grad_input, grad_output):
|
||||||
if module in self.FSTP_global_weights:
|
if module in self.FSTP_global_weights:
|
||||||
del self.FSTP_global_weights[module]
|
del self.FSTP_global_weights[module]
|
||||||
if module in self.FSTP_global_handle:
|
if module in self.FSTP_global_handle:
|
||||||
del self.FSTP_global_handle[module]
|
del self.FSTP_global_handle[module]
|
||||||
|
|
||||||
for block in self.FSTP_blocks:
|
# for head in self.head:
|
||||||
|
# head.register_full_backward_hook(_post_backward_hook_for_head)
|
||||||
|
|
||||||
|
# for block in self.FSTP_blocks:
|
||||||
# block.register_forward_pre_hook(_pre_forward_hook_for_block)
|
# block.register_forward_pre_hook(_pre_forward_hook_for_block)
|
||||||
# block.register_forward_hook(_post_forward_hook_for_block)
|
# block.register_forward_hook(_post_forward_hook_for_block)
|
||||||
block.register_full_backward_pre_hook(_pre_backward_hook_for_block)
|
# block.register_full_backward_pre_hook(_pre_backward_hook_for_block)
|
||||||
# block.register_full_backward_hook(_post_backward_hook_for_block)
|
# block.register_full_backward_hook(_post_backward_hook_for_block)
|
||||||
|
|
||||||
for out_proj in self.FSTP_outs:
|
for out_proj in self.FSTP_outs:
|
||||||
|
|
Loading…
Reference in New Issue