From c008d4ad0c755586d5e7cb74b483ad707c78967c Mon Sep 17 00:00:00 2001 From: Michelle <97082656+MichelleMa8@users.noreply.github.com> Date: Mon, 20 Feb 2023 10:38:40 +0800 Subject: [PATCH] [NFC] polish colossalai/engine/schedule/_pipeline_schedule.py code style (#2744) --- colossalai/engine/schedule/_pipeline_schedule.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/colossalai/engine/schedule/_pipeline_schedule.py b/colossalai/engine/schedule/_pipeline_schedule.py index 97571fa02..712ae8242 100644 --- a/colossalai/engine/schedule/_pipeline_schedule.py +++ b/colossalai/engine/schedule/_pipeline_schedule.py @@ -4,8 +4,9 @@ import inspect from typing import Callable, List, Tuple, Union -import colossalai.communication as comm import torch.cuda + +import colossalai.communication as comm from colossalai.amp.naive_amp import NaiveAMPModel from colossalai.context.parallel_mode import ParallelMode from colossalai.core import global_context as gpc @@ -72,9 +73,9 @@ class PipelineSchedule(BaseSchedule): tensor_shape (torch.Size, optional): Specified shape in pipeline communication. scatter_gather_tensors (bool, optional): If set to `True`, communication will be reduced over pipeline when using 1D tensor parallelization. - + Example: - + # this shows an example of customized data_process_func def data_process_func(stage_output, dataloader_output): output1, output2 = stage_output @@ -157,6 +158,7 @@ class PipelineSchedule(BaseSchedule): def pre_processing(self, engine): from colossalai.zero.sharded_model.sharded_model_v2 import ShardedModelV2 + # TODO: remove this after testing new zero with pipeline parallelism model = engine.model if isinstance(model, NaiveAMPModel): @@ -229,7 +231,7 @@ class PipelineSchedule(BaseSchedule): return data, label def _forward_step(self, engine, input_obj, return_tensors, return_output_label=True, accum_loss=None): - """Forward step for passed-in model. If it is the first stage, the input tensor + """Forward step for passed-in model. If it is the first stage, the input tensor is obtained from data_iterator, otherwise the passed-in input_obj is used. Returns output tensor. This is a helper function and can be ignored by users. @@ -266,7 +268,7 @@ class PipelineSchedule(BaseSchedule): return output_obj def _backward_step(self, engine, input_obj, output_obj, output_obj_grad): - """Backward step through the passed-in output tensor. If it is the last stage, the + """Backward step through the passed-in output tensor. If it is the last stage, the output_obj_grad is None, otherwise it is the gradients with respect to stage's output tensor. Returns the gradients with respect to the input tensor (None if first stage). This is a helper function and can be ignored by users. @@ -511,7 +513,7 @@ class InterleavedPipelineSchedule(PipelineSchedule): return_tensors, return_output_label=True, accum_loss=None): - """Forward step for passed-in model. If it is the first stage, the input tensor + """Forward step for passed-in model. If it is the first stage, the input tensor is obtained from data_iterator, otherwise the passed-in input_obj is used. Returns output tensor. This is a helper function and can be ignored by users.