mirror of https://github.com/hpcaitech/ColossalAI

pipeline last stage supports multi output (#151)

parent 1ff5be36c2
commit 7bf1e98b97
@@ -4,7 +4,6 @@
 from typing import List, Tuple, Union, Callable
 import inspect
 import torch.cuda
-from torch import Tensor

 import colossalai.communication as comm
 from colossalai.context.parallel_mode import ParallelMode

@@ -14,14 +13,27 @@ from colossalai.utils.cuda import get_current_device
 from colossalai.zero import (ZeroRedundancyOptimizer_Level_2,
                              ZeroRedundancyOptimizer_Level_3)
 from colossalai.utils import switch_virtual_pipeline_parallel_rank
+from colossalai.logging import get_dist_logger
 from ._base_schedule import BaseSchedule


-def squeeze(x: Union[Tensor, tuple, list]):
-    if isinstance(x, (tuple, list)):
-        return x[0]
+def pack_return_tensors(return_tensors):
+    output, label = tuple(zip(*return_tensors))
+    if isinstance(output[0], torch.Tensor):
+        output = torch.cat(output, dim=0)
+    elif isinstance(output[0], (list, tuple)):
+        output = tuple(torch.cat(tensors, dim=0) for tensors in zip(*output))
     else:
-        return x
+        raise TypeError(f'Output of model must be tensor or list/tuple of tensors')
+    if isinstance(label[0], torch.Tensor):
+        label = torch.cat(label, dim=0)
+    else:
+        merged_label = {k: [] for k in label[0].keys()}
+        for d in label:
+            for k, v in d.items():
+                merged_label[k].append(v)
+        label = {k: torch.cat(v, dim=0) for k, v in merged_label.items()}
+    return output, label


 class PipelineSchedule(BaseSchedule):

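For readers skimming the diff, here is a hedged sketch of what the new pack_return_tensors helper does with the (output, label) pairs collected from each micro-batch. The micro-batch sizes, tensor shapes, and the 'targets' key below are invented for illustration; the function itself is the one defined in the hunk above:

import torch

# two micro-batches, each contributing (output, label); outputs are tuples, labels are dicts
micro1 = ((torch.randn(4, 10), torch.randn(4, 5)), {'targets': torch.zeros(4, dtype=torch.long)})
micro2 = ((torch.randn(4, 10), torch.randn(4, 5)), {'targets': torch.ones(4, dtype=torch.long)})

output, label = pack_return_tensors([micro1, micro2])
# output -> (tensor of shape [8, 10], tensor of shape [8, 5]): tuple outputs are concatenated element-wise
# label  -> {'targets': tensor of shape [8]}: dict labels are merged key by key along dim 0
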
@@ -49,6 +61,7 @@ class PipelineSchedule(BaseSchedule):
         self.scatter_gather_tensors = False
         if gpc.is_initialized(ParallelMode.PARALLEL_1D) and gpc.get_world_size(ParallelMode.PARALLEL_1D) > 1:
             self.scatter_gather_tensors = scatter_gather_tensors
+        self._logger = get_dist_logger()

     def load_batch(self, data_iter):
         # Pipeline schedule just puts data in memory

@@ -129,7 +142,6 @@ class PipelineSchedule(BaseSchedule):
         """
         data, label = self.load_micro_batch()
         output_tensor = self._call_engine(engine.model, input_tensor, data)
-        output_tensor = squeeze(output_tensor)

         if gpc.is_last_rank(ParallelMode.PIPELINE):
             if return_output_label:

@@ -139,8 +151,13 @@ class PipelineSchedule(BaseSchedule):
                     accum_loss.add_(loss_reduced.detach())
                 return loss_reduced
             else:
+                # forward only, it's useless since backward is not needed
                 return output_tensor
         else:
+            assert isinstance(
+                output_tensor, torch.Tensor), 'Output of model using pipeline parallelism must be a tensor (except the last stage).'
+            self._logger.debug(
+                f'Global rank {gpc.get_global_rank()}, pipeline rank {gpc.get_local_rank(ParallelMode.PIPELINE)} forward output tensor {output_tensor.shape}, dtype {output_tensor.dtype}')
             return output_tensor

     def backward_step(self, engine, input_tensor, output_tensor, output_tensor_grad):

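The added assert makes the contract for non-last stages explicit: only the last pipeline stage may produce multiple outputs, every other stage must hand a single tensor to the communication layer. A minimal sketch, with an invented module name, of what does and does not satisfy it:

import torch.nn as nn

class MiddleChunk(nn.Module):
    """Hypothetical non-last pipeline stage."""
    def __init__(self, hidden=256):
        super().__init__()
        self.block = nn.Linear(hidden, hidden)

    def forward(self, x):
        return self.block(x)         # a single tensor: passes the new assert
        # return self.block(x), x    # a tuple here would trip the assert on non-last stages
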
@@ -319,12 +336,10 @@ class PipelineSchedule(BaseSchedule):
                 comm.send_backward(input_tensor_grad, scatter_gather_tensors=self.scatter_gather_tensors)

         if len(return_tensors) > 0:
-            output, label = tuple(map(list, zip(*return_tensors)))
-            return (torch.cat(output, dim=0),
-                    torch.cat(label, dim=0),
-                    accum_loss)
+            output, label = pack_return_tensors(return_tensors)
+            return output, label, accum_loss
         else:
-            return tuple((None, None, accum_loss))
+            return None, None, accum_loss


 class InterleavedPipelineSchedule(PipelineSchedule):

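The return shape of forward_backward_step stays (output, label, accum_loss), but the concatenation across micro-batches now happens inside pack_return_tensors, so output may be a tensor or a tuple of tensors and label a tensor or a dict. A hedged sketch of a caller on the last stage; the schedule, engine and data_iter objects are assumed to be set up elsewhere:

output, label, loss = schedule.forward_backward_step(
    engine, data_iter, forward_only=True, return_loss=True, return_output_label=True)

if output is not None:
    # on the last stage: output is already concatenated over micro-batches,
    # either a single tensor or a tuple with one entry per model output;
    # label is a tensor or a dict of tensors with the same leading dimension
    print(type(output), type(label), loss)
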
@@ -389,7 +404,6 @@ class InterleavedPipelineSchedule(PipelineSchedule):
         """
         data, label = self.load_micro_batch(model_chunk_id)
         output_tensor = self._call_engine(engine.model[model_chunk_id], input_tensor, data)
-        output_tensor = squeeze(output_tensor)

         if gpc.is_pipeline_last_stage():
             if return_output_label:

@@ -399,8 +413,13 @@ class InterleavedPipelineSchedule(PipelineSchedule):
                     accum_loss.add_(loss_reduced.detach())
                 return loss_reduced
             else:
+                # forward only, it's useless since backward is not needed
                 return output_tensor
         else:
+            assert isinstance(
+                output_tensor, torch.Tensor), 'Output of model using pipeline parallelism must be a tensor (except the last stage).'
+            self._logger.debug(
+                f'Global rank {gpc.get_global_rank()}, pipeline rank {gpc.get_local_rank(ParallelMode.PIPELINE)} forward output tensor {output_tensor.shape}, dtype {output_tensor.dtype}')
             return output_tensor

     def forward_backward_step(self, engine, data_iter, forward_only=False, return_loss=True, return_output_label=True):

@@ -665,9 +684,7 @@ class InterleavedPipelineSchedule(PipelineSchedule):
                                        scatter_gather_tensors=self.scatter_gather_tensors))

         if len(return_tensors) > 0:
-            output, label = tuple(map(list, zip(*return_tensors)))
-            return (torch.cat(output, dim=0),
-                    torch.cat(label, dim=0),
-                    accum_loss)
+            output, label = pack_return_tensors(return_tensors)
+            return output, label, accum_loss
         else:
-            return tuple((None, None, accum_loss))
+            return None, None, accum_loss

@@ -139,7 +139,7 @@ class VanillaResNet(ModelFromConfig):
     def forward(self, x: Tensor):
         for layer in self.layers:
             x = layer(x)
-        return x,
+        return x

     def init_weights(self):
         for m in self.modules():
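With the squeeze helper gone, a last-stage model no longer has to wrap its result in a one-element tuple (hence return x instead of return x, above), and it may now return several tensors that pack_return_tensors will merge. A hedged illustration with invented head names:

import torch.nn as nn

class TwoHeadTail(nn.Module):
    """Hypothetical last pipeline stage with two outputs."""
    def __init__(self, hidden=256, num_classes=10):
        super().__init__()
        self.cls_head = nn.Linear(hidden, num_classes)
        self.aux_head = nn.Linear(hidden, 1)

    def forward(self, x):
        # both tensors reach pack_return_tensors as a tuple and are
        # concatenated element-wise across micro-batches
        return self.cls_head(x), self.aux_head(x)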