diff --git a/colossalai/engine/schedule/_pipeline_schedule_v2.py b/colossalai/engine/schedule/_pipeline_schedule_v2.py
index 50a87aafa..28c58bd82 100644
--- a/colossalai/engine/schedule/_pipeline_schedule_v2.py
+++ b/colossalai/engine/schedule/_pipeline_schedule_v2.py
@@ -1,11 +1,12 @@
 #!/usr/bin/env python
 # -*- encoding: utf-8 -*-
 
-from typing import Tuple, Iterable
+from typing import Iterable, Tuple
 
-from colossalai import engine
-import colossalai.communication.p2p_v2 as comm
 import torch.cuda
+
+import colossalai.communication.p2p_v2 as comm
+from colossalai import engine
 from colossalai.context.parallel_mode import ParallelMode
 from colossalai.core import global_context as gpc
 from colossalai.utils.cuda import get_current_device
@@ -35,7 +36,7 @@ def pack_return_tensors(return_tensors):
 
 class PipelineScheduleV2(PipelineSchedule):
     """Derived class of PipelineSchedule, the only difference is that forward_backward_step is reconstructed with p2p_v2
-    
+
     Args:
         num_microbatches (int): The number of microbatches.
         data_process_func (Callable, optional):
@@ -43,9 +44,9 @@ class PipelineScheduleV2(PipelineSchedule):
         tensor_shape (torch.Size, optional): Specified shape in pipeline communication.
         scatter_gather_tensors (bool, optional):
            If set to `True`, communication will be reduced over pipeline when using 1D tensor parallelization.
-    
+
     Example:
-    
+
         # this shows an example of customized data_process_func
         def data_process_func(stage_output, dataloader_output):
            output1, output2 = stage_output