diff --git a/configs/7B_sft.py b/configs/7B_sft.py
index 98bceeb..36f9ac1 100644
--- a/configs/7B_sft.py
+++ b/configs/7B_sft.py
@@ -5,7 +5,7 @@ SEQ_LEN = 4096
 HIDDEN_SIZE = 8192
 NUM_ATTENTION_HEAD = 32
 MLP_RATIO = 8 / 3
-NUM_LAYER = 8
+NUM_LAYER = 4
 VOCAB_SIZE = 103168
 
 MODEL_ONLY_FOLDER = "local:llm_ckpts/xxxx"
@@ -57,7 +57,7 @@ data = dict(
     # defaults to 0, means disable evaluate
     valid_every=50,
     pack_sample_into_one=False,
-    total_steps=50000,
+    total_steps=20,
     skip_batches="",
     rampup_batch_size="",
     # Datasets with less than 50 rows will be discarded
@@ -161,10 +161,11 @@ pipeline parallel (dict):
 sequence parallel (bool): enable/disable sequence parallel, defaults to False.
 """
 parallel = dict(
-    zero1=dict(size=1, fsdp=False),
-    tensor=dict(size=8, mode='fstp'),  # the mode should be 'origin_tp' or 'fstp'. if the mode is 'fstp', the sequence_parallel should be True
+    zero1=dict(size=-1, fsdp=False),
+    tensor=dict(size=8, mode="fstp"),
     pipeline=dict(size=1, interleaved_overlap=True),
     sequence_parallel=True,
+    block_0_full_weight=True,
 )
 
 cudnn_deterministic = False
diff --git a/internlm/model/linear.py b/internlm/model/linear.py
index 8a17c71..8e19ab6 100644
--- a/internlm/model/linear.py
+++ b/internlm/model/linear.py
@@ -559,21 +559,21 @@ class CoarseGrainedFSTPAllGatherSyncHandler:
 
         def _pre_backward_hook_for_block(block: nn.Module, grad_output):
             block_index = self.block_to_index[block]
-            if block_index == gpc.config.NUM_LAYER - 1:
-                # all gather weight for the last block
-                fsdp_modules = self.index_to_fsdp_modules[block_index]
-                for module in fsdp_modules:
-                    total_weight, weight_handle = all_gather_raw(module.weight, self.process_group, async_op=True)
-                    weight_handle.wait()
-                    self.FSTP_global_weights[module] = total_weight
-            else:
-                # wait handle for current block
-                handles = self.block_handles[block]
-                for handle in handles:
-                    handle.wait()
+            # if block_index == gpc.config.NUM_LAYER - 1:
+            #     # all gather weight for the last block
+            #     fsdp_modules = self.index_to_fsdp_modules[block_index]
+            #     for module in fsdp_modules:
+            #         total_weight, weight_handle = all_gather_raw(module.weight, self.process_group, async_op=True)
+            #         weight_handle.wait()
+            #         self.FSTP_global_weights[module] = total_weight
+            # else:
+            #     # wait handle for current block
+            #     handles = self.block_handles[block]
+            #     for handle in handles:
+            #         handle.wait()
 
             # start the all-gather for next block
-            if block_index - 1 >= 0:
+            if block_index - 1 > 0:
                 self._all_gather_block_weight(block_index - 1)
 
         def _post_backward_hook_for_block(block: nn.Module, grad_input, grad_output):
@@ -588,36 +588,41 @@ class CoarseGrainedFSTPAllGatherSyncHandler:
             block_index = self.module_to_index[module]
             name_index = self.module_name_index[module]
             if block_index != 0:
-                if name_index == 4:
-                    total_weight, weight_handler = all_gather_raw(module.weight, self.process_group, async_op=True)
-                    weight_handler.wait()
-                    self.FSTP_global_weights[module] = total_weight
+                # if name_index == 4:
+                #     total_weight, weight_handler = all_gather_raw(module.weight, self.process_group, async_op=True)
+                #     weight_handler.wait()
+                #     self.FSTP_global_weights[module] = total_weight
 
-                    # start the all-gather for next module
-                    next_module = self.block_module[block_index][name_index - 1]
-                    self.FSTP_global_weights[next_module], weights_handler = all_gather_raw(
-                        next_module.weight, self.process_group, async_op=True
-                    )
-                    self.FSTP_global_handle[next_module] = weights_handler
-                else:
+                #     # start the all-gather for next module
+                #     next_module = self.block_module[block_index][name_index - 1]
+                #     self.FSTP_global_weights[next_module], weights_handler = all_gather_raw(
+                #         next_module.weight, self.process_group, async_op=True
+                #     )
+                #     self.FSTP_global_handle[next_module] = weights_handler
+                # else:
+                #     handler = self.FSTP_global_handle[module]
+                #     handler.wait()
+                #     if name_index != 0:
+                #         next_module = self.block_module[block_index][name_index - 1]
+                #         self.FSTP_global_weights[next_module], weights_handler = all_gather_raw(
+                #             next_module.weight, self.process_group, async_op=True
+                #         )
+                #         self.FSTP_global_handle[next_module] = weights_handler
+                if module in self.FSTP_global_handle:
                     handler = self.FSTP_global_handle[module]
                     handler.wait()
-                    if name_index != 0:
-                        next_module = self.block_module[block_index][name_index - 1]
-                        self.FSTP_global_weights[next_module], weights_handler = all_gather_raw(
-                            next_module.weight, self.process_group, async_op=True
-                        )
-                        self.FSTP_global_handle[next_module] = weights_handler
 
         def _post_backward_hook_for_module(module, grad_input, grad_output):
             if module in self.FSTP_global_weights:
                 del self.FSTP_global_weights[module]
+            if module in self.FSTP_global_handle:
+                del self.FSTP_global_handle[module]
 
-        # for block in self.FSTP_blocks:
-        #     block.register_forward_pre_hook(_pre_forward_hook_for_block)
-        #     block.register_forward_hook(_post_forward_hook_for_block)
-        #     block.register_full_backward_pre_hook(_pre_backward_hook_for_block)
-        #     block.register_full_backward_hook(_post_backward_hook_for_block)
+        for block in self.FSTP_blocks:
+            # block.register_forward_pre_hook(_pre_forward_hook_for_block)
+            # block.register_forward_hook(_post_forward_hook_for_block)
+            block.register_full_backward_pre_hook(_pre_backward_hook_for_block)
+            # block.register_full_backward_hook(_post_backward_hook_for_block)
 
         for out_proj in self.FSTP_outs:
             out_proj.register_forward_pre_hook(_pre_forward_hook_for_out_proj)
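
Reviewer note (not part of the patch): after this change, the coarse-grained handler drives backward weight prefetch entirely from block-level backward pre-hooks: when backward enters block i, it launches the asynchronous all-gather for block i-1's sharded weights, and each module simply waits on its handle if one exists, then frees the gathered buffer in its post-backward hook. The sketch below is a minimal, self-contained illustration of that pattern under stated assumptions, not InternLM code: PrefetchAllGatherHandler and its attribute names are hypothetical stand-ins, torch.distributed.all_gather_into_tensor stands in for the repo's all_gather_raw helper, and torch >= 2.0 plus an initialized process group are assumed.

    import torch
    import torch.distributed as dist
    from torch import nn


    class PrefetchAllGatherHandler:
        """All-gather block i-1's sharded weights while block i runs backward."""

        def __init__(self, blocks, process_group=None):
            self.blocks = list(blocks)
            self.group = process_group
            self.block_to_index = {b: i for i, b in enumerate(self.blocks)}
            self.full_weights = {}  # module -> gathered (flat) weight buffer
            self.handles = {}       # module -> in-flight all-gather work handle

        def _all_gather_block_weight(self, idx):
            # launch async all-gathers for every sharded linear in block `idx`
            for module in self.blocks[idx].modules():
                if isinstance(module, nn.Linear):
                    world = dist.get_world_size(self.group)
                    buf = torch.empty(
                        world * module.weight.numel(),
                        dtype=module.weight.dtype,
                        device=module.weight.device,
                    )
                    self.handles[module] = dist.all_gather_into_tensor(
                        buf, module.weight.data.contiguous(),
                        group=self.group, async_op=True,
                    )
                    self.full_weights[module] = buf

        def _pre_backward_for_block(self, block, grad_output):
            # producer side: entering block i's backward -> prefetch block i-1.
            # `idx - 1 > 0` (not `>= 0`) skips block 0, which is assumed to
            # hold a full unsharded weight copy (the block_0_full_weight case).
            idx = self.block_to_index[block]
            if idx - 1 > 0:
                self._all_gather_block_weight(idx - 1)

        def _pre_backward_for_module(self, module, grad_output):
            # consumer side: block until this module's gathered weight is ready
            handle = self.handles.get(module)
            if handle is not None:
                handle.wait()

        def _post_backward_for_module(self, module, grad_input, grad_output):
            # release the gathered buffer and handle once backward has used them
            self.full_weights.pop(module, None)
            self.handles.pop(module, None)

        def register(self):
            for block in self.blocks:
                block.register_full_backward_pre_hook(self._pre_backward_for_block)
                for module in block.modules():
                    if isinstance(module, nn.Linear):
                        module.register_full_backward_pre_hook(self._pre_backward_for_module)
                        module.register_full_backward_hook(self._post_backward_for_module)

This mirrors the patched registration: only register_full_backward_pre_hook stays live on blocks, while the wait is guarded by a membership test (if module in self.FSTP_global_handle) so modules whose weights were never gathered, such as those in block 0, pass through without blocking.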
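On the config side, zero1 moves from a fixed size of 1 to -1. Per the comment block in configs/7B_sft.py, a non-positive zero1 size makes the ZeRO-1 process group span the whole data-parallel group (optimizer states sharded across all dp ranks), whereas size 1 disables sharding entirely. Together with NUM_LAYER 8 -> 4 and total_steps 50000 -> 20, these look like debug-scale settings for exercising the new hook path. A one-line illustration of the size convention (resolve_zero1_size is a hypothetical helper, not InternLM code):

    def resolve_zero1_size(zero1: int, dp_world_size: int) -> int:
        # zero1 <= 0: shard optimizer states across the whole dp group;
        # otherwise use the configured subgroup size.
        return dp_world_size if zero1 <= 0 else zero1

    assert resolve_zero1_size(-1, 16) == 16  # this patch's setting
    assert resolve_zero1_size(1, 16) == 1    # previous setting: no sharding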