mirror of https://github.com/hpcaitech/ColossalAI
[NFC] polish colossalai/engine/schedule/_pipeline_schedule.py code style (#2744)
parent 58abde2857
commit c008d4ad0c
@@ -4,8 +4,9 @@
 import inspect
 from typing import Callable, List, Tuple, Union
 
-import colossalai.communication as comm
 import torch.cuda
+
+import colossalai.communication as comm
 from colossalai.amp.naive_amp import NaiveAMPModel
 from colossalai.context.parallel_mode import ParallelMode
 from colossalai.core import global_context as gpc
@@ -72,9 +73,9 @@ class PipelineSchedule(BaseSchedule):
         tensor_shape (torch.Size, optional): Specified shape in pipeline communication.
         scatter_gather_tensors (bool, optional):
             If set to `True`, communication will be reduced over pipeline when using 1D tensor parallelization.
 
     Example:
 
         # this shows an example of customized data_process_func
         def data_process_func(stage_output, dataloader_output):
             output1, output2 = stage_output
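The hunk cuts the docstring's example off after the first unpacking line. As a hedged illustration only, such a customized `data_process_func` could continue along these lines; every name beyond the documented `stage_output` and `dataloader_output` is an assumption, and the `(data, label)` return pair is inferred from the `return data, label` context visible in a later hunk:

```python
# Illustrative sketch, not the library's docstring verbatim.
def data_process_func(stage_output, dataloader_output):
    output1, output2 = stage_output      # tensors produced by the previous stage
    item1, item2 = dataloader_output     # items from this stage's dataloader
    data = (output1, output2, item1)     # arguments for this stage's forward pass
    label = item2                        # kept so the last stage can compute the loss
    return data, label
```

The function would then be handed to the schedule at construction time, e.g. `PipelineSchedule(num_microbatches=4, data_process_func=data_process_func)`, assuming the constructor accepts the keyword parameters its docstring documents.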
@@ -157,6 +158,7 @@ class PipelineSchedule(BaseSchedule):
 
     def pre_processing(self, engine):
         from colossalai.zero.sharded_model.sharded_model_v2 import ShardedModelV2
+
         # TODO: remove this after testing new zero with pipeline parallelism
         model = engine.model
         if isinstance(model, NaiveAMPModel):
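The hunk ends just inside the wrapper check. For orientation, a minimal sketch of the unwrapping such a check typically performs; the attribute names (`.model` on `NaiveAMPModel`, `.module` on `ShardedModelV2`) are assumptions about the wrapper classes, not lines taken from this commit:

```python
import inspect

from colossalai.amp.naive_amp import NaiveAMPModel
from colossalai.zero.sharded_model.sharded_model_v2 import ShardedModelV2


def unwrap_model(model):
    # Peel the AMP and ZeRO wrappers off so the schedule can see the user model.
    if isinstance(model, NaiveAMPModel):
        model = model.model      # assumed: the AMP wrapper keeps the user module here
    if isinstance(model, ShardedModelV2):
        model = model.module     # assumed: the ZeRO wrapper keeps the user module here
    # with the bare model in hand, the schedule can read the real forward signature
    return model, inspect.signature(model.forward)
```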
@@ -229,7 +231,7 @@ class PipelineSchedule(BaseSchedule):
         return data, label
 
     def _forward_step(self, engine, input_obj, return_tensors, return_output_label=True, accum_loss=None):
         """Forward step for passed-in model. If it is the first stage, the input tensor
         is obtained from data_iterator, otherwise the passed-in input_obj is used.
         Returns output tensor. This is a helper function and can be ignored by users.
 
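Conceptually, the step this docstring describes reduces to the sketch below; this is not the library implementation, and every name in it is illustrative:

```python
# Hedged sketch of one pipeline forward step. The first stage feeds the model
# from the dataloader; later stages feed it the tensor received upstream; the
# last stage also folds the scaled loss into the running accumulator.
def forward_step_sketch(model, criterion, input_obj, micro_batch,
                        is_first_stage, is_last_stage,
                        num_microbatches, accum_loss=None):
    data, label = micro_batch
    output_obj = model(data) if is_first_stage else model(input_obj)
    if is_last_stage and accum_loss is not None:
        loss = criterion(output_obj, label) / num_microbatches  # average over micro-batches
        accum_loss.add_(loss.detach())
        return loss        # the loss stands in for the last stage's output tensor
    return output_obj      # sent onward to the next pipeline stage
```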
@@ -266,7 +268,7 @@ class PipelineSchedule(BaseSchedule):
         return output_obj
 
     def _backward_step(self, engine, input_obj, output_obj, output_obj_grad):
         """Backward step through the passed-in output tensor. If it is the last stage, the
         output_obj_grad is None, otherwise it is the gradients with respect to stage's output tensor.
         Returns the gradients with respect to the input tensor (None if first stage).
         This is a helper function and can be ignored by users.
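The docstring maps directly onto plain autograd. A minimal sketch under that reading; tensor names mirror the docstring, everything else is illustrative:

```python
import torch


# Hedged sketch of the backward step: run autograd from this stage's output and
# return the gradient w.r.t. the input so it can be sent to the previous stage.
def backward_step_sketch(input_obj, output_obj, output_obj_grad):
    if input_obj is not None:
        input_obj.retain_grad()    # keep .grad on a non-leaf tensor
    # On the last stage output_obj is the loss and output_obj_grad is None,
    # so autograd supplies the implicit scalar gradient of 1.0.
    torch.autograd.backward(output_obj, grad_tensors=output_obj_grad)
    # The first stage receives no upstream tensor, hence returns None.
    return input_obj.grad if input_obj is not None else None
```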
@@ -511,7 +513,7 @@ class InterleavedPipelineSchedule(PipelineSchedule):
                       return_tensors,
                       return_output_label=True,
                       accum_loss=None):
         """Forward step for passed-in model. If it is the first stage, the input tensor
         is obtained from data_iterator, otherwise the passed-in input_obj is used.
         Returns output tensor. This is a helper function and can be ignored by users.
 