pipeline last stage supports multiple outputs (#151)

pull/154/head
ver217 2022-01-17 15:57:47 +08:00 committed by GitHub
parent 1ff5be36c2
commit 7bf1e98b97
2 changed files with 35 additions and 18 deletions
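
In brief: the last pipeline stage may now return multiple outputs. The squeeze helper, which silently unwrapped 1-tuples, is removed; a new pack_return_tensors helper concatenates per-micro-batch outputs (a single tensor or a list/tuple of tensors) and labels (a tensor or a dict of tensors) along the batch dimension. Non-last stages now assert that the model returns a single tensor and log its shape and dtype at debug level, and VanillaResNet.forward drops a trailing comma that had turned its output into a 1-tuple.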

View File

@@ -4,7 +4,6 @@
 from typing import List, Tuple, Union, Callable
 import inspect
 import torch.cuda
-from torch import Tensor
 import colossalai.communication as comm
 from colossalai.context.parallel_mode import ParallelMode
@@ -14,14 +13,27 @@ from colossalai.utils.cuda import get_current_device
 from colossalai.zero import (ZeroRedundancyOptimizer_Level_2,
                              ZeroRedundancyOptimizer_Level_3)
 from colossalai.utils import switch_virtual_pipeline_parallel_rank
+from colossalai.logging import get_dist_logger
 from ._base_schedule import BaseSchedule


-def squeeze(x: Union[Tensor, tuple, list]):
-    if isinstance(x, (tuple, list)):
-        return x[0]
-    else:
-        return x
+def pack_return_tensors(return_tensors):
+    output, label = tuple(zip(*return_tensors))
+    if isinstance(output[0], torch.Tensor):
+        output = torch.cat(output, dim=0)
+    elif isinstance(output[0], (list, tuple)):
+        output = tuple(torch.cat(tensors, dim=0) for tensors in zip(*output))
+    else:
+        raise TypeError(f'Output of model must be tensor or list/tuple of tensors')
+    if isinstance(label[0], torch.Tensor):
+        label = torch.cat(label, dim=0)
+    else:
+        merged_label = {k: [] for k in label[0].keys()}
+        for d in label:
+            for k, v in d.items():
+                merged_label[k].append(v)
+        label = {k: torch.cat(v, dim=0) for k, v in merged_label.items()}
+    return output, label


 class PipelineSchedule(BaseSchedule):
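
For orientation, pack_return_tensors concatenates each micro-batch's (output, label) pair along the batch dimension; the output may be a single tensor or a list/tuple of tensors, the label a tensor or a dict of tensors. A toy sketch of the two output branches (shapes and values are assumptions, not taken from the repo):

import torch

# Case 1: each micro-batch yields (tensor output, tensor label).
mb1 = (torch.randn(4, 10), torch.randint(0, 10, (4,)))
mb2 = (torch.randn(4, 10), torch.randint(0, 10, (4,)))
output, label = tuple(zip(*(mb1, mb2)))   # same unzip step as pack_return_tensors
output = torch.cat(output, dim=0)         # shape (8, 10)
label = torch.cat(label, dim=0)           # shape (8,)

# Case 2: the model returns a tuple of tensors (the new multi-output case);
# each position is concatenated independently across micro-batches.
out1 = (torch.randn(4, 10), torch.randn(4, 5))
out2 = (torch.randn(4, 10), torch.randn(4, 5))
merged = tuple(torch.cat(ts, dim=0) for ts in zip(*(out1, out2)))
# merged[0].shape == (8, 10); merged[1].shape == (8, 5)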
@@ -49,6 +61,7 @@ class PipelineSchedule(BaseSchedule):
         self.scatter_gather_tensors = False
         if gpc.is_initialized(ParallelMode.PARALLEL_1D) and gpc.get_world_size(ParallelMode.PARALLEL_1D) > 1:
             self.scatter_gather_tensors = scatter_gather_tensors
+        self._logger = get_dist_logger()

     def load_batch(self, data_iter):
         # Pipeline schedule just puts data in memory
@@ -129,7 +142,6 @@ class PipelineSchedule(BaseSchedule):
         """
         data, label = self.load_micro_batch()
         output_tensor = self._call_engine(engine.model, input_tensor, data)
-        output_tensor = squeeze(output_tensor)

         if gpc.is_last_rank(ParallelMode.PIPELINE):
             if return_output_label:
@@ -139,8 +151,13 @@ class PipelineSchedule(BaseSchedule):
                 accum_loss.add_(loss_reduced.detach())
                 return loss_reduced
             else:
+                # forward only; the return value is unused since backward is not needed
                 return output_tensor
         else:
+            assert isinstance(
+                output_tensor, torch.Tensor), 'Output of model using pipeline parallelism must be a tensor (except the last stage).'
+            self._logger.debug(
+                f'Global rank {gpc.get_global_rank()}, pipeline rank {gpc.get_local_rank(ParallelMode.PIPELINE)} forward output tensor {output_tensor.shape}, dtype {output_tensor.dtype}')
             return output_tensor

     def backward_step(self, engine, input_tensor, output_tensor, output_tensor_grad):
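
With squeeze gone, a non-last stage must hand exactly one tensor to the point-to-point communication helpers (such as comm.send_backward below), so the new assertion fails loudly on tuple outputs instead of silently unwrapping them, and the debug line records each rank's output shape and dtype. A toy sketch of what the check accepts and rejects (illustrative code, not from the repo):

import torch

def check_stage_output(output_tensor):
    # Mirrors the assertion added to forward_step above.
    assert isinstance(output_tensor, torch.Tensor), \
        'Output of model using pipeline parallelism must be a tensor (except the last stage).'
    return output_tensor

check_stage_output(torch.randn(2, 8))         # passes: a single tensor
try:
    check_stage_output((torch.randn(2, 8),))  # a 1-tuple now fails loudly
except AssertionError as err:
    print(err)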
@@ -319,12 +336,10 @@ class PipelineSchedule(BaseSchedule):
                 comm.send_backward(input_tensor_grad, scatter_gather_tensors=self.scatter_gather_tensors)

         if len(return_tensors) > 0:
-            output, label = tuple(map(list, zip(*return_tensors)))
-            return (torch.cat(output, dim=0),
-                    torch.cat(label, dim=0),
-                    accum_loss)
+            output, label = pack_return_tensors(return_tensors)
+            return output, label, accum_loss
         else:
-            return tuple((None, None, accum_loss))
+            return None, None, accum_loss


 class InterleavedPipelineSchedule(PipelineSchedule):
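
The return-statement cleanup is behavior-preserving: tuple((None, None, accum_loss)) merely rebuilt the tuple that the inner parentheses had already created. A quick illustrative check, plus the consumer-side unpacking implied by the forward_backward_step signature shown later in this diff (the schedule, engine, and data_iter are assumed to be already initialized):

accum_loss = None
assert tuple((None, None, accum_loss)) == (None, None, accum_loss)

# output, label, loss = schedule.forward_backward_step(
#     engine, data_iter, forward_only=True, return_loss=True)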
@@ -389,7 +404,6 @@ class InterleavedPipelineSchedule(PipelineSchedule):
         """
         data, label = self.load_micro_batch(model_chunk_id)
         output_tensor = self._call_engine(engine.model[model_chunk_id], input_tensor, data)
-        output_tensor = squeeze(output_tensor)

         if gpc.is_pipeline_last_stage():
             if return_output_label:
@@ -399,8 +413,13 @@ class InterleavedPipelineSchedule(PipelineSchedule):
                 accum_loss.add_(loss_reduced.detach())
                 return loss_reduced
             else:
+                # forward only; the return value is unused since backward is not needed
                 return output_tensor
         else:
+            assert isinstance(
+                output_tensor, torch.Tensor), 'Output of model using pipeline parallelism must be a tensor (except the last stage).'
+            self._logger.debug(
+                f'Global rank {gpc.get_global_rank()}, pipeline rank {gpc.get_local_rank(ParallelMode.PIPELINE)} forward output tensor {output_tensor.shape}, dtype {output_tensor.dtype}')
             return output_tensor

     def forward_backward_step(self, engine, data_iter, forward_only=False, return_loss=True, return_output_label=True):
@@ -665,9 +684,7 @@ class InterleavedPipelineSchedule(PipelineSchedule):
                                            scatter_gather_tensors=self.scatter_gather_tensors))

         if len(return_tensors) > 0:
-            output, label = tuple(map(list, zip(*return_tensors)))
-            return (torch.cat(output, dim=0),
-                    torch.cat(label, dim=0),
-                    accum_loss)
+            output, label = pack_return_tensors(return_tensors)
+            return output, label, accum_loss
         else:
-            return tuple((None, None, accum_loss))
+            return None, None, accum_loss
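
pack_return_tensors also covers dict-style labels, gathering values key by key before concatenation; a toy sketch mirroring that branch (keys and shapes are assumptions):

import torch

labels = [{'cls': torch.tensor([0, 1]), 'box': torch.randn(2, 4)},
          {'cls': torch.tensor([2, 3]), 'box': torch.randn(2, 4)}]
merged_label = {k: [] for k in labels[0].keys()}
for d in labels:
    for k, v in d.items():
        merged_label[k].append(v)
label = {k: torch.cat(v, dim=0) for k, v in merged_label.items()}
# label['cls'].shape == (4,); label['box'].shape == (4, 4)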

View File

@@ -139,7 +139,7 @@ class VanillaResNet(ModelFromConfig):
     def forward(self, x: Tensor):
         for layer in self.layers:
             x = layer(x)
-        return x,
+        return x

     def init_weights(self):
         for m in self.modules():
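
This one-line model fix follows from the schedule changes above: the trailing comma in return x, made forward return a 1-tuple, which the removed squeeze helper used to unwrap and which the new non-last-stage assertion would now reject. A minimal illustration of the pitfall (toy code, not from the repo):

import torch

x = torch.randn(2, 3)

def forward_old(t):
    return t,   # trailing comma: returns the 1-tuple (t,)

def forward_new(t):
    return t    # returns the tensor itself

assert isinstance(forward_old(x), tuple)
assert isinstance(forward_new(x), torch.Tensor)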