diff --git a/internlm/model/linear.py b/internlm/model/linear.py index 8e19ab6..e8727ac 100644 --- a/internlm/model/linear.py +++ b/internlm/model/linear.py @@ -449,7 +449,8 @@ class CoarseGrainedFSTPAllGatherSyncHandler: self.index_to_fsdp_modules = dict() # key: transformer block index; value: fsdp modules self.module_name_index = dict() # key: FSTP module; value: the name in index in self.module_name self.block_module = dict() # key: transformer block index; value: {name_index: FSTP module} - + self.head = [] + # just want to share same for loop for ModuleList and Module if not isinstance(model, nn.ModuleList): model = [model] @@ -487,16 +488,18 @@ class CoarseGrainedFSTPAllGatherSyncHandler: index = index + 1 else: continue + elif isinstance(children, ScaleColumnParallelLinear): + self.head.append(children) def _all_gather_block_weight(self, block_index: int): block = self.index_to_block[block_index] fsdp_modules = self.index_to_fsdp_modules[block_index] - self.block_handles[block] = [] + # self.block_handles[block] = [] for module in fsdp_modules: total_weight, weight_handle = all_gather_raw(module.weight, self.process_group, async_op=True) self.FSTP_global_weights[module] = total_weight self.FSTP_global_handle[module] = weight_handle - self.block_handles[block].append(weight_handle) + # self.block_handles[block].append(weight_handle) def _register_sync_parameters_hook(self) -> None: """ @@ -558,6 +561,7 @@ class CoarseGrainedFSTPAllGatherSyncHandler: self._all_gather_block_weight(block_index - 1) def _pre_backward_hook_for_block(block: nn.Module, grad_output): + # import pdb; pdb.set_trace() block_index = self.block_to_index[block] # if block_index == gpc.config.NUM_LAYER - 1: # # all gather weight for the last block @@ -571,10 +575,14 @@ class CoarseGrainedFSTPAllGatherSyncHandler: # handles = self.block_handles[block] # for handle in handles: # handle.wait() - + # if block_index == gpc.config.NUM_LAYER - 1: + # self._all_gather_block_weight(block_index) # start the all-gather for next block if block_index - 1 > 0: self._all_gather_block_weight(block_index - 1) + + # def _post_backward_hook_for_head(module: nn.Module, grad_input, grad_output): + # self._all_gather_block_weight(gpc.config.NUM_LAYER - 1) def _post_backward_hook_for_block(block: nn.Module, grad_input, grad_output): block_index = self.block_to_index[block] @@ -588,45 +596,58 @@ class CoarseGrainedFSTPAllGatherSyncHandler: block_index = self.module_to_index[module] name_index = self.module_name_index[module] if block_index != 0: - # if name_index == 4: - # total_weight, weight_handler = all_gather_raw(module.weight, self.process_group, async_op=True) - # weight_handler.wait() - # self.FSTP_global_weights[module] = total_weight + if name_index == 4 and block_index == gpc.config.NUM_LAYER - 1: + total_weight, weight_handler = all_gather_raw(module.weight, self.process_group, async_op=True) + weight_handler.wait() + self.FSTP_global_weights[module] = total_weight - # # start the all-gather for next module - # next_module = self.block_module[block_index][name_index - 1] - # self.FSTP_global_weights[next_module], weights_handler = all_gather_raw( - # next_module.weight, self.process_group, async_op=True - # ) - # self.FSTP_global_handle[next_module] = weights_handler - # else: - # handler = self.FSTP_global_handle[module] - # handler.wait() - # if name_index != 0: - # next_module = self.block_module[block_index][name_index - 1] - # self.FSTP_global_weights[next_module], weights_handler = all_gather_raw( - # next_module.weight, self.process_group, async_op=True - # ) - # self.FSTP_global_handle[next_module] = weights_handler - if module in self.FSTP_global_handle: + # start the all-gather for next module + next_module = self.block_module[block_index][name_index - 1] + self.FSTP_global_weights[next_module], weights_handler = all_gather_raw( + next_module.weight, self.process_group, async_op=True + ) + self.FSTP_global_handle[next_module] = weights_handler + elif name_index == 0: handler = self.FSTP_global_handle[module] handler.wait() + + if block_index - 1 > 0: + next_module = self.block_module[block_index - 1][4] + self.FSTP_global_weights[next_module], weights_handler = all_gather_raw( + next_module.weight, self.process_group, async_op=True + ) + self.FSTP_global_handle[next_module] = weights_handler + else: + handler = self.FSTP_global_handle[module] + handler.wait() + if name_index != 0: + next_module = self.block_module[block_index][name_index - 1] + self.FSTP_global_weights[next_module], weights_handler = all_gather_raw( + next_module.weight, self.process_group, async_op=True + ) + self.FSTP_global_handle[next_module] = weights_handler + # if module in self.FSTP_global_handle: + # handler = self.FSTP_global_handle[module] + # handler.wait() def _post_backward_hook_for_module(module, grad_input, grad_output): if module in self.FSTP_global_weights: del self.FSTP_global_weights[module] if module in self.FSTP_global_handle: del self.FSTP_global_handle[module] + + # for head in self.head: + # head.register_full_backward_hook(_post_backward_hook_for_head) - for block in self.FSTP_blocks: + # for block in self.FSTP_blocks: # block.register_forward_pre_hook(_pre_forward_hook_for_block) # block.register_forward_hook(_post_forward_hook_for_block) - block.register_full_backward_pre_hook(_pre_backward_hook_for_block) + # block.register_full_backward_pre_hook(_pre_backward_hook_for_block) # block.register_full_backward_hook(_post_backward_hook_for_block) for out_proj in self.FSTP_outs: out_proj.register_forward_pre_hook(_pre_forward_hook_for_out_proj) - + # for wqkv in self.FSTP_wqkvs: # wqkv.register_full_backward_pre_hook(_pre_backward_hook_for_wqkv)