mirror of https://github.com/hpcaitech/ColossalAI
[tutorial] fixed pipeline bug for sequence parallel (#1943)
parent
e52f9d9109
commit
c6ea65011f
|
@ -35,6 +35,17 @@ def parse_args():
|
|||
return parser.parse_args()
|
||||
|
||||
|
||||
def pipeline_data_process_func(stage_output, micro_batch_data):
|
||||
tokens, types, sentence_order, loss_mask, lm_labels, padding_mask = micro_batch_data
|
||||
if gpc.is_first_rank(ParallelMode.PIPELINE):
|
||||
data = (tokens, padding_mask, types, lm_labels)
|
||||
label = (loss_mask, sentence_order)
|
||||
else:
|
||||
data = (stage_output, padding_mask, types, lm_labels)
|
||||
label = (loss_mask, sentence_order)
|
||||
return data, label
|
||||
|
||||
|
||||
def main():
|
||||
# initialize
|
||||
args = parse_args()
|
||||
|
@ -155,6 +166,7 @@ def main():
|
|||
if use_pipeline:
|
||||
train_data_iter = SequenceParallelDataIterator(trainloader)
|
||||
valid_data_iter = SequenceParallelDataIterator(validloader)
|
||||
engine.schedule.data_process_func = pipeline_data_process_func
|
||||
|
||||
logger.info("start training")
|
||||
|
||||
|
|
Loading…
Reference in New Issue