2021-10-28 16:21:23 +00:00
|
|
|
#!/usr/bin/env python
|
|
|
|
# -*- encoding: utf-8 -*-
|
|
|
|
|
2021-12-30 07:56:46 +00:00
|
|
|
import inspect
|
2022-01-21 07:46:02 +00:00
|
|
|
from typing import Callable, List, Tuple, Union
|
2021-10-28 16:21:23 +00:00
|
|
|
|
2022-01-21 07:46:02 +00:00
|
|
|
import torch.cuda
|
2023-02-20 02:38:40 +00:00
|
|
|
|
|
|
|
import colossalai.communication as comm
|
2022-01-21 07:46:02 +00:00
|
|
|
from colossalai.amp.naive_amp import NaiveAMPModel
|
2021-10-28 16:21:23 +00:00
|
|
|
from colossalai.context.parallel_mode import ParallelMode
|
|
|
|
from colossalai.core import global_context as gpc
|
2022-01-21 07:46:02 +00:00
|
|
|
from colossalai.logging import get_dist_logger
|
|
|
|
from colossalai.utils import switch_virtual_pipeline_parallel_rank
|
2021-12-30 07:56:46 +00:00
|
|
|
from colossalai.utils.cuda import get_current_device
|
2022-01-21 07:46:02 +00:00
|
|
|
|
2021-10-28 16:21:23 +00:00
|
|
|
from ._base_schedule import BaseSchedule
|
|
|
|
|
2022-04-19 02:13:08 +00:00
|
|
|
|
2022-04-07 07:54:14 +00:00
|
|
|
def get_tensor_shape():
    """Infer the activation shape exchanged between pipeline stages.

    Resolution order:

    1. ``gpc.config.TENSOR_SHAPE`` if the config defines it (returned as-is).
    2. ``None`` when pipeline parallelism is not initialized.
    3. ``(SEQ_LENGTH // seq_size, GLOBAL_BATCH_SIZE // dp_size // NUM_MICRO_BATCHES, HIDDEN_SIZE)``
       when the config defines all of ``SEQ_LENGTH``, ``GLOBAL_BATCH_SIZE``,
       ``NUM_MICRO_BATCHES`` and ``HIDDEN_SIZE``. ``dp_size`` / ``seq_size`` are the
       data / sequence parallel world sizes, or 1 when that mode is not initialized.
    4. ``None`` otherwise.

    Returns:
        The configured ``TENSOR_SHAPE``, an inferred 3-tuple of ints, or ``None``.
    """
    if hasattr(gpc.config, 'TENSOR_SHAPE'):
        return gpc.config.TENSOR_SHAPE

    if not gpc.is_initialized(ParallelMode.PIPELINE):
        return None

    # Every config attribute read by the formula below must be present; otherwise
    # fall back to None instead of raising AttributeError.
    required_attrs = ('SEQ_LENGTH', 'GLOBAL_BATCH_SIZE', 'NUM_MICRO_BATCHES', 'HIDDEN_SIZE')
    if not all(hasattr(gpc.config, attr) for attr in required_attrs):
        return None

    dp_size = gpc.get_world_size(ParallelMode.DATA) if gpc.is_initialized(ParallelMode.DATA) else 1
    seq_size = gpc.get_world_size(ParallelMode.SEQUENCE) if gpc.is_initialized(ParallelMode.SEQUENCE) else 1

    tensor_shape = (gpc.config.SEQ_LENGTH // seq_size,
                    gpc.config.GLOBAL_BATCH_SIZE // dp_size // gpc.config.NUM_MICRO_BATCHES,
                    gpc.config.HIDDEN_SIZE)
    return tensor_shape
|
2021-10-28 16:21:23 +00:00
|
|
|
|
2022-04-19 02:13:08 +00:00
|
|
|
|
2022-01-17 07:57:47 +00:00
|
|
|
def pack_return_tensors(return_tensors):
    """Concatenate per-microbatch ``(output, label)`` pairs along dim 0.

    Args:
        return_tensors (List[Tuple]): one ``(output, label)`` pair per microbatch.
            ``output`` is a tensor or a list/tuple of tensors. ``label`` is a tensor,
            a dict mapping keys to tensors, or ``None``.

    Returns:
        Tuple: ``(output, label)``.
            ``output`` is a single tensor, or a tuple of tensors concatenated
            position-wise.
            ``label`` is a tensor, a dict of tensors concatenated key-wise, or
            ``None`` when the microbatch labels are ``None``.

    Raises:
        TypeError: if the model output is neither a tensor nor a list/tuple of tensors.
    """
    outputs, labels = tuple(zip(*return_tensors))

    first_output = outputs[0]
    if isinstance(first_output, torch.Tensor):
        output = torch.cat(outputs, dim=0)
    elif isinstance(first_output, (list, tuple)):
        # concatenate the i-th tensor of every microbatch together
        output = tuple(torch.cat(tensors, dim=0) for tensors in zip(*outputs))
    else:
        raise TypeError('Output of model must be tensor or list/tuple of tensors')

    first_label = labels[0]
    if isinstance(first_label, torch.Tensor):
        label = torch.cat(labels, dim=0)
    elif first_label is None:
        # Microbatches carried no label (e.g. dict inputs without a 'label' key).
        label = None
    else:
        merged_label = {k: [] for k in first_label.keys()}
        for label_dict in labels:
            for k, v in label_dict.items():
                merged_label[k].append(v)
        label = {k: torch.cat(v, dim=0) for k, v in merged_label.items()}

    return output, label
