pipeline last stage supports multiple outputs (#151)

pull/154/head
ver217 committed via GitHub on 2022-01-17 15:57:47 +08:00
parent 1ff5be36c2
commit 7bf1e98b97
2 changed files with 35 additions and 18 deletions

View File

@@ -4,7 +4,6 @@
 from typing import List, Tuple, Union, Callable
 import inspect
 import torch.cuda
-from torch import Tensor
 
 import colossalai.communication as comm
 from colossalai.context.parallel_mode import ParallelMode
@@ -14,14 +13,27 @@ from colossalai.utils.cuda import get_current_device
 from colossalai.zero import (ZeroRedundancyOptimizer_Level_2,
                              ZeroRedundancyOptimizer_Level_3)
 from colossalai.utils import switch_virtual_pipeline_parallel_rank
+from colossalai.logging import get_dist_logger
 from ._base_schedule import BaseSchedule
 
 
-def squeeze(x: Union[Tensor, tuple, list]):
-    if isinstance(x, (tuple, list)):
-        return x[0]
-    else:
-        return x
+def pack_return_tensors(return_tensors):
+    output, label = tuple(zip(*return_tensors))
+    if isinstance(output[0], torch.Tensor):
+        output = torch.cat(output, dim=0)
+    elif isinstance(output[0], (list, tuple)):
+        output = tuple(torch.cat(tensors, dim=0) for tensors in zip(*output))
+    else:
+        raise TypeError(f'Output of model must be tensor or list/tuple of tensors')
+    if isinstance(label[0], torch.Tensor):
+        label = torch.cat(label, dim=0)
+    else:
+        merged_label = {k: [] for k in label[0].keys()}
+        for d in label:
+            for k, v in d.items():
+                merged_label[k].append(v)
+        label = {k: torch.cat(v, dim=0) for k, v in merged_label.items()}
+    return output, label
 
 
 class PipelineSchedule(BaseSchedule):
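For clarity, here is a minimal sketch of what the new pack_return_tensors helper does with the (output, label) pairs the schedule collects per micro-batch. The tensors, shapes, and the assumption that the helper is in scope are purely illustrative; only the concatenation behaviour comes from the diff above.

    import torch
    # assumes pack_return_tensors from the diff above is in scope

    # Two micro-batches of size 2, each with a single-tensor output and a label.
    single = [
        (torch.randn(2, 10), torch.randint(0, 10, (2,))),
        (torch.randn(2, 10), torch.randint(0, 10, (2,))),
    ]
    output, label = pack_return_tensors(single)
    print(output.shape, label.shape)  # torch.Size([4, 10]) torch.Size([4])

    # Multi-output last stage (the point of this PR): each output is a tuple,
    # and pack_return_tensors concatenates element-wise along dim 0.
    multi = [
        ((torch.randn(2, 10), torch.randn(2, 5)), torch.randint(0, 10, (2,))),
        ((torch.randn(2, 10), torch.randn(2, 5)), torch.randint(0, 10, (2,))),
    ]
    (logits, aux), labels = pack_return_tensors(multi)
    print(logits.shape, aux.shape, labels.shape)  # [4, 10] [4, 5] [4]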
@@ -49,6 +61,7 @@ class PipelineSchedule(BaseSchedule):
         self.scatter_gather_tensors = False
         if gpc.is_initialized(ParallelMode.PARALLEL_1D) and gpc.get_world_size(ParallelMode.PARALLEL_1D) > 1:
             self.scatter_gather_tensors = scatter_gather_tensors
+        self._logger = get_dist_logger()
 
     def load_batch(self, data_iter):
         # Pipeline schedule just puts data in memory
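get_dist_logger, newly imported above and stored on the schedule as self._logger, is ColossalAI's distributed logger factory. A tiny usage sketch; the message text is made up, only the import and the debug call mirror the diff:

    from colossalai.logging import get_dist_logger

    logger = get_dist_logger()
    # Debug output like the shape/dtype messages added below is emitted by the
    # calling rank, which helps trace what each pipeline stage is producing.
    logger.debug('pipeline stage forward output: shape and dtype go here')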
@@ -129,7 +142,6 @@ class PipelineSchedule(BaseSchedule):
         """
         data, label = self.load_micro_batch()
         output_tensor = self._call_engine(engine.model, input_tensor, data)
-        output_tensor = squeeze(output_tensor)
 
         if gpc.is_last_rank(ParallelMode.PIPELINE):
             if return_output_label:
@@ -139,8 +151,13 @@ class PipelineSchedule(BaseSchedule):
                 accum_loss.add_(loss_reduced.detach())
                 return loss_reduced
             else:
+                # forward only, it's useless since backward is not needed
                 return output_tensor
         else:
+            assert isinstance(
+                output_tensor, torch.Tensor), 'Output of model using pipeline parallelism must be a tensor (except the last stage).'
+            self._logger.debug(
+                f'Global rank {gpc.get_global_rank()}, pipeline rank {gpc.get_local_rank(ParallelMode.PIPELINE)} forward output tensor {output_tensor.shape}, dtype {output_tensor.dtype}')
             return output_tensor
 
     def backward_step(self, engine, input_tensor, output_tensor, output_tensor_grad):
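The new assert spells out the contract this PR settles on: intermediate pipeline stages must still emit a single tensor (that is what the schedule sends to the next stage), while only the last stage may return several outputs. A hedged illustration with made-up module names, not code from this PR:

    import torch.nn as nn

    class IntermediateStage(nn.Module):
        """Any non-last pipeline stage: must return exactly one tensor,
        which is what the schedule communicates to the next stage."""
        def __init__(self, dim=256):
            super().__init__()
            self.block = nn.Linear(dim, dim)

        def forward(self, x):
            return self.block(x)  # a single tensor

    class LastStage(nn.Module):
        """The last pipeline stage: after this PR it may return multiple
        outputs, e.g. classification logits plus an auxiliary head."""
        def __init__(self, dim=256, num_classes=10):
            super().__init__()
            self.cls_head = nn.Linear(dim, num_classes)
            self.aux_head = nn.Linear(dim, 1)

        def forward(self, x):
            return self.cls_head(x), self.aux_head(x)  # tuple is allowed here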
@@ -319,12 +336,10 @@ class PipelineSchedule(BaseSchedule):
                 comm.send_backward(input_tensor_grad, scatter_gather_tensors=self.scatter_gather_tensors)
 
         if len(return_tensors) > 0:
-            output, label = tuple(map(list, zip(*return_tensors)))
-            return (torch.cat(output, dim=0),
-                    torch.cat(label, dim=0),
-                    accum_loss)
+            output, label = pack_return_tensors(return_tensors)
+            return output, label, accum_loss
         else:
-            return tuple((None, None, accum_loss))
+            return None, None, accum_loss
 
 
 class InterleavedPipelineSchedule(PipelineSchedule):
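Caller-side, the tidier return (a plain 3-tuple instead of tuple((...))) looks roughly like this. The names schedule, engine and data_iter are assumed from a typical ColossalAI evaluation loop and are not part of this diff:

    # Non-last stages come back with output and label set to None,
    # matching the `return None, None, accum_loss` branch above.
    output, label, loss = schedule.forward_backward_step(
        engine, data_iter, forward_only=True, return_loss=True)

    if output is not None:
        # `output` may be a tensor or, for a multi-output last stage, a tuple
        # of tensors; `label` may be a tensor or a dict of tensors.
        logits = output[0] if isinstance(output, tuple) else output
        accuracy = (logits.argmax(dim=-1) == label).float().mean()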
@@ -389,7 +404,6 @@ class InterleavedPipelineSchedule(PipelineSchedule):
         """
         data, label = self.load_micro_batch(model_chunk_id)
         output_tensor = self._call_engine(engine.model[model_chunk_id], input_tensor, data)
-        output_tensor = squeeze(output_tensor)
 
         if gpc.is_pipeline_last_stage():
             if return_output_label:
@@ -399,8 +413,13 @@ class InterleavedPipelineSchedule(PipelineSchedule):
                 accum_loss.add_(loss_reduced.detach())
                 return loss_reduced
             else:
+                # forward only, it's useless since backward is not needed
                 return output_tensor
         else:
+            assert isinstance(
+                output_tensor, torch.Tensor), 'Output of model using pipeline parallelism must be a tensor (except the last stage).'
+            self._logger.debug(
+                f'Global rank {gpc.get_global_rank()}, pipeline rank {gpc.get_local_rank(ParallelMode.PIPELINE)} forward output tensor {output_tensor.shape}, dtype {output_tensor.dtype}')
             return output_tensor
 
     def forward_backward_step(self, engine, data_iter, forward_only=False, return_loss=True, return_output_label=True):
@@ -665,9 +684,7 @@ class InterleavedPipelineSchedule(PipelineSchedule):
                         scatter_gather_tensors=self.scatter_gather_tensors))
 
         if len(return_tensors) > 0:
-            output, label = tuple(map(list, zip(*return_tensors)))
-            return (torch.cat(output, dim=0),
-                    torch.cat(label, dim=0),
-                    accum_loss)
+            output, label = pack_return_tensors(return_tensors)
+            return output, label, accum_loss
         else:
-            return tuple((None, None, accum_loss))
+            return None, None, accum_loss

View File

@@ -139,7 +139,7 @@ class VanillaResNet(ModelFromConfig):
     def forward(self, x: Tensor):
         for layer in self.layers:
             x = layer(x)
-        return x,
+        return x
 
     def init_weights(self):
         for m in self.modules():
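The one-character change above is easy to misread: the old `return x,` (note the trailing comma) wrapped the result in a 1-element tuple, which the removed squeeze() helper used to unwrap. With squeeze gone, stages are expected to return tensors directly, so the comma is dropped. A quick illustration:

    import torch

    x = torch.ones(2, 3)

    def old_forward(t):
        return t,   # trailing comma -> 1-element tuple

    def new_forward(t):
        return t    # the tensor itself

    print(type(old_forward(x)))  # <class 'tuple'>
    print(type(new_forward(x)))  # <class 'torch.Tensor'>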