diff --git a/colossalai/communication/__init__.py b/colossalai/communication/__init__.py index b71744d9f..220481b7a 100644 --- a/colossalai/communication/__init__.py +++ b/colossalai/communication/__init__.py @@ -3,7 +3,7 @@ from .p2p import (send_forward, send_forward_recv_forward, send_backward_recv_fo send_backward_recv_backward, send_forward_recv_backward, send_forward_backward_recv_forward_backward, recv_forward, recv_backward) from .ring import ring_forward -from .utils import send_tensor_meta, recv_tensor_meta +from .utils import send_obj_meta, recv_obj_meta __all__ = [ 'all_gather', @@ -21,6 +21,6 @@ __all__ = [ 'recv_backward', 'recv_forward', 'ring_forward', - 'send_tensor_meta', - 'recv_tensor_meta', + 'send_obj_meta', + 'recv_obj_meta', ] diff --git a/colossalai/communication/utils.py b/colossalai/communication/utils.py index f57a0009c..ef9eceea8 100644 --- a/colossalai/communication/utils.py +++ b/colossalai/communication/utils.py @@ -9,14 +9,21 @@ from typing import Union, List, Tuple TensorShape = Union[torch.Size, List[int], Tuple[int]] -def send_tensor_meta(tensor, need_meta=True, next_rank=None) -> bool: - """Sends tensor meta information before sending a specific tensor. - Since the recipient must know the shape of the tensor in p2p communications, - meta information of the tensor should be sent before communications. This function - synchronizes with :func:`recv_tensor_meta`. +def send_meta_helper(obj, next_rank, tensor_kwargs): + send_shape = torch.tensor(obj.size(), **tensor_kwargs) + send_ndims = torch.tensor(len(obj.size()), **tensor_kwargs) + dist.send(send_ndims, next_rank) + dist.send(send_shape, next_rank) + + +def send_obj_meta(obj, need_meta=True, next_rank=None) -> bool: + """Sends obj meta information before sending a specific obj. + Since the recipient must know the shape of the obj in p2p communications, + meta information of the obj should be sent before communications. This function + synchronizes with :func:`recv_obj_meta`. Args: - tensor (:class:`torch.Tensor`): Tensor to be sent. + obj (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): obj to be sent. need_meta (bool, optional): If False, meta information won't be sent. next_rank (int): The rank of the next member in pipeline parallel group. @@ -28,42 +35,57 @@ def send_tensor_meta(tensor, need_meta=True, next_rank=None) -> bool: next_rank = gpc.get_next_global_rank(ParallelMode.PIPELINE) tensor_kwargs = {'dtype': torch.long, 'device': get_current_device()} - - send_shape = torch.tensor(tensor.size(), **tensor_kwargs) - send_ndims = torch.tensor(len(tensor.size()), **tensor_kwargs) - dist.send(send_ndims, next_rank) - dist.send(send_shape, next_rank) + if isinstance(obj, torch.Tensor): + send_obj_nums = torch.tensor(1, **tensor_kwargs) + dist.send(send_obj_nums, next_rank) + send_meta_helper(obj, next_rank, tensor_kwargs) + else: + send_obj_nums = torch.tensor(len(obj), **tensor_kwargs) + dist.send(send_obj_nums, next_rank) + for tensor_to_send in obj: + send_meta_helper(tensor_to_send, next_rank, tensor_kwargs) return False -def recv_tensor_meta(tensor_shape: TensorShape, prev_rank=None) -> torch.Size: - """Receives tensor meta information before receiving a specific tensor. - Since the recipient must know the shape of the tensor in p2p communications, - meta information of the tensor should be received before communications. This function - synchronizes with :func:`send_tensor_meta`. 
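The `send_meta_helper`/`send_obj_meta` pair above, together with the `recv_meta_helper`/`recv_obj_meta` pair added just below, implement a small handshake: the sender first transmits the number of tensors, then, for each tensor, its number of dimensions followed by its shape. A minimal standalone sketch of the same protocol — the names `send_shapes`/`recv_shapes` are hypothetical, and a CPU (gloo) process group is assumed to be initialized; the real helpers also handle device placement and the `need_meta`/cached-shape logic shown in this patch:

```python
# Hypothetical sketch of the obj-meta handshake; not part of the library API.
import torch
import torch.distributed as dist


def send_shapes(tensors, dst):
    kwargs = {'dtype': torch.long}
    # Message 1: how many tensors follow.
    dist.send(torch.tensor(len(tensors), **kwargs), dst)
    for t in tensors:
        # Message 2: the tensor's number of dimensions.
        dist.send(torch.tensor(t.dim(), **kwargs), dst)
        # Message 3: the shape itself, as a 1-D long tensor.
        dist.send(torch.tensor(t.size(), **kwargs), dst)


def recv_shapes(src):
    kwargs = {'dtype': torch.long}
    num = torch.empty((), **kwargs)
    dist.recv(num, src)
    shapes = []
    for _ in range(int(num.item())):
        ndims = torch.empty((), **kwargs)
        dist.recv(ndims, src)
        shape = torch.empty(int(ndims.item()), **kwargs)
        dist.recv(shape, src)
        shapes.append(torch.Size(shape.tolist()))
    return shapes
```

Note that this sketch always returns a list, whereas the patched `recv_obj_meta` below returns a bare `torch.Size` when exactly one tensor is announced.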
+def recv_meta_helper(prev_rank, tensor_kwargs):
+    recv_ndims = torch.empty((), **tensor_kwargs)
+    dist.recv(recv_ndims, prev_rank)
+    recv_shape = torch.empty(recv_ndims, **tensor_kwargs)
+    dist.recv(recv_shape, prev_rank)
+    return recv_shape
+
+
+def recv_obj_meta(obj_shape, prev_rank=None) -> Union[torch.Size, List[torch.Size]]:
+    """Receives obj meta information before receiving a specific obj.
+    Since the recipient must know the shape of the obj in p2p communications,
+    meta information of the obj should be received before communications. This function
+    synchronizes with :func:`send_obj_meta`.
 
     Args:
-        tensor_shape (:class:`torch.Size`): The shape of the tensor to be received.
-        prev_rank (int): The rank of the source of the tensor.
+        obj_shape (Union[:class:`torch.Size`, List[:class:`torch.Size`]]): The shape of the obj to be received.
+        prev_rank (int): The rank of the source of the obj.
 
     Returns:
-        :class:`torch.Size`: The shape of the tensor to be received.
+        Union[:class:`torch.Size`, List[:class:`torch.Size`]]: The shape of the obj to be received.
     """
-    if tensor_shape is None:
+    if obj_shape is None:
         if prev_rank is None:
             prev_rank = gpc.get_prev_global_rank(ParallelMode.PIPELINE)
 
         tensor_kwargs = {'dtype': torch.long, 'device': get_current_device()}
+        recv_obj_nums = torch.empty((), **tensor_kwargs)
+        dist.recv(recv_obj_nums, prev_rank)
+        if recv_obj_nums.item() == 1:
+            recv_shape = recv_meta_helper(prev_rank, tensor_kwargs)
+            obj_shape = torch.Size(recv_shape)
+        else:
+            obj_shape = []
+            for i in range(recv_obj_nums.item()):
+                recv_shape = recv_meta_helper(prev_rank, tensor_kwargs)
+                obj_shape.append(torch.Size(recv_shape))
 
-        recv_ndims = torch.empty((), **tensor_kwargs)
-        dist.recv(recv_ndims, prev_rank)
-        recv_shape = torch.empty(recv_ndims, **tensor_kwargs)
-        dist.recv(recv_shape, prev_rank)
-
-        tensor_shape = torch.Size(recv_shape)
-
-    return tensor_shape
+    return obj_shape
 
 
 def split_tensor_into_1d_equal_chunks(tensor: torch.Tensor, new_buffer=False) -> torch.Tensor:
diff --git a/colossalai/engine/schedule/_pipeline_schedule.py b/colossalai/engine/schedule/_pipeline_schedule.py
index 8c408b436..5a6b9597f 100644
--- a/colossalai/engine/schedule/_pipeline_schedule.py
+++ b/colossalai/engine/schedule/_pipeline_schedule.py
@@ -130,7 +130,7 @@ class PipelineSchedule(BaseSchedule):
             assert p.kind != inspect.Parameter.VAR_POSITIONAL, '*args is not supported'
 
     @staticmethod
-    def _call_engine(model, input_tensor, batch_data):
+    def _call_engine(model, input_obj, batch_data):
         if isinstance(model, NaiveAMPModel):
             sig = inspect.signature(model.model.forward)
         elif hasattr(model, 'colo_attr'):
@@ -140,16 +140,22 @@ class PipelineSchedule(BaseSchedule):
         if isinstance(batch_data, torch.Tensor):
             for p in sig.parameters.values():
                 if p.kind == inspect.Parameter.VAR_KEYWORD:
-                    if input_tensor is None:
+                    if input_obj is None:
                         return model(batch_data)
                     else:
-                        return model(input_tensor)
-            if input_tensor is None:
+                        return model(input_obj)
+            if input_obj is None:
                 return model(batch_data)
-            elif len(sig.parameters) > 1:
-                return model(input_tensor, batch_data)
+            elif isinstance(input_obj, torch.Tensor):
+                if len(sig.parameters) > 1:
+                    return model(input_obj, batch_data)
+                else:
+                    return model(input_obj)
             else:
-                return model(input_tensor)
+                if len(sig.parameters) > len(input_obj):
+                    return model(*input_obj, batch_data)
+                else:
+                    return model(*input_obj)
         else:
             filter_batch = True
             for p in sig.parameters.values():
@@ -157,79 +163,88 @@ class PipelineSchedule(BaseSchedule):
                 filter_batch = False
             if filter_batch:
                 batch_data = {k: v
for k, v in batch_data.items() if k in sig.parameters} - if input_tensor is None and filter_batch: + if input_obj is None and filter_batch: return model(**batch_data) + elif isinstance(input_obj, torch.Tensor) or input_obj is None: + return model(input_obj, **batch_data) else: - return model(input_tensor, **batch_data) + return model(*input_obj, **batch_data) - def _forward_step(self, engine, input_tensor, return_tensors, return_output_label=True, accum_loss=None): + def _forward_step(self, engine, input_obj, return_tensors, return_output_label=True, accum_loss=None): """Forward step for passed-in model. If it is the first stage, the input tensor - is obtained from data_iterator, otherwise the passed-in input_tensor is used. + is obtained from data_iterator, otherwise the passed-in input_obj is used. Returns output tensor. This is a helper function and can be ignored by users. Args: engine (colossalai.engine.Engine): Colossalai engine for training and inference. - input_tensor (:class:`torch.Tensor`): Input tensor for this pipeline stage. + input_obj (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): Input tensor for this pipeline stage. return_tensors (List[:class:`torch.Tensor`]): A list of tensors to return. return_output_label (bool, optional): Whether returns output labels. accum_loss (optional): Where accumulated loss stores. Returns: - :class:`torch.Tensor`: output or the loss value of the current pipeline stage. + Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]: output or the loss value of the current pipeline stage. """ data, label = self.load_micro_batch() - output_tensor = self._call_engine(engine.model, input_tensor, data) + output_obj = self._call_engine(engine.model, input_obj, data) if gpc.is_last_rank(ParallelMode.PIPELINE): if return_output_label: - return_tensors.append((output_tensor, label)) + return_tensors.append((output_obj, label)) if accum_loss is not None: - loss_reduced = self._call_engine_criterion(engine, output_tensor, label) / self.num_microbatches + loss_reduced = self._call_engine_criterion(engine, output_obj, label) / self.num_microbatches accum_loss.add_(loss_reduced.detach()) return loss_reduced else: # forward only, it's useless since backward is not needed - return output_tensor + return output_obj else: - assert isinstance( - output_tensor, - torch.Tensor), 'Output of model using pipeline parallelism must be a tensor (except the last stage).' - self._logger.debug( - f'Global rank {gpc.get_global_rank()}, pipeline rank {gpc.get_local_rank(ParallelMode.PIPELINE)} forward output tensor {output_tensor.shape}, dtype {output_tensor.dtype}' - ) - return output_tensor + if isinstance(output_obj, torch.Tensor): + self._logger.debug( + f'Global rank {gpc.get_global_rank()}, pipeline rank {gpc.get_local_rank(ParallelMode.PIPELINE)} forward output tensor {output_obj.shape}, dtype {output_obj.dtype}' + ) + return output_obj - def _backward_step(self, engine, input_tensor, output_tensor, output_tensor_grad): + def _backward_step(self, engine, input_obj, output_obj, output_obj_grad): """Backward step through the passed-in output tensor. If it is the last stage, the - output_tensor_grad is None, otherwise it is the gradients with respect to stage's output tensor. + output_obj_grad is None, otherwise it is the gradients with respect to stage's output tensor. Returns the gradients with respect to the input tensor (None if first stage). This is a helper function and can be ignored by users. 
Args: engine (colossalai.engine.Engine): Colossalai engine for training and inference. - input_tensor (:class:`torch.Tensor`): input tensor for this pipeline stage. - output_tensor (:class:`torch.Tensor`): output tensor for this pipeline stage. - output_tensor_grad (:class:`torch.Tensor`): gradient of output tensor for this pipeline stage. + input_obj (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): input tensor for this pipeline stage. + output_obj (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): output tensor for this pipeline stage. + output_obj_grad (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): gradient of output tensor for this pipeline stage. Returns: - :class:`torch.Tensor`: gradient of input tensor. + Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]: gradient of input tensor. """ - # Retain the grad on the input_tensor. - if input_tensor is not None: - input_tensor.retain_grad() - + # Retain the grad on the input_obj. + if input_obj is not None: + if isinstance(input_obj, torch.Tensor): + input_obj.retain_grad() + else: + for in_tensor in input_obj: + if in_tensor is not None: + in_tensor.retain_grad() # Backward pass. - if output_tensor_grad is None: - engine.backward(output_tensor) + if output_obj_grad is None: + engine.backward(output_obj) else: - engine.backward_by_grad(output_tensor, output_tensor_grad) + engine.backward_by_grad(output_obj, output_obj_grad) - # Collect the grad of the input_tensor. - input_tensor_grad = None - if input_tensor is not None: - input_tensor_grad = input_tensor.grad + # Collect the grad of the input_obj. + input_obj_grad = None + if input_obj is not None: + if isinstance(input_obj, torch.Tensor): + input_obj_grad = input_obj.grad + else: + input_obj_grad = [] + for in_tensor in input_obj: + input_obj_grad.append(in_tensor.grad) - return input_tensor_grad + return input_obj_grad def forward_backward_step(self, engine, data_iter, forward_only=False, return_loss=True, return_output_label=True): """Runs non-interleaved 1F1B schedule, with communication between pipeline stages. @@ -257,108 +272,113 @@ class PipelineSchedule(BaseSchedule): num_microbatches_remaining = self.num_microbatches - num_warmup_microbatches # Input, output tensors only need to be saved when doing backward passes - input_tensors = None - output_tensors = None + input_objs = None + output_objs = None if not forward_only: - input_tensors = [] - output_tensors = [] + input_objs = [] + output_objs = [] return_tensors = [] if return_loss and gpc.is_pipeline_last_stage(ignore_virtual=True): accum_loss = torch.zeros(1, device=get_current_device()) else: accum_loss = None # Used for tensor meta information communication - ft_shape = self.tensor_shape - bt_shape = None + ft_shapes = self.tensor_shape + bt_shapes = None fs_checker = self.tensor_shape is None # Run warmup forward passes. 
        for i in range(num_warmup_microbatches):
             if not gpc.is_first_rank(ParallelMode.PIPELINE):
-                ft_shape = comm.recv_tensor_meta(ft_shape)
-            input_tensor = comm.recv_forward(ft_shape,
-                                             dtype=self.dtype,
-                                             scatter_gather_tensors=self.scatter_gather_tensors)
-            output_tensor = self._forward_step(engine,
-                                               input_tensor,
-                                               return_tensors,
-                                               return_output_label=return_output_label,
-                                               accum_loss=accum_loss)
+                ft_shapes = comm.recv_obj_meta(ft_shapes)
+            input_obj = comm.recv_forward(ft_shapes,
+                                          dtype=self.dtype,
+                                          scatter_gather_tensors=self.scatter_gather_tensors)
+            output_obj = self._forward_step(engine,
+                                            input_obj,
+                                            return_tensors,
+                                            return_output_label=return_output_label,
+                                            accum_loss=accum_loss)
             if not gpc.is_last_rank(ParallelMode.PIPELINE):
-                bt_shape = output_tensor.shape
-                fs_checker = comm.send_tensor_meta(output_tensor, fs_checker)
-                comm.send_forward(output_tensor, scatter_gather_tensors=self.scatter_gather_tensors)
+                if isinstance(output_obj, torch.Tensor):
+                    bt_shapes = output_obj.shape
+                else:
+                    bt_shapes = []
+                    for out_tensor in output_obj:
+                        bt_shapes.append(out_tensor.shape)
+                fs_checker = comm.send_obj_meta(output_obj, fs_checker)
+                comm.send_forward(output_obj, scatter_gather_tensors=self.scatter_gather_tensors)
 
             if not forward_only:
-                input_tensors.append(input_tensor)
-                output_tensors.append(output_tensor)
+                input_objs.append(input_obj)
+                output_objs.append(output_obj)
 
         # Before running 1F1B, need to receive first forward tensor.
         # If all microbatches are run in warmup / cooldown phase, then no need to
         # receive this tensor here.
         if num_microbatches_remaining > 0:
             if not gpc.is_first_rank(ParallelMode.PIPELINE):
-                ft_shape = comm.recv_tensor_meta(ft_shape)
-            input_tensor = comm.recv_forward(ft_shape,
-                                             dtype=self.dtype,
-                                             scatter_gather_tensors=self.scatter_gather_tensors)
+                ft_shapes = comm.recv_obj_meta(ft_shapes)
+            input_obj = comm.recv_forward(ft_shapes,
+                                          dtype=self.dtype,
+                                          scatter_gather_tensors=self.scatter_gather_tensors)
 
         # Run 1F1B in steady state.
         for i in range(num_microbatches_remaining):
             last_iteration = (i == (num_microbatches_remaining - 1))
 
-            output_tensor = self._forward_step(engine,
-                                               input_tensor,
-                                               return_tensors,
-                                               return_output_label=return_output_label,
-                                               accum_loss=accum_loss)
+            output_obj = self._forward_step(engine,
+                                            input_obj,
+                                            return_tensors,
+                                            return_output_label=return_output_label,
+                                            accum_loss=accum_loss)
             if forward_only:
-                comm.send_forward(output_tensor, scatter_gather_tensors=self.scatter_gather_tensors)
+                comm.send_forward(output_obj, scatter_gather_tensors=self.scatter_gather_tensors)
 
                 if not last_iteration:
-                    input_tensor = comm.recv_forward(ft_shape,
-                                                     dtype=self.dtype,
-                                                     scatter_gather_tensors=self.scatter_gather_tensors)
+                    input_obj = comm.recv_forward(ft_shapes,
+                                                  dtype=self.dtype,
+                                                  scatter_gather_tensors=self.scatter_gather_tensors)
 
             else:
-                output_tensor_grad = comm.send_forward_recv_backward(output_tensor,
-                                                                     bt_shape,
-                                                                     dtype=self.dtype,
-                                                                     scatter_gather_tensors=self.scatter_gather_tensors)
+                output_obj_grad = comm.send_forward_recv_backward(output_obj,
+                                                                  bt_shapes,
+                                                                  dtype=self.dtype,
+                                                                  scatter_gather_tensors=self.scatter_gather_tensors)
 
-                # Add input_tensor and output_tensor to end of list.
-                input_tensors.append(input_tensor)
-                output_tensors.append(output_tensor)
+                # Add input_obj and output_obj to end of list.
+                input_objs.append(input_obj)
+                output_objs.append(output_obj)
 
-                # Pop input_tensor and output_tensor from the start of the list for
+                # Pop input_obj and output_obj from the start of the list for
                 # the backward pass.
- input_tensor = input_tensors.pop(0) - output_tensor = output_tensors.pop(0) + input_obj = input_objs.pop(0) + output_obj = output_objs.pop(0) - input_tensor_grad = self._backward_step(engine, input_tensor, output_tensor, output_tensor_grad) + input_obj_grad = self._backward_step(engine, input_obj, output_obj, output_obj_grad) if last_iteration: - input_tensor = None - comm.send_backward(input_tensor_grad, scatter_gather_tensors=self.scatter_gather_tensors) + input_obj = None + comm.send_backward(input_obj_grad, scatter_gather_tensors=self.scatter_gather_tensors) else: - input_tensor = comm.send_backward_recv_forward(input_tensor_grad, - ft_shape, - dtype=self.dtype, - scatter_gather_tensors=self.scatter_gather_tensors) + input_obj = comm.send_backward_recv_forward(input_obj_grad, + ft_shapes, + dtype=self.dtype, + scatter_gather_tensors=self.scatter_gather_tensors) # Run cooldown backward passes. if not forward_only: for i in range(num_warmup_microbatches): - input_tensor = input_tensors.pop(0) - output_tensor = output_tensors.pop(0) + input_obj = input_objs.pop(0) + output_obj = output_objs.pop(0) - output_tensor_grad = comm.recv_backward(bt_shape, - dtype=self.dtype, - scatter_gather_tensors=self.scatter_gather_tensors) + output_obj_grad = comm.recv_backward(bt_shapes, + dtype=self.dtype, + scatter_gather_tensors=self.scatter_gather_tensors) - input_tensor_grad = self._backward_step(engine, input_tensor, output_tensor, output_tensor_grad) + input_obj_grad = self._backward_step(engine, input_obj, output_obj, output_obj_grad) - comm.send_backward(input_tensor_grad, scatter_gather_tensors=self.scatter_gather_tensors) + comm.send_backward(input_obj_grad, scatter_gather_tensors=self.scatter_gather_tensors) if len(return_tensors) > 0: output, label = pack_return_tensors(return_tensors) @@ -426,45 +446,43 @@ class InterleavedPipelineSchedule(PipelineSchedule): def _forward_step(self, engine, model_chunk_id, - input_tensor, + input_obj, return_tensors, return_output_label=True, accum_loss=None): """Forward step for passed-in model. If it is the first stage, the input tensor - is obtained from data_iterator, otherwise the passed-in input_tensor is used. + is obtained from data_iterator, otherwise the passed-in input_obj is used. Returns output tensor. This is a helper function and can be ignored by users. Args: engine (colossalai.engine.Engine): Colossalai engine for training and inference. model_chunk_id (int): The id of model chunks. - input_tensor (:class:`torch.Tensor`): Input tensor for this pipeline stage. + input_obj (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): Input tensor for this pipeline stage. return_tensors (List[:class:`torch.Tensor`]): A list of tensors to return. return_output_label (bool, optional): Whether returns output labels. accum_loss (optional): Where accumulated loss stores. Returns: - :class:`torch.Tensor`: output or the loss value of the current pipeline stage. + Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]: output or the loss value of the current pipeline stage. 
""" data, label = self.load_micro_batch(model_chunk_id) - output_tensor = self._call_engine(engine.model[model_chunk_id], input_tensor, data) + output_obj = self._call_engine(engine.model[model_chunk_id], input_obj, data) if gpc.is_pipeline_last_stage(): if return_output_label: - return_tensors.append((output_tensor, label)) + return_tensors.append((output_obj, label)) if accum_loss is not None: - loss_reduced = self._call_engine_criterion(engine, output_tensor, label) / self.num_microbatches + loss_reduced = self._call_engine_criterion(engine, output_obj, label) / self.num_microbatches accum_loss.add_(loss_reduced.detach()) return loss_reduced else: # forward only, it's useless since backward is not needed - return output_tensor + return output_obj else: - assert isinstance( - output_tensor, - torch.Tensor), 'Output of model using pipeline parallelism must be a tensor (except the last stage).' - self._logger.debug( - f'Global rank {gpc.get_global_rank()}, pipeline rank {gpc.get_local_rank(ParallelMode.PIPELINE)} forward output tensor {output_tensor.shape}, dtype {output_tensor.dtype}' - ) - return output_tensor + if isinstance(output_obj, torch.Tensor): + self._logger.debug( + f'Global rank {gpc.get_global_rank()}, pipeline rank {gpc.get_local_rank(ParallelMode.PIPELINE)} forward output tensor {output_obj.shape}, dtype {output_obj.dtype}' + ) + return output_obj def forward_backward_step(self, engine, data_iter, forward_only=False, return_loss=True, return_output_label=True): """Run interleaved 1F1B schedule (model split into model chunks), with @@ -486,19 +504,19 @@ class InterleavedPipelineSchedule(PipelineSchedule): 'The argument \'return_loss\' has to be True when \'forward_only\' is False, but got False.' self.load_batch(data_iter) model = engine.model - input_tensors = [[] for _ in range(len(model))] - output_tensors = [[] for _ in range(len(model))] + input_objs = [[] for _ in range(len(model))] + output_objs = [[] for _ in range(len(model))] return_tensors = [] if not forward_only: - output_tensor_grads = [[] for _ in range(len(model))] + output_obj_grads = [[] for _ in range(len(model))] if return_loss and gpc.is_pipeline_last_stage(ignore_virtual=True): accum_loss = torch.zeros(1, device=get_current_device()) else: accum_loss = None - # Used for tensor meta information communication - input_tensor_shapes = [self.tensor_shape for _ in range(len(model))] - output_tensor_shapes = [None for _ in range(len(model))] + # Used for obj meta information communication + input_obj_shapes = [self.tensor_shape for _ in range(len(model))] + output_obj_shapes = [None for _ in range(len(model))] send_tensor_shape_flags = [self.tensor_shape is None for _ in range(len(model))] pipeline_parallel_size = gpc.get_world_size(ParallelMode.PIPELINE) @@ -545,24 +563,24 @@ class InterleavedPipelineSchedule(PipelineSchedule): # forward step if gpc.is_pipeline_first_stage(): - if len(input_tensors[model_chunk_id]) == \ - len(output_tensors[model_chunk_id]): - input_tensors[model_chunk_id].append(None) - input_tensor = input_tensors[model_chunk_id][-1] - output_tensor = self._forward_step(engine, - model_chunk_id, - input_tensor, - return_tensors, - return_output_label=return_output_label, - accum_loss=accum_loss) - output_tensors[model_chunk_id].append(output_tensor) + if len(input_objs[model_chunk_id]) == \ + len(output_objs[model_chunk_id]): + input_objs[model_chunk_id].append(None) + input_obj = input_objs[model_chunk_id][-1] + output_obj = self._forward_step(engine, + model_chunk_id, + input_obj, + 
return_tensors, + return_output_label=return_output_label, + accum_loss=accum_loss) + output_objs[model_chunk_id].append(output_obj) # if forward-only, no need to save tensors for a backward pass if forward_only: - input_tensors[model_chunk_id].pop() - output_tensors[model_chunk_id].pop() + input_objs[model_chunk_id].pop() + output_objs[model_chunk_id].pop() - return output_tensor + return output_obj def _backward_step_helper(microbatch_id): """Helper method to run backward step with model split into chunks @@ -572,31 +590,35 @@ class InterleavedPipelineSchedule(PipelineSchedule): gpc.set_virtual_pipeline_parallel_rank(model_chunk_id) if gpc.is_pipeline_last_stage(): - if len(output_tensor_grads[model_chunk_id]) == 0: - output_tensor_grads[model_chunk_id].append(None) - input_tensor = input_tensors[model_chunk_id].pop(0) - output_tensor = output_tensors[model_chunk_id].pop(0) - output_tensor_grad = output_tensor_grads[model_chunk_id].pop(0) - input_tensor_grad = self._backward_step(engine, input_tensor, output_tensor, output_tensor_grad) + if len(output_obj_grads[model_chunk_id]) == 0: + output_obj_grads[model_chunk_id].append(None) + input_obj = input_objs[model_chunk_id].pop(0) + output_obj = output_objs[model_chunk_id].pop(0) + output_obj_grad = output_obj_grads[model_chunk_id].pop(0) + input_obj_grad = self._backward_step(engine, input_obj, output_obj, output_obj_grad) - return input_tensor_grad + return input_obj_grad # Run warmup forward passes. gpc.set_virtual_pipeline_parallel_rank(0) if not gpc.is_pipeline_first_stage(): - input_tensor_shapes[0] = comm.recv_tensor_meta(input_tensor_shapes[0]) - input_tensors[0].append( - comm.recv_forward(input_tensor_shapes[0], - dtype=self.dtype, + input_obj_shapes[0] = comm.recv_obj_meta(input_obj_shapes[0]) + input_objs[0].append( + comm.recv_forward(input_obj_shapes[0], dtype=self.dtype, scatter_gather_tensors=self.scatter_gather_tensors)) for k in range(num_warmup_microbatches): model_chunk_id = get_model_chunk_id(k, forward=True) - output_tensor = _forward_step_helper(k) + output_obj = _forward_step_helper(k) if not gpc.is_pipeline_last_stage(): - output_tensor_shapes[model_chunk_id] = output_tensor.shape - send_tensor_shape_flags[model_chunk_id] = comm.send_tensor_meta(output_tensor, - send_tensor_shape_flags[model_chunk_id]) + if isinstance(output_obj, torch.Tensor): + output_obj_shapes[model_chunk_id] = output_obj.shape + else: + output_obj_shapes[model_chunk_id] = [] + for out_tensor in output_obj: + output_obj_shapes[model_chunk_id].append(out_tensor.shape) + send_tensor_shape_flags[model_chunk_id] = comm.send_obj_meta(output_obj, + send_tensor_shape_flags[model_chunk_id]) # Determine if tensor should be received from previous stage. next_forward_model_chunk_id = get_model_chunk_id(k + 1, forward=True) recv_prev = True @@ -608,65 +630,65 @@ class InterleavedPipelineSchedule(PipelineSchedule): # Don't send tensor downstream if on last stage. if gpc.is_pipeline_last_stage(): - output_tensor = None + output_obj = None with switch_virtual_pipeline_parallel_rank(next_forward_model_chunk_id): if not gpc.is_pipeline_first_stage(): - input_tensor_shapes[next_forward_model_chunk_id] = comm.recv_tensor_meta( - input_tensor_shapes[next_forward_model_chunk_id]) + input_obj_shapes[next_forward_model_chunk_id] = comm.recv_obj_meta( + input_obj_shapes[next_forward_model_chunk_id]) # Send and receive tensors as appropriate (send tensors computed # in this iteration; receive tensors for next iteration). 
- input_shape = input_tensor_shapes[next_forward_model_chunk_id] if recv_prev else None + input_shape = input_obj_shapes[next_forward_model_chunk_id] if recv_prev else None if k == (num_warmup_microbatches - 1) and not forward_only and \ not all_warmup_microbatches: - input_tensor_grad = None + input_obj_grad = None recv_next = True if gpc.is_pipeline_last_stage(ignore_virtual=True): recv_next = False - output_shape = output_tensor_shapes[num_model_chunks - 1] if recv_next else None - input_tensor, output_tensor_grad = \ + output_shape = output_obj_shapes[num_model_chunks - 1] if recv_next else None + input_obj, output_obj_grad = \ comm.send_forward_backward_recv_forward_backward( - output_tensor, input_tensor_grad, + output_obj, input_obj_grad, input_shape, output_shape, recv_prev=recv_prev, recv_next=recv_next, dtype=self.dtype, scatter_gather_tensors=self.scatter_gather_tensors) - output_tensor_grads[num_model_chunks - 1].append(output_tensor_grad) + output_obj_grads[num_model_chunks - 1].append(output_obj_grad) else: - input_tensor = \ + input_obj = \ comm.send_forward_recv_forward( - output_tensor, + output_obj, input_shape, recv_prev=recv_prev, dtype=self.dtype, scatter_gather_tensors=self.scatter_gather_tensors) - input_tensors[next_forward_model_chunk_id].append(input_tensor) + input_objs[next_forward_model_chunk_id].append(input_obj) # Run 1F1B in steady state. for k in range(num_microbatches_remaining): # Forward pass. forward_k = k + num_warmup_microbatches - output_tensor = _forward_step_helper(forward_k) + output_obj = _forward_step_helper(forward_k) # Backward pass. backward_k = k - input_tensor_grad = _backward_step_helper(backward_k) + input_obj_grad = _backward_step_helper(backward_k) - # Send output_tensor and input_tensor_grad, receive input_tensor - # and output_tensor_grad. + # Send output_obj and input_obj_grad, receive input_obj + # and output_obj_grad. # Determine if current stage has anything to send in either direction, - # otherwise set tensor to None. + # otherwise set obj to None. forward_model_chunk_id = get_model_chunk_id(forward_k, forward=True) gpc.set_virtual_pipeline_parallel_rank(forward_model_chunk_id) if gpc.is_pipeline_last_stage(): - output_tensor = None + output_obj = None backward_model_chunk_id = get_model_chunk_id(backward_k, forward=False) gpc.set_virtual_pipeline_parallel_rank(backward_model_chunk_id) if gpc.is_pipeline_first_stage(): - input_tensor_grad = None + input_obj_grad = None # Determine if peers are sending, and where in data structure to put # received tensors. @@ -696,33 +718,33 @@ class InterleavedPipelineSchedule(PipelineSchedule): if k == (num_microbatches_remaining - 1): recv_prev = False - input_shape = input_tensor_shapes[next_forward_model_chunk_id] if recv_prev else None - output_shape = output_tensor_shapes[next_backward_model_chunk_id] if recv_next else None - # Communicate tensors. - input_tensor, output_tensor_grad = \ + input_shape = input_obj_shapes[next_forward_model_chunk_id] if recv_prev else None + output_shape = output_obj_shapes[next_backward_model_chunk_id] if recv_next else None + # Communicate objs. 
+ input_obj, output_obj_grad = \ comm.send_forward_backward_recv_forward_backward( - output_tensor, input_tensor_grad, + output_obj, input_obj_grad, input_shape, output_shape, recv_prev=recv_prev, recv_next=recv_next, dtype=self.dtype, scatter_gather_tensors=self.scatter_gather_tensors) - # Put input_tensor and output_tensor_grad in data structures in the + # Put input_obj and output_obj_grad in data structures in the # right location. if recv_prev: - input_tensors[next_forward_model_chunk_id].append(input_tensor) + input_objs[next_forward_model_chunk_id].append(input_obj) if recv_next: - output_tensor_grads[next_backward_model_chunk_id].append(output_tensor_grad) + output_obj_grads[next_backward_model_chunk_id].append(output_obj_grad) # Run cooldown backward passes (flush out pipeline). if not forward_only: if all_warmup_microbatches: - output_tensor_grads[num_model_chunks - 1].append( - comm.recv_backward(output_tensor_shapes[num_model_chunks - 1], + output_obj_grads[num_model_chunks - 1].append( + comm.recv_backward(output_obj_shapes[num_model_chunks - 1], scatter_gather_tensors=self.scatter_gather_tensors)) for k in range(num_microbatches_remaining, num_microbatches): - input_tensor_grad = _backward_step_helper(k) + input_obj_grad = _backward_step_helper(k) next_backward_model_chunk_id = get_model_chunk_id(k + 1, forward=False) recv_next = True if gpc.is_pipeline_last_stage(ignore_virtual=True): @@ -730,9 +752,9 @@ class InterleavedPipelineSchedule(PipelineSchedule): recv_next = False if k == (num_microbatches - 1): recv_next = False - output_shape = output_tensor_shapes[next_backward_model_chunk_id] if recv_next else None - output_tensor_grads[next_backward_model_chunk_id].append( - comm.send_backward_recv_backward(input_tensor_grad, + output_shape = output_obj_shapes[next_backward_model_chunk_id] if recv_next else None + output_obj_grads[next_backward_model_chunk_id].append( + comm.send_backward_recv_backward(input_obj_grad, output_shape, recv_next=recv_next, dtype=self.dtype, diff --git a/tests/test_trainer/test_pipeline/test_p2p.py b/tests/test_trainer/test_pipeline/test_p2p.py index 5b9d80453..72820c6a1 100644 --- a/tests/test_trainer/test_pipeline/test_p2p.py +++ b/tests/test_trainer/test_pipeline/test_p2p.py @@ -7,9 +7,9 @@ import pytest import torch import torch.distributed as dist import torch.multiprocessing as mp -from colossalai.communication import (recv_backward, recv_forward, recv_tensor_meta, send_backward, +from colossalai.communication import (recv_backward, recv_forward, recv_obj_meta, send_backward, send_backward_recv_forward, send_forward, send_forward_recv_backward, - send_tensor_meta) + send_obj_meta) from colossalai.context.parallel_mode import ParallelMode from colossalai.core import global_context as gpc from colossalai.initialize import launch
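Taken together, these changes let a pipeline stage pass a list (or tuple) of tensors to the next stage instead of a single tensor: `send_obj_meta` announces the tensor count and per-tensor shapes, and `_forward_step`/`_backward_step` save and differentiate through each element. A hedged illustration of the new capability — the module below is hypothetical and not part of this patch:

```python
# Hypothetical stage module whose forward emits two tensors for the next
# pipeline stage; sizes and names are made up for illustration.
import torch
import torch.nn as nn


class TwoStreamStage(nn.Module):

    def __init__(self, dim: int = 64):
        super().__init__()
        self.proj = nn.Linear(dim, dim)

    def forward(self, x: torch.Tensor):
        hidden = self.proj(x)
        # Both tensors travel to the next stage: send_obj_meta() first sends
        # the count (2), then each tensor's ndims and shape, so the receiver
        # can size its buffers via recv_obj_meta() before recv_forward().
        return hidden, x
```

Because `_call_engine` unpacks a non-tensor input with `*input_obj`, the downstream stage's `forward` should accept a matching number of positional arguments (here, two).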