mirror of https://github.com/InternLM/InternLM
fix all-gather overlap when model_checkpoint is 0
parent b3def4c162
commit 10b5056e1e
@@ -1,7 +1,7 @@
 JOB_NAME = "7b_train"
 DO_ALERT = False

-SEQ_LEN = 4096
+SEQ_LEN = 2048
 HIDDEN_SIZE = 4096
 NUM_ATTENTION_HEAD = 32
 MLP_RATIO = 8 / 3
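As a sanity check, the unchanged constants above pin down the attention geometry. This is pure arithmetic, not code from the repo, and it ignores whatever rounding InternLM applies to the MLP width:

HIDDEN_SIZE = 4096
NUM_ATTENTION_HEAD = 32
head_dim = HIDDEN_SIZE // NUM_ATTENTION_HEAD  # 128 per attention head
mlp_width = HIDDEN_SIZE * 8 / 3               # ~10922.7 before any rounding
print(head_dim, mlp_width)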
@@ -49,9 +49,9 @@ VALID_FOLDER = "/path/to/dataset"
 data = dict(
     seq_len=SEQ_LEN,
     # micro_num means the number of micro_batch contained in one gradient update
-    micro_num=1,
+    micro_num=4,
     # packed_length = micro_bsz * SEQ_LEN
-    micro_bsz=1,
+    micro_bsz=2,
     # defaults to the value of micro_num
     valid_micro_num=4,
     # defaults to 0, means disable evaluate
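Following the comments in this hunk, the new values raise the tokens consumed per rank per optimizer step from 4096 (1 x 1 x 4096) to 16384 (4 x 2 x 2048). A minimal sketch of that arithmetic; the dp_size here is a hypothetical stand-in, not a value read from the repo:

SEQ_LEN = 2048
micro_num = 4                               # micro-batches per gradient update
micro_bsz = 2                               # sequences packed per micro-batch
packed_length = micro_bsz * SEQ_LEN         # 4096 tokens per micro-batch
tokens_per_rank = micro_num * packed_length # 16384 tokens per step per rank
dp_size = 8                                 # hypothetical data-parallel degree
print(tokens_per_rank * dp_size)            # 131072 tokens per global step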
@@ -163,7 +163,7 @@ pipeline parallel (dict):
 """
 parallel = dict(
     zero1=dict(size=-1, fsdp=False),
-    tensor=dict(size=8, sp="intern", intern_overlap=True, reduce_scatter_overlap=True),
+    tensor=dict(size=4, sp="intern", intern_overlap=True, reduce_scatter_overlap=True),
     pipeline=dict(size=1, interleaved_overlap=True),
 )
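Halving the tensor-parallel size frees ranks for data parallelism on a fixed GPU budget. A hedged sketch of that bookkeeping; the world size is an assumption and the variable names are illustrative, not InternLM's:

world_size = 32                  # assumed: 4 nodes x 8 GPUs
tensor = 4                       # was 8 before this commit
pipeline = 1
data_parallel = world_size // (tensor * pipeline)
print(data_parallel)             # 8 data-parallel ranks; only 4 with tensor=8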
@@ -315,7 +315,7 @@ class FSTPOverlapHandler:
         # 1. register post_backward_hook @head module to prefetch for the last block's last module
         # 2. register pre_backward_hook @fstp_module to wait handle for current module and to prefetch for next module
         # 3. register post_backward_hook @fstp_module to release resource
-        if self.model_checkpoint is False:
+        if not self.model_checkpoint:
             for head in self.head:
                 head.register_full_backward_hook(_post_backward_hook_for_head)
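This one-line change is the fix named in the commit title. When the config stores model_checkpoint as the integer 0 rather than the boolean False, the identity test `model_checkpoint is False` evaluates to False (0 and False are distinct objects), so the backward hooks that drive the all-gather prefetch were never registered. The truthiness test covers both spellings:

model_checkpoint = 0
print(model_checkpoint is False)  # False: 0 is not the same object as False
print(model_checkpoint == False)  # True: equality coerces, identity does not
print(not model_checkpoint)       # True: falsy test handles 0, False, None

register_full_backward_hook is the stock PyTorch nn.Module API. A minimal standalone sketch of the pattern; the hook body is illustrative, not the handler's real prefetch logic:

import torch
import torch.nn as nn

def _post_backward_hook_for_head(module, grad_input, grad_output):
    # Stand-in for the prefetch kick-off; runs once gradients for this
    # module's inputs have been computed during the backward pass.
    print(f"backward reached {module.__class__.__name__}")

head = nn.Linear(8, 8)
head.register_full_backward_hook(_post_backward_hook_for_head)
head(torch.randn(2, 8)).sum().backward()  # fires the hook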
@@ -407,6 +407,7 @@ def initialize_llm_profile(profiling: bool = False, start_time: str = None):
 tgs_list = []
 tflops_list = []
+tflops_list_2 = []
 loss_list = []

 @llm_timeout(func_name="record_current_batch_training_metrics")
@@ -599,11 +600,12 @@ def record_current_batch_training_metrics(
             step_count=batch_count,
             cur_step_loss=loss.item(),
         )

     loss_list.append(loss.item())
     if batch_count >= 5:
         tgs_list.append(tgs_origin)
         tflops_list.append(tflops)
+        tflops_list_2.append(tflops_2)

     if batch_count == gpc.config.data.total_steps - 1:
         print(tgs_list, flush=True)
         avg_tgs = sum(tgs_list) / len(tgs_list)
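The `batch_count >= 5` guard keeps the first five steps out of the averages, since warm-up work (allocator growth, communicator setup, caching effects) would drag the throughput numbers down. A toy illustration with made-up samples:

samples = [210.0, 980.5, 1490.2, 1503.8, 1497.1, 1501.0, 1499.6]
steady = samples[5:]                # drop warm-up steps 0..4
print(sum(steady) / len(steady))    # ~1500.3
print(sum(samples) / len(samples))  # ~1240.3, skewed by warm-up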
@@ -625,3 +627,5 @@ def record_current_batch_training_metrics(
         if abs(tf - avg_tflops_2) > 10:
             tflops_list_2.remove(tf)
+        print(f"avg_tflops_2: {sum(tflops_list_2)/len(tflops_list_2)}", flush=True)

+        print("loss: ", loss_list, flush=True)
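One caveat worth flagging: the loop header is not visible in this hunk, but `tf` looks like the iteration variable over `tflops_list_2` itself, and calling remove() on a list while iterating it skips elements. A comprehension-based filter is the safe equivalent; the values below are illustrative:

tflops_list_2 = [182.0, 183.5, 120.0, 184.1, 250.0, 182.9]
avg = sum(tflops_list_2) / len(tflops_list_2)
tflops_list_2 = [tf for tf in tflops_list_2 if abs(tf - avg) <= 10]
print(f"avg_tflops_2: {sum(tflops_list_2)/len(tflops_list_2)}", flush=True)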