support fine grained

pull/407/head
yingtongxiong 2023-10-17 15:14:39 +08:00
parent d1af0d6aee
commit 6408b944c2
1 changed files with 48 additions and 27 deletions

View File

@ -449,6 +449,7 @@ class CoarseGrainedFSTPAllGatherSyncHandler:
self.index_to_fsdp_modules = dict() # key: transformer block index; value: fsdp modules self.index_to_fsdp_modules = dict() # key: transformer block index; value: fsdp modules
self.module_name_index = dict() # key: FSTP module; value: the name in index in self.module_name self.module_name_index = dict() # key: FSTP module; value: the name in index in self.module_name
self.block_module = dict() # key: transformer block index; value: {name_index: FSTP module} self.block_module = dict() # key: transformer block index; value: {name_index: FSTP module}
self.head = []
# just want to share same for loop for ModuleList and Module # just want to share same for loop for ModuleList and Module
if not isinstance(model, nn.ModuleList): if not isinstance(model, nn.ModuleList):
@ -487,16 +488,18 @@ class CoarseGrainedFSTPAllGatherSyncHandler:
index = index + 1 index = index + 1
else: else:
continue continue
elif isinstance(children, ScaleColumnParallelLinear):
self.head.append(children)
def _all_gather_block_weight(self, block_index: int): def _all_gather_block_weight(self, block_index: int):
block = self.index_to_block[block_index] block = self.index_to_block[block_index]
fsdp_modules = self.index_to_fsdp_modules[block_index] fsdp_modules = self.index_to_fsdp_modules[block_index]
self.block_handles[block] = [] # self.block_handles[block] = []
for module in fsdp_modules: for module in fsdp_modules:
total_weight, weight_handle = all_gather_raw(module.weight, self.process_group, async_op=True) total_weight, weight_handle = all_gather_raw(module.weight, self.process_group, async_op=True)
self.FSTP_global_weights[module] = total_weight self.FSTP_global_weights[module] = total_weight
self.FSTP_global_handle[module] = weight_handle self.FSTP_global_handle[module] = weight_handle
self.block_handles[block].append(weight_handle) # self.block_handles[block].append(weight_handle)
def _register_sync_parameters_hook(self) -> None: def _register_sync_parameters_hook(self) -> None:
""" """
@ -558,6 +561,7 @@ class CoarseGrainedFSTPAllGatherSyncHandler:
self._all_gather_block_weight(block_index - 1) self._all_gather_block_weight(block_index - 1)
def _pre_backward_hook_for_block(block: nn.Module, grad_output): def _pre_backward_hook_for_block(block: nn.Module, grad_output):
# import pdb; pdb.set_trace()
block_index = self.block_to_index[block] block_index = self.block_to_index[block]
# if block_index == gpc.config.NUM_LAYER - 1: # if block_index == gpc.config.NUM_LAYER - 1:
# # all gather weight for the last block # # all gather weight for the last block
@ -571,11 +575,15 @@ class CoarseGrainedFSTPAllGatherSyncHandler:
# handles = self.block_handles[block] # handles = self.block_handles[block]
# for handle in handles: # for handle in handles:
# handle.wait() # handle.wait()
# if block_index == gpc.config.NUM_LAYER - 1:
# self._all_gather_block_weight(block_index)
# start the all-gather for next block # start the all-gather for next block
if block_index - 1 > 0: if block_index - 1 > 0:
self._all_gather_block_weight(block_index - 1) self._all_gather_block_weight(block_index - 1)
# def _post_backward_hook_for_head(module: nn.Module, grad_input, grad_output):
# self._all_gather_block_weight(gpc.config.NUM_LAYER - 1)
def _post_backward_hook_for_block(block: nn.Module, grad_input, grad_output): def _post_backward_hook_for_block(block: nn.Module, grad_input, grad_output):
block_index = self.block_to_index[block] block_index = self.block_to_index[block]
fsdp_modules = self.index_to_fsdp_modules[block_index] fsdp_modules = self.index_to_fsdp_modules[block_index]
@ -588,40 +596,53 @@ class CoarseGrainedFSTPAllGatherSyncHandler:
block_index = self.module_to_index[module] block_index = self.module_to_index[module]
name_index = self.module_name_index[module] name_index = self.module_name_index[module]
if block_index != 0: if block_index != 0:
# if name_index == 4: if name_index == 4 and block_index == gpc.config.NUM_LAYER - 1:
# total_weight, weight_handler = all_gather_raw(module.weight, self.process_group, async_op=True) total_weight, weight_handler = all_gather_raw(module.weight, self.process_group, async_op=True)
# weight_handler.wait() weight_handler.wait()
# self.FSTP_global_weights[module] = total_weight self.FSTP_global_weights[module] = total_weight
# # start the all-gather for next module # start the all-gather for next module
# next_module = self.block_module[block_index][name_index - 1] next_module = self.block_module[block_index][name_index - 1]
# self.FSTP_global_weights[next_module], weights_handler = all_gather_raw( self.FSTP_global_weights[next_module], weights_handler = all_gather_raw(
# next_module.weight, self.process_group, async_op=True next_module.weight, self.process_group, async_op=True
# ) )
# self.FSTP_global_handle[next_module] = weights_handler self.FSTP_global_handle[next_module] = weights_handler
# else: elif name_index == 0:
# handler = self.FSTP_global_handle[module]
# handler.wait()
# if name_index != 0:
# next_module = self.block_module[block_index][name_index - 1]
# self.FSTP_global_weights[next_module], weights_handler = all_gather_raw(
# next_module.weight, self.process_group, async_op=True
# )
# self.FSTP_global_handle[next_module] = weights_handler
if module in self.FSTP_global_handle:
handler = self.FSTP_global_handle[module] handler = self.FSTP_global_handle[module]
handler.wait() handler.wait()
if block_index - 1 > 0:
next_module = self.block_module[block_index - 1][4]
self.FSTP_global_weights[next_module], weights_handler = all_gather_raw(
next_module.weight, self.process_group, async_op=True
)
self.FSTP_global_handle[next_module] = weights_handler
else:
handler = self.FSTP_global_handle[module]
handler.wait()
if name_index != 0:
next_module = self.block_module[block_index][name_index - 1]
self.FSTP_global_weights[next_module], weights_handler = all_gather_raw(
next_module.weight, self.process_group, async_op=True
)
self.FSTP_global_handle[next_module] = weights_handler
# if module in self.FSTP_global_handle:
# handler = self.FSTP_global_handle[module]
# handler.wait()
def _post_backward_hook_for_module(module, grad_input, grad_output): def _post_backward_hook_for_module(module, grad_input, grad_output):
if module in self.FSTP_global_weights: if module in self.FSTP_global_weights:
del self.FSTP_global_weights[module] del self.FSTP_global_weights[module]
if module in self.FSTP_global_handle: if module in self.FSTP_global_handle:
del self.FSTP_global_handle[module] del self.FSTP_global_handle[module]
for block in self.FSTP_blocks: # for head in self.head:
# head.register_full_backward_hook(_post_backward_hook_for_head)
# for block in self.FSTP_blocks:
# block.register_forward_pre_hook(_pre_forward_hook_for_block) # block.register_forward_pre_hook(_pre_forward_hook_for_block)
# block.register_forward_hook(_post_forward_hook_for_block) # block.register_forward_hook(_post_forward_hook_for_block)
block.register_full_backward_pre_hook(_pre_backward_hook_for_block) # block.register_full_backward_pre_hook(_pre_backward_hook_for_block)
# block.register_full_backward_hook(_post_backward_hook_for_block) # block.register_full_backward_hook(_post_backward_hook_for_block)
for out_proj in self.FSTP_outs: for out_proj in self.FSTP_outs: