#!/usr/bin/env python
# -*- encoding: utf-8 -*-

from collections import deque
from typing import Iterable, Tuple

import torch.cuda

import colossalai.legacy.communication.p2p_v2 as comm
from colossalai.legacy.context.parallel_mode import ParallelMode
from colossalai.legacy.core import global_context as gpc
from colossalai.legacy.engine import Engine
from colossalai.utils.cuda import get_current_device

from ._pipeline_schedule import PipelineSchedule


def pack_return_tensors(return_tensors):
    """Merge per-microbatch ``(output, label)`` pairs into one batch-level pair.

    Args:
        return_tensors: A non-empty sequence of ``(output, label)`` pairs, one per
            microbatch. Each ``output`` is either a tensor or a list/tuple of
            tensors; each ``label`` is either a tensor or a dict-like object whose
            values are tensors.

    Returns:
        A tuple ``(output, label)`` where tensors are concatenated along dim 0.
        A list/tuple output becomes a tuple of position-wise concatenated tensors;
        a dict label becomes a dict of key-wise concatenated tensors.

    Raises:
        TypeError: If the first output is neither a tensor nor a list/tuple.
    """
    output, label = tuple(zip(*return_tensors))

    # The type of the first element decides how all outputs are merged.
    if isinstance(output[0], torch.Tensor):
        output = torch.cat(output, dim=0)
    elif isinstance(output[0], (list, tuple)):
        output = tuple(torch.cat(tensors, dim=0) for tensors in zip(*output))
    else:
        raise TypeError("Output of model must be tensor or list/tuple of tensors")

    if isinstance(label[0], torch.Tensor):
        label = torch.cat(label, dim=0)
    else:
        # Non-tensor labels are treated as mappings; the key set is taken from
        # the first microbatch's label.
        merged_label = {k: [] for k in label[0]}
        for d in label:
            for k, v in d.items():
                merged_label[k].append(v)
        label = {k: torch.cat(v, dim=0) for k, v in merged_label.items()}
    return output, label


class PipelineScheduleV2(PipelineSchedule):
    """Derived class of PipelineSchedule, the only difference is that
    forward_backward_step is reconstructed with p2p_v2

    Args:
        num_microbatches (int): The number of microbatches.
        data_process_func (Callable, optional):
            The preprocessing function which receives a batch of data, and it will be executed in `load_batch`.
        tensor_shape (torch.Size, optional): Specified shape in pipeline communication.
        scatter_gather_tensors (bool, optional):
            If set to `True`, communication will be reduced over pipeline when using 1D tensor parallelization.

    Example:

        # this shows an example of customized data_process_func
        def data_process_func(stage_output, dataloader_output):
            output1, output2 = stage_output
            item1, item2, item3 = dataloader_output

            # assume item2 is not needed
            data = (output1, output2, item1)
            label = item3
            return data, label

    """

    def forward_backward_step(
        self, engine: Engine, data_iter: Iterable, forward_only=False, return_loss=True, return_output_label=True
    ) -> Tuple[torch.Tensor]:
        """Runs non-interleaved 1F1B schedule, with communication between pipeline stages.

        Args:
            engine (colossalai.legacy.engine.Engine): Colossalai engine for training and inference.
            data_iter (Iterable): Dataloader as the form of an iterator, obtained by calling iter(dataloader).
            forward_only (bool, optional):
                Whether run forward step only. Default is false. If true, no backward will be run.
            return_loss (bool, optional): Whether returns the loss value. Default is true.
            return_output_label (bool, optional): If False, the output and label won't be returned.

        Returns:
            Tuple[:class:`torch.Tensor`]: Always a 3-tuple ``(output, label, loss)``.
            ``output`` and ``label`` are None when no return tensors were collected
            during the forward steps. ``loss`` is None unless ``return_loss`` is True
            and this rank is the last pipeline stage.

        Raises:
            AssertionError: If ``forward_only`` and ``return_loss`` are both False.
        """
        assert (
            forward_only or return_loss
        ), "The argument 'return_loss' has to be True when 'forward_only' is False, but got False."
        self.load_batch(data_iter)

        # Number of forward passes this stage runs before entering the 1F1B
        # steady state; a higher pipeline rank needs fewer of them.
        num_warmup_microbatches = (
            gpc.get_world_size(ParallelMode.PIPELINE) - gpc.get_local_rank(ParallelMode.PIPELINE) - 1
        )
        num_warmup_microbatches = min(num_warmup_microbatches, self.num_microbatches)
        num_microbatches_remaining = self.num_microbatches - num_warmup_microbatches

        # Input, output tensors only need to be saved when doing backward passes.
        # They are consumed strictly first-in-first-out, so deques give O(1)
        # pops from the front (list.pop(0) is O(n)).
        input_objs = None
        output_objs = None
        if not forward_only:
            input_objs = deque()
            output_objs = deque()
        return_tensors = []
        if return_loss and gpc.is_pipeline_last_stage(ignore_virtual=True):
            accum_loss = torch.zeros(1, device=get_current_device())
        else:
            accum_loss = None

        # Run warmup forward passes.
        for _ in range(num_warmup_microbatches):
            input_obj = comm.recv_forward()

            output_obj = self._forward_step(
                engine, input_obj, return_tensors, return_output_label=return_output_label, accum_loss=accum_loss
            )

            comm.send_forward(output_obj)

            if not forward_only:
                input_objs.append(input_obj)
                output_objs.append(output_obj)

        # Before running 1F1B, need to receive first forward tensor.
        # If all microbatches are run in warmup / cooldown phase, then no need to
        # receive this tensor here.
        if num_microbatches_remaining > 0:
            input_obj = comm.recv_forward()

        # Run 1F1B in steady state.
        for i in range(num_microbatches_remaining):
            last_iteration = i == (num_microbatches_remaining - 1)

            output_obj = self._forward_step(
                engine, input_obj, return_tensors, return_output_label=return_output_label, accum_loss=accum_loss
            )
            if forward_only:
                comm.send_forward(output_obj)

                if not last_iteration:
                    input_obj = comm.recv_forward()

            else:
                # TODO adjust here
                comm.send_forward(output_obj)
                output_obj_grad = comm.recv_backward()

                # Add input_obj and output_obj to end of list.
                input_objs.append(input_obj)
                output_objs.append(output_obj)

                # Pop input_obj and output_obj from the start of the list for
                # the backward pass.
                input_obj = input_objs.popleft()
                output_obj = output_objs.popleft()

                input_obj_grad = self._backward_step(engine, input_obj, output_obj, output_obj_grad)

                # The relative order of the communication calls below is kept
                # exactly as-is; it must match the peer stages.
                if last_iteration:
                    input_obj = None
                    comm.send_backward(input_obj_grad)
                else:
                    input_obj = comm.recv_forward()
                    comm.send_backward(input_obj_grad)

        # Run cooldown backward passes.
        if not forward_only:
            for _ in range(num_warmup_microbatches):
                input_obj = input_objs.popleft()
                output_obj = output_objs.popleft()

                output_obj_grad = comm.recv_backward()

                input_obj_grad = self._backward_step(engine, input_obj, output_obj, output_obj_grad)

                comm.send_backward(input_obj_grad)

        if return_tensors:
            output, label = pack_return_tensors(return_tensors)
            return output, label, accum_loss
        else:
            return None, None, accum_loss