[pipeline]refactor ppschedule to support tensor list (#1050)

* [CLI] add CLI launcher

* Revert "[CLI] add CLI launcher"

This reverts commit df7e6506d4.

* refactor ppschedule to support tensor list

* polish
Author: YuliangLiu0306
Date: 2022-06-02 13:48:59 +08:00 (committed by GitHub)
Parent: e3fde4ee6b
Commit: b167258b6a
4 changed files with 259 additions and 215 deletions

File 1 of 4

@@ -3,7 +3,7 @@ from .p2p import (send_forward, send_forward_recv_forward, send_backward_recv_forward,
                   send_backward_recv_backward, send_forward_recv_backward, send_forward_backward_recv_forward_backward,
                   recv_forward, recv_backward)
 from .ring import ring_forward
-from .utils import send_tensor_meta, recv_tensor_meta
+from .utils import send_obj_meta, recv_obj_meta

 __all__ = [
     'all_gather',
@@ -21,6 +21,6 @@ __all__ = [
     'recv_backward',
     'recv_forward',
     'ring_forward',
-    'send_tensor_meta',
-    'recv_tensor_meta',
+    'send_obj_meta',
+    'recv_obj_meta',
 ]

File 2 of 4

@@ -9,14 +9,21 @@ from typing import Union, List, Tuple

 TensorShape = Union[torch.Size, List[int], Tuple[int]]


-def send_tensor_meta(tensor, need_meta=True, next_rank=None) -> bool:
-    """Sends tensor meta information before sending a specific tensor.
-    Since the recipient must know the shape of the tensor in p2p communications,
-    meta information of the tensor should be sent before communications. This function
-    synchronizes with :func:`recv_tensor_meta`.
+def send_meta_helper(obj, next_rank, tensor_kwargs):
+    send_shape = torch.tensor(obj.size(), **tensor_kwargs)
+    send_ndims = torch.tensor(len(obj.size()), **tensor_kwargs)
+    dist.send(send_ndims, next_rank)
+    dist.send(send_shape, next_rank)
+
+
+def send_obj_meta(obj, need_meta=True, next_rank=None) -> bool:
+    """Sends obj meta information before sending a specific obj.
+    Since the recipient must know the shape of the obj in p2p communications,
+    meta information of the obj should be sent before communications. This function
+    synchronizes with :func:`recv_obj_meta`.

     Args:
-        tensor (:class:`torch.Tensor`): Tensor to be sent.
+        obj (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): obj to be sent.
         need_meta (bool, optional): If False, meta information won't be sent.
         next_rank (int): The rank of the next member in pipeline parallel group.
@@ -28,42 +35,57 @@ def send_tensor_meta(tensor, need_meta=True, next_rank=None) -> bool:
             next_rank = gpc.get_next_global_rank(ParallelMode.PIPELINE)

         tensor_kwargs = {'dtype': torch.long, 'device': get_current_device()}
-        send_shape = torch.tensor(tensor.size(), **tensor_kwargs)
-        send_ndims = torch.tensor(len(tensor.size()), **tensor_kwargs)
-        dist.send(send_ndims, next_rank)
-        dist.send(send_shape, next_rank)
+        if isinstance(obj, torch.Tensor):
+            send_obj_nums = torch.tensor(1, **tensor_kwargs)
+            dist.send(send_obj_nums, next_rank)
+            send_meta_helper(obj, next_rank, tensor_kwargs)
+        else:
+            send_obj_nums = torch.tensor(len(obj), **tensor_kwargs)
+            dist.send(send_obj_nums, next_rank)
+            for tensor_to_send in obj:
+                send_meta_helper(tensor_to_send, next_rank, tensor_kwargs)

     return False


-def recv_tensor_meta(tensor_shape: TensorShape, prev_rank=None) -> torch.Size:
-    """Receives tensor meta information before receiving a specific tensor.
-    Since the recipient must know the shape of the tensor in p2p communications,
-    meta information of the tensor should be received before communications. This function
-    synchronizes with :func:`send_tensor_meta`.
+def recv_meta_helper(prev_rank, tensor_kwargs):
+    recv_ndims = torch.empty((), **tensor_kwargs)
+    dist.recv(recv_ndims, prev_rank)
+    recv_shape = torch.empty(recv_ndims, **tensor_kwargs)
+    dist.recv(recv_shape, prev_rank)
+    return recv_shape
+
+
+def recv_obj_meta(obj_shape, prev_rank=None) -> torch.Size:
+    """Receives obj meta information before receiving a specific obj.
+    Since the recipient must know the shape of the obj in p2p communications,
+    meta information of the obj should be received before communications. This function
+    synchronizes with :func:`send_obj_meta`.

     Args:
-        tensor_shape (:class:`torch.Size`): The shape of the tensor to be received.
-        prev_rank (int): The rank of the source of the tensor.
+        obj_shape (Union[:class:`torch.Size`, List[:class:`torch.Size`]]): The shape of the obj to be received.
+        prev_rank (int): The rank of the source of the obj.

     Returns:
-        :class:`torch.Size`: The shape of the tensor to be received.
+        Union[:class:`torch.Size`, List[:class:`torch.Size`]]: The shape of the obj to be received.
     """
-    if tensor_shape is None:
+    if obj_shape is None:
         if prev_rank is None:
             prev_rank = gpc.get_prev_global_rank(ParallelMode.PIPELINE)

         tensor_kwargs = {'dtype': torch.long, 'device': get_current_device()}
-        recv_ndims = torch.empty((), **tensor_kwargs)
-        dist.recv(recv_ndims, prev_rank)
-        recv_shape = torch.empty(recv_ndims, **tensor_kwargs)
-        dist.recv(recv_shape, prev_rank)
-        tensor_shape = torch.Size(recv_shape)
+        recv_obj_nums = torch.empty((), **tensor_kwargs)
+        dist.recv(recv_obj_nums, prev_rank)
+        if recv_obj_nums.item() == 1:
+            recv_shape = recv_meta_helper(prev_rank, tensor_kwargs)
+            obj_shape = torch.Size(recv_shape)
+        else:
+            obj_shape = []
+            for i in range(recv_obj_nums.item()):
+                recv_shape = recv_meta_helper(prev_rank, tensor_kwargs)
+                obj_shape.append(torch.Size(recv_shape))

-    return tensor_shape
+    return obj_shape


 def split_tensor_into_1d_equal_chunks(tensor: torch.Tensor, new_buffer=False) -> torch.Tensor:
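
The two new helpers split the old single-tensor handshake into a reusable per-tensor step plus an outer loop: the sender first announces how many tensors follow, then transmits one (ndims, shape) pair per tensor, and the receiver mirrors that order and hands back either a single torch.Size or a list of them. A minimal standalone sketch of the same wire protocol, written directly against torch.distributed and assuming an already-initialized process group (the helper names below are illustrative, not part of the patch):

    import torch
    import torch.distributed as dist

    META_KWARGS = dict(dtype=torch.long)  # the patch additionally pins device=get_current_device()

    def send_shapes(tensors, dst):
        # 1) announce how many tensors follow, 2) for each one send ndims, then the shape itself
        dist.send(torch.tensor(len(tensors), **META_KWARGS), dst)
        for t in tensors:
            dist.send(torch.tensor(t.dim(), **META_KWARGS), dst)
            dist.send(torch.tensor(t.size(), **META_KWARGS), dst)

    def recv_shapes(src):
        num = torch.empty((), **META_KWARGS)
        dist.recv(num, src)
        shapes = []
        for _ in range(int(num.item())):
            ndims = torch.empty((), **META_KWARGS)
            dist.recv(ndims, src)
            shape = torch.empty(int(ndims.item()), **META_KWARGS)
            dist.recv(shape, src)
            shapes.append(torch.Size(shape.tolist()))
        # like recv_obj_meta, collapse the single-tensor case to a bare torch.Size
        return shapes[0] if len(shapes) == 1 else shapes

As in the patch, the receiver only needs this exchange when it does not already know the shapes (obj_shape is None); otherwise no extra messages are exchanged.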

File 3 of 4

@@ -130,7 +130,7 @@ class PipelineSchedule(BaseSchedule):
                 assert p.kind != inspect.Parameter.VAR_POSITIONAL, '*args is not supported'

     @staticmethod
-    def _call_engine(model, input_tensor, batch_data):
+    def _call_engine(model, input_obj, batch_data):
         if isinstance(model, NaiveAMPModel):
             sig = inspect.signature(model.model.forward)
         elif hasattr(model, 'colo_attr'):
@@ -140,16 +140,22 @@ class PipelineSchedule(BaseSchedule):
         if isinstance(batch_data, torch.Tensor):
             for p in sig.parameters.values():
                 if p.kind == inspect.Parameter.VAR_KEYWORD:
-                    if input_tensor is None:
+                    if input_obj is None:
                         return model(batch_data)
                     else:
-                        return model(input_tensor)
-            if input_tensor is None:
+                        return model(input_obj)
+            if input_obj is None:
                 return model(batch_data)
-            elif len(sig.parameters) > 1:
-                return model(input_tensor, batch_data)
+            elif isinstance(input_obj, torch.Tensor):
+                if len(sig.parameters) > 1:
+                    return model(input_obj, batch_data)
+                else:
+                    return model(input_obj)
             else:
-                return model(input_tensor)
+                if len(sig.parameters) > len(input_obj):
+                    return model(*input_obj, batch_data)
+                else:
+                    return model(*input_obj)
         else:
             filter_batch = True
             for p in sig.parameters.values():
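
With a tensor micro-batch, _call_engine now distinguishes a single-tensor input_obj from a list: a tensor is passed positionally as before, while a list is unpacked with *, and the micro-batch is appended only if the model signature still has a spare parameter. A toy illustration of that arity rule (the stage module is hypothetical, not from the repository):

    import torch
    import torch.nn as nn

    class TwoInputStage(nn.Module):
        # a middle stage consuming the two tensors produced by the previous stage
        def forward(self, hidden, residual):
            return hidden + residual

    stage = TwoInputStage()
    input_obj = [torch.randn(4, 8), torch.randn(4, 8)]
    # len(sig.parameters) == len(input_obj), so the schedule calls the stage as:
    out = stage(*input_obj)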
@@ -157,79 +163,88 @@ class PipelineSchedule(BaseSchedule):
                     filter_batch = False
             if filter_batch:
                 batch_data = {k: v for k, v in batch_data.items() if k in sig.parameters}
-            if input_tensor is None and filter_batch:
+            if input_obj is None and filter_batch:
                 return model(**batch_data)
+            elif isinstance(input_obj, torch.Tensor) or input_obj is None:
+                return model(input_obj, **batch_data)
             else:
-                return model(input_tensor, **batch_data)
+                return model(*input_obj, **batch_data)

-    def _forward_step(self, engine, input_tensor, return_tensors, return_output_label=True, accum_loss=None):
+    def _forward_step(self, engine, input_obj, return_tensors, return_output_label=True, accum_loss=None):
         """Forward step for passed-in model. If it is the first stage, the input tensor
-        is obtained from data_iterator, otherwise the passed-in input_tensor is used.
+        is obtained from data_iterator, otherwise the passed-in input_obj is used.
         Returns output tensor. This is a helper function and can be ignored by users.

         Args:
             engine (colossalai.engine.Engine): Colossalai engine for training and inference.
-            input_tensor (:class:`torch.Tensor`): Input tensor for this pipeline stage.
+            input_obj (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): Input tensor for this pipeline stage.
             return_tensors (List[:class:`torch.Tensor`]): A list of tensors to return.
             return_output_label (bool, optional): Whether returns output labels.
             accum_loss (optional): Where accumulated loss stores.
         Returns:
-            :class:`torch.Tensor`: output or the loss value of the current pipeline stage.
+            Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]: output or the loss value of the current pipeline stage.
         """
         data, label = self.load_micro_batch()
-        output_tensor = self._call_engine(engine.model, input_tensor, data)
+        output_obj = self._call_engine(engine.model, input_obj, data)

         if gpc.is_last_rank(ParallelMode.PIPELINE):
             if return_output_label:
-                return_tensors.append((output_tensor, label))
+                return_tensors.append((output_obj, label))
             if accum_loss is not None:
-                loss_reduced = self._call_engine_criterion(engine, output_tensor, label) / self.num_microbatches
+                loss_reduced = self._call_engine_criterion(engine, output_obj, label) / self.num_microbatches
                 accum_loss.add_(loss_reduced.detach())
                 return loss_reduced
             else:
                 # forward only, it's useless since backward is not needed
-                return output_tensor
+                return output_obj
         else:
-            assert isinstance(
-                output_tensor,
-                torch.Tensor), 'Output of model using pipeline parallelism must be a tensor (except the last stage).'
-            self._logger.debug(
-                f'Global rank {gpc.get_global_rank()}, pipeline rank {gpc.get_local_rank(ParallelMode.PIPELINE)} forward output tensor {output_tensor.shape}, dtype {output_tensor.dtype}'
-            )
-            return output_tensor
+            if isinstance(output_obj, torch.Tensor):
+                self._logger.debug(
+                    f'Global rank {gpc.get_global_rank()}, pipeline rank {gpc.get_local_rank(ParallelMode.PIPELINE)} forward output tensor {output_obj.shape}, dtype {output_obj.dtype}'
+                )
+            return output_obj

-    def _backward_step(self, engine, input_tensor, output_tensor, output_tensor_grad):
+    def _backward_step(self, engine, input_obj, output_obj, output_obj_grad):
         """Backward step through the passed-in output tensor. If it is the last stage, the
-        output_tensor_grad is None, otherwise it is the gradients with respect to stage's output tensor.
+        output_obj_grad is None, otherwise it is the gradients with respect to stage's output tensor.
         Returns the gradients with respect to the input tensor (None if first stage).
         This is a helper function and can be ignored by users.

         Args:
             engine (colossalai.engine.Engine): Colossalai engine for training and inference.
-            input_tensor (:class:`torch.Tensor`): input tensor for this pipeline stage.
-            output_tensor (:class:`torch.Tensor`): output tensor for this pipeline stage.
-            output_tensor_grad (:class:`torch.Tensor`): gradient of output tensor for this pipeline stage.
+            input_obj (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): input tensor for this pipeline stage.
+            output_obj (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): output tensor for this pipeline stage.
+            output_obj_grad (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): gradient of output tensor for this pipeline stage.

         Returns:
-            :class:`torch.Tensor`: gradient of input tensor.
+            Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]: gradient of input tensor.
         """

-        # Retain the grad on the input_tensor.
-        if input_tensor is not None:
-            input_tensor.retain_grad()
+        # Retain the grad on the input_obj.
+        if input_obj is not None:
+            if isinstance(input_obj, torch.Tensor):
+                input_obj.retain_grad()
+            else:
+                for in_tensor in input_obj:
+                    if in_tensor is not None:
+                        in_tensor.retain_grad()

         # Backward pass.
-        if output_tensor_grad is None:
-            engine.backward(output_tensor)
+        if output_obj_grad is None:
+            engine.backward(output_obj)
         else:
-            engine.backward_by_grad(output_tensor, output_tensor_grad)
+            engine.backward_by_grad(output_obj, output_obj_grad)

-        # Collect the grad of the input_tensor.
-        input_tensor_grad = None
-        if input_tensor is not None:
-            input_tensor_grad = input_tensor.grad
+        # Collect the grad of the input_obj.
+        input_obj_grad = None
+        if input_obj is not None:
+            if isinstance(input_obj, torch.Tensor):
+                input_obj_grad = input_obj.grad
+            else:
+                input_obj_grad = []
+                for in_tensor in input_obj:
+                    input_obj_grad.append(in_tensor.grad)

-        return input_tensor_grad
+        return input_obj_grad

     def forward_backward_step(self, engine, data_iter, forward_only=False, return_loss=True, return_output_label=True):
         """Runs non-interleaved 1F1B schedule, with communication between pipeline stages.
@@ -257,108 +272,113 @@ class PipelineSchedule(BaseSchedule):
         num_microbatches_remaining = self.num_microbatches - num_warmup_microbatches

         # Input, output tensors only need to be saved when doing backward passes
-        input_tensors = None
-        output_tensors = None
+        input_objs = None
+        output_objs = None
         if not forward_only:
-            input_tensors = []
-            output_tensors = []
+            input_objs = []
+            output_objs = []
         return_tensors = []
         if return_loss and gpc.is_pipeline_last_stage(ignore_virtual=True):
             accum_loss = torch.zeros(1, device=get_current_device())
         else:
             accum_loss = None

         # Used for tensor meta information communication
-        ft_shape = self.tensor_shape
-        bt_shape = None
+        ft_shapes = self.tensor_shape
+        bt_shapes = None
         fs_checker = self.tensor_shape is None

         # Run warmup forward passes.
         for i in range(num_warmup_microbatches):
             if not gpc.is_first_rank(ParallelMode.PIPELINE):
-                ft_shape = comm.recv_tensor_meta(ft_shape)
-            input_tensor = comm.recv_forward(ft_shape,
-                                             dtype=self.dtype,
-                                             scatter_gather_tensors=self.scatter_gather_tensors)
-            output_tensor = self._forward_step(engine,
-                                               input_tensor,
-                                               return_tensors,
-                                               return_output_label=return_output_label,
-                                               accum_loss=accum_loss)
+                ft_shapes = comm.recv_obj_meta(ft_shapes)
+            input_obj = comm.recv_forward(ft_shapes,
+                                          dtype=self.dtype,
+                                          scatter_gather_tensors=self.scatter_gather_tensors)
+            output_obj = self._forward_step(engine,
+                                            input_obj,
+                                            return_tensors,
+                                            return_output_label=return_output_label,
+                                            accum_loss=accum_loss)
             if not gpc.is_last_rank(ParallelMode.PIPELINE):
-                bt_shape = output_tensor.shape
-                fs_checker = comm.send_tensor_meta(output_tensor, fs_checker)
-            comm.send_forward(output_tensor, scatter_gather_tensors=self.scatter_gather_tensors)
+                if isinstance(output_obj, torch.Tensor):
+                    bt_shapes = output_obj.shape
+                else:
+                    bt_shapes = []
+                    for out_tensor in output_obj:
+                        bt_shapes.append(out_tensor.shape)
+                fs_checker = comm.send_obj_meta(output_obj, fs_checker)
+            comm.send_forward(output_obj, scatter_gather_tensors=self.scatter_gather_tensors)

             if not forward_only:
-                input_tensors.append(input_tensor)
-                output_tensors.append(output_tensor)
+                input_objs.append(input_obj)
+                output_objs.append(output_obj)

         # Before running 1F1B, need to receive first forward tensor.
         # If all microbatches are run in warmup / cooldown phase, then no need to
         # receive this tensor here.
         if num_microbatches_remaining > 0:
             if not gpc.is_first_rank(ParallelMode.PIPELINE):
-                ft_shape = comm.recv_tensor_meta(ft_shape)
-            input_tensor = comm.recv_forward(ft_shape,
-                                             dtype=self.dtype,
-                                             scatter_gather_tensors=self.scatter_gather_tensors)
+                ft_shapes = comm.recv_obj_meta(ft_shapes)
+            input_obj = comm.recv_forward(ft_shapes,
+                                          dtype=self.dtype,
+                                          scatter_gather_tensors=self.scatter_gather_tensors)

         # Run 1F1B in steady state.
         for i in range(num_microbatches_remaining):
             last_iteration = (i == (num_microbatches_remaining - 1))

-            output_tensor = self._forward_step(engine,
-                                               input_tensor,
-                                               return_tensors,
-                                               return_output_label=return_output_label,
-                                               accum_loss=accum_loss)
+            output_obj = self._forward_step(engine,
+                                            input_obj,
+                                            return_tensors,
+                                            return_output_label=return_output_label,
+                                            accum_loss=accum_loss)
             if forward_only:
-                comm.send_forward(output_tensor, scatter_gather_tensors=self.scatter_gather_tensors)
+                comm.send_forward(output_obj, scatter_gather_tensors=self.scatter_gather_tensors)

                 if not last_iteration:
-                    input_tensor = comm.recv_forward(ft_shape,
-                                                     dtype=self.dtype,
-                                                     scatter_gather_tensors=self.scatter_gather_tensors)
+                    input_obj = comm.recv_forward(ft_shapes,
+                                                  dtype=self.dtype,
+                                                  scatter_gather_tensors=self.scatter_gather_tensors)

             else:
-                output_tensor_grad = comm.send_forward_recv_backward(output_tensor,
-                                                                     bt_shape,
-                                                                     dtype=self.dtype,
-                                                                     scatter_gather_tensors=self.scatter_gather_tensors)
+                output_obj_grad = comm.send_forward_recv_backward(output_obj,
+                                                                  bt_shapes,
+                                                                  dtype=self.dtype,
+                                                                  scatter_gather_tensors=self.scatter_gather_tensors)

-                # Add input_tensor and output_tensor to end of list.
-                input_tensors.append(input_tensor)
-                output_tensors.append(output_tensor)
+                # Add input_obj and output_obj to end of list.
+                input_objs.append(input_obj)
+                output_objs.append(output_obj)

-                # Pop input_tensor and output_tensor from the start of the list for
+                # Pop output_obj and output_obj from the start of the list for
                 # the backward pass.
-                input_tensor = input_tensors.pop(0)
-                output_tensor = output_tensors.pop(0)
+                input_obj = input_objs.pop(0)
+                output_obj = output_objs.pop(0)

-                input_tensor_grad = self._backward_step(engine, input_tensor, output_tensor, output_tensor_grad)
+                input_obj_grad = self._backward_step(engine, input_obj, output_obj, output_obj_grad)

                 if last_iteration:
-                    input_tensor = None
-                    comm.send_backward(input_tensor_grad, scatter_gather_tensors=self.scatter_gather_tensors)
+                    input_obj = None
+                    comm.send_backward(input_obj_grad, scatter_gather_tensors=self.scatter_gather_tensors)
                 else:
-                    input_tensor = comm.send_backward_recv_forward(input_tensor_grad,
-                                                                   ft_shape,
-                                                                   dtype=self.dtype,
-                                                                   scatter_gather_tensors=self.scatter_gather_tensors)
+                    input_obj = comm.send_backward_recv_forward(input_obj_grad,
+                                                                ft_shapes,
+                                                                dtype=self.dtype,
+                                                                scatter_gather_tensors=self.scatter_gather_tensors)

         # Run cooldown backward passes.
         if not forward_only:
             for i in range(num_warmup_microbatches):
-                input_tensor = input_tensors.pop(0)
-                output_tensor = output_tensors.pop(0)
+                input_obj = input_objs.pop(0)
+                output_obj = output_objs.pop(0)

-                output_tensor_grad = comm.recv_backward(bt_shape,
-                                                        dtype=self.dtype,
-                                                        scatter_gather_tensors=self.scatter_gather_tensors)
+                output_obj_grad = comm.recv_backward(bt_shapes,
+                                                     dtype=self.dtype,
+                                                     scatter_gather_tensors=self.scatter_gather_tensors)

-                input_tensor_grad = self._backward_step(engine, input_tensor, output_tensor, output_tensor_grad)
+                input_obj_grad = self._backward_step(engine, input_obj, output_obj, output_obj_grad)

-                comm.send_backward(input_tensor_grad, scatter_gather_tensors=self.scatter_gather_tensors)
+                comm.send_backward(input_obj_grad, scatter_gather_tensors=self.scatter_gather_tensors)

         if len(return_tensors) > 0:
             output, label = pack_return_tensors(return_tensors)
@@ -426,45 +446,43 @@ class InterleavedPipelineSchedule(PipelineSchedule):
     def _forward_step(self,
                       engine,
                       model_chunk_id,
-                      input_tensor,
+                      input_obj,
                       return_tensors,
                       return_output_label=True,
                       accum_loss=None):
         """Forward step for passed-in model. If it is the first stage, the input tensor
-        is obtained from data_iterator, otherwise the passed-in input_tensor is used.
+        is obtained from data_iterator, otherwise the passed-in input_obj is used.
         Returns output tensor. This is a helper function and can be ignored by users.

         Args:
             engine (colossalai.engine.Engine): Colossalai engine for training and inference.
             model_chunk_id (int): The id of model chunks.
-            input_tensor (:class:`torch.Tensor`): Input tensor for this pipeline stage.
+            input_obj (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): Input tensor for this pipeline stage.
             return_tensors (List[:class:`torch.Tensor`]): A list of tensors to return.
             return_output_label (bool, optional): Whether returns output labels.
             accum_loss (optional): Where accumulated loss stores.
         Returns:
-            :class:`torch.Tensor`: output or the loss value of the current pipeline stage.
+            Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]: output or the loss value of the current pipeline stage.
         """
         data, label = self.load_micro_batch(model_chunk_id)
-        output_tensor = self._call_engine(engine.model[model_chunk_id], input_tensor, data)
+        output_obj = self._call_engine(engine.model[model_chunk_id], input_obj, data)

         if gpc.is_pipeline_last_stage():
             if return_output_label:
-                return_tensors.append((output_tensor, label))
+                return_tensors.append((output_obj, label))
             if accum_loss is not None:
-                loss_reduced = self._call_engine_criterion(engine, output_tensor, label) / self.num_microbatches
+                loss_reduced = self._call_engine_criterion(engine, output_obj, label) / self.num_microbatches
                 accum_loss.add_(loss_reduced.detach())
                 return loss_reduced
             else:
                 # forward only, it's useless since backward is not needed
-                return output_tensor
+                return output_obj
         else:
-            assert isinstance(
-                output_tensor,
-                torch.Tensor), 'Output of model using pipeline parallelism must be a tensor (except the last stage).'
-            self._logger.debug(
-                f'Global rank {gpc.get_global_rank()}, pipeline rank {gpc.get_local_rank(ParallelMode.PIPELINE)} forward output tensor {output_tensor.shape}, dtype {output_tensor.dtype}'
-            )
-            return output_tensor
+            if isinstance(output_obj, torch.Tensor):
+                self._logger.debug(
+                    f'Global rank {gpc.get_global_rank()}, pipeline rank {gpc.get_local_rank(ParallelMode.PIPELINE)} forward output tensor {output_obj.shape}, dtype {output_obj.dtype}'
+                )
+            return output_obj

     def forward_backward_step(self, engine, data_iter, forward_only=False, return_loss=True, return_output_label=True):
         """Run interleaved 1F1B schedule (model split into model chunks), with
@@ -486,19 +504,19 @@ class InterleavedPipelineSchedule(PipelineSchedule):
             'The argument \'return_loss\' has to be True when \'forward_only\' is False, but got False.'

         self.load_batch(data_iter)
         model = engine.model
-        input_tensors = [[] for _ in range(len(model))]
-        output_tensors = [[] for _ in range(len(model))]
+        input_objs = [[] for _ in range(len(model))]
+        output_objs = [[] for _ in range(len(model))]
         return_tensors = []
         if not forward_only:
-            output_tensor_grads = [[] for _ in range(len(model))]
+            output_obj_grads = [[] for _ in range(len(model))]
         if return_loss and gpc.is_pipeline_last_stage(ignore_virtual=True):
             accum_loss = torch.zeros(1, device=get_current_device())
         else:
             accum_loss = None

-        # Used for tensor meta information communication
-        input_tensor_shapes = [self.tensor_shape for _ in range(len(model))]
-        output_tensor_shapes = [None for _ in range(len(model))]
+        # Used for obj meta information communication
+        input_obj_shapes = [self.tensor_shape for _ in range(len(model))]
+        output_obj_shapes = [None for _ in range(len(model))]
         send_tensor_shape_flags = [self.tensor_shape is None for _ in range(len(model))]

         pipeline_parallel_size = gpc.get_world_size(ParallelMode.PIPELINE)
@@ -545,24 +563,24 @@ class InterleavedPipelineSchedule(PipelineSchedule):
             # forward step
             if gpc.is_pipeline_first_stage():
-                if len(input_tensors[model_chunk_id]) == \
-                        len(output_tensors[model_chunk_id]):
-                    input_tensors[model_chunk_id].append(None)
-            input_tensor = input_tensors[model_chunk_id][-1]
-            output_tensor = self._forward_step(engine,
-                                               model_chunk_id,
-                                               input_tensor,
-                                               return_tensors,
-                                               return_output_label=return_output_label,
-                                               accum_loss=accum_loss)
-            output_tensors[model_chunk_id].append(output_tensor)
+                if len(input_objs[model_chunk_id]) == \
+                        len(output_objs[model_chunk_id]):
+                    input_objs[model_chunk_id].append(None)
+            input_obj = input_objs[model_chunk_id][-1]
+            output_obj = self._forward_step(engine,
+                                            model_chunk_id,
+                                            input_obj,
+                                            return_tensors,
+                                            return_output_label=return_output_label,
+                                            accum_loss=accum_loss)
+            output_objs[model_chunk_id].append(output_obj)

             # if forward-only, no need to save tensors for a backward pass
             if forward_only:
-                input_tensors[model_chunk_id].pop()
-                output_tensors[model_chunk_id].pop()
+                input_objs[model_chunk_id].pop()
+                output_objs[model_chunk_id].pop()

-            return output_tensor
+            return output_obj

         def _backward_step_helper(microbatch_id):
             """Helper method to run backward step with model split into chunks
@@ -572,31 +590,35 @@ class InterleavedPipelineSchedule(PipelineSchedule):
             gpc.set_virtual_pipeline_parallel_rank(model_chunk_id)

             if gpc.is_pipeline_last_stage():
-                if len(output_tensor_grads[model_chunk_id]) == 0:
-                    output_tensor_grads[model_chunk_id].append(None)
-            input_tensor = input_tensors[model_chunk_id].pop(0)
-            output_tensor = output_tensors[model_chunk_id].pop(0)
-            output_tensor_grad = output_tensor_grads[model_chunk_id].pop(0)
-            input_tensor_grad = self._backward_step(engine, input_tensor, output_tensor, output_tensor_grad)
+                if len(output_obj_grads[model_chunk_id]) == 0:
+                    output_obj_grads[model_chunk_id].append(None)
+            input_obj = input_objs[model_chunk_id].pop(0)
+            output_obj = output_objs[model_chunk_id].pop(0)
+            output_obj_grad = output_obj_grads[model_chunk_id].pop(0)
+            input_obj_grad = self._backward_step(engine, input_obj, output_obj, output_obj_grad)

-            return input_tensor_grad
+            return input_obj_grad

         # Run warmup forward passes.
         gpc.set_virtual_pipeline_parallel_rank(0)
         if not gpc.is_pipeline_first_stage():
-            input_tensor_shapes[0] = comm.recv_tensor_meta(input_tensor_shapes[0])
-        input_tensors[0].append(
-            comm.recv_forward(input_tensor_shapes[0],
-                              dtype=self.dtype,
-                              scatter_gather_tensors=self.scatter_gather_tensors))
+            input_obj_shapes[0] = comm.recv_obj_meta(input_obj_shapes[0])
+        input_objs[0].append(
+            comm.recv_forward(input_obj_shapes[0], dtype=self.dtype,
+                              scatter_gather_tensors=self.scatter_gather_tensors))

         for k in range(num_warmup_microbatches):
             model_chunk_id = get_model_chunk_id(k, forward=True)
-            output_tensor = _forward_step_helper(k)
+            output_obj = _forward_step_helper(k)

             if not gpc.is_pipeline_last_stage():
-                output_tensor_shapes[model_chunk_id] = output_tensor.shape
-                send_tensor_shape_flags[model_chunk_id] = comm.send_tensor_meta(output_tensor,
-                                                                                send_tensor_shape_flags[model_chunk_id])
+                if isinstance(output_obj, torch.Tensor):
+                    output_obj_shapes[model_chunk_id] = output_obj.shape
+                else:
+                    output_obj_shapes[model_chunk_id] = []
+                    for out_tensor in output_obj:
+                        output_obj_shapes[model_chunk_id].append(out_tensor.shape)
+                send_tensor_shape_flags[model_chunk_id] = comm.send_obj_meta(output_obj,
+                                                                             send_tensor_shape_flags[model_chunk_id])

             # Determine if tensor should be received from previous stage.
             next_forward_model_chunk_id = get_model_chunk_id(k + 1, forward=True)
             recv_prev = True
@@ -608,65 +630,65 @@ class InterleavedPipelineSchedule(PipelineSchedule):
             # Don't send tensor downstream if on last stage.
             if gpc.is_pipeline_last_stage():
-                output_tensor = None
+                output_obj = None

             with switch_virtual_pipeline_parallel_rank(next_forward_model_chunk_id):
                 if not gpc.is_pipeline_first_stage():
-                    input_tensor_shapes[next_forward_model_chunk_id] = comm.recv_tensor_meta(
-                        input_tensor_shapes[next_forward_model_chunk_id])
+                    input_obj_shapes[next_forward_model_chunk_id] = comm.recv_obj_meta(
+                        input_obj_shapes[next_forward_model_chunk_id])
             # Send and receive tensors as appropriate (send tensors computed
             # in this iteration; receive tensors for next iteration).
-            input_shape = input_tensor_shapes[next_forward_model_chunk_id] if recv_prev else None
+            input_shape = input_obj_shapes[next_forward_model_chunk_id] if recv_prev else None
             if k == (num_warmup_microbatches - 1) and not forward_only and \
                     not all_warmup_microbatches:
-                input_tensor_grad = None
+                input_obj_grad = None
                 recv_next = True
                 if gpc.is_pipeline_last_stage(ignore_virtual=True):
                     recv_next = False
-                output_shape = output_tensor_shapes[num_model_chunks - 1] if recv_next else None
-                input_tensor, output_tensor_grad = \
+                output_shape = output_obj_shapes[num_model_chunks - 1] if recv_next else None
+                input_obj, output_obj_grad = \
                     comm.send_forward_backward_recv_forward_backward(
-                        output_tensor, input_tensor_grad,
+                        output_obj, input_obj_grad,
                         input_shape,
                         output_shape,
                         recv_prev=recv_prev, recv_next=recv_next,
                         dtype=self.dtype,
                         scatter_gather_tensors=self.scatter_gather_tensors)
-                output_tensor_grads[num_model_chunks - 1].append(output_tensor_grad)
+                output_obj_grads[num_model_chunks - 1].append(output_obj_grad)
             else:
-                input_tensor = \
+                input_obj = \
                     comm.send_forward_recv_forward(
-                        output_tensor,
+                        output_obj,
                         input_shape,
                         recv_prev=recv_prev,
                         dtype=self.dtype,
                         scatter_gather_tensors=self.scatter_gather_tensors)
-            input_tensors[next_forward_model_chunk_id].append(input_tensor)
+            input_objs[next_forward_model_chunk_id].append(input_obj)

         # Run 1F1B in steady state.
         for k in range(num_microbatches_remaining):
             # Forward pass.
             forward_k = k + num_warmup_microbatches
-            output_tensor = _forward_step_helper(forward_k)
+            output_obj = _forward_step_helper(forward_k)

             # Backward pass.
             backward_k = k
-            input_tensor_grad = _backward_step_helper(backward_k)
+            input_obj_grad = _backward_step_helper(backward_k)

-            # Send output_tensor and input_tensor_grad, receive input_tensor
-            # and output_tensor_grad.
+            # Send output_obj and input_obj_grad, receive input_obj
+            # and output_obj_grad.

             # Determine if current stage has anything to send in either direction,
-            # otherwise set tensor to None.
+            # otherwise set obj to None.
             forward_model_chunk_id = get_model_chunk_id(forward_k, forward=True)
             gpc.set_virtual_pipeline_parallel_rank(forward_model_chunk_id)
             if gpc.is_pipeline_last_stage():
-                output_tensor = None
+                output_obj = None

             backward_model_chunk_id = get_model_chunk_id(backward_k, forward=False)
             gpc.set_virtual_pipeline_parallel_rank(backward_model_chunk_id)
             if gpc.is_pipeline_first_stage():
-                input_tensor_grad = None
+                input_obj_grad = None

             # Determine if peers are sending, and where in data structure to put
             # received tensors.
@@ -696,33 +718,33 @@ class InterleavedPipelineSchedule(PipelineSchedule):
             if k == (num_microbatches_remaining - 1):
                 recv_prev = False

-            input_shape = input_tensor_shapes[next_forward_model_chunk_id] if recv_prev else None
-            output_shape = output_tensor_shapes[next_backward_model_chunk_id] if recv_next else None
+            input_shape = input_obj_shapes[next_forward_model_chunk_id] if recv_prev else None
+            output_shape = output_obj_shapes[next_backward_model_chunk_id] if recv_next else None

-            # Communicate tensors.
-            input_tensor, output_tensor_grad = \
+            # Communicate objs.
+            input_obj, output_obj_grad = \
                 comm.send_forward_backward_recv_forward_backward(
-                    output_tensor, input_tensor_grad,
+                    output_obj, input_obj_grad,
                     input_shape,
                     output_shape,
                     recv_prev=recv_prev, recv_next=recv_next,
                     dtype=self.dtype,
                     scatter_gather_tensors=self.scatter_gather_tensors)

-            # Put input_tensor and output_tensor_grad in data structures in the
+            # Put input_obj and output_obj_grad in data structures in the
             # right location.
             if recv_prev:
-                input_tensors[next_forward_model_chunk_id].append(input_tensor)
+                input_objs[next_forward_model_chunk_id].append(input_obj)
             if recv_next:
-                output_tensor_grads[next_backward_model_chunk_id].append(output_tensor_grad)
+                output_obj_grads[next_backward_model_chunk_id].append(output_obj_grad)

         # Run cooldown backward passes (flush out pipeline).
         if not forward_only:
             if all_warmup_microbatches:
-                output_tensor_grads[num_model_chunks - 1].append(
-                    comm.recv_backward(output_tensor_shapes[num_model_chunks - 1],
-                                       scatter_gather_tensors=self.scatter_gather_tensors))
+                output_obj_grads[num_model_chunks - 1].append(
+                    comm.recv_backward(output_obj_shapes[num_model_chunks - 1],
+                                       scatter_gather_tensors=self.scatter_gather_tensors))
             for k in range(num_microbatches_remaining, num_microbatches):
-                input_tensor_grad = _backward_step_helper(k)
+                input_obj_grad = _backward_step_helper(k)
                 next_backward_model_chunk_id = get_model_chunk_id(k + 1, forward=False)
                 recv_next = True
                 if gpc.is_pipeline_last_stage(ignore_virtual=True):
@@ -730,9 +752,9 @@ class InterleavedPipelineSchedule(PipelineSchedule):
                     recv_next = False
                 if k == (num_microbatches - 1):
                     recv_next = False
-                output_shape = output_tensor_shapes[next_backward_model_chunk_id] if recv_next else None
-                output_tensor_grads[next_backward_model_chunk_id].append(
-                    comm.send_backward_recv_backward(input_tensor_grad,
+                output_shape = output_obj_shapes[next_backward_model_chunk_id] if recv_next else None
+                output_obj_grads[next_backward_model_chunk_id].append(
+                    comm.send_backward_recv_backward(input_obj_grad,
                                                      output_shape,
                                                      recv_next=recv_next,
                                                      dtype=self.dtype,
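
Taken together, both the non-interleaved and the interleaved schedule can now move a list of tensors across a stage boundary: the warmup phase records one shape per element, send_obj_meta/recv_obj_meta announce how many tensors follow, and the backward passes return one gradient per element. A hypothetical stage that relies on this (purely illustrative, not a model from the repository):

    import torch
    import torch.nn as nn

    class EncoderStage(nn.Module):
        """A stage whose boundary carries two tensors: hidden states and a residual branch."""

        def __init__(self, dim=512):
            super().__init__()
            self.proj = nn.Linear(dim, dim)

        def forward(self, hidden_states, residual):
            hidden_states = self.proj(hidden_states) + residual
            return hidden_states, residual

Before this commit the intermediate stages asserted that the forward output is a single tensor; with the changes above, the returned pair is sent downstream as a two-element list and two gradients flow back during the cooldown phase.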

File 4 of 4

@@ -7,9 +7,9 @@ import pytest
 import torch
 import torch.distributed as dist
 import torch.multiprocessing as mp
-from colossalai.communication import (recv_backward, recv_forward, recv_tensor_meta, send_backward,
+from colossalai.communication import (recv_backward, recv_forward, recv_obj_meta, send_backward,
                                       send_backward_recv_forward, send_forward, send_forward_recv_backward,
-                                      send_tensor_meta)
+                                      send_obj_meta)
 from colossalai.context.parallel_mode import ParallelMode
 from colossalai.core import global_context as gpc
 from colossalai.initialize import launch