support fine grained

pull/407/head
yingtongxiong 2023-10-17 15:14:39 +08:00
parent d1af0d6aee
commit 6408b944c2
1 changed files with 48 additions and 27 deletions

View File

@ -449,7 +449,8 @@ class CoarseGrainedFSTPAllGatherSyncHandler:
self.index_to_fsdp_modules = dict() # key: transformer block index; value: fsdp modules
self.module_name_index = dict() # key: FSTP module; value: the name in index in self.module_name
self.block_module = dict() # key: transformer block index; value: {name_index: FSTP module}
self.head = []
# just want to share same for loop for ModuleList and Module
if not isinstance(model, nn.ModuleList):
model = [model]
@ -487,16 +488,18 @@ class CoarseGrainedFSTPAllGatherSyncHandler:
index = index + 1
else:
continue
elif isinstance(children, ScaleColumnParallelLinear):
self.head.append(children)
# Start an all-gather of the weight of every FSTP module that belongs to the
# transformer block at `block_index`, and record the results on `self`.
#
# Args:
#     block_index: key into self.index_to_block / self.index_to_fsdp_modules.
#
# Returns nothing; the effects are the entries written to
# self.FSTP_global_weights, self.FSTP_global_handle and self.block_handles.
#
# NOTE(review): this text comes from a diff view in which removed and added
# lines appear side by side without +/- markers and without indentation. Each
# `self.block_handles` statement below is immediately followed by a
# commented-out copy of itself; they are most likely the before/after of one
# line. Confirm against the real file which variant is live.
def _all_gather_block_weight(self, block_index: int):
# Resolve the block object and its list of FSTP modules from the index.
block = self.index_to_block[block_index]
fsdp_modules = self.index_to_fsdp_modules[block_index]
# Reset the per-block handle list before issuing new gathers.
self.block_handles[block] = []
# self.block_handles[block] = []
for module in fsdp_modules:
# all_gather_raw is called with async_op=True and yields a (tensor, handle)
# pair; presumably the handle must be waited on before the tensor is read
# (TODO confirm against all_gather_raw's definition).
total_weight, weight_handle = all_gather_raw(module.weight, self.process_group, async_op=True)
# Cache both pieces keyed by the module so other code can look them up.
self.FSTP_global_weights[module] = total_weight
self.FSTP_global_handle[module] = weight_handle
# Also track the handle under its owning block.
self.block_handles[block].append(weight_handle)
# self.block_handles[block].append(weight_handle)
def _register_sync_parameters_hook(self) -> None:
"""
@ -558,6 +561,7 @@ class CoarseGrainedFSTPAllGatherSyncHandler:
self._all_gather_block_weight(block_index - 1)
def _pre_backward_hook_for_block(block: nn.Module, grad_output):
# import pdb; pdb.set_trace()
block_index = self.block_to_index[block]
# if block_index == gpc.config.NUM_LAYER - 1:
# # all gather weight for the last block
@ -571,10 +575,14 @@ class CoarseGrainedFSTPAllGatherSyncHandler:
# handles = self.block_handles[block]
# for handle in handles:
# handle.wait()
# if block_index == gpc.config.NUM_LAYER - 1:
# self._all_gather_block_weight(block_index)
# start the all-gather for next block
if block_index - 1 > 0:
self._all_gather_block_weight(block_index - 1)
# def _post_backward_hook_for_head(module: nn.Module, grad_input, grad_output):
# self._all_gather_block_weight(gpc.config.NUM_LAYER - 1)
def _post_backward_hook_for_block(block: nn.Module, grad_input, grad_output):
block_index = self.block_to_index[block]
@ -588,45 +596,58 @@ class CoarseGrainedFSTPAllGatherSyncHandler:
block_index = self.module_to_index[module]
name_index = self.module_name_index[module]
if block_index != 0:
# if name_index == 4:
# total_weight, weight_handler = all_gather_raw(module.weight, self.process_group, async_op=True)
# weight_handler.wait()
# self.FSTP_global_weights[module] = total_weight
if name_index == 4 and block_index == gpc.config.NUM_LAYER - 1:
total_weight, weight_handler = all_gather_raw(module.weight, self.process_group, async_op=True)
weight_handler.wait()
self.FSTP_global_weights[module] = total_weight
# # start the all-gather for next module
# next_module = self.block_module[block_index][name_index - 1]
# self.FSTP_global_weights[next_module], weights_handler = all_gather_raw(
# next_module.weight, self.process_group, async_op=True
# )
# self.FSTP_global_handle[next_module] = weights_handler
# else:
# handler = self.FSTP_global_handle[module]
# handler.wait()
# if name_index != 0:
# next_module = self.block_module[block_index][name_index - 1]
# self.FSTP_global_weights[next_module], weights_handler = all_gather_raw(
# next_module.weight, self.process_group, async_op=True
# )
# self.FSTP_global_handle[next_module] = weights_handler
if module in self.FSTP_global_handle:
# start the all-gather for next module
next_module = self.block_module[block_index][name_index - 1]
self.FSTP_global_weights[next_module], weights_handler = all_gather_raw(
next_module.weight, self.process_group, async_op=True
)
self.FSTP_global_handle[next_module] = weights_handler
elif name_index == 0:
handler = self.FSTP_global_handle[module]
handler.wait()
if block_index - 1 > 0:
next_module = self.block_module[block_index - 1][4]
self.FSTP_global_weights[next_module], weights_handler = all_gather_raw(
next_module.weight, self.process_group, async_op=True
)
self.FSTP_global_handle[next_module] = weights_handler
else:
handler = self.FSTP_global_handle[module]
handler.wait()
if name_index != 0:
next_module = self.block_module[block_index][name_index - 1]
self.FSTP_global_weights[next_module], weights_handler = all_gather_raw(
next_module.weight, self.process_group, async_op=True
)
self.FSTP_global_handle[next_module] = weights_handler
# if module in self.FSTP_global_handle:
# handler = self.FSTP_global_handle[module]
# handler.wait()
def _post_backward_hook_for_module(module, grad_input, grad_output):
    """Forget the cached gathered weight and handle for ``module``.

    Removes ``module`` from ``self.FSTP_global_weights`` and from
    ``self.FSTP_global_handle`` when it is present; a module that is absent
    from either mapping is left alone and no error is raised.

    ``grad_input`` and ``grad_output`` are accepted but never read. ``self``
    is not a parameter here; it is resolved from the enclosing scope.
    """
    # Keep the original order: gathered weights first, then the handle.
    for cache in (self.FSTP_global_weights, self.FSTP_global_handle):
        if module in cache:
            del cache[module]
# for head in self.head:
# head.register_full_backward_hook(_post_backward_hook_for_head)
for block in self.FSTP_blocks:
# for block in self.FSTP_blocks:
# block.register_forward_pre_hook(_pre_forward_hook_for_block)
# block.register_forward_hook(_post_forward_hook_for_block)
block.register_full_backward_pre_hook(_pre_backward_hook_for_block)
# block.register_full_backward_pre_hook(_pre_backward_hook_for_block)
# block.register_full_backward_hook(_post_backward_hook_for_block)
for out_proj in self.FSTP_outs:
out_proj.register_forward_pre_hook(_pre_forward_hook_for_out_proj)
# for wqkv in self.FSTP_wqkvs:
# wqkv.register_full_backward_pre_hook(_pre_backward_hook_for_wqkv)