fix all-gather overlap when model_checkpoint is 0

pull/436/head
yingtongxiong 2023-11-01 12:31:52 +08:00
parent b3def4c162
commit 10b5056e1e
3 changed files with 10 additions and 6 deletions


@@ -1,7 +1,7 @@
 JOB_NAME = "7b_train"
 DO_ALERT = False
 
-SEQ_LEN = 4096
+SEQ_LEN = 2048
 HIDDEN_SIZE = 4096
 NUM_ATTENTION_HEAD = 32
 MLP_RATIO = 8 / 3
@@ -49,9 +49,9 @@ VALID_FOLDER = "/path/to/dataset"
 data = dict(
     seq_len=SEQ_LEN,
     # micro_num means the number of micro_batch contained in one gradient update
-    micro_num=1,
+    micro_num=4,
     # packed_length = micro_bsz * SEQ_LEN
-    micro_bsz=1,
+    micro_bsz=2,
     # defaults to the value of micro_num
     valid_micro_num=4,
     # defaults to 0, means disable evaluate
# defaults to 0, means disable evaluate
@@ -163,7 +163,7 @@ pipeline parallel (dict):
 """
 parallel = dict(
     zero1=dict(size=-1, fsdp=False),
-    tensor=dict(size=8, sp="intern", intern_overlap=True, reduce_scatter_overlap=True),
+    tensor=dict(size=4, sp="intern", intern_overlap=True, reduce_scatter_overlap=True),
     pipeline=dict(size=1, interleaved_overlap=True),
 )
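
The token budget per optimizer step follows from the config comment `packed_length = micro_bsz * SEQ_LEN`. A minimal sketch of the arithmetic (the helper function is hypothetical; only the config names come from the commit):

```python
# Hypothetical helper: tokens consumed per gradient update on one rank.
# The config comments state packed_length = micro_bsz * SEQ_LEN and that
# micro_num micro-batches make up one gradient update.
def tokens_per_gradient_update(seq_len: int, micro_bsz: int, micro_num: int) -> int:
    packed_length = micro_bsz * seq_len  # tokens in one packed micro-batch
    return micro_num * packed_length     # accumulated over micro_num micro-batches

# With the new values: 4 * (2 * 2048) = 16384 tokens per update.
print(tokens_per_gradient_update(seq_len=2048, micro_bsz=2, micro_num=4))
```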


@@ -315,7 +315,7 @@ class FSTPOverlapHandler:
         # 1. register post_backward_hook @head module to prefetch for the last block's last module
         # 2. register pre_backward_hook @fstp_module to wait handle for current module and to prefetch for next module
         # 3. register post_backward_hook @fstp_module to release resource
-        if self.model_checkpoint is False:
+        if not self.model_checkpoint:
             for head in self.head:
                 head.register_full_backward_hook(_post_backward_hook_for_head)
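
The three numbered comments describe the hook scheme this hunk guards. Below is a self-contained sketch of the registration pattern, assuming PyTorch >= 2.0 for `register_full_backward_pre_hook`; the hook bodies are placeholders, not the handler's actual prefetch/release logic:

```python
import torch.nn as nn

def _post_backward_hook_for_head(module, grad_input, grad_output):
    # Placeholder: prefetch (all-gather) the last block's last FSTP module.
    pass

def _pre_backward_hook_for_module(module, grad_output):
    # Placeholder: wait on this module's all-gather handle, then
    # prefetch the next module's parameters.
    pass

head = nn.Linear(8, 8)         # stand-in for a model head
fstp_module = nn.Linear(8, 8)  # stand-in for an FSTP-wrapped module

head.register_full_backward_hook(_post_backward_hook_for_head)
fstp_module.register_full_backward_pre_hook(_pre_backward_hook_for_module)
```

The fix itself replaces `is False` with `not`, so the head hook is also skipped when `model_checkpoint` is 0 rather than the literal `False`; presumably with activation checkpointing enabled the recomputed forward pass drives the prefetch instead.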


@@ -407,6 +407,7 @@ def initialize_llm_profile(profiling: bool = False, start_time: str = None):
 tgs_list = []
 tflops_list = []
 tflops_list_2 = []
+loss_list = []
 
 
 @llm_timeout(func_name="record_current_batch_training_metrics")
@@ -599,11 +600,12 @@ def record_current_batch_training_metrics(
         step_count=batch_count,
         cur_step_loss=loss.item(),
     )
+    loss_list.append(loss.item())
     if batch_count >= 5:
         tgs_list.append(tgs_origin)
         tflops_list.append(tflops)
         tflops_list_2.append(tflops_2)
     if batch_count == gpc.config.data.total_steps - 1:
         print(tgs_list, flush=True)
         avg_tgs = sum(tgs_list) / len(tgs_list)
@@ -625,3 +627,5 @@ def record_current_batch_training_metrics(
             if abs(tf - avg_tflops_2) > 10:
                 tflops_list_2.remove(tf)
         print(f"avg_tflops_2: {sum(tflops_list_2)/len(tflops_list_2)}", flush=True)
+        print("loss: ", loss_list, flush=True)