Browse Source

[NFC] polish colossalai/engine/schedule/_pipeline_schedule.py code style (#2744)

pull/2829/head
Michelle 2 years ago committed by GitHub
parent
commit
c008d4ad0c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 14
      colossalai/engine/schedule/_pipeline_schedule.py

14
colossalai/engine/schedule/_pipeline_schedule.py

@ -4,8 +4,9 @@
import inspect
from typing import Callable, List, Tuple, Union
import colossalai.communication as comm
import torch.cuda
import colossalai.communication as comm
from colossalai.amp.naive_amp import NaiveAMPModel
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
@ -72,9 +73,9 @@ class PipelineSchedule(BaseSchedule):
tensor_shape (torch.Size, optional): Specified shape in pipeline communication.
scatter_gather_tensors (bool, optional):
If set to `True`, communication will be reduced over pipeline when using 1D tensor parallelization.
Example:
# this shows an example of customized data_process_func
def data_process_func(stage_output, dataloader_output):
output1, output2 = stage_output
@ -157,6 +158,7 @@ class PipelineSchedule(BaseSchedule):
def pre_processing(self, engine):
from colossalai.zero.sharded_model.sharded_model_v2 import ShardedModelV2
# TODO: remove this after testing new zero with pipeline parallelism
model = engine.model
if isinstance(model, NaiveAMPModel):
@ -229,7 +231,7 @@ class PipelineSchedule(BaseSchedule):
return data, label
def _forward_step(self, engine, input_obj, return_tensors, return_output_label=True, accum_loss=None):
"""Forward step for passed-in model. If it is the first stage, the input tensor
is obtained from data_iterator, otherwise the passed-in input_obj is used.
Returns output tensor. This is a helper function and can be ignored by users.
@ -266,7 +268,7 @@ class PipelineSchedule(BaseSchedule):
return output_obj
def _backward_step(self, engine, input_obj, output_obj, output_obj_grad):
"""Backward step through the passed-in output tensor. If it is the last stage, the
output_obj_grad is None, otherwise it is the gradients with respect to stage's output tensor.
Returns the gradients with respect to the input tensor (None if first stage).
This is a helper function and can be ignored by users.
@ -511,7 +513,7 @@ class InterleavedPipelineSchedule(PipelineSchedule):
return_tensors,
return_output_label=True,
accum_loss=None):
"""Forward step for passed-in model. If it is the first stage, the input tensor
is obtained from data_iterator, otherwise the passed-in input_obj is used.
Returns output tensor. This is a helper function and can be ignored by users.

Loading…
Cancel
Save