[pipeline]refactor ppschedule to support tensor list (#1050)

* [CLI] add CLI launcher

* Revert "[CLI] add CLI launcher"

This reverts commit df7e6506d4.

* refactor ppschedule to support tensor list

* polish
YuliangLiu0306 2022-06-02 13:48:59 +08:00 committed by GitHub
parent e3fde4ee6b
commit b167258b6a
4 changed files with 259 additions and 215 deletions


@@ -3,7 +3,7 @@ from .p2p import (send_forward, send_forward_recv_forward, send_backward_recv_fo
send_backward_recv_backward, send_forward_recv_backward, send_forward_backward_recv_forward_backward,
recv_forward, recv_backward)
from .ring import ring_forward
from .utils import send_tensor_meta, recv_tensor_meta
from .utils import send_obj_meta, recv_obj_meta
__all__ = [
'all_gather',
@@ -21,6 +21,6 @@ __all__ = [
'recv_backward',
'recv_forward',
'ring_forward',
'send_tensor_meta',
'recv_tensor_meta',
'send_obj_meta',
'recv_obj_meta',
]


@@ -9,14 +9,21 @@ from typing import Union, List, Tuple
TensorShape = Union[torch.Size, List[int], Tuple[int]]
def send_tensor_meta(tensor, need_meta=True, next_rank=None) -> bool:
"""Sends tensor meta information before sending a specific tensor.
Since the recipient must know the shape of the tensor in p2p communications,
meta information of the tensor should be sent before communications. This function
synchronizes with :func:`recv_tensor_meta`.
def send_meta_helper(obj, next_rank, tensor_kwargs):
send_shape = torch.tensor(obj.size(), **tensor_kwargs)
send_ndims = torch.tensor(len(obj.size()), **tensor_kwargs)
dist.send(send_ndims, next_rank)
dist.send(send_shape, next_rank)
def send_obj_meta(obj, need_meta=True, next_rank=None) -> bool:
"""Sends obj meta information before sending a specific obj.
Since the recipient must know the shape of the obj in p2p communications,
meta information of the obj should be sent before communications. This function
synchronizes with :func:`recv_obj_meta`.
Args:
tensor (:class:`torch.Tensor`): Tensor to be sent.
obj (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): obj to be sent.
need_meta (bool, optional): If False, meta information won't be sent.
next_rank (int): The rank of the next member in pipeline parallel group.
@@ -28,42 +35,57 @@ def send_tensor_meta(tensor, need_meta=True, next_rank=None) -> bool:
next_rank = gpc.get_next_global_rank(ParallelMode.PIPELINE)
tensor_kwargs = {'dtype': torch.long, 'device': get_current_device()}
send_shape = torch.tensor(tensor.size(), **tensor_kwargs)
send_ndims = torch.tensor(len(tensor.size()), **tensor_kwargs)
dist.send(send_ndims, next_rank)
dist.send(send_shape, next_rank)
if isinstance(obj, torch.Tensor):
send_obj_nums = torch.tensor(1, **tensor_kwargs)
dist.send(send_obj_nums, next_rank)
send_meta_helper(obj, next_rank, tensor_kwargs)
else:
send_obj_nums = torch.tensor(len(obj), **tensor_kwargs)
dist.send(send_obj_nums, next_rank)
for tensor_to_send in obj:
send_meta_helper(tensor_to_send, next_rank, tensor_kwargs)
return False
def recv_tensor_meta(tensor_shape: TensorShape, prev_rank=None) -> torch.Size:
"""Receives tensor meta information before receiving a specific tensor.
Since the recipient must know the shape of the tensor in p2p communications,
meta information of the tensor should be received before communications. This function
synchronizes with :func:`send_tensor_meta`.
def recv_meta_helper(prev_rank, tensor_kwargs):
recv_ndims = torch.empty((), **tensor_kwargs)
dist.recv(recv_ndims, prev_rank)
recv_shape = torch.empty(recv_ndims, **tensor_kwargs)
dist.recv(recv_shape, prev_rank)
return recv_shape
def recv_obj_meta(obj_shape, prev_rank=None) -> torch.Size:
"""Receives obj meta information before receiving a specific obj.
Since the recipient must know the shape of the obj in p2p communications,
meta information of the obj should be received before communications. This function
synchronizes with :func:`send_obj_meta`.
Args:
tensor_shape (:class:`torch.Size`): The shape of the tensor to be received.
prev_rank (int): The rank of the source of the tensor.
obj_shape (Union[:class:`torch.Size`, List[:class:`torch.Size`]]): The shape of the obj to be received.
prev_rank (int): The rank of the source of the obj.
Returns:
:class:`torch.Size`: The shape of the tensor to be received.
Union[:class:`torch.Size`, List[:class:`torch.Size`]]: The shape of the obj to be received.
"""
if tensor_shape is None:
if obj_shape is None:
if prev_rank is None:
prev_rank = gpc.get_prev_global_rank(ParallelMode.PIPELINE)
tensor_kwargs = {'dtype': torch.long, 'device': get_current_device()}
recv_obj_nums = torch.empty((), **tensor_kwargs)
dist.recv(recv_obj_nums, prev_rank)
if recv_obj_nums.item() == 1:
recv_shape = recv_meta_helper(prev_rank, tensor_kwargs)
obj_shape = torch.Size(recv_shape)
else:
obj_shape = []
for i in range(recv_obj_nums.item()):
recv_shape = recv_meta_helper(prev_rank, tensor_kwargs)
obj_shape.append(torch.Size(recv_shape))
recv_ndims = torch.empty((), **tensor_kwargs)
dist.recv(recv_ndims, prev_rank)
recv_shape = torch.empty(recv_ndims, **tensor_kwargs)
dist.recv(recv_shape, prev_rank)
tensor_shape = torch.Size(recv_shape)
return tensor_shape
return obj_shape
def split_tensor_into_1d_equal_chunks(tensor: torch.Tensor, new_buffer=False) -> torch.Tensor:

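For illustration, here is a minimal local sketch (hypothetical helper names, no torch.distributed involved) of the wire format that send_obj_meta/recv_obj_meta agree on after this change: first a count of objects, then an ndims tensor and a shape tensor per element.

```python
import torch

# Hypothetical local encoder/decoder mirroring the metadata protocol above:
# [num_objs, (ndims, shape) * num_objs]. Purely illustrative, no communication.
def encode_obj_meta(obj):
    tensors = [obj] if isinstance(obj, torch.Tensor) else list(obj)
    meta = [torch.tensor(len(tensors))]
    for t in tensors:
        meta.append(torch.tensor(t.dim()))
        meta.append(torch.tensor(t.size()))
    return meta

def decode_obj_meta(meta):
    num_objs = meta[0].item()
    shapes, i = [], 1
    for _ in range(num_objs):
        ndims = meta[i].item()
        assert meta[i + 1].numel() == ndims
        shapes.append(torch.Size(meta[i + 1].tolist()))
        i += 2
    # A single tensor decodes to torch.Size, a list decodes to a list of shapes.
    return shapes[0] if num_objs == 1 else shapes

obj = [torch.zeros(2, 3), torch.zeros(4)]
print(decode_obj_meta(encode_obj_meta(obj)))  # [torch.Size([2, 3]), torch.Size([4])]
```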

@@ -130,7 +130,7 @@ class PipelineSchedule(BaseSchedule):
assert p.kind != inspect.Parameter.VAR_POSITIONAL, '*args is not supported'
@staticmethod
def _call_engine(model, input_tensor, batch_data):
def _call_engine(model, input_obj, batch_data):
if isinstance(model, NaiveAMPModel):
sig = inspect.signature(model.model.forward)
elif hasattr(model, 'colo_attr'):
@@ -140,16 +140,22 @@ class PipelineSchedule(BaseSchedule):
if isinstance(batch_data, torch.Tensor):
for p in sig.parameters.values():
if p.kind == inspect.Parameter.VAR_KEYWORD:
if input_tensor is None:
if input_obj is None:
return model(batch_data)
else:
return model(input_tensor)
if input_tensor is None:
return model(input_obj)
if input_obj is None:
return model(batch_data)
elif len(sig.parameters) > 1:
return model(input_tensor, batch_data)
elif isinstance(input_obj, torch.Tensor):
if len(sig.parameters) > 1:
return model(input_obj, batch_data)
else:
return model(input_obj)
else:
return model(input_tensor)
if len(sig.parameters) > len(input_obj):
return model(*input_obj, batch_data)
else:
return model(*input_obj)
else:
filter_batch = True
for p in sig.parameters.values():
@@ -157,79 +163,88 @@ class PipelineSchedule(BaseSchedule):
filter_batch = False
if filter_batch:
batch_data = {k: v for k, v in batch_data.items() if k in sig.parameters}
if input_tensor is None and filter_batch:
if input_obj is None and filter_batch:
return model(**batch_data)
elif isinstance(input_obj, torch.Tensor) or input_obj is None:
return model(input_obj, **batch_data)
else:
return model(input_tensor, **batch_data)
return model(*input_obj, **batch_data)
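A brief aside on the dispatch above: when input_obj is a list, the stage unpacks it positionally and only appends the micro-batch data if the forward signature has spare parameters. A toy sketch, assuming two made-up stage modules:

```python
import inspect
import torch
import torch.nn as nn

# Hypothetical two-input stage: the unpacked list fills the signature exactly.
class TwoInputStage(nn.Module):
    def forward(self, hidden, mask):
        return hidden * mask

# Hypothetical stage with an extra slot that receives the micro-batch data.
class StageWithData(nn.Module):
    def forward(self, hidden, mask, batch_data):
        return hidden * mask + batch_data

input_obj = [torch.ones(2, 4), torch.zeros(2, 4)]
batch_data = torch.ones(2, 4)

for model in (TwoInputStage(), StageWithData()):
    sig = inspect.signature(model.forward)
    # Mirrors the rule above: pass batch_data only if the signature is wider
    # than the unpacked input list.
    if len(sig.parameters) > len(input_obj):
        out = model(*input_obj, batch_data)
    else:
        out = model(*input_obj)
    print(type(model).__name__, out.shape)
```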
def _forward_step(self, engine, input_tensor, return_tensors, return_output_label=True, accum_loss=None):
def _forward_step(self, engine, input_obj, return_tensors, return_output_label=True, accum_loss=None):
"""Forward step for passed-in model. If it is the first stage, the input tensor
is obtained from data_iterator, otherwise the passed-in input_tensor is used.
is obtained from data_iterator, otherwise the passed-in input_obj is used.
Returns output tensor. This is a helper function and can be ignored by users.
Args:
engine (colossalai.engine.Engine): Colossalai engine for training and inference.
input_tensor (:class:`torch.Tensor`): Input tensor for this pipeline stage.
input_obj (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): Input tensor for this pipeline stage.
return_tensors (List[:class:`torch.Tensor`]): A list of tensors to return.
return_output_label (bool, optional): Whether to return output labels.
accum_loss (optional): Where the accumulated loss is stored.
Returns:
:class:`torch.Tensor`: output or the loss value of the current pipeline stage.
Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]: output or the loss value of the current pipeline stage.
"""
data, label = self.load_micro_batch()
output_tensor = self._call_engine(engine.model, input_tensor, data)
output_obj = self._call_engine(engine.model, input_obj, data)
if gpc.is_last_rank(ParallelMode.PIPELINE):
if return_output_label:
return_tensors.append((output_tensor, label))
return_tensors.append((output_obj, label))
if accum_loss is not None:
loss_reduced = self._call_engine_criterion(engine, output_tensor, label) / self.num_microbatches
loss_reduced = self._call_engine_criterion(engine, output_obj, label) / self.num_microbatches
accum_loss.add_(loss_reduced.detach())
return loss_reduced
else:
# forward only, it's useless since backward is not needed
return output_tensor
return output_obj
else:
assert isinstance(
output_tensor,
torch.Tensor), 'Output of model using pipeline parallelism must be a tensor (except the last stage).'
self._logger.debug(
f'Global rank {gpc.get_global_rank()}, pipeline rank {gpc.get_local_rank(ParallelMode.PIPELINE)} forward output tensor {output_tensor.shape}, dtype {output_tensor.dtype}'
)
return output_tensor
if isinstance(output_obj, torch.Tensor):
self._logger.debug(
f'Global rank {gpc.get_global_rank()}, pipeline rank {gpc.get_local_rank(ParallelMode.PIPELINE)} forward output tensor {output_obj.shape}, dtype {output_obj.dtype}'
)
return output_obj
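For context on the last-stage branch above, a minimal local sketch of the loss accumulation (the criterion, shapes, and micro-batch count are placeholders): every micro-batch contributes its reduced loss divided by num_microbatches, so accum_loss ends up as the mean over the whole batch.

```python
import torch

# Placeholder criterion and data, standing in for the engine criterion and labels.
num_microbatches = 4
accum_loss = torch.zeros(1)
criterion = torch.nn.MSELoss()
for _ in range(num_microbatches):
    output_obj = torch.zeros(2, 3)        # stage output for one micro-batch
    label = torch.ones(2, 3)
    loss_reduced = criterion(output_obj, label) / num_microbatches
    accum_loss.add_(loss_reduced.detach())
print(accum_loss)  # tensor([1.]) -- mean loss over all micro-batches
```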
def _backward_step(self, engine, input_tensor, output_tensor, output_tensor_grad):
def _backward_step(self, engine, input_obj, output_obj, output_obj_grad):
"""Backward step through the passed-in output tensor. If it is the last stage, the
output_tensor_grad is None, otherwise it is the gradients with respect to stage's output tensor.
output_obj_grad is None, otherwise it is the gradients with respect to stage's output tensor.
Returns the gradients with respect to the input tensor (None if first stage).
This is a helper function and can be ignored by users.
Args:
engine (colossalai.engine.Engine): Colossalai engine for training and inference.
input_tensor (:class:`torch.Tensor`): input tensor for this pipeline stage.
output_tensor (:class:`torch.Tensor`): output tensor for this pipeline stage.
output_tensor_grad (:class:`torch.Tensor`): gradient of output tensor for this pipeline stage.
input_obj (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): input tensor for this pipeline stage.
output_obj (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): output tensor for this pipeline stage.
output_obj_grad (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): gradient of output tensor for this pipeline stage.
Returns:
:class:`torch.Tensor`: gradient of input tensor.
Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]: gradient of input tensor.
"""
# Retain the grad on the input_tensor.
if input_tensor is not None:
input_tensor.retain_grad()
# Retain the grad on the input_obj.
if input_obj is not None:
if isinstance(input_obj, torch.Tensor):
input_obj.retain_grad()
else:
for in_tensor in input_obj:
if in_tensor is not None:
in_tensor.retain_grad()
# Backward pass.
if output_tensor_grad is None:
engine.backward(output_tensor)
if output_obj_grad is None:
engine.backward(output_obj)
else:
engine.backward_by_grad(output_tensor, output_tensor_grad)
engine.backward_by_grad(output_obj, output_obj_grad)
# Collect the grad of the input_tensor.
input_tensor_grad = None
if input_tensor is not None:
input_tensor_grad = input_tensor.grad
# Collect the grad of the input_obj.
input_obj_grad = None
if input_obj is not None:
if isinstance(input_obj, torch.Tensor):
input_obj_grad = input_obj.grad
else:
input_obj_grad = []
for in_tensor in input_obj:
input_obj_grad.append(in_tensor.grad)
return input_tensor_grad
return input_obj_grad
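To make the list branch above concrete, here is a local sketch with engine.backward_by_grad stood in for by torch.autograd.backward: retain the grad on every element of input_obj, backpropagate the gradient received for output_obj, then collect .grad per element.

```python
import torch

# Two toy pipeline inputs; in the schedule these arrive from the previous stage.
input_obj = [torch.randn(3, requires_grad=True), torch.randn(3, requires_grad=True)]
for in_tensor in input_obj:
    in_tensor.retain_grad()

output_obj = input_obj[0] * 2 + input_obj[1]
output_obj_grad = torch.ones_like(output_obj)     # gradient received from the next stage
torch.autograd.backward(output_obj, grad_tensors=output_obj_grad)

# Mirrors the grad collection above: one .grad per element of the input list.
input_obj_grad = [in_tensor.grad for in_tensor in input_obj]
print([g.tolist() for g in input_obj_grad])       # [[2.0, 2.0, 2.0], [1.0, 1.0, 1.0]]
```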
def forward_backward_step(self, engine, data_iter, forward_only=False, return_loss=True, return_output_label=True):
"""Runs non-interleaved 1F1B schedule, with communication between pipeline stages.
@@ -257,108 +272,113 @@ class PipelineSchedule(BaseSchedule):
num_microbatches_remaining = self.num_microbatches - num_warmup_microbatches
# Input, output tensors only need to be saved when doing backward passes
input_tensors = None
output_tensors = None
input_objs = None
output_objs = None
if not forward_only:
input_tensors = []
output_tensors = []
input_objs = []
output_objs = []
return_tensors = []
if return_loss and gpc.is_pipeline_last_stage(ignore_virtual=True):
accum_loss = torch.zeros(1, device=get_current_device())
else:
accum_loss = None
# Used for tensor meta information communication
ft_shape = self.tensor_shape
bt_shape = None
ft_shapes = self.tensor_shape
bt_shapes = None
fs_checker = self.tensor_shape is None
# Run warmup forward passes.
for i in range(num_warmup_microbatches):
if not gpc.is_first_rank(ParallelMode.PIPELINE):
ft_shape = comm.recv_tensor_meta(ft_shape)
input_tensor = comm.recv_forward(ft_shape,
dtype=self.dtype,
scatter_gather_tensors=self.scatter_gather_tensors)
output_tensor = self._forward_step(engine,
input_tensor,
return_tensors,
return_output_label=return_output_label,
accum_loss=accum_loss)
ft_shapes = comm.recv_obj_meta(ft_shapes)
input_obj = comm.recv_forward(ft_shapes,
dtype=self.dtype,
scatter_gather_tensors=self.scatter_gather_tensors)
output_obj = self._forward_step(engine,
input_obj,
return_tensors,
return_output_label=return_output_label,
accum_loss=accum_loss)
if not gpc.is_last_rank(ParallelMode.PIPELINE):
bt_shape = output_tensor.shape
fs_checker = comm.send_tensor_meta(output_tensor, fs_checker)
comm.send_forward(output_tensor, scatter_gather_tensors=self.scatter_gather_tensors)
if isinstance(output_obj, torch.Tensor):
bt_shapes = output_obj.shape
else:
bt_shapes = []
for out_tensor in output_obj:
bt_shapes.append(out_tensor.shape)
fs_checker = comm.send_obj_meta(output_obj, fs_checker)
comm.send_forward(output_obj, scatter_gather_tensors=self.scatter_gather_tensors)
if not forward_only:
input_tensors.append(input_tensor)
output_tensors.append(output_tensor)
input_objs.append(input_obj)
output_objs.append(output_obj)
# Before running 1F1B, need to receive first forward tensor.
# If all microbatches are run in warmup / cooldown phase, then no need to
# receive this tensor here.
if num_microbatches_remaining > 0:
if not gpc.is_first_rank(ParallelMode.PIPELINE):
ft_shape = comm.recv_tensor_meta(ft_shape)
input_tensor = comm.recv_forward(ft_shape,
dtype=self.dtype,
scatter_gather_tensors=self.scatter_gather_tensors)
ft_shapes = comm.recv_obj_meta(ft_shapes)
input_obj = comm.recv_forward(ft_shapes,
dtype=self.dtype,
scatter_gather_tensors=self.scatter_gather_tensors)
# Run 1F1B in steady state.
for i in range(num_microbatches_remaining):
last_iteration = (i == (num_microbatches_remaining - 1))
output_tensor = self._forward_step(engine,
input_tensor,
return_tensors,
return_output_label=return_output_label,
accum_loss=accum_loss)
output_obj = self._forward_step(engine,
input_obj,
return_tensors,
return_output_label=return_output_label,
accum_loss=accum_loss)
if forward_only:
comm.send_forward(output_tensor, scatter_gather_tensors=self.scatter_gather_tensors)
comm.send_forward(output_obj, scatter_gather_tensors=self.scatter_gather_tensors)
if not last_iteration:
input_tensor = comm.recv_forward(ft_shape,
dtype=self.dtype,
scatter_gather_tensors=self.scatter_gather_tensors)
input_obj = comm.recv_forward(ft_shapes,
dtype=self.dtype,
scatter_gather_tensors=self.scatter_gather_tensors)
else:
output_tensor_grad = comm.send_forward_recv_backward(output_tensor,
bt_shape,
dtype=self.dtype,
scatter_gather_tensors=self.scatter_gather_tensors)
output_obj_grad = comm.send_forward_recv_backward(output_obj,
bt_shapes,
dtype=self.dtype,
scatter_gather_tensors=self.scatter_gather_tensors)
# Add input_tensor and output_tensor to end of list.
input_tensors.append(input_tensor)
output_tensors.append(output_tensor)
# Add input_obj and output_obj to end of list.
input_objs.append(input_obj)
output_objs.append(output_obj)
# Pop input_tensor and output_tensor from the start of the list for
# Pop input_obj and output_obj from the start of the list for
# the backward pass.
input_tensor = input_tensors.pop(0)
output_tensor = output_tensors.pop(0)
input_obj = input_objs.pop(0)
output_obj = output_objs.pop(0)
input_tensor_grad = self._backward_step(engine, input_tensor, output_tensor, output_tensor_grad)
input_obj_grad = self._backward_step(engine, input_obj, output_obj, output_obj_grad)
if last_iteration:
input_tensor = None
comm.send_backward(input_tensor_grad, scatter_gather_tensors=self.scatter_gather_tensors)
input_obj = None
comm.send_backward(input_obj_grad, scatter_gather_tensors=self.scatter_gather_tensors)
else:
input_tensor = comm.send_backward_recv_forward(input_tensor_grad,
ft_shape,
dtype=self.dtype,
scatter_gather_tensors=self.scatter_gather_tensors)
input_obj = comm.send_backward_recv_forward(input_obj_grad,
ft_shapes,
dtype=self.dtype,
scatter_gather_tensors=self.scatter_gather_tensors)
# Run cooldown backward passes.
if not forward_only:
for i in range(num_warmup_microbatches):
input_tensor = input_tensors.pop(0)
output_tensor = output_tensors.pop(0)
input_obj = input_objs.pop(0)
output_obj = output_objs.pop(0)
output_tensor_grad = comm.recv_backward(bt_shape,
dtype=self.dtype,
scatter_gather_tensors=self.scatter_gather_tensors)
output_obj_grad = comm.recv_backward(bt_shapes,
dtype=self.dtype,
scatter_gather_tensors=self.scatter_gather_tensors)
input_tensor_grad = self._backward_step(engine, input_tensor, output_tensor, output_tensor_grad)
input_obj_grad = self._backward_step(engine, input_obj, output_obj, output_obj_grad)
comm.send_backward(input_tensor_grad, scatter_gather_tensors=self.scatter_gather_tensors)
comm.send_backward(input_obj_grad, scatter_gather_tensors=self.scatter_gather_tensors)
if len(return_tensors) > 0:
output, label = pack_return_tensors(return_tensors)
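One illustrative note on the warmup bookkeeping in this hunk: ft_shapes/bt_shapes are either a single torch.Size or a list of shapes, depending on whether the stage exchanges one tensor or several. A tiny sketch, with a made-up helper name:

```python
import torch

# Hypothetical helper mirroring how bt_shapes is filled in the warmup loop above.
def collect_shapes(output_obj):
    if isinstance(output_obj, torch.Tensor):
        return output_obj.shape
    return [out_tensor.shape for out_tensor in output_obj]

print(collect_shapes(torch.zeros(8, 16)))                    # torch.Size([8, 16])
print(collect_shapes([torch.zeros(8, 16), torch.zeros(8)]))  # [torch.Size([8, 16]), torch.Size([8])]
```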
@@ -426,45 +446,43 @@ class InterleavedPipelineSchedule(PipelineSchedule):
def _forward_step(self,
engine,
model_chunk_id,
input_tensor,
input_obj,
return_tensors,
return_output_label=True,
accum_loss=None):
"""Forward step for passed-in model. If it is the first stage, the input tensor
is obtained from data_iterator, otherwise the passed-in input_tensor is used.
is obtained from data_iterator, otherwise the passed-in input_obj is used.
Returns output tensor. This is a helper function and can be ignored by users.
Args:
engine (colossalai.engine.Engine): Colossalai engine for training and inference.
model_chunk_id (int): The id of model chunks.
input_tensor (:class:`torch.Tensor`): Input tensor for this pipeline stage.
input_obj (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): Input tensor for this pipeline stage.
return_tensors (List[:class:`torch.Tensor`]): A list of tensors to return.
return_output_label (bool, optional): Whether to return output labels.
accum_loss (optional): Where the accumulated loss is stored.
Returns:
:class:`torch.Tensor`: output or the loss value of the current pipeline stage.
Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]: output or the loss value of the current pipeline stage.
"""
data, label = self.load_micro_batch(model_chunk_id)
output_tensor = self._call_engine(engine.model[model_chunk_id], input_tensor, data)
output_obj = self._call_engine(engine.model[model_chunk_id], input_obj, data)
if gpc.is_pipeline_last_stage():
if return_output_label:
return_tensors.append((output_tensor, label))
return_tensors.append((output_obj, label))
if accum_loss is not None:
loss_reduced = self._call_engine_criterion(engine, output_tensor, label) / self.num_microbatches
loss_reduced = self._call_engine_criterion(engine, output_obj, label) / self.num_microbatches
accum_loss.add_(loss_reduced.detach())
return loss_reduced
else:
# forward only, it's useless since backward is not needed
return output_tensor
return output_obj
else:
assert isinstance(
output_tensor,
torch.Tensor), 'Output of model using pipeline parallelism must be a tensor (except the last stage).'
self._logger.debug(
f'Global rank {gpc.get_global_rank()}, pipeline rank {gpc.get_local_rank(ParallelMode.PIPELINE)} forward output tensor {output_tensor.shape}, dtype {output_tensor.dtype}'
)
return output_tensor
if isinstance(output_obj, torch.Tensor):
self._logger.debug(
f'Global rank {gpc.get_global_rank()}, pipeline rank {gpc.get_local_rank(ParallelMode.PIPELINE)} forward output tensor {output_obj.shape}, dtype {output_obj.dtype}'
)
return output_obj
def forward_backward_step(self, engine, data_iter, forward_only=False, return_loss=True, return_output_label=True):
"""Run interleaved 1F1B schedule (model split into model chunks), with
@@ -486,19 +504,19 @@ class InterleavedPipelineSchedule(PipelineSchedule):
'The argument \'return_loss\' has to be True when \'forward_only\' is False, but got False.'
self.load_batch(data_iter)
model = engine.model
input_tensors = [[] for _ in range(len(model))]
output_tensors = [[] for _ in range(len(model))]
input_objs = [[] for _ in range(len(model))]
output_objs = [[] for _ in range(len(model))]
return_tensors = []
if not forward_only:
output_tensor_grads = [[] for _ in range(len(model))]
output_obj_grads = [[] for _ in range(len(model))]
if return_loss and gpc.is_pipeline_last_stage(ignore_virtual=True):
accum_loss = torch.zeros(1, device=get_current_device())
else:
accum_loss = None
# Used for tensor meta information communication
input_tensor_shapes = [self.tensor_shape for _ in range(len(model))]
output_tensor_shapes = [None for _ in range(len(model))]
# Used for obj meta information communication
input_obj_shapes = [self.tensor_shape for _ in range(len(model))]
output_obj_shapes = [None for _ in range(len(model))]
send_tensor_shape_flags = [self.tensor_shape is None for _ in range(len(model))]
pipeline_parallel_size = gpc.get_world_size(ParallelMode.PIPELINE)
@@ -545,24 +563,24 @@ class InterleavedPipelineSchedule(PipelineSchedule):
# forward step
if gpc.is_pipeline_first_stage():
if len(input_tensors[model_chunk_id]) == \
len(output_tensors[model_chunk_id]):
input_tensors[model_chunk_id].append(None)
input_tensor = input_tensors[model_chunk_id][-1]
output_tensor = self._forward_step(engine,
model_chunk_id,
input_tensor,
return_tensors,
return_output_label=return_output_label,
accum_loss=accum_loss)
output_tensors[model_chunk_id].append(output_tensor)
if len(input_objs[model_chunk_id]) == \
len(output_objs[model_chunk_id]):
input_objs[model_chunk_id].append(None)
input_obj = input_objs[model_chunk_id][-1]
output_obj = self._forward_step(engine,
model_chunk_id,
input_obj,
return_tensors,
return_output_label=return_output_label,
accum_loss=accum_loss)
output_objs[model_chunk_id].append(output_obj)
# if forward-only, no need to save tensors for a backward pass
if forward_only:
input_tensors[model_chunk_id].pop()
output_tensors[model_chunk_id].pop()
input_objs[model_chunk_id].pop()
output_objs[model_chunk_id].pop()
return output_tensor
return output_obj
def _backward_step_helper(microbatch_id):
"""Helper method to run backward step with model split into chunks
@@ -572,31 +590,35 @@ class InterleavedPipelineSchedule(PipelineSchedule):
gpc.set_virtual_pipeline_parallel_rank(model_chunk_id)
if gpc.is_pipeline_last_stage():
if len(output_tensor_grads[model_chunk_id]) == 0:
output_tensor_grads[model_chunk_id].append(None)
input_tensor = input_tensors[model_chunk_id].pop(0)
output_tensor = output_tensors[model_chunk_id].pop(0)
output_tensor_grad = output_tensor_grads[model_chunk_id].pop(0)
input_tensor_grad = self._backward_step(engine, input_tensor, output_tensor, output_tensor_grad)
if len(output_obj_grads[model_chunk_id]) == 0:
output_obj_grads[model_chunk_id].append(None)
input_obj = input_objs[model_chunk_id].pop(0)
output_obj = output_objs[model_chunk_id].pop(0)
output_obj_grad = output_obj_grads[model_chunk_id].pop(0)
input_obj_grad = self._backward_step(engine, input_obj, output_obj, output_obj_grad)
return input_tensor_grad
return input_obj_grad
# Run warmup forward passes.
gpc.set_virtual_pipeline_parallel_rank(0)
if not gpc.is_pipeline_first_stage():
input_tensor_shapes[0] = comm.recv_tensor_meta(input_tensor_shapes[0])
input_tensors[0].append(
comm.recv_forward(input_tensor_shapes[0],
dtype=self.dtype,
input_obj_shapes[0] = comm.recv_obj_meta(input_obj_shapes[0])
input_objs[0].append(
comm.recv_forward(input_obj_shapes[0], dtype=self.dtype,
scatter_gather_tensors=self.scatter_gather_tensors))
for k in range(num_warmup_microbatches):
model_chunk_id = get_model_chunk_id(k, forward=True)
output_tensor = _forward_step_helper(k)
output_obj = _forward_step_helper(k)
if not gpc.is_pipeline_last_stage():
output_tensor_shapes[model_chunk_id] = output_tensor.shape
send_tensor_shape_flags[model_chunk_id] = comm.send_tensor_meta(output_tensor,
send_tensor_shape_flags[model_chunk_id])
if isinstance(output_obj, torch.Tensor):
output_obj_shapes[model_chunk_id] = output_obj.shape
else:
output_obj_shapes[model_chunk_id] = []
for out_tensor in output_obj:
output_obj_shapes[model_chunk_id].append(out_tensor.shape)
send_tensor_shape_flags[model_chunk_id] = comm.send_obj_meta(output_obj,
send_tensor_shape_flags[model_chunk_id])
# Determine if tensor should be received from previous stage.
next_forward_model_chunk_id = get_model_chunk_id(k + 1, forward=True)
recv_prev = True
@@ -608,65 +630,65 @@ class InterleavedPipelineSchedule(PipelineSchedule):
# Don't send tensor downstream if on last stage.
if gpc.is_pipeline_last_stage():
output_tensor = None
output_obj = None
with switch_virtual_pipeline_parallel_rank(next_forward_model_chunk_id):
if not gpc.is_pipeline_first_stage():
input_tensor_shapes[next_forward_model_chunk_id] = comm.recv_tensor_meta(
input_tensor_shapes[next_forward_model_chunk_id])
input_obj_shapes[next_forward_model_chunk_id] = comm.recv_obj_meta(
input_obj_shapes[next_forward_model_chunk_id])
# Send and receive tensors as appropriate (send tensors computed
# in this iteration; receive tensors for next iteration).
input_shape = input_tensor_shapes[next_forward_model_chunk_id] if recv_prev else None
input_shape = input_obj_shapes[next_forward_model_chunk_id] if recv_prev else None
if k == (num_warmup_microbatches - 1) and not forward_only and \
not all_warmup_microbatches:
input_tensor_grad = None
input_obj_grad = None
recv_next = True
if gpc.is_pipeline_last_stage(ignore_virtual=True):
recv_next = False
output_shape = output_tensor_shapes[num_model_chunks - 1] if recv_next else None
input_tensor, output_tensor_grad = \
output_shape = output_obj_shapes[num_model_chunks - 1] if recv_next else None
input_obj, output_obj_grad = \
comm.send_forward_backward_recv_forward_backward(
output_tensor, input_tensor_grad,
output_obj, input_obj_grad,
input_shape,
output_shape,
recv_prev=recv_prev, recv_next=recv_next,
dtype=self.dtype,
scatter_gather_tensors=self.scatter_gather_tensors)
output_tensor_grads[num_model_chunks - 1].append(output_tensor_grad)
output_obj_grads[num_model_chunks - 1].append(output_obj_grad)
else:
input_tensor = \
input_obj = \
comm.send_forward_recv_forward(
output_tensor,
output_obj,
input_shape,
recv_prev=recv_prev,
dtype=self.dtype,
scatter_gather_tensors=self.scatter_gather_tensors)
input_tensors[next_forward_model_chunk_id].append(input_tensor)
input_objs[next_forward_model_chunk_id].append(input_obj)
# Run 1F1B in steady state.
for k in range(num_microbatches_remaining):
# Forward pass.
forward_k = k + num_warmup_microbatches
output_tensor = _forward_step_helper(forward_k)
output_obj = _forward_step_helper(forward_k)
# Backward pass.
backward_k = k
input_tensor_grad = _backward_step_helper(backward_k)
input_obj_grad = _backward_step_helper(backward_k)
# Send output_tensor and input_tensor_grad, receive input_tensor
# and output_tensor_grad.
# Send output_obj and input_obj_grad, receive input_obj
# and output_obj_grad.
# Determine if current stage has anything to send in either direction,
# otherwise set tensor to None.
# otherwise set obj to None.
forward_model_chunk_id = get_model_chunk_id(forward_k, forward=True)
gpc.set_virtual_pipeline_parallel_rank(forward_model_chunk_id)
if gpc.is_pipeline_last_stage():
output_tensor = None
output_obj = None
backward_model_chunk_id = get_model_chunk_id(backward_k, forward=False)
gpc.set_virtual_pipeline_parallel_rank(backward_model_chunk_id)
if gpc.is_pipeline_first_stage():
input_tensor_grad = None
input_obj_grad = None
# Determine if peers are sending, and where in data structure to put
# received tensors.
@@ -696,33 +718,33 @@ class InterleavedPipelineSchedule(PipelineSchedule):
if k == (num_microbatches_remaining - 1):
recv_prev = False
input_shape = input_tensor_shapes[next_forward_model_chunk_id] if recv_prev else None
output_shape = output_tensor_shapes[next_backward_model_chunk_id] if recv_next else None
# Communicate tensors.
input_tensor, output_tensor_grad = \
input_shape = input_obj_shapes[next_forward_model_chunk_id] if recv_prev else None
output_shape = output_obj_shapes[next_backward_model_chunk_id] if recv_next else None
# Communicate objs.
input_obj, output_obj_grad = \
comm.send_forward_backward_recv_forward_backward(
output_tensor, input_tensor_grad,
output_obj, input_obj_grad,
input_shape,
output_shape,
recv_prev=recv_prev, recv_next=recv_next,
dtype=self.dtype,
scatter_gather_tensors=self.scatter_gather_tensors)
# Put input_tensor and output_tensor_grad in data structures in the
# Put input_obj and output_obj_grad in data structures in the
# right location.
if recv_prev:
input_tensors[next_forward_model_chunk_id].append(input_tensor)
input_objs[next_forward_model_chunk_id].append(input_obj)
if recv_next:
output_tensor_grads[next_backward_model_chunk_id].append(output_tensor_grad)
output_obj_grads[next_backward_model_chunk_id].append(output_obj_grad)
# Run cooldown backward passes (flush out pipeline).
if not forward_only:
if all_warmup_microbatches:
output_tensor_grads[num_model_chunks - 1].append(
comm.recv_backward(output_tensor_shapes[num_model_chunks - 1],
output_obj_grads[num_model_chunks - 1].append(
comm.recv_backward(output_obj_shapes[num_model_chunks - 1],
scatter_gather_tensors=self.scatter_gather_tensors))
for k in range(num_microbatches_remaining, num_microbatches):
input_tensor_grad = _backward_step_helper(k)
input_obj_grad = _backward_step_helper(k)
next_backward_model_chunk_id = get_model_chunk_id(k + 1, forward=False)
recv_next = True
if gpc.is_pipeline_last_stage(ignore_virtual=True):
@@ -730,9 +752,9 @@ class InterleavedPipelineSchedule(PipelineSchedule):
recv_next = False
if k == (num_microbatches - 1):
recv_next = False
output_shape = output_tensor_shapes[next_backward_model_chunk_id] if recv_next else None
output_tensor_grads[next_backward_model_chunk_id].append(
comm.send_backward_recv_backward(input_tensor_grad,
output_shape = output_obj_shapes[next_backward_model_chunk_id] if recv_next else None
output_obj_grads[next_backward_model_chunk_id].append(
comm.send_backward_recv_backward(input_obj_grad,
output_shape,
recv_next=recv_next,
dtype=self.dtype,

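A closing aside on the interleaved schedule: each local model chunk keeps its own FIFO of inputs, outputs, and output gradients, appended during the warmup/steady-state forwards and popped from the front by the matching backward step. A toy sketch with placeholder chunk count and tensors:

```python
import torch

num_model_chunks = 2
input_objs = [[] for _ in range(num_model_chunks)]
output_objs = [[] for _ in range(num_model_chunks)]

# A warmup forward for chunk 0 pushes its (possibly list-valued) input and output.
input_objs[0].append([torch.zeros(4, 8), torch.zeros(4)])
output_objs[0].append(torch.ones(4, 8))

# The matching backward step later consumes from the front of the same queues.
input_obj = input_objs[0].pop(0)
output_obj = output_objs[0].pop(0)
print(len(input_obj), output_obj.shape)  # 2 torch.Size([4, 8])
```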

@@ -7,9 +7,9 @@ import pytest
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from colossalai.communication import (recv_backward, recv_forward, recv_tensor_meta, send_backward,
from colossalai.communication import (recv_backward, recv_forward, recv_obj_meta, send_backward,
send_backward_recv_forward, send_forward, send_forward_recv_backward,
send_tensor_meta)
send_obj_meta)
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.initialize import launch