|
|
@ -4,8 +4,9 @@ |
|
|
|
import inspect |
|
|
|
import inspect |
|
|
|
from typing import Callable, List, Tuple, Union |
|
|
|
from typing import Callable, List, Tuple, Union |
|
|
|
|
|
|
|
|
|
|
|
import colossalai.communication as comm |
|
|
|
|
|
|
|
import torch.cuda |
|
|
|
import torch.cuda |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import colossalai.communication as comm |
|
|
|
from colossalai.amp.naive_amp import NaiveAMPModel |
|
|
|
from colossalai.amp.naive_amp import NaiveAMPModel |
|
|
|
from colossalai.context.parallel_mode import ParallelMode |
|
|
|
from colossalai.context.parallel_mode import ParallelMode |
|
|
|
from colossalai.core import global_context as gpc |
|
|
|
from colossalai.core import global_context as gpc |
|
|
@ -72,9 +73,9 @@ class PipelineSchedule(BaseSchedule): |
|
|
|
tensor_shape (torch.Size, optional): Specified shape in pipeline communication. |
|
|
|
tensor_shape (torch.Size, optional): Specified shape in pipeline communication. |
|
|
|
scatter_gather_tensors (bool, optional): |
|
|
|
scatter_gather_tensors (bool, optional): |
|
|
|
If set to `True`, communication will be reduced over pipeline when using 1D tensor parallelization. |
|
|
|
If set to `True`, communication will be reduced over pipeline when using 1D tensor parallelization. |
|
|
|
|
|
|
|
|
|
|
|
Example: |
|
|
|
Example: |
|
|
|
|
|
|
|
|
|
|
|
# this shows an example of customized data_process_func |
|
|
|
# this shows an example of customized data_process_func |
|
|
|
def data_process_func(stage_output, dataloader_output): |
|
|
|
def data_process_func(stage_output, dataloader_output): |
|
|
|
output1, output2 = stage_output |
|
|
|
output1, output2 = stage_output |
|
|
@ -157,6 +158,7 @@ class PipelineSchedule(BaseSchedule): |
|
|
|
|
|
|
|
|
|
|
|
def pre_processing(self, engine): |
|
|
|
def pre_processing(self, engine): |
|
|
|
from colossalai.zero.sharded_model.sharded_model_v2 import ShardedModelV2 |
|
|
|
from colossalai.zero.sharded_model.sharded_model_v2 import ShardedModelV2 |
|
|
|
|
|
|
|
|
|
|
|
# TODO: remove this after testing new zero with pipeline parallelism |
|
|
|
# TODO: remove this after testing new zero with pipeline parallelism |
|
|
|
model = engine.model |
|
|
|
model = engine.model |
|
|
|
if isinstance(model, NaiveAMPModel): |
|
|
|
if isinstance(model, NaiveAMPModel): |
|
|
@ -229,7 +231,7 @@ class PipelineSchedule(BaseSchedule): |
|
|
|
return data, label |
|
|
|
return data, label |
|
|
|
|
|
|
|
|
|
|
|
def _forward_step(self, engine, input_obj, return_tensors, return_output_label=True, accum_loss=None): |
|
|
|
def _forward_step(self, engine, input_obj, return_tensors, return_output_label=True, accum_loss=None): |
|
|
|
"""Forward step for passed-in model. If it is the first stage, the input tensor |
|
|
|
"""Forward step for passed-in model. If it is the first stage, the input tensor |
|
|
|
is obtained from data_iterator, otherwise the passed-in input_obj is used. |
|
|
|
is obtained from data_iterator, otherwise the passed-in input_obj is used. |
|
|
|
Returns output tensor. This is a helper function and can be ignored by users. |
|
|
|
Returns output tensor. This is a helper function and can be ignored by users. |
|
|
|
|
|
|
|
|
|
|
@ -266,7 +268,7 @@ class PipelineSchedule(BaseSchedule): |
|
|
|
return output_obj |
|
|
|
return output_obj |
|
|
|
|
|
|
|
|
|
|
|
def _backward_step(self, engine, input_obj, output_obj, output_obj_grad): |
|
|
|
def _backward_step(self, engine, input_obj, output_obj, output_obj_grad): |
|
|
|
"""Backward step through the passed-in output tensor. If it is the last stage, the |
|
|
|
"""Backward step through the passed-in output tensor. If it is the last stage, the |
|
|
|
output_obj_grad is None, otherwise it is the gradients with respect to stage's output tensor. |
|
|
|
output_obj_grad is None, otherwise it is the gradients with respect to stage's output tensor. |
|
|
|
Returns the gradients with respect to the input tensor (None if first stage). |
|
|
|
Returns the gradients with respect to the input tensor (None if first stage). |
|
|
|
This is a helper function and can be ignored by users. |
|
|
|
This is a helper function and can be ignored by users. |
|
|
@ -511,7 +513,7 @@ class InterleavedPipelineSchedule(PipelineSchedule): |
|
|
|
return_tensors, |
|
|
|
return_tensors, |
|
|
|
return_output_label=True, |
|
|
|
return_output_label=True, |
|
|
|
accum_loss=None): |
|
|
|
accum_loss=None): |
|
|
|
"""Forward step for passed-in model. If it is the first stage, the input tensor |
|
|
|
"""Forward step for passed-in model. If it is the first stage, the input tensor |
|
|
|
is obtained from data_iterator, otherwise the passed-in input_obj is used. |
|
|
|
is obtained from data_iterator, otherwise the passed-in input_obj is used. |
|
|
|
Returns output tensor. This is a helper function and can be ignored by users. |
|
|
|
Returns output tensor. This is a helper function and can be ignored by users. |
|
|
|
|
|
|
|
|
|
|
|