mirror of https://github.com/hpcaitech/ColossalAI
[pipeline]refactor ppschedule to support tensor list (#1050)
* [CLI] add CLI launcher
* Revert "[CLI] add CLI launcher"
This reverts commit df7e6506d4.
* refactor ppschedule to support tensor list
* polish
pull/1063/head  v0.1.6
parent e3fde4ee6b · commit b167258b6a
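In short, this commit renames the tensor-meta helpers to obj-meta helpers and teaches both pipeline schedules to pass either a single tensor or a list of tensors between stages. As a rough illustration of the kind of stage output this enables (the module below is a hypothetical example written for this note, not part of the patch), an intermediate stage may now hand more than one tensor to the next stage:

import torch
import torch.nn as nn

class TwoOutputStage(nn.Module):
    # Hypothetical pipeline stage whose forward returns two tensors; with this
    # patch the schedule can ship such a list of activations to the next stage.
    def __init__(self, dim: int = 16):
        super().__init__()
        self.proj = nn.Linear(dim, dim)

    def forward(self, x: torch.Tensor):
        hidden = self.proj(x)
        mask = (hidden > 0).float()    # an extra tensor travelling alongside the activation
        return hidden, mask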
@@ -3,7 +3,7 @@ from .p2p import (send_forward, send_forward_recv_forward, send_backward_recv_fo
send_backward_recv_backward, send_forward_recv_backward, send_forward_backward_recv_forward_backward,
recv_forward, recv_backward)
from .ring import ring_forward
from .utils import send_tensor_meta, recv_tensor_meta
from .utils import send_obj_meta, recv_obj_meta

__all__ = [
'all_gather',

@@ -21,6 +21,6 @@ __all__ = [
'recv_backward',
'recv_forward',
'ring_forward',
'send_tensor_meta',
'recv_tensor_meta',
'send_obj_meta',
'recv_obj_meta',
]
@@ -9,14 +9,21 @@ from typing import Union, List, Tuple
TensorShape = Union[torch.Size, List[int], Tuple[int]]

def send_tensor_meta(tensor, need_meta=True, next_rank=None) -> bool:
"""Sends tensor meta information before sending a specific tensor.
Since the recipient must know the shape of the tensor in p2p communications,
meta information of the tensor should be sent before communications. This function
synchronizes with :func:`recv_tensor_meta`.
def send_meta_helper(obj, next_rank, tensor_kwargs):
send_shape = torch.tensor(obj.size(), **tensor_kwargs)
send_ndims = torch.tensor(len(obj.size()), **tensor_kwargs)
dist.send(send_ndims, next_rank)
dist.send(send_shape, next_rank)

def send_obj_meta(obj, need_meta=True, next_rank=None) -> bool:
"""Sends obj meta information before sending a specific obj.
Since the recipient must know the shape of the obj in p2p communications,
meta information of the obj should be sent before communications. This function
synchronizes with :func:`recv_obj_meta`.

Args:
tensor (:class:`torch.Tensor`): Tensor to be sent.
obj (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): obj to be sent.
need_meta (bool, optional): If False, meta information won't be sent.
next_rank (int): The rank of the next member in pipeline parallel group.

@@ -28,42 +35,57 @@ def send_tensor_meta(tensor, need_meta=True, next_rank=None) -> bool:
next_rank = gpc.get_next_global_rank(ParallelMode.PIPELINE)

tensor_kwargs = {'dtype': torch.long, 'device': get_current_device()}

send_shape = torch.tensor(tensor.size(), **tensor_kwargs)
send_ndims = torch.tensor(len(tensor.size()), **tensor_kwargs)
dist.send(send_ndims, next_rank)
dist.send(send_shape, next_rank)
if isinstance(obj, torch.Tensor):
send_obj_nums = torch.tensor(1, **tensor_kwargs)
dist.send(send_obj_nums, next_rank)
send_meta_helper(obj, next_rank, tensor_kwargs)
else:
send_obj_nums = torch.tensor(len(obj), **tensor_kwargs)
dist.send(send_obj_nums, next_rank)
for tensor_to_send in obj:
send_meta_helper(tensor_to_send, next_rank, tensor_kwargs)

return False

def recv_tensor_meta(tensor_shape: TensorShape, prev_rank=None) -> torch.Size:
"""Receives tensor meta information before receiving a specific tensor.
Since the recipient must know the shape of the tensor in p2p communications,
meta information of the tensor should be received before communications. This function
synchronizes with :func:`send_tensor_meta`.
def recv_meta_helper(prev_rank, tensor_kwargs):
recv_ndims = torch.empty((), **tensor_kwargs)
dist.recv(recv_ndims, prev_rank)
recv_shape = torch.empty(recv_ndims, **tensor_kwargs)
dist.recv(recv_shape, prev_rank)
return recv_shape

def recv_obj_meta(obj_shape, prev_rank=None) -> torch.Size:
"""Receives obj meta information before receiving a specific obj.
Since the recipient must know the shape of the obj in p2p communications,
meta information of the obj should be received before communications. This function
synchronizes with :func:`send_obj_meta`.

Args:
tensor_shape (:class:`torch.Size`): The shape of the tensor to be received.
prev_rank (int): The rank of the source of the tensor.
obj_shape (Union[:class:`torch.Size`, List[:class:`torch.Size`]]): The shape of the obj to be received.
prev_rank (int): The rank of the source of the obj.

Returns:
:class:`torch.Size`: The shape of the tensor to be received.
Union[:class:`torch.Size`, List[:class:`torch.Size`]]: The shape of the obj to be received.
"""
if tensor_shape is None:
if obj_shape is None:
if prev_rank is None:
prev_rank = gpc.get_prev_global_rank(ParallelMode.PIPELINE)

tensor_kwargs = {'dtype': torch.long, 'device': get_current_device()}
recv_obj_nums = torch.empty((), **tensor_kwargs)
dist.recv(recv_obj_nums, prev_rank)
if recv_obj_nums.item() == 1:
recv_shape = recv_meta_helper(prev_rank, tensor_kwargs)
obj_shape = torch.Size(recv_shape)
else:
obj_shape = []
for i in range(recv_obj_nums.item()):
recv_shape = recv_meta_helper(prev_rank, tensor_kwargs)
obj_shape.append(torch.Size(recv_shape))

recv_ndims = torch.empty((), **tensor_kwargs)
dist.recv(recv_ndims, prev_rank)
recv_shape = torch.empty(recv_ndims, **tensor_kwargs)
dist.recv(recv_shape, prev_rank)

tensor_shape = torch.Size(recv_shape)

return tensor_shape
return obj_shape

def split_tensor_into_1d_equal_chunks(tensor: torch.Tensor, new_buffer=False) -> torch.Tensor:
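The hunk above is the heart of the change: before any payload is exchanged, the sender first announces how many tensors follow, then the ndims and shape of each one, and the receiver mirrors that order. A minimal standalone illustration of the same handshake over torch.distributed (two CPU processes on the gloo backend; this is a sketch written for this note, not code from the repository):

import torch
import torch.distributed as dist
import torch.multiprocessing as mp

def run(rank, world_size):
    dist.init_process_group('gloo', init_method='tcp://127.0.0.1:29511',
                            rank=rank, world_size=world_size)
    if rank == 0:
        obj = [torch.zeros(2, 3), torch.zeros(5)]
        dist.send(torch.tensor(len(obj)), dst=1)           # 1) how many tensors follow
        for t in obj:
            dist.send(torch.tensor(len(t.size())), dst=1)  # 2) ndims of this tensor
            dist.send(torch.tensor(t.size()), dst=1)       # 3) its shape
    else:
        num = torch.empty((), dtype=torch.long)
        dist.recv(num, src=0)
        shapes = []
        for _ in range(num.item()):
            ndims = torch.empty((), dtype=torch.long)
            dist.recv(ndims, src=0)
            shape = torch.empty(ndims.item(), dtype=torch.long)
            dist.recv(shape, src=0)
            shapes.append(torch.Size(shape.tolist()))
        print(shapes)   # [torch.Size([2, 3]), torch.Size([5])]
    dist.destroy_process_group()

if __name__ == '__main__':
    mp.spawn(run, args=(2,), nprocs=2)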
@@ -130,7 +130,7 @@ class PipelineSchedule(BaseSchedule):
assert p.kind != inspect.Parameter.VAR_POSITIONAL, '*args is not supported'

@staticmethod
def _call_engine(model, input_tensor, batch_data):
def _call_engine(model, input_obj, batch_data):
if isinstance(model, NaiveAMPModel):
sig = inspect.signature(model.model.forward)
elif hasattr(model, 'colo_attr'):
@@ -140,16 +140,22 @@ class PipelineSchedule(BaseSchedule):
if isinstance(batch_data, torch.Tensor):
for p in sig.parameters.values():
if p.kind == inspect.Parameter.VAR_KEYWORD:
if input_tensor is None:
if input_obj is None:
return model(batch_data)
else:
return model(input_tensor)
if input_tensor is None:
return model(input_obj)
if input_obj is None:
return model(batch_data)
elif len(sig.parameters) > 1:
return model(input_tensor, batch_data)
elif isinstance(input_obj, torch.Tensor):
if len(sig.parameters) > 1:
return model(input_obj, batch_data)
else:
return model(input_obj)
else:
return model(input_tensor)
if len(sig.parameters) > len(input_obj):
return model(*input_obj, batch_data)
else:
return model(*input_obj)
else:
filter_batch = True
for p in sig.parameters.values():
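The rewritten `_call_engine` above now distinguishes three cases for the activation coming from the previous stage: `None` on the first stage, a single tensor, or a list of tensors that is unpacked positionally. A stripped-down sketch of that dispatch for plain (non-dict) batch data, with illustrative names only (the real method additionally inspects the model signature to decide whether to pass the micro-batch as well):

import torch

def call_stage(model, input_obj, batch_data):
    # Hypothetical stand-in for PipelineSchedule._call_engine's positional branch.
    if input_obj is None:                       # first stage: feed the micro-batch itself
        return model(batch_data)
    if isinstance(input_obj, torch.Tensor):     # single activation from the previous stage
        return model(input_obj)
    return model(*input_obj)                    # list of activations, unpacked positionally

def stage(hidden, mask=None):
    return hidden if mask is None else hidden * mask

x = torch.randn(4, 8)
print(call_stage(stage, None, x).shape)                      # torch.Size([4, 8])
print(call_stage(stage, [x, (x > 0).float()], None).shape)   # torch.Size([4, 8])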
@@ -157,79 +163,88 @@ class PipelineSchedule(BaseSchedule):
filter_batch = False
if filter_batch:
batch_data = {k: v for k, v in batch_data.items() if k in sig.parameters}
if input_tensor is None and filter_batch:
if input_obj is None and filter_batch:
return model(**batch_data)
elif isinstance(input_obj, torch.Tensor) or input_obj is None:
return model(input_obj, **batch_data)
else:
return model(input_tensor, **batch_data)
return model(*input_obj, **batch_data)

def _forward_step(self, engine, input_tensor, return_tensors, return_output_label=True, accum_loss=None):
def _forward_step(self, engine, input_obj, return_tensors, return_output_label=True, accum_loss=None):
"""Forward step for passed-in model. If it is the first stage, the input tensor
is obtained from data_iterator, otherwise the passed-in input_tensor is used.
is obtained from data_iterator, otherwise the passed-in input_obj is used.
Returns output tensor. This is a helper function and can be ignored by users.

Args:
engine (colossalai.engine.Engine): Colossalai engine for training and inference.
input_tensor (:class:`torch.Tensor`): Input tensor for this pipeline stage.
input_obj (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): Input tensor for this pipeline stage.
return_tensors (List[:class:`torch.Tensor`]): A list of tensors to return.
return_output_label (bool, optional): Whether returns output labels.
accum_loss (optional): Where accumulated loss stores.
Returns:
:class:`torch.Tensor`: output or the loss value of the current pipeline stage.
Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]: output or the loss value of the current pipeline stage.
"""
data, label = self.load_micro_batch()
output_tensor = self._call_engine(engine.model, input_tensor, data)
output_obj = self._call_engine(engine.model, input_obj, data)

if gpc.is_last_rank(ParallelMode.PIPELINE):
if return_output_label:
return_tensors.append((output_tensor, label))
return_tensors.append((output_obj, label))
if accum_loss is not None:
loss_reduced = self._call_engine_criterion(engine, output_tensor, label) / self.num_microbatches
loss_reduced = self._call_engine_criterion(engine, output_obj, label) / self.num_microbatches
accum_loss.add_(loss_reduced.detach())
return loss_reduced
else:
# forward only, it's useless since backward is not needed
return output_tensor
return output_obj
else:
assert isinstance(
output_tensor,
torch.Tensor), 'Output of model using pipeline parallelism must be a tensor (except the last stage).'
self._logger.debug(
f'Global rank {gpc.get_global_rank()}, pipeline rank {gpc.get_local_rank(ParallelMode.PIPELINE)} forward output tensor {output_tensor.shape}, dtype {output_tensor.dtype}'
)
return output_tensor
if isinstance(output_obj, torch.Tensor):
self._logger.debug(
f'Global rank {gpc.get_global_rank()}, pipeline rank {gpc.get_local_rank(ParallelMode.PIPELINE)} forward output tensor {output_obj.shape}, dtype {output_obj.dtype}'
)
return output_obj

def _backward_step(self, engine, input_tensor, output_tensor, output_tensor_grad):
def _backward_step(self, engine, input_obj, output_obj, output_obj_grad):
"""Backward step through the passed-in output tensor. If it is the last stage, the
output_tensor_grad is None, otherwise it is the gradients with respect to stage's output tensor.
output_obj_grad is None, otherwise it is the gradients with respect to stage's output tensor.
Returns the gradients with respect to the input tensor (None if first stage).
This is a helper function and can be ignored by users.

Args:
engine (colossalai.engine.Engine): Colossalai engine for training and inference.
input_tensor (:class:`torch.Tensor`): input tensor for this pipeline stage.
output_tensor (:class:`torch.Tensor`): output tensor for this pipeline stage.
output_tensor_grad (:class:`torch.Tensor`): gradient of output tensor for this pipeline stage.
input_obj (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): input tensor for this pipeline stage.
output_obj (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): output tensor for this pipeline stage.
output_obj_grad (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): gradient of output tensor for this pipeline stage.

Returns:
:class:`torch.Tensor`: gradient of input tensor.
Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]: gradient of input tensor.
"""

# Retain the grad on the input_tensor.
if input_tensor is not None:
input_tensor.retain_grad()

# Retain the grad on the input_obj.
if input_obj is not None:
if isinstance(input_obj, torch.Tensor):
input_obj.retain_grad()
else:
for in_tensor in input_obj:
if in_tensor is not None:
in_tensor.retain_grad()
# Backward pass.
if output_tensor_grad is None:
engine.backward(output_tensor)
if output_obj_grad is None:
engine.backward(output_obj)
else:
engine.backward_by_grad(output_tensor, output_tensor_grad)
engine.backward_by_grad(output_obj, output_obj_grad)

# Collect the grad of the input_tensor.
input_tensor_grad = None
if input_tensor is not None:
input_tensor_grad = input_tensor.grad
# Collect the grad of the input_obj.
input_obj_grad = None
if input_obj is not None:
if isinstance(input_obj, torch.Tensor):
input_obj_grad = input_obj.grad
else:
input_obj_grad = []
for in_tensor in input_obj:
input_obj_grad.append(in_tensor.grad)

return input_tensor_grad
return input_obj_grad

def forward_backward_step(self, engine, data_iter, forward_only=False, return_loss=True, return_output_label=True):
"""Runs non-interleaved 1F1B schedule, with communication between pipeline stages.
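`_backward_step` above is the list-aware mirror image: it retains gradients on every incoming activation, backpropagates through the stage output (by injected gradient when this is not the last stage), and collects one `.grad` per input tensor. The following self-contained sketch reproduces that flow with plain autograd standing in for `engine.backward` / `engine.backward_by_grad`:

import torch

def backward_step_sketch(input_obj, output_obj, output_obj_grad):
    inputs = [input_obj] if isinstance(input_obj, torch.Tensor) else list(input_obj)
    for t in inputs:
        t.retain_grad()                        # keep grads on non-leaf activations
    if output_obj_grad is None:                # last stage: output is a scalar loss
        output_obj.backward()
    else:                                      # intermediate stage: gradient arrives from downstream
        torch.autograd.backward(output_obj, grad_tensors=output_obj_grad)
    grads = [t.grad for t in inputs]
    return grads[0] if isinstance(input_obj, torch.Tensor) else grads

a = torch.randn(3, requires_grad=True) * 2.0   # pretend activations received from upstream
b = torch.randn(3, requires_grad=True) * 2.0
out = a * b
print(backward_step_sketch([a, b], out, torch.ones_like(out)))   # [grad wrt a, grad wrt b]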
@@ -257,108 +272,113 @@ class PipelineSchedule(BaseSchedule):
num_microbatches_remaining = self.num_microbatches - num_warmup_microbatches

# Input, output tensors only need to be saved when doing backward passes
input_tensors = None
output_tensors = None
input_objs = None
output_objs = None
if not forward_only:
input_tensors = []
output_tensors = []
input_objs = []
output_objs = []
return_tensors = []
if return_loss and gpc.is_pipeline_last_stage(ignore_virtual=True):
accum_loss = torch.zeros(1, device=get_current_device())
else:
accum_loss = None
# Used for tensor meta information communication
ft_shape = self.tensor_shape
bt_shape = None
ft_shapes = self.tensor_shape
bt_shapes = None
fs_checker = self.tensor_shape is None

# Run warmup forward passes.
for i in range(num_warmup_microbatches):
if not gpc.is_first_rank(ParallelMode.PIPELINE):
ft_shape = comm.recv_tensor_meta(ft_shape)
input_tensor = comm.recv_forward(ft_shape,
dtype=self.dtype,
scatter_gather_tensors=self.scatter_gather_tensors)
output_tensor = self._forward_step(engine,
input_tensor,
return_tensors,
return_output_label=return_output_label,
accum_loss=accum_loss)
ft_shapes = comm.recv_obj_meta(ft_shapes)
input_obj = comm.recv_forward(ft_shapes,
dtype=self.dtype,
scatter_gather_tensors=self.scatter_gather_tensors)
output_obj = self._forward_step(engine,
input_obj,
return_tensors,
return_output_label=return_output_label,
accum_loss=accum_loss)
if not gpc.is_last_rank(ParallelMode.PIPELINE):
bt_shape = output_tensor.shape
fs_checker = comm.send_tensor_meta(output_tensor, fs_checker)
comm.send_forward(output_tensor, scatter_gather_tensors=self.scatter_gather_tensors)
if isinstance(output_obj, torch.Tensor):
bt_shapes = output_obj.shape
else:
bt_shapes = []
for out_tensor in output_obj:
bt_shapes.append(out_tensor.shape)
fs_checker = comm.send_obj_meta(output_obj, fs_checker)
comm.send_forward(output_obj, scatter_gather_tensors=self.scatter_gather_tensors)

if not forward_only:
input_tensors.append(input_tensor)
output_tensors.append(output_tensor)
input_objs.append(input_obj)
output_objs.append(output_obj)

# Before running 1F1B, need to receive first forward tensor.
# If all microbatches are run in warmup / cooldown phase, then no need to
# receive this tensor here.
if num_microbatches_remaining > 0:
if not gpc.is_first_rank(ParallelMode.PIPELINE):
ft_shape = comm.recv_tensor_meta(ft_shape)
input_tensor = comm.recv_forward(ft_shape,
dtype=self.dtype,
scatter_gather_tensors=self.scatter_gather_tensors)
ft_shapes = comm.recv_obj_meta(ft_shapes)
input_obj = comm.recv_forward(ft_shapes,
dtype=self.dtype,
scatter_gather_tensors=self.scatter_gather_tensors)

# Run 1F1B in steady state.
for i in range(num_microbatches_remaining):
last_iteration = (i == (num_microbatches_remaining - 1))

output_tensor = self._forward_step(engine,
input_tensor,
return_tensors,
return_output_label=return_output_label,
accum_loss=accum_loss)
output_obj = self._forward_step(engine,
input_obj,
return_tensors,
return_output_label=return_output_label,
accum_loss=accum_loss)
if forward_only:
comm.send_forward(output_tensor, scatter_gather_tensors=self.scatter_gather_tensors)
comm.send_forward(output_obj, scatter_gather_tensors=self.scatter_gather_tensors)

if not last_iteration:
input_tensor = comm.recv_forward(ft_shape,
dtype=self.dtype,
scatter_gather_tensors=self.scatter_gather_tensors)
input_obj = comm.recv_forward(ft_shapes,
dtype=self.dtype,
scatter_gather_tensors=self.scatter_gather_tensors)

else:
output_tensor_grad = comm.send_forward_recv_backward(output_tensor,
bt_shape,
dtype=self.dtype,
scatter_gather_tensors=self.scatter_gather_tensors)
output_obj_grad = comm.send_forward_recv_backward(output_obj,
bt_shapes,
dtype=self.dtype,
scatter_gather_tensors=self.scatter_gather_tensors)

# Add input_tensor and output_tensor to end of list.
input_tensors.append(input_tensor)
output_tensors.append(output_tensor)
# Add input_obj and output_obj to end of list.
input_objs.append(input_obj)
output_objs.append(output_obj)

# Pop input_tensor and output_tensor from the start of the list for
# Pop output_obj and output_obj from the start of the list for
# the backward pass.
input_tensor = input_tensors.pop(0)
output_tensor = output_tensors.pop(0)
input_obj = input_objs.pop(0)
output_obj = output_objs.pop(0)

input_tensor_grad = self._backward_step(engine, input_tensor, output_tensor, output_tensor_grad)
input_obj_grad = self._backward_step(engine, input_obj, output_obj, output_obj_grad)

if last_iteration:
input_tensor = None
comm.send_backward(input_tensor_grad, scatter_gather_tensors=self.scatter_gather_tensors)
input_obj = None
comm.send_backward(input_obj_grad, scatter_gather_tensors=self.scatter_gather_tensors)
else:
input_tensor = comm.send_backward_recv_forward(input_tensor_grad,
ft_shape,
dtype=self.dtype,
scatter_gather_tensors=self.scatter_gather_tensors)
input_obj = comm.send_backward_recv_forward(input_obj_grad,
ft_shapes,
dtype=self.dtype,
scatter_gather_tensors=self.scatter_gather_tensors)

# Run cooldown backward passes.
if not forward_only:
for i in range(num_warmup_microbatches):
input_tensor = input_tensors.pop(0)
output_tensor = output_tensors.pop(0)
input_obj = input_objs.pop(0)
output_obj = output_objs.pop(0)

output_tensor_grad = comm.recv_backward(bt_shape,
dtype=self.dtype,
scatter_gather_tensors=self.scatter_gather_tensors)
output_obj_grad = comm.recv_backward(bt_shapes,
dtype=self.dtype,
scatter_gather_tensors=self.scatter_gather_tensors)

input_tensor_grad = self._backward_step(engine, input_tensor, output_tensor, output_tensor_grad)
input_obj_grad = self._backward_step(engine, input_obj, output_obj, output_obj_grad)

comm.send_backward(input_tensor_grad, scatter_gather_tensors=self.scatter_gather_tensors)
comm.send_backward(input_obj_grad, scatter_gather_tensors=self.scatter_gather_tensors)

if len(return_tensors) > 0:
output, label = pack_return_tensors(return_tensors)
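One detail worth calling out in the non-interleaved schedule above: the shape hints exchanged before the real communication (`ft_shapes` / `bt_shapes`) are now either a single `torch.Size` or a list of them, mirroring whether the stage output is a tensor or a list. A tiny helper equivalent to the bookkeeping in the warmup loop:

import torch

def shapes_of(output_obj):
    # Mirrors the bt_shapes bookkeeping above: one Size for a tensor, a list of Sizes otherwise.
    if isinstance(output_obj, torch.Tensor):
        return output_obj.shape
    return [t.shape for t in output_obj]

print(shapes_of(torch.zeros(2, 4)))                        # torch.Size([2, 4])
print(shapes_of([torch.zeros(2, 4), torch.zeros(2)]))      # [torch.Size([2, 4]), torch.Size([2])]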
@@ -426,45 +446,43 @@ class InterleavedPipelineSchedule(PipelineSchedule):
def _forward_step(self,
engine,
model_chunk_id,
input_tensor,
input_obj,
return_tensors,
return_output_label=True,
accum_loss=None):
"""Forward step for passed-in model. If it is the first stage, the input tensor
is obtained from data_iterator, otherwise the passed-in input_tensor is used.
is obtained from data_iterator, otherwise the passed-in input_obj is used.
Returns output tensor. This is a helper function and can be ignored by users.

Args:
engine (colossalai.engine.Engine): Colossalai engine for training and inference.
model_chunk_id (int): The id of model chunks.
input_tensor (:class:`torch.Tensor`): Input tensor for this pipeline stage.
input_obj (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): Input tensor for this pipeline stage.
return_tensors (List[:class:`torch.Tensor`]): A list of tensors to return.
return_output_label (bool, optional): Whether returns output labels.
accum_loss (optional): Where accumulated loss stores.
Returns:
:class:`torch.Tensor`: output or the loss value of the current pipeline stage.
Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]: output or the loss value of the current pipeline stage.
"""
data, label = self.load_micro_batch(model_chunk_id)
output_tensor = self._call_engine(engine.model[model_chunk_id], input_tensor, data)
output_obj = self._call_engine(engine.model[model_chunk_id], input_obj, data)

if gpc.is_pipeline_last_stage():
if return_output_label:
return_tensors.append((output_tensor, label))
return_tensors.append((output_obj, label))
if accum_loss is not None:
loss_reduced = self._call_engine_criterion(engine, output_tensor, label) / self.num_microbatches
loss_reduced = self._call_engine_criterion(engine, output_obj, label) / self.num_microbatches
accum_loss.add_(loss_reduced.detach())
return loss_reduced
else:
# forward only, it's useless since backward is not needed
return output_tensor
return output_obj
else:
assert isinstance(
output_tensor,
torch.Tensor), 'Output of model using pipeline parallelism must be a tensor (except the last stage).'
self._logger.debug(
f'Global rank {gpc.get_global_rank()}, pipeline rank {gpc.get_local_rank(ParallelMode.PIPELINE)} forward output tensor {output_tensor.shape}, dtype {output_tensor.dtype}'
)
return output_tensor
if isinstance(output_obj, torch.Tensor):
self._logger.debug(
f'Global rank {gpc.get_global_rank()}, pipeline rank {gpc.get_local_rank(ParallelMode.PIPELINE)} forward output tensor {output_obj.shape}, dtype {output_obj.dtype}'
)
return output_obj

def forward_backward_step(self, engine, data_iter, forward_only=False, return_loss=True, return_output_label=True):
"""Run interleaved 1F1B schedule (model split into model chunks), with
@@ -486,19 +504,19 @@ class InterleavedPipelineSchedule(PipelineSchedule):
'The argument \'return_loss\' has to be True when \'forward_only\' is False, but got False.'
self.load_batch(data_iter)
model = engine.model
input_tensors = [[] for _ in range(len(model))]
output_tensors = [[] for _ in range(len(model))]
input_objs = [[] for _ in range(len(model))]
output_objs = [[] for _ in range(len(model))]
return_tensors = []
if not forward_only:
output_tensor_grads = [[] for _ in range(len(model))]
output_obj_grads = [[] for _ in range(len(model))]
if return_loss and gpc.is_pipeline_last_stage(ignore_virtual=True):
accum_loss = torch.zeros(1, device=get_current_device())
else:
accum_loss = None

# Used for tensor meta information communication
input_tensor_shapes = [self.tensor_shape for _ in range(len(model))]
output_tensor_shapes = [None for _ in range(len(model))]
# Used for obj meta information communication
input_obj_shapes = [self.tensor_shape for _ in range(len(model))]
output_obj_shapes = [None for _ in range(len(model))]
send_tensor_shape_flags = [self.tensor_shape is None for _ in range(len(model))]

pipeline_parallel_size = gpc.get_world_size(ParallelMode.PIPELINE)
@@ -545,24 +563,24 @@ class InterleavedPipelineSchedule(PipelineSchedule):
# forward step
if gpc.is_pipeline_first_stage():
if len(input_tensors[model_chunk_id]) == \
len(output_tensors[model_chunk_id]):
input_tensors[model_chunk_id].append(None)
input_tensor = input_tensors[model_chunk_id][-1]
output_tensor = self._forward_step(engine,
model_chunk_id,
input_tensor,
return_tensors,
return_output_label=return_output_label,
accum_loss=accum_loss)
output_tensors[model_chunk_id].append(output_tensor)
if len(input_objs[model_chunk_id]) == \
len(output_objs[model_chunk_id]):
input_objs[model_chunk_id].append(None)
input_obj = input_objs[model_chunk_id][-1]
output_obj = self._forward_step(engine,
model_chunk_id,
input_obj,
return_tensors,
return_output_label=return_output_label,
accum_loss=accum_loss)
output_objs[model_chunk_id].append(output_obj)

# if forward-only, no need to save tensors for a backward pass
if forward_only:
input_tensors[model_chunk_id].pop()
output_tensors[model_chunk_id].pop()
input_objs[model_chunk_id].pop()
output_objs[model_chunk_id].pop()

return output_tensor
return output_obj

def _backward_step_helper(microbatch_id):
"""Helper method to run backward step with model split into chunks
@@ -572,31 +590,35 @@ class InterleavedPipelineSchedule(PipelineSchedule):
gpc.set_virtual_pipeline_parallel_rank(model_chunk_id)

if gpc.is_pipeline_last_stage():
if len(output_tensor_grads[model_chunk_id]) == 0:
output_tensor_grads[model_chunk_id].append(None)
input_tensor = input_tensors[model_chunk_id].pop(0)
output_tensor = output_tensors[model_chunk_id].pop(0)
output_tensor_grad = output_tensor_grads[model_chunk_id].pop(0)
input_tensor_grad = self._backward_step(engine, input_tensor, output_tensor, output_tensor_grad)
if len(output_obj_grads[model_chunk_id]) == 0:
output_obj_grads[model_chunk_id].append(None)
input_obj = input_objs[model_chunk_id].pop(0)
output_obj = output_objs[model_chunk_id].pop(0)
output_obj_grad = output_obj_grads[model_chunk_id].pop(0)
input_obj_grad = self._backward_step(engine, input_obj, output_obj, output_obj_grad)

return input_tensor_grad
return input_obj_grad

# Run warmup forward passes.
gpc.set_virtual_pipeline_parallel_rank(0)
if not gpc.is_pipeline_first_stage():
input_tensor_shapes[0] = comm.recv_tensor_meta(input_tensor_shapes[0])
input_tensors[0].append(
comm.recv_forward(input_tensor_shapes[0],
dtype=self.dtype,
input_obj_shapes[0] = comm.recv_obj_meta(input_obj_shapes[0])
input_objs[0].append(
comm.recv_forward(input_obj_shapes[0], dtype=self.dtype,
scatter_gather_tensors=self.scatter_gather_tensors))

for k in range(num_warmup_microbatches):
model_chunk_id = get_model_chunk_id(k, forward=True)
output_tensor = _forward_step_helper(k)
output_obj = _forward_step_helper(k)
if not gpc.is_pipeline_last_stage():
output_tensor_shapes[model_chunk_id] = output_tensor.shape
send_tensor_shape_flags[model_chunk_id] = comm.send_tensor_meta(output_tensor,
send_tensor_shape_flags[model_chunk_id])
if isinstance(output_obj, torch.Tensor):
output_obj_shapes[model_chunk_id] = output_obj.shape
else:
output_obj_shapes[model_chunk_id] = []
for out_tensor in output_obj:
output_obj_shapes[model_chunk_id].append(out_tensor.shape)
send_tensor_shape_flags[model_chunk_id] = comm.send_obj_meta(output_obj,
send_tensor_shape_flags[model_chunk_id])
# Determine if tensor should be received from previous stage.
next_forward_model_chunk_id = get_model_chunk_id(k + 1, forward=True)
recv_prev = True
@@ -608,65 +630,65 @@ class InterleavedPipelineSchedule(PipelineSchedule):
# Don't send tensor downstream if on last stage.
if gpc.is_pipeline_last_stage():
output_tensor = None
output_obj = None

with switch_virtual_pipeline_parallel_rank(next_forward_model_chunk_id):
if not gpc.is_pipeline_first_stage():
input_tensor_shapes[next_forward_model_chunk_id] = comm.recv_tensor_meta(
input_tensor_shapes[next_forward_model_chunk_id])
input_obj_shapes[next_forward_model_chunk_id] = comm.recv_obj_meta(
input_obj_shapes[next_forward_model_chunk_id])
# Send and receive tensors as appropriate (send tensors computed
# in this iteration; receive tensors for next iteration).
input_shape = input_tensor_shapes[next_forward_model_chunk_id] if recv_prev else None
input_shape = input_obj_shapes[next_forward_model_chunk_id] if recv_prev else None
if k == (num_warmup_microbatches - 1) and not forward_only and \
not all_warmup_microbatches:
input_tensor_grad = None
input_obj_grad = None
recv_next = True
if gpc.is_pipeline_last_stage(ignore_virtual=True):
recv_next = False
output_shape = output_tensor_shapes[num_model_chunks - 1] if recv_next else None
input_tensor, output_tensor_grad = \
output_shape = output_obj_shapes[num_model_chunks - 1] if recv_next else None
input_obj, output_obj_grad = \
comm.send_forward_backward_recv_forward_backward(
output_tensor, input_tensor_grad,
output_obj, input_obj_grad,
input_shape,
output_shape,
recv_prev=recv_prev, recv_next=recv_next,
dtype=self.dtype,
scatter_gather_tensors=self.scatter_gather_tensors)
output_tensor_grads[num_model_chunks - 1].append(output_tensor_grad)
output_obj_grads[num_model_chunks - 1].append(output_obj_grad)
else:
input_tensor = \
input_obj = \
comm.send_forward_recv_forward(
output_tensor,
output_obj,
input_shape,
recv_prev=recv_prev,
dtype=self.dtype,
scatter_gather_tensors=self.scatter_gather_tensors)
input_tensors[next_forward_model_chunk_id].append(input_tensor)
input_objs[next_forward_model_chunk_id].append(input_obj)

# Run 1F1B in steady state.
for k in range(num_microbatches_remaining):
# Forward pass.
forward_k = k + num_warmup_microbatches
output_tensor = _forward_step_helper(forward_k)
output_obj = _forward_step_helper(forward_k)

# Backward pass.
backward_k = k
input_tensor_grad = _backward_step_helper(backward_k)
input_obj_grad = _backward_step_helper(backward_k)

# Send output_tensor and input_tensor_grad, receive input_tensor
# and output_tensor_grad.
# Send output_obj and input_obj_grad, receive input_obj
# and output_obj_grad.

# Determine if current stage has anything to send in either direction,
# otherwise set tensor to None.
# otherwise set obj to None.
forward_model_chunk_id = get_model_chunk_id(forward_k, forward=True)
gpc.set_virtual_pipeline_parallel_rank(forward_model_chunk_id)
if gpc.is_pipeline_last_stage():
output_tensor = None
output_obj = None

backward_model_chunk_id = get_model_chunk_id(backward_k, forward=False)
gpc.set_virtual_pipeline_parallel_rank(backward_model_chunk_id)
if gpc.is_pipeline_first_stage():
input_tensor_grad = None
input_obj_grad = None

# Determine if peers are sending, and where in data structure to put
# received tensors.
@@ -696,33 +718,33 @@ class InterleavedPipelineSchedule(PipelineSchedule):
if k == (num_microbatches_remaining - 1):
recv_prev = False

input_shape = input_tensor_shapes[next_forward_model_chunk_id] if recv_prev else None
output_shape = output_tensor_shapes[next_backward_model_chunk_id] if recv_next else None
# Communicate tensors.
input_tensor, output_tensor_grad = \
input_shape = input_obj_shapes[next_forward_model_chunk_id] if recv_prev else None
output_shape = output_obj_shapes[next_backward_model_chunk_id] if recv_next else None
# Communicate objs.
input_obj, output_obj_grad = \
comm.send_forward_backward_recv_forward_backward(
output_tensor, input_tensor_grad,
output_obj, input_obj_grad,
input_shape,
output_shape,
recv_prev=recv_prev, recv_next=recv_next,
dtype=self.dtype,
scatter_gather_tensors=self.scatter_gather_tensors)

# Put input_tensor and output_tensor_grad in data structures in the
# Put input_obj and output_obj_grad in data structures in the
# right location.
if recv_prev:
input_tensors[next_forward_model_chunk_id].append(input_tensor)
input_objs[next_forward_model_chunk_id].append(input_obj)
if recv_next:
output_tensor_grads[next_backward_model_chunk_id].append(output_tensor_grad)
output_obj_grads[next_backward_model_chunk_id].append(output_obj_grad)

# Run cooldown backward passes (flush out pipeline).
if not forward_only:
if all_warmup_microbatches:
output_tensor_grads[num_model_chunks - 1].append(
comm.recv_backward(output_tensor_shapes[num_model_chunks - 1],
output_obj_grads[num_model_chunks - 1].append(
comm.recv_backward(output_obj_shapes[num_model_chunks - 1],
scatter_gather_tensors=self.scatter_gather_tensors))
for k in range(num_microbatches_remaining, num_microbatches):
input_tensor_grad = _backward_step_helper(k)
input_obj_grad = _backward_step_helper(k)
next_backward_model_chunk_id = get_model_chunk_id(k + 1, forward=False)
recv_next = True
if gpc.is_pipeline_last_stage(ignore_virtual=True):
@@ -730,9 +752,9 @@ class InterleavedPipelineSchedule(PipelineSchedule):
recv_next = False
if k == (num_microbatches - 1):
recv_next = False
output_shape = output_tensor_shapes[next_backward_model_chunk_id] if recv_next else None
output_tensor_grads[next_backward_model_chunk_id].append(
comm.send_backward_recv_backward(input_tensor_grad,
output_shape = output_obj_shapes[next_backward_model_chunk_id] if recv_next else None
output_obj_grads[next_backward_model_chunk_id].append(
comm.send_backward_recv_backward(input_obj_grad,
output_shape,
recv_next=recv_next,
dtype=self.dtype,
@@ -7,9 +7,9 @@ import pytest
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from colossalai.communication import (recv_backward, recv_forward, recv_tensor_meta, send_backward,
from colossalai.communication import (recv_backward, recv_forward, recv_obj_meta, send_backward,
send_backward_recv_forward, send_forward, send_forward_recv_backward,
send_tensor_meta)
send_obj_meta)
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.initialize import launch
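The test hunk above only tracks the rename: code that previously imported `send_tensor_meta` / `recv_tensor_meta` from `colossalai.communication` now imports the obj-based names exported in the first hunk of this diff. For downstream users the migration is just the import swap:

# before this commit
# from colossalai.communication import send_tensor_meta, recv_tensor_meta
# after this commit
from colossalai.communication import send_obj_meta, recv_obj_meta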