From 7bf1e98b970defdf69b396913dd04bc4bb729a33 Mon Sep 17 00:00:00 2001
From: ver217
Date: Mon, 17 Jan 2022 15:57:47 +0800
Subject: [PATCH] pipeline last stage supports multi output (#151)

---
 .../engine/schedule/_pipeline_schedule.py | 51 ++++++++++++-------
 .../test_pipeline/model/resnet.py         |  2 +-
 2 files changed, 35 insertions(+), 18 deletions(-)

diff --git a/colossalai/engine/schedule/_pipeline_schedule.py b/colossalai/engine/schedule/_pipeline_schedule.py
index 71e39848f..42a585e08 100644
--- a/colossalai/engine/schedule/_pipeline_schedule.py
+++ b/colossalai/engine/schedule/_pipeline_schedule.py
@@ -4,7 +4,6 @@
 from typing import List, Tuple, Union, Callable
 import inspect
 import torch.cuda
-from torch import Tensor
 
 import colossalai.communication as comm
 from colossalai.context.parallel_mode import ParallelMode
@@ -14,14 +13,27 @@ from colossalai.utils.cuda import get_current_device
 from colossalai.zero import (ZeroRedundancyOptimizer_Level_2,
                              ZeroRedundancyOptimizer_Level_3)
 from colossalai.utils import switch_virtual_pipeline_parallel_rank
+from colossalai.logging import get_dist_logger
 from ._base_schedule import BaseSchedule
 
 
-def squeeze(x: Union[Tensor, tuple, list]):
-    if isinstance(x, (tuple, list)):
-        return x[0]
+def pack_return_tensors(return_tensors):
+    output, label = tuple(zip(*return_tensors))
+    if isinstance(output[0], torch.Tensor):
+        output = torch.cat(output, dim=0)
+    elif isinstance(output[0], (list, tuple)):
+        output = tuple(torch.cat(tensors, dim=0) for tensors in zip(*output))
     else:
-        return x
+        raise TypeError(f'Output of model must be tensor or list/tuple of tensors')
+    if isinstance(label[0], torch.Tensor):
+        label = torch.cat(label, dim=0)
+    else:
+        merged_label = {k: [] for k in label[0].keys()}
+        for d in label:
+            for k, v in d.items():
+                merged_label[k].append(v)
+        label = {k: torch.cat(v, dim=0) for k, v in merged_label.items()}
+    return output, label
 
 
 class PipelineSchedule(BaseSchedule):
@@ -49,6 +61,7 @@ class PipelineSchedule(BaseSchedule):
         self.scatter_gather_tensors = False
         if gpc.is_initialized(ParallelMode.PARALLEL_1D) and gpc.get_world_size(ParallelMode.PARALLEL_1D) > 1:
             self.scatter_gather_tensors = scatter_gather_tensors
+        self._logger = get_dist_logger()
 
     def load_batch(self, data_iter):
         # Pipeline schedule just puts data in memory
@@ -129,7 +142,6 @@ class PipelineSchedule(BaseSchedule):
         """
         data, label = self.load_micro_batch()
         output_tensor = self._call_engine(engine.model, input_tensor, data)
-        output_tensor = squeeze(output_tensor)
 
         if gpc.is_last_rank(ParallelMode.PIPELINE):
             if return_output_label:
@@ -139,8 +151,13 @@
                 accum_loss.add_(loss_reduced.detach())
                 return loss_reduced
             else:
+                # forward only, it's useless since backward is not needed
                 return output_tensor
         else:
+            assert isinstance(
+                output_tensor, torch.Tensor), 'Output of model using pipeline parallelism must be a tensor (except the last stage).'
+            self._logger.debug(
+                f'Global rank {gpc.get_global_rank()}, pipeline rank {gpc.get_local_rank(ParallelMode.PIPELINE)} forward output tensor {output_tensor.shape}, dtype {output_tensor.dtype}')
             return output_tensor
 
     def backward_step(self, engine, input_tensor, output_tensor, output_tensor_grad):
@@ -319,12 +336,10 @@
                comm.send_backward(input_tensor_grad, scatter_gather_tensors=self.scatter_gather_tensors)
 
        if len(return_tensors) > 0:
-            output, label = tuple(map(list, zip(*return_tensors)))
-            return (torch.cat(output, dim=0),
-                    torch.cat(label, dim=0),
-                    accum_loss)
+            output, label = pack_return_tensors(return_tensors)
+            return output, label, accum_loss
         else:
-            return tuple((None, None, accum_loss))
+            return None, None, accum_loss
 
 
 class InterleavedPipelineSchedule(PipelineSchedule):
@@ -389,7 +404,6 @@
         """
         data, label = self.load_micro_batch(model_chunk_id)
         output_tensor = self._call_engine(engine.model[model_chunk_id], input_tensor, data)
-        output_tensor = squeeze(output_tensor)
 
         if gpc.is_pipeline_last_stage():
             if return_output_label:
@@ -399,8 +413,13 @@
                 accum_loss.add_(loss_reduced.detach())
                 return loss_reduced
             else:
+                # forward only, it's useless since backward is not needed
                 return output_tensor
         else:
+            assert isinstance(
+                output_tensor, torch.Tensor), 'Output of model using pipeline parallelism must be a tensor (except the last stage).'
+            self._logger.debug(
+                f'Global rank {gpc.get_global_rank()}, pipeline rank {gpc.get_local_rank(ParallelMode.PIPELINE)} forward output tensor {output_tensor.shape}, dtype {output_tensor.dtype}')
             return output_tensor
 
     def forward_backward_step(self, engine, data_iter, forward_only=False, return_loss=True, return_output_label=True):
@@ -665,9 +684,7 @@
                                        scatter_gather_tensors=self.scatter_gather_tensors))
 
         if len(return_tensors) > 0:
-            output, label = tuple(map(list, zip(*return_tensors)))
-            return (torch.cat(output, dim=0),
-                    torch.cat(label, dim=0),
-                    accum_loss)
+            output, label = pack_return_tensors(return_tensors)
+            return output, label, accum_loss
         else:
-            return tuple((None, None, accum_loss))
+            return None, None, accum_loss
diff --git a/tests/test_trainer/test_pipeline/model/resnet.py b/tests/test_trainer/test_pipeline/model/resnet.py
index ffb158ecc..11d964943 100644
--- a/tests/test_trainer/test_pipeline/model/resnet.py
+++ b/tests/test_trainer/test_pipeline/model/resnet.py
@@ -139,7 +139,7 @@ class VanillaResNet(ModelFromConfig):
     def forward(self, x: Tensor):
         for layer in self.layers:
             x = layer(x)
-        return x,
+        return x
 
     def init_weights(self):
         for m in self.modules():
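
Note (not part of the patch): a minimal sketch of how the new pack_return_tensors helper combines per-microbatch results now that the last pipeline stage may return multiple output tensors. The microbatch shapes and the 'y' label key below are hypothetical, chosen only for illustration.

import torch

# Hypothetical per-microbatch results as collected in forward_backward_step:
# each entry is (output, label); here every output is a tuple of two tensors
# and every label is a dict, the two new cases pack_return_tensors handles.
return_tensors = [
    ((torch.ones(4, 8), torch.zeros(4, 2)), {'y': torch.arange(4)}),
    ((torch.ones(4, 8), torch.zeros(4, 2)), {'y': torch.arange(4)}),
]

output, label = tuple(zip(*return_tensors))
# Multi-output case: concatenate position-wise across microbatches.
output = tuple(torch.cat(tensors, dim=0) for tensors in zip(*output))
# Dict labels: concatenate per key.
label = {k: torch.cat([d[k] for d in label], dim=0) for k in label[0]}

print([t.shape for t in output])  # [torch.Size([8, 8]), torch.Size([8, 2])]
print(label['y'].shape)           # torch.Size([8])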