|
2021-10-28 16:21:23 +00:00
|
|
|
|
|
|
|
|
|
|
|
class PipelineSchedule(BaseSchedule):
    """A helper schedule class for pipeline parallelism running environment.
    It uses non-interleaved 1F1B strategy. Other properties are similar as
    :class:`NonPipelineSchedule`.

    Args:
        num_microbatches (int): The number of microbatches.
        data_process_func (Callable, optional):
            A function called once per microbatch during the forward step as
            ``data_process_func(stage_output, dataloader_output)``. It must return a
            ``(data, label)`` pair; ``data`` is what is fed to the model.
            ``stage_output`` is the object received from the previous pipeline stage.
            ``dataloader_output`` is this microbatch's slice of the dataloader batch.
        tensor_shape (torch.Size, optional): Specified shape in pipeline communication.
        scatter_gather_tensors (bool, optional):
            If set to `True`, communication will be reduced over pipeline when using 1D tensor parallelization.

    Example:

        # this shows an example of customized data_process_func
        def data_process_func(stage_output, dataloader_output):
            output1, output2 = stage_output
            item1, item2, item3 = dataloader_output

            # assume item2 is not needed
            data = (output1, output2, item1)
            label = item3
            return data, label

    """

    def __init__(self,
                 num_microbatches,
                 data_process_func: Callable = None,
                 tensor_shape: Union[torch.Size, List[int], Tuple[int]] = None,
                 scatter_gather_tensors: bool = False):

        # we need to make sure that the signature of the data_process_func is valid
        if data_process_func:
            sig = inspect.signature(data_process_func)
            assert len(sig.parameters) == 2, \
                'The data_process_func only takes in two parameters for PipelineSchedule, ' \
                'which is the tensors passed by the previous pipeline stage and the dataloader output from this stage, ' \
                'i.e. data_process_func(stage_output, dataloader_output).'

        super().__init__(data_process_func=data_process_func)

        assert num_microbatches > 0, f'expected num_microbatches to be larger than 0, but got {num_microbatches}'

        self.num_microbatches = num_microbatches
        # dtype used when receiving activations/gradients; switched to fp16 in pre_processing
        self.dtype = torch.float
        assert not isinstance(tensor_shape,
                              int), "tensor_shape type should be one of Union[torch.Size, List[int], Tuple[int]]."
        if tensor_shape is None or isinstance(tensor_shape, torch.Size):
            self.tensor_shape = tensor_shape
        else:
            self.tensor_shape = torch.Size(tensor_shape)
        # scatter/gather is only honoured when 1D tensor parallelism spans more than one rank
        self.scatter_gather_tensors = False
        if gpc.is_initialized(ParallelMode.PARALLEL_1D) and gpc.get_world_size(ParallelMode.PARALLEL_1D) > 1:
            self.scatter_gather_tensors = scatter_gather_tensors
        self._logger = get_dist_logger()

        # cache for the batch data
        self.batch_data = None

    def load_batch(self, data_iter):
        """Fetch one batch from ``data_iter`` and cache it on the host.

        The batch is kept off-device; each microbatch is moved to the device
        lazily in :meth:`load_micro_batch`. Resets the microbatch offset.
        """
        # Pipeline schedule just puts data in memory
        batch_data = super().load_batch(data_iter, to_gpu=False)
        self.microbatch_offset = 0
        assert self.batch_size % self.num_microbatches == 0, \
            "Batch size should divided by the number of microbatches"
        self.microbatch_size = self.batch_size // self.num_microbatches
        self.batch_data = batch_data

    def _get_data_slice(self, data, offset):
        """Return rows ``[offset, offset + microbatch_size)`` of ``data``.

        Tensors are sliced directly and dicts are sliced value-wise.
        For a list/tuple that contains a dict, every dict entry is sliced into a
        single merged dict. Any non-dict element that follows a dict is stored
        under ``'label'``. Other lists/tuples are sliced element-wise into a list.

        Raises:
            TypeError: if ``data`` is not a tensor, list, tuple or dict.
        """
        end = offset + self.microbatch_size
        if isinstance(data, torch.Tensor):
            return data[offset:end]
        elif isinstance(data, (list, tuple)):
            data_dict = {}
            for element in data:
                if isinstance(element, dict):
                    data_dict.update({k: v[offset:end] for k, v in element.items()})
                elif data_dict:
                    data_dict['label'] = element[offset:end]
            if data_dict:
                return data_dict
            return [val[offset:end] for val in data]
        elif isinstance(data, dict):
            return {k: v[offset:end] for k, v in data.items()}
        else:
            raise TypeError(f"Expected data to be of type torch.Tensor, list, tuple, or dict, but got {type(data)}")

    def load_micro_batch(self):
        """Slice the next microbatch from the cached batch, advance the offset, and move it to the device."""
        micro_batch_data = self._get_data_slice(self.batch_data, self.microbatch_offset)
        self.microbatch_offset += self.microbatch_size
        return self._move_to_device(micro_batch_data)

    def pre_processing(self, engine):
        """Select the communication dtype based on how ``engine.model`` is wrapped.

        Uses ``torch.half`` when the model is a ``NaiveAMPModel``, or when the model
        (after unwrapping a ``NaiveAMPModel``) is a ``ShardedModelV2``.
        """
        from colossalai.zero.legacy import ShardedModelV2

        # TODO: remove this after testing new zero with pipeline parallelism
        model = engine.model
        if isinstance(model, NaiveAMPModel):
            self.dtype = torch.half
            model = model.model
        if isinstance(model, ShardedModelV2):
            self.dtype = torch.half

    @staticmethod
    def _call_engine(model, data):
        """Call ``model`` with ``data`` unpacked according to its type.

        A tensor is passed positionally and a list/tuple is splatted.
        For a dict, an optional ``'stage_output'`` entry is popped and passed
        positionally (splatted if it is a list/tuple); the remaining entries are
        passed as keyword arguments. ``None`` data returns ``None`` without calling
        the model.

        Raises:
            TypeError: on an unsupported ``data`` or ``stage_output`` type.
        """
        if data is not None:
            if isinstance(data, torch.Tensor):
                return model(data)
            elif isinstance(data, (list, tuple)):
                return model(*data)
            elif isinstance(data, dict):
                stage_output = None
                if 'stage_output' in data:
                    stage_output = data.pop('stage_output')
                if stage_output is None:
                    return model(**data)
                elif isinstance(stage_output, torch.Tensor):
                    return model(stage_output, **data)
                elif isinstance(stage_output, (tuple, list)):
                    return model(*stage_output, **data)
                else:
                    raise TypeError(
                        f"Expected stage_output to be of type torch.Tensor, list, or tuple, but got {type(stage_output)}"
                    )
            else:
                raise TypeError(f"Expected data to be of type torch.Tensor, list, tuple, or dict, but got {type(data)}")

    def _get_actual_forward_func(self, module):
        """Return the signature of the forward method of the model wrapped by ``module``.

        ``NaiveAMPModel`` is unwrapped via ``.model``. Modules exposing ``colo_attr``
        are unwrapped via ``.module``; presumably these are sharded ZeRO wrappers
        (TODO confirm).
        """
        if isinstance(module, NaiveAMPModel):
            sig = inspect.signature(module.model.forward)
        elif hasattr(module, 'colo_attr'):
            sig = inspect.signature(module.module.forward)
        else:
            sig = inspect.signature(module.forward)
        return sig

    def _get_data_label_for_current_step(self, stage_output, micro_batch_data, criterion, model):
        """Build the model input and the label for the current microbatch.

        Args:
            stage_output: object received from the previous pipeline stage.
            micro_batch_data: this stage's microbatch (list/tuple of ``(data, label)``,
                or a dict optionally containing ``'label'``).
            criterion: unused; kept for interface compatibility.
            model: unused; kept for interface compatibility.

        Returns:
            Tuple: ``(data, label)``; ``label`` may be ``None`` for dict inputs without ``'label'``.

        Raises:
            TypeError: if ``micro_batch_data`` is neither a list/tuple nor a dict and
                no ``data_process_func`` was given.
        """
        if self.data_process_func:
            # use customized function to get data and label
            data, label = self.data_process_func(stage_output, micro_batch_data)
        elif isinstance(micro_batch_data, (tuple, list)):
            if gpc.is_first_rank(ParallelMode.PIPELINE):
                # for the first stage, we use the data from the
                # dataloader output by default
                data, label = micro_batch_data
            else:
                # for non-first stage, we use the output passed
                # by the previous as the model input
                data = stage_output
                _, label = micro_batch_data
        elif isinstance(micro_batch_data, dict):
            # _call_engine pops 'stage_output' back out and passes it positionally
            data = {'stage_output': stage_output}
            label = micro_batch_data.pop('label') if 'label' in micro_batch_data else None
            data.update(micro_batch_data)
        else:
            raise TypeError(f"Expected micro batch data to be of type list, tuple, or dict when no "
                            f"data_process_func is given, but got {type(micro_batch_data)}")
        return data, label

    def _forward_step(self, engine, input_obj, return_tensors, return_output_label=True, accum_loss=None):
        """Forward step for passed-in model. The model input is built from the next
        microbatch of the loaded batch and ``input_obj`` (the previous stage's output).
        Returns output tensor. This is a helper function and can be ignored by users.

        Args:
            engine (colossalai.engine.Engine): Colossalai engine for training and inference.
            input_obj (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): Input tensor for this pipeline stage.
            return_tensors (List[:class:`torch.Tensor`]): A list of tensors to return.
            return_output_label (bool, optional): Whether returns output labels.
            accum_loss (optional): Where accumulated loss stores.
        Returns:
            Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]: output or the loss value of the current pipeline stage.
        """
        micro_batch_data = self.load_micro_batch()

        data, label = self._get_data_label_for_current_step(input_obj, micro_batch_data, engine.criterion, engine.model)

        output_obj = self._call_engine(engine.model, data)

        if gpc.is_last_rank(ParallelMode.PIPELINE):
            if return_output_label:
                return_tensors.append((output_obj, label))
            if accum_loss is not None:
                # scale by 1/num_microbatches so the accumulated loss is a batch mean
                loss_reduced = self._call_engine_criterion(engine, output_obj, label) / self.num_microbatches
                accum_loss.add_(loss_reduced.detach())
                return loss_reduced
            else:
                # forward only, it's useless since backward is not needed
                return output_obj
        else:
            if isinstance(output_obj, torch.Tensor):
                self._logger.debug(
                    f'Global rank {gpc.get_global_rank()}, pipeline rank {gpc.get_local_rank(ParallelMode.PIPELINE)} forward output tensor {output_obj.shape}, dtype {output_obj.dtype}'
                )
            return output_obj

    def _backward_step(self, engine, input_obj, output_obj, output_obj_grad):
        """Backward step through the passed-in output tensor. If it is the last stage, the
        output_obj_grad is None, otherwise it is the gradients with respect to stage's output tensor.
        Returns the gradients with respect to the input tensor (None if first stage).
        This is a helper function and can be ignored by users.

        Args:
            engine (colossalai.engine.Engine): Colossalai engine for training and inference.
            input_obj (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): input tensor for this pipeline stage.
            output_obj (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): output tensor for this pipeline stage.
            output_obj_grad (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): gradient of output tensor for this pipeline stage.

        Returns:
            Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]: gradient of input tensor.
            For a list input, entries that are ``None`` get a ``None`` gradient.
        """

        # Retain the grad on the input_obj.
        if input_obj is not None:
            if isinstance(input_obj, torch.Tensor):
                input_obj.retain_grad()
            else:
                for in_tensor in input_obj:
                    if in_tensor is not None:
                        in_tensor.retain_grad()

        # Backward pass.
        if output_obj_grad is None:
            engine.backward(output_obj)
        else:
            engine.backward_by_grad(output_obj, output_obj_grad)

        # Collect the grad of the input_obj.
        input_obj_grad = None
        if input_obj is not None:
            if isinstance(input_obj, torch.Tensor):
                input_obj_grad = input_obj.grad
            else:
                # mirror the None-skipping above instead of raising AttributeError
                input_obj_grad = [in_tensor.grad if in_tensor is not None else None for in_tensor in input_obj]

        return input_obj_grad

    def forward_backward_step(self, engine, data_iter, forward_only=False, return_loss=True, return_output_label=True):
        """Runs non-interleaved 1F1B schedule, with communication between pipeline stages.

        Always returns a 3-tuple ``(output, label, loss)``.
        ``output`` and ``label`` are only populated on the last pipeline stage when
        ``return_output_label`` is True. ``loss`` is only populated on the last
        stage when ``return_loss`` is True. Otherwise these entries are ``None``.

        Args:
            engine (colossalai.engine.Engine): Colossalai engine for training and inference.
            data_iter (Iterable): Dataloader as the form of an iterator, obtained by calling iter(dataloader).
            forward_only (bool, optional):
                Whether run forward step only. Default is false. If true, no backward will be run.
            return_loss (bool, optional): Whether returns the loss value. Default is true.
            return_output_label (bool, optional): If False, the output and label won't be returned.

        Returns:
            Tuple[:class:`torch.Tensor`]: A tuple of (output, label, loss), loss and label could be None.
        """

        assert forward_only or return_loss, \
            'The argument \'return_loss\' has to be True when \'forward_only\' is False, but got False.'
        self.load_batch(data_iter)
        # stage r runs (world_size - r - 1) forward passes before steady-state 1F1B
        num_warmup_microbatches = \
            (gpc.get_world_size(ParallelMode.PIPELINE)
             - gpc.get_local_rank(ParallelMode.PIPELINE) - 1)
        num_warmup_microbatches = min(num_warmup_microbatches, self.num_microbatches)
        num_microbatches_remaining = self.num_microbatches - num_warmup_microbatches

        # Input, output tensors only need to be saved when doing backward passes
        input_objs = None
        output_objs = None
        if not forward_only:
            input_objs = []
            output_objs = []
        return_tensors = []
        if return_loss and gpc.is_pipeline_last_stage(ignore_virtual=True):
            accum_loss = torch.zeros(1, device=get_current_device())
        else:
            accum_loss = None
        # Used for tensor meta information communication
        ft_shapes = self.tensor_shape
        bt_shapes = None
        fs_checker = self.tensor_shape is None

        # Run warmup forward passes.
        for _ in range(num_warmup_microbatches):
            if not gpc.is_first_rank(ParallelMode.PIPELINE):
                ft_shapes = comm.recv_obj_meta(ft_shapes)
            input_obj = comm.recv_forward(ft_shapes,
                                          dtype=self.dtype,
                                          scatter_gather_tensors=self.scatter_gather_tensors)
            output_obj = self._forward_step(engine,
                                            input_obj,
                                            return_tensors,
                                            return_output_label=return_output_label,
                                            accum_loss=accum_loss)
            if not gpc.is_last_rank(ParallelMode.PIPELINE):
                # remember output shapes: they are the shapes of the gradients received later
                if isinstance(output_obj, torch.Tensor):
                    bt_shapes = output_obj.shape
                else:
                    bt_shapes = [out_tensor.shape for out_tensor in output_obj]
                fs_checker = comm.send_obj_meta(output_obj, fs_checker)
            comm.send_forward(output_obj, scatter_gather_tensors=self.scatter_gather_tensors)

            if not forward_only:
                input_objs.append(input_obj)
                output_objs.append(output_obj)

        # Before running 1F1B, need to receive first forward tensor.
        # If all microbatches are run in warmup / cooldown phase, then no need to
        # receive this tensor here.
        if num_microbatches_remaining > 0:
            if not gpc.is_first_rank(ParallelMode.PIPELINE):
                ft_shapes = comm.recv_obj_meta(ft_shapes)
            input_obj = comm.recv_forward(ft_shapes,
                                          dtype=self.dtype,
                                          scatter_gather_tensors=self.scatter_gather_tensors)

        # Run 1F1B in steady state.
        for i in range(num_microbatches_remaining):
            last_iteration = (i == (num_microbatches_remaining - 1))

            output_obj = self._forward_step(engine,
                                            input_obj,
                                            return_tensors,
                                            return_output_label=return_output_label,
                                            accum_loss=accum_loss)
            if forward_only:
                comm.send_forward(output_obj, scatter_gather_tensors=self.scatter_gather_tensors)

                if not last_iteration:
                    input_obj = comm.recv_forward(ft_shapes,
                                                  dtype=self.dtype,
                                                  scatter_gather_tensors=self.scatter_gather_tensors)

            else:
                output_obj_grad = comm.send_forward_recv_backward(output_obj,
                                                                  bt_shapes,
                                                                  dtype=self.dtype,
                                                                  scatter_gather_tensors=self.scatter_gather_tensors)

                # Add input_obj and output_obj to end of list.
                input_objs.append(input_obj)
                output_objs.append(output_obj)

                # Pop input_obj and output_obj from the start of the list for
                # the backward pass.
                input_obj = input_objs.pop(0)
                output_obj = output_objs.pop(0)

                input_obj_grad = self._backward_step(engine, input_obj, output_obj, output_obj_grad)

                if last_iteration:
                    input_obj = None
                    comm.send_backward(input_obj_grad, scatter_gather_tensors=self.scatter_gather_tensors)
                else:
                    input_obj = comm.send_backward_recv_forward(input_obj_grad,
                                                                ft_shapes,
                                                                dtype=self.dtype,
                                                                scatter_gather_tensors=self.scatter_gather_tensors)

        # Run cooldown backward passes.
        if not forward_only:
            for _ in range(num_warmup_microbatches):
                input_obj = input_objs.pop(0)
                output_obj = output_objs.pop(0)

                output_obj_grad = comm.recv_backward(bt_shapes,
                                                     dtype=self.dtype,
                                                     scatter_gather_tensors=self.scatter_gather_tensors)

                input_obj_grad = self._backward_step(engine, input_obj, output_obj, output_obj_grad)

                comm.send_backward(input_obj_grad, scatter_gather_tensors=self.scatter_gather_tensors)

        if len(return_tensors) > 0:
            output, label = pack_return_tensors(return_tensors)
            return output, label, accum_loss
        else:
            return None, None, accum_loss
