mirror of https://github.com/hpcaitech/ColossalAI
[tutorial] fixed pipeline bug for sequence parallel (#1943)
parent
e52f9d9109
commit
c6ea65011f
|
@ -35,6 +35,17 @@ def parse_args():
|
||||||
return parser.parse_args()
|
return parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
|
def pipeline_data_process_func(stage_output, micro_batch_data):
|
||||||
|
tokens, types, sentence_order, loss_mask, lm_labels, padding_mask = micro_batch_data
|
||||||
|
if gpc.is_first_rank(ParallelMode.PIPELINE):
|
||||||
|
data = (tokens, padding_mask, types, lm_labels)
|
||||||
|
label = (loss_mask, sentence_order)
|
||||||
|
else:
|
||||||
|
data = (stage_output, padding_mask, types, lm_labels)
|
||||||
|
label = (loss_mask, sentence_order)
|
||||||
|
return data, label
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
# initialize
|
# initialize
|
||||||
args = parse_args()
|
args = parse_args()
|
||||||
|
@ -155,6 +166,7 @@ def main():
|
||||||
if use_pipeline:
|
if use_pipeline:
|
||||||
train_data_iter = SequenceParallelDataIterator(trainloader)
|
train_data_iter = SequenceParallelDataIterator(trainloader)
|
||||||
valid_data_iter = SequenceParallelDataIterator(validloader)
|
valid_data_iter = SequenceParallelDataIterator(validloader)
|
||||||
|
engine.schedule.data_process_func = pipeline_data_process_func
|
||||||
|
|
||||||
logger.info("start training")
|
logger.info("start training")
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue