[NFC] polish colossalai/engine/schedule/_pipeline_schedule_v2.py code style (#3275)

pull/3313/head
Zirui Zhu 2023-03-28 14:31:38 +08:00 committed by binmakeswell
parent 196d4696d0
commit 1168b50e33
1 changed files with 7 additions and 6 deletions

View File

@ -1,11 +1,12 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
from typing import Tuple, Iterable
from typing import Iterable, Tuple
from colossalai import engine
import colossalai.communication.p2p_v2 as comm
import torch.cuda
import colossalai.communication.p2p_v2 as comm
from colossalai import engine
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.utils.cuda import get_current_device
@ -35,7 +36,7 @@ def pack_return_tensors(return_tensors):
class PipelineScheduleV2(PipelineSchedule):
"""Derived class of PipelineSchedule, the only difference is that
forward_backward_step is reconstructed with p2p_v2
Args:
num_microbatches (int): The number of microbatches.
data_process_func (Callable, optional):
@ -43,9 +44,9 @@ class PipelineScheduleV2(PipelineSchedule):
tensor_shape (torch.Size, optional): Specified shape in pipeline communication.
scatter_gather_tensors (bool, optional):
If set to `True`, communication will be reduced over pipeline when using 1D tensor parallelization.
Example:
# this shows an example of customized data_process_func
def data_process_func(stage_output, dataloader_output):
output1, output2 = stage_output