mirror of https://github.com/hpcaitech/ColossalAI
[NFC] polish colossalai/engine/schedule/_pipeline_schedule_v2.py code style (#3275)
parent 196d4696d0
commit 1168b50e33
@@ -1,11 +1,12 @@
 #!/usr/bin/env python
 # -*- encoding: utf-8 -*-
 
-from typing import Tuple, Iterable
+from typing import Iterable, Tuple
 
-from colossalai import engine
-import colossalai.communication.p2p_v2 as comm
 import torch.cuda
+
+import colossalai.communication.p2p_v2 as comm
+from colossalai import engine
 from colossalai.context.parallel_mode import ParallelMode
 from colossalai.core import global_context as gpc
 from colossalai.utils.cuda import get_current_device
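The reorder above is consistent with isort's conventional grouping: standard-library imports (typing) first, then third-party packages (torch), then first-party colossalai modules, with the groups separated by blank lines and sorted within each group. Presumably this hunk is the output of the project's formatter rather than a hand edit, as the [NFC] tag suggests.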
@@ -35,7 +36,7 @@ def pack_return_tensors(return_tensors):
 class PipelineScheduleV2(PipelineSchedule):
     """Derived class of PipelineSchedule, the only difference is that
        forward_backward_step is reconstructed with p2p_v2
-
+
     Args:
         num_microbatches (int): The number of microbatches.
         data_process_func (Callable, optional):
@@ -43,9 +44,9 @@ class PipelineScheduleV2(PipelineSchedule):
         tensor_shape (torch.Size, optional): Specified shape in pipeline communication.
         scatter_gather_tensors (bool, optional):
             If set to `True`, communication will be reduced over pipeline when using 1D tensor parallelization.
-
+
     Example:
-
+
         # this shows an example of customized data_process_func
         def data_process_func(stage_output, dataloader_output):
             output1, output2 = stage_output
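In the second and third hunks the removed and re-inserted lines look identical because the change is invisible in plain text: presumably the blank lines inside the docstring carried trailing whitespace that the formatter stripped, exactly the kind of non-functional cleanup the [NFC] tag announces.

The docstring's Example section sketches a customized data_process_func. For reference, here is a minimal self-contained sketch of such a hook and how it might be wired up; the field names item1/item2/item3, the micro-batch layout, and the import path are assumptions for illustration, not part of this commit:

    from colossalai.engine.schedule import PipelineScheduleV2  # assumed import path

    def data_process_func(stage_output, dataloader_output):
        # `stage_output` holds the tensors produced by the previous pipeline
        # stage; `dataloader_output` is this stage's micro-batch. The hook
        # must return the (data, label) pair the stage should consume.
        output1, output2 = stage_output
        item1, item2, item3 = dataloader_output  # hypothetical micro-batch fields

        # Suppose item2 is the label and everything else feeds the model.
        data = (output1, output2, item1, item3)
        label = item2
        return data, label

    # The constructor arguments mirror the Args section of the docstring above.
    schedule = PipelineScheduleV2(
        num_microbatches=4,
        data_process_func=data_process_func,
        scatter_gather_tensors=True,
    )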