mirror of https://github.com/hpcaitech/ColossalAI
[NFC] polish colossalai/engine/schedule/_pipeline_schedule.py code style (#2744)
parent
58abde2857
commit
c008d4ad0c
|
@ -4,8 +4,9 @@
|
|||
import inspect
|
||||
from typing import Callable, List, Tuple, Union
|
||||
|
||||
import colossalai.communication as comm
|
||||
import torch.cuda
|
||||
|
||||
import colossalai.communication as comm
|
||||
from colossalai.amp.naive_amp import NaiveAMPModel
|
||||
from colossalai.context.parallel_mode import ParallelMode
|
||||
from colossalai.core import global_context as gpc
|
||||
|
@ -72,9 +73,9 @@ class PipelineSchedule(BaseSchedule):
|
|||
tensor_shape (torch.Size, optional): Specified shape in pipeline communication.
|
||||
scatter_gather_tensors (bool, optional):
|
||||
If set to `True`, communication over the pipeline will be reduced when using 1D tensor parallelism.
|
||||
|
||||
|
||||
Example:
|
||||
|
||||
|
||||
# this shows an example of customized data_process_func
|
||||
def data_process_func(stage_output, dataloader_output):
|
||||
output1, output2 = stage_output
|
||||
|
@ -157,6 +158,7 @@ class PipelineSchedule(BaseSchedule):
|
|||
|
||||
def pre_processing(self, engine):
|
||||
from colossalai.zero.sharded_model.sharded_model_v2 import ShardedModelV2
|
||||
|
||||
# TODO: remove this after testing new zero with pipeline parallelism
|
||||
model = engine.model
|
||||
if isinstance(model, NaiveAMPModel):
|
||||
|
@ -229,7 +231,7 @@ class PipelineSchedule(BaseSchedule):
|
|||
return data, label
|
||||
|
||||
def _forward_step(self, engine, input_obj, return_tensors, return_output_label=True, accum_loss=None):
|
||||
"""Forward step for passed-in model. If it is the first stage, the input tensor
|
||||
"""Forward step for passed-in model. If it is the first stage, the input tensor
|
||||
is obtained from data_iterator, otherwise the passed-in input_obj is used.
|
||||
Returns output tensor. This is a helper function and can be ignored by users.
|
||||
|
||||
|
@ -266,7 +268,7 @@ class PipelineSchedule(BaseSchedule):
|
|||
return output_obj
|
||||
|
||||
def _backward_step(self, engine, input_obj, output_obj, output_obj_grad):
|
||||
"""Backward step through the passed-in output tensor. If it is the last stage, the
|
||||
"""Backward step through the passed-in output tensor. If it is the last stage, the
|
||||
output_obj_grad is None, otherwise it is the gradients with respect to stage's output tensor.
|
||||
Returns the gradients with respect to the input tensor (None if first stage).
|
||||
This is a helper function and can be ignored by users.
|
||||
|
@ -511,7 +513,7 @@ class InterleavedPipelineSchedule(PipelineSchedule):
|
|||
return_tensors,
|
||||
return_output_label=True,
|
||||
accum_loss=None):
|
||||
"""Forward step for passed-in model. If it is the first stage, the input tensor
|
||||
"""Forward step for passed-in model. If it is the first stage, the input tensor
|
||||
is obtained from data_iterator, otherwise the passed-in input_obj is used.
|
||||
Returns output tensor. This is a helper function and can be ignored by users.
|
||||
|
||||
|
|
Loading…
Reference in New Issue