mirror of https://github.com/hpcaitech/ColossalAI
pipeline last stage supports multi output (#151)
parent
1ff5be36c2
commit
7bf1e98b97
|
@ -4,7 +4,6 @@
|
|||
from typing import List, Tuple, Union, Callable
|
||||
import inspect
|
||||
import torch.cuda
|
||||
from torch import Tensor
|
||||
|
||||
import colossalai.communication as comm
|
||||
from colossalai.context.parallel_mode import ParallelMode
|
||||
|
@ -14,14 +13,27 @@ from colossalai.utils.cuda import get_current_device
|
|||
from colossalai.zero import (ZeroRedundancyOptimizer_Level_2,
|
||||
ZeroRedundancyOptimizer_Level_3)
|
||||
from colossalai.utils import switch_virtual_pipeline_parallel_rank
|
||||
from colossalai.logging import get_dist_logger
|
||||
from ._base_schedule import BaseSchedule
|
||||
|
||||
|
||||
def squeeze(x: Union[Tensor, tuple, list]):
|
||||
if isinstance(x, (tuple, list)):
|
||||
return x[0]
|
||||
def pack_return_tensors(return_tensors):
|
||||
output, label = tuple(zip(*return_tensors))
|
||||
if isinstance(output[0], torch.Tensor):
|
||||
output = torch.cat(output, dim=0)
|
||||
elif isinstance(output[0], (list, tuple)):
|
||||
output = tuple(torch.cat(tensors, dim=0) for tensors in zip(*output))
|
||||
else:
|
||||
return x
|
||||
raise TypeError(f'Output of model must be tensor or list/tuple of tensors')
|
||||
if isinstance(label[0], torch.Tensor):
|
||||
label = torch.cat(label, dim=0)
|
||||
else:
|
||||
merged_label = {k: [] for k in label[0].keys()}
|
||||
for d in label:
|
||||
for k, v in d.items():
|
||||
merged_label[k].append(v)
|
||||
label = {k: torch.cat(v, dim=0) for k, v in merged_label.items()}
|
||||
return output, label
|
||||
|
||||
|
||||
class PipelineSchedule(BaseSchedule):
|
||||
|
@ -49,6 +61,7 @@ class PipelineSchedule(BaseSchedule):
|
|||
self.scatter_gather_tensors = False
|
||||
if gpc.is_initialized(ParallelMode.PARALLEL_1D) and gpc.get_world_size(ParallelMode.PARALLEL_1D) > 1:
|
||||
self.scatter_gather_tensors = scatter_gather_tensors
|
||||
self._logger = get_dist_logger()
|
||||
|
||||
def load_batch(self, data_iter):
|
||||
# Pipeline schedule just puts data in memory
|
||||
|
@ -129,7 +142,6 @@ class PipelineSchedule(BaseSchedule):
|
|||
"""
|
||||
data, label = self.load_micro_batch()
|
||||
output_tensor = self._call_engine(engine.model, input_tensor, data)
|
||||
output_tensor = squeeze(output_tensor)
|
||||
|
||||
if gpc.is_last_rank(ParallelMode.PIPELINE):
|
||||
if return_output_label:
|
||||
|
@ -139,8 +151,13 @@ class PipelineSchedule(BaseSchedule):
|
|||
accum_loss.add_(loss_reduced.detach())
|
||||
return loss_reduced
|
||||
else:
|
||||
# forward only, it's useless since backward is not needed
|
||||
return output_tensor
|
||||
else:
|
||||
assert isinstance(
|
||||
output_tensor, torch.Tensor), 'Output of model using pipeline parallelism must be a tensor (except the last stage).'
|
||||
self._logger.debug(
|
||||
f'Global rank {gpc.get_global_rank()}, pipeline rank {gpc.get_local_rank(ParallelMode.PIPELINE)} forward output tensor {output_tensor.shape}, dtype {output_tensor.dtype}')
|
||||
return output_tensor
|
||||
|
||||
def backward_step(self, engine, input_tensor, output_tensor, output_tensor_grad):
|
||||
|
@ -319,12 +336,10 @@ class PipelineSchedule(BaseSchedule):
|
|||
comm.send_backward(input_tensor_grad, scatter_gather_tensors=self.scatter_gather_tensors)
|
||||
|
||||
if len(return_tensors) > 0:
|
||||
output, label = tuple(map(list, zip(*return_tensors)))
|
||||
return (torch.cat(output, dim=0),
|
||||
torch.cat(label, dim=0),
|
||||
accum_loss)
|
||||
output, label = pack_return_tensors(return_tensors)
|
||||
return output, label, accum_loss
|
||||
else:
|
||||
return tuple((None, None, accum_loss))
|
||||
return None, None, accum_loss
|
||||
|
||||
|
||||
class InterleavedPipelineSchedule(PipelineSchedule):
|
||||
|
@ -389,7 +404,6 @@ class InterleavedPipelineSchedule(PipelineSchedule):
|
|||
"""
|
||||
data, label = self.load_micro_batch(model_chunk_id)
|
||||
output_tensor = self._call_engine(engine.model[model_chunk_id], input_tensor, data)
|
||||
output_tensor = squeeze(output_tensor)
|
||||
|
||||
if gpc.is_pipeline_last_stage():
|
||||
if return_output_label:
|
||||
|
@ -399,8 +413,13 @@ class InterleavedPipelineSchedule(PipelineSchedule):
|
|||
accum_loss.add_(loss_reduced.detach())
|
||||
return loss_reduced
|
||||
else:
|
||||
# forward only, it's useless since backward is not needed
|
||||
return output_tensor
|
||||
else:
|
||||
assert isinstance(
|
||||
output_tensor, torch.Tensor), 'Output of model using pipeline parallelism must be a tensor (except the last stage).'
|
||||
self._logger.debug(
|
||||
f'Global rank {gpc.get_global_rank()}, pipeline rank {gpc.get_local_rank(ParallelMode.PIPELINE)} forward output tensor {output_tensor.shape}, dtype {output_tensor.dtype}')
|
||||
return output_tensor
|
||||
|
||||
def forward_backward_step(self, engine, data_iter, forward_only=False, return_loss=True, return_output_label=True):
|
||||
|
@ -665,9 +684,7 @@ class InterleavedPipelineSchedule(PipelineSchedule):
|
|||
scatter_gather_tensors=self.scatter_gather_tensors))
|
||||
|
||||
if len(return_tensors) > 0:
|
||||
output, label = tuple(map(list, zip(*return_tensors)))
|
||||
return (torch.cat(output, dim=0),
|
||||
torch.cat(label, dim=0),
|
||||
accum_loss)
|
||||
output, label = pack_return_tensors(return_tensors)
|
||||
return output, label, accum_loss
|
||||
else:
|
||||
return tuple((None, None, accum_loss))
|
||||
return None, None, accum_loss
|
||||
|
|
|
@ -139,7 +139,7 @@ class VanillaResNet(ModelFromConfig):
|
|||
def forward(self, x: Tensor):
|
||||
for layer in self.layers:
|
||||
x = layer(x)
|
||||
return x,
|
||||
return x
|
||||
|
||||
def init_weights(self):
|
||||
for m in self.modules():
|
||||
|
|
Loading…
Reference in New Issue