[tutorial] fixed pipeline bug for sequence parallel (#1943)

pull/1944/head
Frank Lee 2022-11-14 18:06:57 +08:00 committed by GitHub
parent e52f9d9109
commit c6ea65011f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 12 additions and 0 deletions

View File

@ -35,6 +35,17 @@ def parse_args():
return parser.parse_args() return parser.parse_args()
def pipeline_data_process_func(stage_output, micro_batch_data):
tokens, types, sentence_order, loss_mask, lm_labels, padding_mask = micro_batch_data
if gpc.is_first_rank(ParallelMode.PIPELINE):
data = (tokens, padding_mask, types, lm_labels)
label = (loss_mask, sentence_order)
else:
data = (stage_output, padding_mask, types, lm_labels)
label = (loss_mask, sentence_order)
return data, label
def main(): def main():
# initialize # initialize
args = parse_args() args = parse_args()
@ -155,6 +166,7 @@ def main():
if use_pipeline: if use_pipeline:
train_data_iter = SequenceParallelDataIterator(trainloader) train_data_iter = SequenceParallelDataIterator(trainloader)
valid_data_iter = SequenceParallelDataIterator(validloader) valid_data_iter = SequenceParallelDataIterator(validloader)
engine.schedule.data_process_func = pipeline_data_process_func
logger.info("start training") logger.info("start training")