|
2021-12-20 15:26:19 +00:00
|
|
|
|
|
|
|
|
|
|
|
class InterleavedPipelineSchedule(PipelineSchedule):

    def __init__(self,
                 num_microbatches: int,
                 num_model_chunks: int,
                 data_process_func: Callable = None,
                 tensor_shape: Union[torch.Size, List[int], Tuple[int]] = None,
                 scatter_gather_tensors: bool = False):
        """A helper schedule class for pipeline parallelism running environment.
        It uses the interleaved 1F1B strategy: each pipeline stage holds several model
        chunks (virtual pipeline stages). Other properties are similar to
        :class:`PipelineSchedule`.

        Args:
            num_microbatches (int): The number of microbatches. Must be an integer multiple
                of the pipeline parallel world size.
            num_model_chunks (int): The number of model chunks held by each pipeline stage (> 0).
            data_process_func (Callable, optional):
                The preprocessing function which receives a batch of data, and it will be executed in `load_batch`.
            tensor_shape (torch.Size, optional): Specified shape in pipeline communication.
            scatter_gather_tensors (bool, optional):
                If set to `True`, communication will be reduced over pipeline when using 1D tensor parallelization.
        """
        assert num_microbatches % gpc.get_world_size(ParallelMode.PIPELINE) == 0, \
            'num_microbatches must be an integer multiple of pipeline parallel world size'
        assert isinstance(num_model_chunks, int) and num_model_chunks > 0, \
            f'expected num_model_chunks to be an integer and larger than 0, but got {num_model_chunks}'
        super().__init__(num_microbatches,
                         data_process_func=data_process_func,
                         tensor_shape=tensor_shape,
                         scatter_gather_tensors=scatter_gather_tensors)
        # Register the number of virtual stages globally and start on chunk 0.
        gpc.set_virtual_pipeline_parallel_size(num_model_chunks)
        gpc.set_virtual_pipeline_parallel_rank(0)
        self.num_model_chunks = num_model_chunks

    def pre_processing(self, engine):
        """Pick the communication dtype and validate the model chunks' forward signatures.

        Sets ``self.dtype`` to ``torch.half`` when the engine model is a ``ShardedModelV2`` or
        its first chunk is a ``NaiveAMPModel``. Asserts that no chunk's ``forward`` takes
        ``*args``.
        """
        # Local import, as in the original, keeps the zero package off the module import path.
        from colossalai.zero.sharded_model.sharded_model_v2 import ShardedModelV2
        if isinstance(engine.model, ShardedModelV2):
            self.dtype = torch.half
        elif isinstance(engine.model[0], NaiveAMPModel):
            self.dtype = torch.half
        # NOTE(review): the loop below iterates engine.model, which assumes it is iterable even
        # in the ShardedModelV2 branch above -- confirm that combination is supported.
        for model in engine.model:
            if isinstance(model, NaiveAMPModel):
                model = model.model
            sig = inspect.signature(model.forward)
            for p in sig.parameters.values():
                assert p.kind != inspect.Parameter.VAR_POSITIONAL, '*args is not supported'

    def load_batch(self, data_iter):
        """Load a batch via the parent class, then reset one microbatch offset per model chunk."""
        super().load_batch(data_iter)
        # Overwrite microbatch_offset: every model chunk walks over the same microbatches,
        # so each chunk needs its own offset to track its position in the batch.
        self.microbatch_offset = [0 for _ in range(self.num_model_chunks)]

    def load_micro_batch(self, model_chunk_id):
        """Slice the next microbatch for ``model_chunk_id``, advance that chunk's offset,
        and move the slice to the device."""
        data = self._get_data_slice(self.batch_data, self.microbatch_offset[model_chunk_id])
        self.microbatch_offset[model_chunk_id] += self.microbatch_size
        return self._move_to_device(data)

    def _forward_step(self,
                      engine,
                      model_chunk_id,
                      input_obj,
                      return_tensors,
                      return_output_label=True,
                      accum_loss=None):
        """Forward step for passed-in model. If it is the first stage, the input tensor
        is obtained from data_iterator, otherwise the passed-in input_obj is used.
        Returns output tensor. This is a helper function and can be ignored by users.

        Args:
            engine (colossalai.engine.Engine): Colossalai engine for training and inference.
            model_chunk_id (int): The id of model chunks.
            input_obj (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): Input tensor for this pipeline stage.
            return_tensors (List[:class:`torch.Tensor`]): A list of tensors to return.
            return_output_label (bool, optional): Whether returns output labels.
            accum_loss (optional): Where accumulated loss stores.
        Returns:
            Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]: output or the loss value of the current pipeline stage.
        """
        micro_batch_data = self.load_micro_batch(model_chunk_id)
        data, label = self._get_data_label_for_current_step(input_obj, micro_batch_data, engine.criterion,
                                                            engine.model[model_chunk_id])

        output_obj = self._call_engine(engine.model[model_chunk_id], data)

        if gpc.is_pipeline_last_stage():
            if return_output_label:
                return_tensors.append((output_obj, label))
            if accum_loss is not None:
                # Scale by the number of microbatches so the accumulated loss is a mean.
                loss_reduced = self._call_engine_criterion(engine, output_obj, label) / self.num_microbatches
                accum_loss.add_(loss_reduced.detach())
                return loss_reduced
            else:
                # forward only, it's useless since backward is not needed
                return output_obj
        else:
            if isinstance(output_obj, torch.Tensor):
                self._logger.debug(
                    f'Global rank {gpc.get_global_rank()}, pipeline rank {gpc.get_local_rank(ParallelMode.PIPELINE)} forward output tensor {output_obj.shape}, dtype {output_obj.dtype}'
                )
            return output_obj

    def forward_backward_step(self, engine, data_iter, forward_only=False, return_loss=True, return_output_label=True):
        """Run interleaved 1F1B schedule (model split into model chunks), with
        communication between pipeline stages as needed.

        Args:
            engine (colossalai.engine.Engine): Colossalai engine for training and inference.
            data_iter (Iterable): Dataloader as the form of an iterator, obtained by calling iter(dataloader).
            forward_only (bool, optional):
                Whether run forward step only. Default is false. If true, no backward will be run.
            return_loss (bool, optional): Whether returns the loss value. Default is true.
            return_output_label (bool, optional): If False, the output and label won't be returned.

        Returns:
            Tuple[:class:`torch.Tensor`]: A tuple of (output, label, loss), loss and label could be None.
                The loss would be returned only in the last stage.
        """
        assert forward_only or return_loss, \
            'The argument \'return_loss\' has to be True when \'forward_only\' is False, but got False.'
        self.load_batch(data_iter)
        model = engine.model
        # Per-chunk FIFO queues of inputs/outputs saved for the backward pass.
        input_objs = [[] for _ in range(len(model))]
        output_objs = [[] for _ in range(len(model))]
        return_tensors = []
        if not forward_only:
            output_obj_grads = [[] for _ in range(len(model))]
        if return_loss and gpc.is_pipeline_last_stage(ignore_virtual=True):
            accum_loss = torch.zeros(1, device=get_current_device())
        else:
            accum_loss = None

        # Used for obj meta information communication
        input_obj_shapes = [self.tensor_shape for _ in range(len(model))]
        output_obj_shapes = [None for _ in range(len(model))]
        send_tensor_shape_flags = [self.tensor_shape is None for _ in range(len(model))]

        pipeline_parallel_size = gpc.get_world_size(ParallelMode.PIPELINE)
        pipeline_parallel_rank = gpc.get_local_rank(ParallelMode.PIPELINE)

        # Compute number of warmup and remaining microbatches.
        num_model_chunks = len(model)
        num_microbatches = self.num_microbatches * num_model_chunks
        all_warmup_microbatches = False
        if forward_only:
            num_warmup_microbatches = num_microbatches
        else:
            # Run all forward passes and then all backward passes if number of
            # microbatches is just the number of pipeline stages.
            # Otherwise, perform (num_model_chunks-1)*pipeline_parallel_size on
            # all workers, followed by more microbatches after depending on
            # stage ID (more forward passes for earlier stages, later stages can
            # immediately start with 1F1B).
            if self.num_microbatches == pipeline_parallel_size:
                num_warmup_microbatches = num_microbatches
                all_warmup_microbatches = True
            else:
                num_warmup_microbatches = \
                    (pipeline_parallel_size - pipeline_parallel_rank - 1) * 2
                num_warmup_microbatches += (num_model_chunks - 1) * pipeline_parallel_size
                num_warmup_microbatches = min(num_warmup_microbatches, num_microbatches)
        num_microbatches_remaining = \
            num_microbatches - num_warmup_microbatches

        def get_model_chunk_id(microbatch_id, forward):
            """Map a global (virtual) microbatch index to the model chunk it runs on.

            Consecutive groups of ``pipeline_parallel_size`` microbatches go to the same chunk;
            the backward direction visits chunks in reverse order.
            """
            microbatch_id_in_group = microbatch_id % (pipeline_parallel_size * num_model_chunks)
            model_chunk_id = microbatch_id_in_group // pipeline_parallel_size
            if not forward:
                model_chunk_id = (num_model_chunks - model_chunk_id - 1)
            return model_chunk_id

        def _forward_step_helper(microbatch_id):
            """Run the forward step for ``microbatch_id`` on its model chunk.

            Sets the virtual pipeline rank to that chunk itself before running, and records
            the input/output objects for the later backward pass (dropped when forward_only).
            """
            model_chunk_id = get_model_chunk_id(microbatch_id, forward=True)
            gpc.set_virtual_pipeline_parallel_rank(model_chunk_id)

            # forward step
            if gpc.is_pipeline_first_stage():
                # The first virtual stage receives nothing; use None as its input placeholder.
                if len(input_objs[model_chunk_id]) == \
                        len(output_objs[model_chunk_id]):
                    input_objs[model_chunk_id].append(None)
            input_obj = input_objs[model_chunk_id][-1]
            output_obj = self._forward_step(engine,
                                            model_chunk_id,
                                            input_obj,
                                            return_tensors,
                                            return_output_label=return_output_label,
                                            accum_loss=accum_loss)
            output_objs[model_chunk_id].append(output_obj)

            # if forward-only, no need to save tensors for a backward pass
            if forward_only:
                input_objs[model_chunk_id].pop()
                output_objs[model_chunk_id].pop()

            return output_obj

        def _backward_step_helper(microbatch_id):
            """Run the backward step for ``microbatch_id`` on its model chunk.

            Sets the virtual pipeline rank to that chunk itself before running, consumes the
            oldest saved input/output/output-grad for the chunk and returns the input grad.
            """
            model_chunk_id = get_model_chunk_id(microbatch_id, forward=False)
            gpc.set_virtual_pipeline_parallel_rank(model_chunk_id)

            if gpc.is_pipeline_last_stage():
                # The last virtual stage starts backward from the loss, so its output grad is None.
                if len(output_obj_grads[model_chunk_id]) == 0:
                    output_obj_grads[model_chunk_id].append(None)
            input_obj = input_objs[model_chunk_id].pop(0)
            output_obj = output_objs[model_chunk_id].pop(0)
            output_obj_grad = output_obj_grads[model_chunk_id].pop(0)
            input_obj_grad = self._backward_step(engine, input_obj, output_obj, output_obj_grad)

            return input_obj_grad

        # Run warmup forward passes.
        gpc.set_virtual_pipeline_parallel_rank(0)
        if not gpc.is_pipeline_first_stage():
            input_obj_shapes[0] = comm.recv_obj_meta(input_obj_shapes[0])
        input_objs[0].append(
            comm.recv_forward(input_obj_shapes[0], dtype=self.dtype,
                              scatter_gather_tensors=self.scatter_gather_tensors))

        for k in range(num_warmup_microbatches):
            model_chunk_id = get_model_chunk_id(k, forward=True)
            output_obj = _forward_step_helper(k)
            if not gpc.is_pipeline_last_stage():
                if isinstance(output_obj, torch.Tensor):
                    output_obj_shapes[model_chunk_id] = output_obj.shape
                else:
                    output_obj_shapes[model_chunk_id] = []
                    for out_tensor in output_obj:
                        output_obj_shapes[model_chunk_id].append(out_tensor.shape)
                send_tensor_shape_flags[model_chunk_id] = comm.send_obj_meta(output_obj,
                                                                             send_tensor_shape_flags[model_chunk_id])
            # Determine if tensor should be received from previous stage.
            next_forward_model_chunk_id = get_model_chunk_id(k + 1, forward=True)
            recv_prev = True
            if gpc.is_pipeline_first_stage(ignore_virtual=True):
                if next_forward_model_chunk_id == 0:
                    recv_prev = False
            if k == (num_microbatches - 1):
                recv_prev = False

            # Don't send tensor downstream if on last stage.
            if gpc.is_pipeline_last_stage():
                output_obj = None

            with switch_virtual_pipeline_parallel_rank(next_forward_model_chunk_id):
                if not gpc.is_pipeline_first_stage():
                    input_obj_shapes[next_forward_model_chunk_id] = comm.recv_obj_meta(
                        input_obj_shapes[next_forward_model_chunk_id])
            # Send and receive tensors as appropriate (send tensors computed
            # in this iteration; receive tensors for next iteration).
            input_shape = input_obj_shapes[next_forward_model_chunk_id] if recv_prev else None
            if k == (num_warmup_microbatches - 1) and not forward_only and \
                    not all_warmup_microbatches:
                input_obj_grad = None
                recv_next = True
                if gpc.is_pipeline_last_stage(ignore_virtual=True):
                    recv_next = False
                output_shape = output_obj_shapes[num_model_chunks - 1] if recv_next else None
                input_obj, output_obj_grad = \
                    comm.send_forward_backward_recv_forward_backward(
                        output_obj, input_obj_grad,
                        input_shape,
                        output_shape,
                        recv_prev=recv_prev, recv_next=recv_next,
                        dtype=self.dtype,
                        scatter_gather_tensors=self.scatter_gather_tensors)
                output_obj_grads[num_model_chunks - 1].append(output_obj_grad)
            else:
                input_obj = \
                    comm.send_forward_recv_forward(
                        output_obj,
                        input_shape,
                        recv_prev=recv_prev,
                        dtype=self.dtype,
                        scatter_gather_tensors=self.scatter_gather_tensors)
            input_objs[next_forward_model_chunk_id].append(input_obj)

        # Run 1F1B in steady state.
        for k in range(num_microbatches_remaining):
            # Forward pass.
            forward_k = k + num_warmup_microbatches
            output_obj = _forward_step_helper(forward_k)

            # Backward pass.
            backward_k = k
            input_obj_grad = _backward_step_helper(backward_k)

            # Send output_obj and input_obj_grad, receive input_obj
            # and output_obj_grad.

            # Determine if current stage has anything to send in either direction,
            # otherwise set obj to None.
            forward_model_chunk_id = get_model_chunk_id(forward_k, forward=True)
            gpc.set_virtual_pipeline_parallel_rank(forward_model_chunk_id)
            if gpc.is_pipeline_last_stage():
                output_obj = None

            backward_model_chunk_id = get_model_chunk_id(backward_k, forward=False)
            gpc.set_virtual_pipeline_parallel_rank(backward_model_chunk_id)
            if gpc.is_pipeline_first_stage():
                input_obj_grad = None

            # Determine if peers are sending, and where in data structure to put
            # received tensors.
            recv_prev = True
            if gpc.is_pipeline_first_stage(ignore_virtual=True):
                # First stage is ahead of last stage by (pipeline_parallel_size - 1).
                next_forward_model_chunk_id = get_model_chunk_id(forward_k - (pipeline_parallel_size - 1), forward=True)
                if next_forward_model_chunk_id == (num_model_chunks - 1):
                    recv_prev = False
                next_forward_model_chunk_id += 1
            else:
                next_forward_model_chunk_id = get_model_chunk_id(forward_k + 1, forward=True)

            recv_next = True
            if gpc.is_pipeline_last_stage(ignore_virtual=True):
                # Last stage is ahead of first stage by (pipeline_parallel_size - 1).
                next_backward_model_chunk_id = get_model_chunk_id(backward_k - (pipeline_parallel_size - 1),
                                                                  forward=False)
                if next_backward_model_chunk_id == 0:
                    recv_next = False
                next_backward_model_chunk_id -= 1
            else:
                next_backward_model_chunk_id = get_model_chunk_id(backward_k + 1, forward=False)

            # If last iteration, don't receive; we already received one extra
            # before the start of the for loop.
            if k == (num_microbatches_remaining - 1):
                recv_prev = False

            # The conditional expressions only index when the flag is set, so out-of-range
            # chunk ids produced above (num_model_chunks or -1) are never dereferenced.
            input_shape = input_obj_shapes[next_forward_model_chunk_id] if recv_prev else None
            output_shape = output_obj_shapes[next_backward_model_chunk_id] if recv_next else None
            # Communicate objs.
            input_obj, output_obj_grad = \
                comm.send_forward_backward_recv_forward_backward(
                    output_obj, input_obj_grad,
                    input_shape,
                    output_shape,
                    recv_prev=recv_prev, recv_next=recv_next,
                    dtype=self.dtype,
                    scatter_gather_tensors=self.scatter_gather_tensors)

            # Put input_obj and output_obj_grad in data structures in the
            # right location.
            if recv_prev:
                input_objs[next_forward_model_chunk_id].append(input_obj)
            if recv_next:
                output_obj_grads[next_backward_model_chunk_id].append(output_obj_grad)

        # Run cooldown backward passes (flush out pipeline).
        if not forward_only:
            if all_warmup_microbatches:
                # Pass dtype like every other recv in this schedule: without it the receive
                # buffer uses comm.recv_backward's own default dtype, which mismatches what the
                # next stage sends when self.dtype is torch.half (AMP / sharded model).
                output_obj_grads[num_model_chunks - 1].append(
                    comm.recv_backward(output_obj_shapes[num_model_chunks - 1],
                                       dtype=self.dtype,
                                       scatter_gather_tensors=self.scatter_gather_tensors))
            for k in range(num_microbatches_remaining, num_microbatches):
                input_obj_grad = _backward_step_helper(k)
                next_backward_model_chunk_id = get_model_chunk_id(k + 1, forward=False)
                recv_next = True
                if gpc.is_pipeline_last_stage(ignore_virtual=True):
                    if next_backward_model_chunk_id == (num_model_chunks - 1):
                        recv_next = False
                if k == (num_microbatches - 1):
                    recv_next = False
                output_shape = output_obj_shapes[next_backward_model_chunk_id] if recv_next else None
                output_obj_grads[next_backward_model_chunk_id].append(
                    comm.send_backward_recv_backward(input_obj_grad,
                                                     output_shape,
                                                     recv_next=recv_next,
                                                     dtype=self.dtype,
                                                     scatter_gather_tensors=self.scatter_gather_tensors))

        if len(return_tensors) > 0:
            output, label = pack_return_tensors(return_tensors)
            return output, label, accum_loss
        else:
            return None, None, accum_loss
|