From 10b5056e1ebfe540f1008c97f4b3bcdafe8b22da Mon Sep 17 00:00:00 2001
From: yingtongxiong <974106207@qq.com>
Date: Wed, 1 Nov 2023 12:31:52 +0800
Subject: [PATCH] fix all-gather overlap when model_checkpoint is 0

---
 configs/7B_sft.py                   | 8 ++++----
 internlm/model/overlap_handler.py   | 2 +-
 internlm/train/training_internlm.py | 6 +++++-
 3 files changed, 10 insertions(+), 6 deletions(-)

diff --git a/configs/7B_sft.py b/configs/7B_sft.py
index b34a838..9928508 100644
--- a/configs/7B_sft.py
+++ b/configs/7B_sft.py
@@ -1,7 +1,7 @@
 JOB_NAME = "7b_train"
 DO_ALERT = False
 
-SEQ_LEN = 4096
+SEQ_LEN = 2048
 HIDDEN_SIZE = 4096
 NUM_ATTENTION_HEAD = 32
 MLP_RATIO = 8 / 3
@@ -49,9 +49,9 @@ VALID_FOLDER = "/path/to/dataset"
 data = dict(
     seq_len=SEQ_LEN,
     # micro_num means the number of micro_batch contained in one gradient update
-    micro_num=1,
+    micro_num=4,
     # packed_length = micro_bsz * SEQ_LEN
-    micro_bsz=1,
+    micro_bsz=2,
     # defaults to the value of micro_num
     valid_micro_num=4,
     # defaults to 0, means disable evaluate
@@ -163,7 +163,7 @@ pipeline parallel (dict):
 """
 parallel = dict(
     zero1=dict(size=-1, fsdp=False),
-    tensor=dict(size=8, sp="intern", intern_overlap=True, reduce_scatter_overlap=True),
+    tensor=dict(size=4, sp="intern", intern_overlap=True, reduce_scatter_overlap=True),
     pipeline=dict(size=1, interleaved_overlap=True),
 )
 
diff --git a/internlm/model/overlap_handler.py b/internlm/model/overlap_handler.py
index 418c4aa..db81150 100644
--- a/internlm/model/overlap_handler.py
+++ b/internlm/model/overlap_handler.py
@@ -315,7 +315,7 @@ class FSTPOverlapHandler:
         # 1. register post_backward_hook @head module to prefetch for the last block's last module
         # 2. register pre_backward_hook @fstp_module to wait handle for current module and to prefetch for next module
         # 3. register post_backward_hook @fstp_module to release resource
-        if self.model_checkpoint is False:
+        if not self.model_checkpoint:
             for head in self.head:
                 head.register_full_backward_hook(_post_backward_hook_for_head)
 
diff --git a/internlm/train/training_internlm.py b/internlm/train/training_internlm.py
index 2b80692..2b5a1bb 100644
--- a/internlm/train/training_internlm.py
+++ b/internlm/train/training_internlm.py
@@ -407,6 +407,7 @@ def initialize_llm_profile(profiling: bool = False, start_time: str = None):
 tgs_list = []
 tflops_list = []
 tflops_list_2 = []
+loss_list = []
 
 
 @llm_timeout(func_name="record_current_batch_training_metrics")
@@ -599,11 +600,12 @@ def record_current_batch_training_metrics(
             step_count=batch_count,
             cur_step_loss=loss.item(),
         )
-
+        loss_list.append(loss.item())
         if batch_count >= 5:
             tgs_list.append(tgs_origin)
             tflops_list.append(tflops)
             tflops_list_2.append(tflops_2)
+
         if batch_count == gpc.config.data.total_steps - 1:
             print(tgs_list, flush=True)
             avg_tgs = sum(tgs_list) / len(tgs_list)
@@ -625,3 +627,5 @@
                 if abs(tf - avg_tflops_2) > 10:
                     tflops_list_2.remove(tf)
             print(f"avg_tflops_2: {sum(tflops_list_2)/len(tflops_list_2)}", flush=True)
+
+            print("loss: ", loss_list, flush=True)
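
Notes (editor annotations, not part of the patch):

The core of this fix is the `is False` comparison in overlap_handler.py. When
`model_checkpoint` is configured as the int 0 rather than the bool False (as
the subject line suggests), `model_checkpoint is False` is an identity test
and evaluates to False, so the backward hooks that drive the all-gather
overlap were never registered even though checkpointing was off. A minimal
sketch of the difference; the variable name `flag` is illustrative only:

    # `is False` tests object identity; it misses falsy non-bool values such as 0.
    for flag in (False, 0, 0.0):
        print(repr(flag), flag is False, not flag)
    # Output:
    # False True True
    # 0 False True    <- the old check skipped hook registration for int 0
    # 0.0 False True

`not self.model_checkpoint` uses truthiness instead, so 0, 0.0, and False all
take the hook-registration path.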
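
For reference, the batch geometry implied by the config comments works out as
follows with the new values (a sketch using only quantities named in this
patch; the per-data-parallel-rank interpretation is an assumption):

    SEQ_LEN = 2048
    micro_bsz = 2
    micro_num = 4

    packed_length = micro_bsz * SEQ_LEN       # 4096 tokens per packed micro batch
    step_tokens = micro_num * packed_length   # 16384 tokens per gradient update
    print(packed_length, step_tokens)         # 4096 16384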
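
One caveat in the surrounding, unchanged metrics code: `tflops_list_2.remove(tf)`
appears to run inside a loop over `tflops_list_2` itself, and mutating a list
while iterating over it can skip elements after each removal. A safer pattern,
sketched here rather than folded into the patch (sample values hypothetical):

    tflops_list_2 = [180.0, 182.5, 150.0, 181.2]  # hypothetical sample values
    avg_tflops_2 = sum(tflops_list_2) / len(tflops_list_2)
    # Rebuild the list instead of removing elements mid-iteration.
    tflops_list_2 = [tf for tf in tflops_list_2 if abs(tf - avg_tflops_2) <= 10]
    print(f"avg_tflops_2: {sum(tflops_list_2)/len(tflops_list_2)}", flush=True)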