from functools import partial from typing import Any, Callable, Iterable, List, Optional, Tuple, Union import torch import torch.cuda import torch.distributed from torch.nn import Module, ModuleList from torch.utils._pytree import tree_map from colossalai.accelerator import get_accelerator from colossalai.interface import OptimizerWrapper from colossalai.pipeline.p2p import PipelineP2PCommunication from colossalai.pipeline.schedule.v_schedule import ScheduledNode from colossalai.pipeline.stage_manager import PipelineStageManager from ._utils import detach, get_batch_size, get_micro_batch, retain_grad, to_device from .base import PipelineSchedule AUTO_SCHEDULE_COMMUNICATION_TYPES = {"RECV_FORWARD", "RECV_BACKWARD", "SEND_FORWARD", "SEND_BACKWARD"} def _wait_p2p(wait_handles: List[torch.cuda.Event]) -> None: if wait_handles is not None: for req in wait_handles: req.wait() class ZeroBubbleVPipeScheduler(PipelineSchedule): def __init__( self, stage_manager: PipelineStageManager, schedule: List[ScheduledNode], num_model_chunks: int, num_microbatch: Optional[int] = None, microbatch_size: Optional[int] = None, enable_metadata_cache: bool = True, overlap_p2p: bool = True, ): super().__init__(stage_manager) # batch info self.num_microbatch = num_microbatch self.microbatch_size = microbatch_size self.num_model_chunks = num_model_chunks self.batch: Any self.batch_size: int self.last_batch_size: Optional[int] = None self.microbatch_offset: List[int] self.collect_non_loss_data = None self.forward_only = None self.schedules = schedule # TODO: optim post valid self.do_post_validation = False self.is_first_run = True self.optimizer = None # P2PMeta cache # self.enable_metadata_cache = enable_metadata_cache # self.send_tensor_metadata = True # self.send_grad_metadata = True # self.tensor_metadata_recv = None # self.grad_metadata_recv = None # P2P communication self.comm = PipelineP2PCommunication(stage_manager, overlap_p2p=overlap_p2p) # init buffer self._free_buffers() def _free_buffers(self): # free local buffer # two dim array, first dim is the model chunk, second dim is the microbatch queue # x & y buffer for schedule b self.input_tensors = [[], []] self.output_tensors = [[], []] # y & dy buffer for schedule w self.output_tensors_dw = [[], []] self.output_tensors_grad_dw = [[], []] # buffer for communication self.send_forward_buffer = [[], []] self.recv_forward_buffer = [[], []] self.send_backward_buffer = [[], []] self.recv_backward_buffer = [[], []] # y buffer for local send fwd self.local_send_forward_buffer = [] # dy buffer for local send bwd self.local_send_backward_buffer = [] def load_batch(self, data_iter: Iterable, device: Optional[torch.device] = None) -> None: """Load a batch from data iterator. Args: data_iter (Iterable): Data iterator. device (Optional[torch.device], optional): Target device. Defaults to None. """ batch = next(data_iter) if device is not None: batch = tree_map(partial(to_device, device=device), batch) self.microbatch_offset = [0 for _ in range(self.num_model_chunks)] self.batch = batch self.batch_size = get_batch_size(batch) if self.microbatch_size is None: assert self.batch_size % self.num_microbatch == 0, "Batch size should divided by the number of microbatch" self.microbatch_size = self.batch_size // self.num_microbatch if self.num_microbatch is None: assert self.batch_size % self.microbatch_size == 0, "Batch size should divided by the microbatch size" self.num_microbatch = self.batch_size // self.microbatch_size if not self.forward_only: assert self.last_batch_size is None or self.last_batch_size == self.batch_size assert self.batch_size == self.microbatch_size * self.num_microbatch assert ( self.num_microbatch % self.stage_manager.num_stages == 0 ), "Number of microbatch should be an integer multiple of number of pipeline parallel devices" if self.forward_only: self.num_microbatch = (self.batch_size - 1) // self.microbatch_size + 1 # NOTE: disable metadata cache when batch size changes (not valid anymore) # if self.batch_size != self.last_batch_size: # self.enable_metadata_cache = False # self.send_tensor_metadata = True # self.send_grad_metadata = True # self.tensor_metadata_recv = None # self.grad_metadata_recv = None self.last_batch_size = self.batch_size def load_micro_batch(self, model_chunk_id: int) -> Any: """Load a micro batch from the current batch. Args: microbatch_id (int): the current model chunk idx. Returns: Any: Micro batch. """ assert self.microbatch_offset[model_chunk_id] <= self.batch_size, "Microbatches exhausted" micro_batch = get_micro_batch(self.batch, self.microbatch_offset[model_chunk_id], self.microbatch_size) self.microbatch_offset[model_chunk_id] += self.microbatch_size return tree_map(partial(to_device, device=get_accelerator().get_current_device()), micro_batch) def get_model_chunk_id(self, microbatch_id: int, is_forward: bool) -> int: """Helper method to get the model chunk ID given the iteration number. Args: microbatch_id (int): the current microbatch idx forward (bool): if is the forward process Returns: int: The model chunk idx of the input microbatch_id """ assert ( microbatch_id < self.num_microbatch * self.num_model_chunks ), f"microbatch_id {microbatch_id} is out of range ({self.num_microbatch * self.num_model_chunks})" microbatch_id_in_group = microbatch_id % (self.stage_manager.num_stages * self.num_model_chunks) model_chunk_id = microbatch_id_in_group // self.stage_manager.num_stages if not is_forward: # Reverse order model_chunk_id = self.num_model_chunks - model_chunk_id - 1 return model_chunk_id def recv_forward(self, model_chunk_id: int, prev_rank: int = None) -> Tuple[Any, List]: """Copy the forward output from the previous stage in pipeline as the input tensor of this stage. For ZBV. Args: model_chunk_id (int): The current model chunk idx. prev_rank (int, optional): The rank of the source of the tensor. Returns: Any: The input tensor or input tensor list. Any: The wait handles for the communication. """ with self.stage_manager.switch_model_chunk_id(model_chunk_id): if model_chunk_id == 0: ################ # chunk = 0 & is_first_stage # do nothing; cause u are chunk 0 in first rank, u have no prev rank; ################# if self.stage_manager.is_first_stage(ignore_chunk=True): return None, [] ################ # chunk = 0 & not is_first_stage # Recv y from PREV_rank as input ################# else: prev_rank = self.stage_manager.get_prev_rank() input_tensor, wait_handles = self.comm.recv_forward(prev_rank=prev_rank) self.recv_forward_buffer[model_chunk_id].append(input_tensor) return input_tensor, wait_handles else: ################ # chunk = 1 & is_last_stage # do nothing; cause u get y from local_send_forward_buffer in schedule f ################ if self.stage_manager.is_last_stage(ignore_chunk=True): return None, [] ################ # chunk = 1 & not is_last_stage # recv y from NEXT_rank as input ################ else: next_rank = self.stage_manager.get_next_rank() input_tensor, wait_handles = self.comm.recv_forward(next_rank) self.recv_forward_buffer[model_chunk_id].append(input_tensor) return input_tensor, wait_handles def recv_backward(self, model_chunk_id: int, next_rank: int = None) -> Tuple[Any, List]: """Copy the gradient tensor from the next stage in pipeline as the input gradient of this stage. For ZBV. Args: model_chunk_id (int): The current model chunk idx. next_rank (int, optional): The rank of the source of the tensor. Returns: Any: The input gradient tensor or gradient tensor list. Any: The wait handles for the communication. """ with self.stage_manager.switch_model_chunk_id(model_chunk_id): if model_chunk_id == 0: # bwd chunk0 is right V; ################ # chunk = 0 & is_last_stage # do nothing; Already get dy from local_send_backward_buffer in schedule b ################ if self.stage_manager.is_last_stage(ignore_chunk=True): return None, [] ################ # chunk = 0 & not is_last_stage # Recv bwd from next stage; ################ else: next_rank = self.stage_manager.get_next_rank() output_tensor_grad, wait_handles = self.comm.recv_backward(next_rank) self.recv_backward_buffer[model_chunk_id].append(output_tensor_grad) return output_tensor_grad, wait_handles else: # bwd chunk1 is left V; ################ # chunk = 1 & is_first_stage # do nothing; get loss from local ################ if self.stage_manager.is_first_stage(ignore_chunk=True): return None, [] ################ # chunk = 1 & not first stage # recv_backward recv bwd from prev stage; ################ else: prev_rank = self.stage_manager.get_prev_rank() output_tensor_grad, wait_handles = self.comm.recv_backward(next_rank=prev_rank) self.recv_backward_buffer[model_chunk_id].append(output_tensor_grad) return output_tensor_grad, wait_handles def send_forward(self, model_chunk_id: int, next_rank: int = None) -> List: """Sends the input tensor to the next stage in pipeline. For ZBV. Args: model_chunk_id (int): The current model chunk idx. next_rank (int, optional): The rank of the recipient of the tensor. Returns: Any: The wait handles for the communication. """ with self.stage_manager.switch_model_chunk_id(model_chunk_id): if model_chunk_id == 0: ################ # chunk = 0 && is_last_stage # do nothing; hold y on local_send_forward_buffer ################ if self.stage_manager.is_last_stage(ignore_chunk=True): return [] ################ # chunk = 0 && not is_last_stage # self.comm.send_forward send y to NEXT stage ################ else: next_rank = self.stage_manager.get_next_rank() output_tensor = self.send_forward_buffer[model_chunk_id].pop(0) send_handles = self.comm.send_forward(output_object=output_tensor, next_rank=next_rank) return send_handles else: ################ # chunk = 1 && is_first_stage # do nothing; Already send LOSS to local_send_backward_buffer in schedule f send part ################ if self.stage_manager.is_first_stage(ignore_chunk=True): return [] ################ # chunk = 1 && not is_first_stage # self.comm.send_forward send y to PREV stage ################ else: prev_rank = self.stage_manager.get_prev_rank() output_tensor = self.send_forward_buffer[model_chunk_id].pop(0) send_handles = self.comm.send_forward(output_tensor, prev_rank) return send_handles def send_backward(self, model_chunk_id: int, prev_rank: int = None) -> List: """Sends the gradient tensor to the previous stage in pipeline. For ZBV. Args: model_chunk_id (int): The current model chunk idx. prev_rank (int, optional): The rank of the recipient of the tensor Returns: Any: The wait handles for the communication. """ with self.stage_manager.switch_model_chunk_id(model_chunk_id): if model_chunk_id == 0: # bwd chunk0 is right V; ################ # chunk = 0 && is_first_stage # do nothing; cause u are the first chunk in first stage; bwd end ################ if self.stage_manager.is_first_stage(ignore_chunk=True): return [] ################ # chunk = 0 && not is_first_stage # Send dx to PREV stage; ################ else: prev_rank = self.stage_manager.get_prev_rank() input_tensor_grad = self.send_backward_buffer[model_chunk_id].pop(0) send_handles = self.comm.send_backward(input_tensor_grad, prev_rank) return send_handles # bwd chunk1 is left V; else: # print(f"model_chunk_id {model_chunk_id} stage {self.stage_manager.stage} self.send_backward_buffer {self.send_backward_buffer}") ################ # chunk = 1 && is_last_stage # do nothing; Already send input_tensor_grad to local_send_bwd_buffer in schedule b; ################ if self.stage_manager.is_last_stage(ignore_chunk=True): return [] ################ # chunk = 1 && not is_last_stage # Send dx to NEXT stage; ################ else: next_rank = self.stage_manager.get_next_rank() input_tensor_grad = self.send_backward_buffer[model_chunk_id].pop(0) send_handles = self.comm.send_backward(input_tensor_grad, next_rank) return send_handles def forward_step( self, model_chunk: Union[ModuleList, Module], model_chunk_id: int, input_obj: Optional[dict], criterion: Callable, accum_loss: Optional[torch.Tensor] = None, outputs: Optional[List[Any]] = None, ) -> Union[torch.Tensor, dict]: """Forward one step of the pipeline Args: model_chunk (ModuleList or Module): Model Chunk to be run; model_chunk_id (int): The current model chunk idx; input_obj (Optional[dict]): x; criterion (Callable): loss function; accum_loss (Optional[torch.Tensor], optional): Accumulated loss. Defaults to None. outputs (Optional[List[Any]], optional): List to store the output of the last stage (final output). Defaults to None. Returns: Union[torch.Tensor, dict]: The intermediate output (dict) of the current stage. If it is the last stage, the output is the loss (Tensor). """ # Load input ids, attention mask and labels # micro_batch = self.load_micro_batch(model_chunk_id=model_chunk_id) # for the first stage, input_obj is None # for other stages, input_obj is the output of the previous/next stage containing hidden_states etc. # Only attention_mask from micro_batch is used with self.stage_manager.switch_model_chunk_id(model_chunk_id): # fwd calculate output_obj = model_chunk[model_chunk_id](input_obj) # last layer in model if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True): loss = criterion(output_obj) / self.num_microbatch if accum_loss is not None: accum_loss.add_(loss.detach()) if outputs is not None: outputs.append(tree_map(detach, output_obj)) return loss else: return output_obj def backward_b_step( self, model_chunk: Union[ModuleList, Module], model_chunk_id: int, # optimizer: OptimizerWrapper, input_obj: Optional[dict], output_obj: Union[dict, torch.Tensor], output_obj_grad: Optional[dict], ) -> Optional[dict]: """Backward dx step of the pipeline; we calculate "dx = w*dy" here; Args: model_chunk (ModuleList or Module): Model Chunk to be run; model_chunk_id (int): The current model chunk idx; optimizer (OptimizerWrapper): Optimizer to update the model input_obj (Optional[dict]): x. output_obj (Union[dict, torch.Tensor]): y. output_obj_grad (dict): dy. Returns: Optional[dict]: dx. """ # calculate bwd b step ; only dx = w*dy; # Retain the grad on the input_obj. tree_map(retain_grad, input_obj) if model_chunk_id == 0: # bwd step torch.autograd.backward( tensors=output_obj, grad_tensors=output_obj_grad, inputs=input_obj, retain_graph=True ) else: if self.stage_manager.is_first_stage(ignore_chunk=True): # loss backward; output_obj is loss torch.autograd.backward(output_obj, inputs=input_obj, retain_graph=True) else: # commom bwd step # BUG:output_obj_grad is None torch.autograd.backward( tensors=output_obj, grad_tensors=output_obj_grad, inputs=input_obj, retain_graph=True ) return input_obj.grad def backward_w_step( self, model_chunk: Union[ModuleList, Module], model_chunk_id: int, # optimizer: OptimizerWrapper, output_obj: Union[dict, torch.Tensor], output_obj_grad: Optional[dict], ): """Backward dw step of the pipeline; we calculate "dw = x*dy" here; Args: model_chunk (ModuleList or Module): Model Chunk to be run; model_chunk_id (int): The current model chunk idx; optimizer (OptimizerWrapper): Optimizer to update the model output_obj (Union[dict, torch.Tensor]): y. output_obj_grad (dict): dy. Returns: Nothing need to return; we only calculate dw then update w; """ # calculate bwd w step ; only dw = x*dy; if model_chunk_id == 0: torch.autograd.backward( tensors=output_obj, grad_tensors=output_obj_grad, inputs=list(model_chunk[model_chunk_id].parameters()) ) else: if self.stage_manager.is_first_stage(ignore_chunk=True): torch.autograd.backward(output_obj_grad, inputs=list(model_chunk[model_chunk_id].parameters())) else: torch.autograd.backward( tensors=output_obj, grad_tensors=output_obj_grad, inputs=list(model_chunk[model_chunk_id].parameters()), ) def schedule_f( self, scheduled_node, model_chunk: torch.nn.ModuleList, model_chunk_id: int, criterion: Callable, accum_loss: Optional[torch.Tensor] = None, outputs: Optional[List[Any]] = None, ): """A complete forward schedule; Include recv fwd --> cal fwd --> send fwd; Args: scheduled_node: model_chunk (ModuleList or Module): Model Chunk to be run; model_chunk_id (int): The current model chunk idx; criterion (Callable): loss function; accum_loss (Optional[torch.Tensor], optional): Accumulated loss. Defaults to None. outputs (Optional[List[Any]], optional): List to store the output of the last stage (final output). Defaults to None. Returns: Nothing. """ # Step1: recv fwd if model_chunk_id == 0: # is first stage; get input from func param if self.stage_manager.is_first_stage(ignore_chunk=True): input_obj = self.load_micro_batch(model_chunk_id=model_chunk_id) else: input_obj = self.recv_forward_buffer[model_chunk_id].pop(0) else: # is last stage; recv from local if self.stage_manager.is_last_stage(ignore_chunk=True): input_obj = self.local_send_forward_buffer.pop(0) # not last stage; recv from next else: input_obj = self.recv_forward_buffer[model_chunk_id].pop(0) # Step2: fwd step output_obj = self.forward_step( model_chunk=model_chunk, model_chunk_id=model_chunk_id, input_obj=input_obj, criterion=criterion, accum_loss=accum_loss, outputs=outputs, ) # add input and output object for backward b self.input_tensors[model_chunk_id].append(input_obj) self.output_tensors[model_chunk_id].append(output_obj) # add output object for backward w self.output_tensors_dw[model_chunk_id].append(output_obj) # Step3: send fwd # add output to send_fwd_buffer if model_chunk_id == 0: # is last stage; send to local_send_forward_buffer if self.stage_manager.is_last_stage(ignore_chunk=True): self.local_send_forward_buffer.append(output_obj) else: self.send_forward_buffer[model_chunk_id].append(output_obj) else: # is first stage; end of fwd; append LOSS to local_send_backward_buffer if self.stage_manager.is_first_stage(ignore_chunk=True): self.local_send_backward_buffer.append(output_obj) else: self.send_forward_buffer[model_chunk_id].append(output_obj) def schedule_b( self, scheduled_node, model_chunk: Union[ModuleList, Module], model_chunk_id: int, # optimizer: OptimizerWrapper, # input_obj: Optional[dict], # output_obj: Union[dict, torch.Tensor], # output_obj_grad: Optional[dict], ): """A complete backward b schedule; Include recv bwd --> cal bwd step --> send bwd; Args: scheduled_node: model_chunk (ModuleList or Module): Model Chunk to be run; model_chunk_id (int): The current model chunk idx; Returns: Nothing. """ # Step1: recv bwd if model_chunk_id == 0: # chunk0 is last stage; recv output_grad from local_send_backward_buffer if self.stage_manager.is_last_stage(ignore_chunk=True): output_tensor_grad = self.local_send_backward_buffer.pop(0) # chunk 0 not last stage; recv output_grad from recv_backward_buffer else: output_tensor_grad = self.recv_backward_buffer[model_chunk_id].pop(0) else: # chunk1, is first stage; recv LOSS from local send bwd buffer if self.stage_manager.is_first_stage(ignore_chunk=True): output_tensor_grad = self.local_send_backward_buffer.pop(0) # chunk1, not first stage; recv output_grad from recv_backward_buffer else: output_tensor_grad = self.recv_backward_buffer[model_chunk_id].pop(0) # print(f"model_chunk_id {model_chunk_id} stage {self.stage_manager.stage}; output_tensor_grad {output_tensor_grad}\n") # get input and output object from buffer; input_obj = self.input_tensors[model_chunk_id].pop(0) output_obj = self.output_tensors[model_chunk_id].pop(0) # save output_tensor_grad for dw if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True): # we save loss here self.output_tensors_grad_dw[model_chunk_id].append(output_obj) else: # we save output_tensor_grad here self.output_tensors_grad_dw[model_chunk_id].append(output_tensor_grad) # _wait_p2p(recv_bwd_handles) # Step2: bwd step input_object_grad = self.backward_b_step( model_chunk=model_chunk, model_chunk_id=model_chunk_id, # optimizer: OptimizerWrapper, input_obj=input_obj, output_obj=output_obj, output_obj_grad=output_tensor_grad, ) # print(f"model_chunk_id {model_chunk_id}; stage {self.stage_manager.stage}; input_object_grad {input_object_grad}") # Step3: send bwd if model_chunk_id == 0: # do nothing; end of bwd; if self.stage_manager.is_first_stage(ignore_chunk=True): pass # save input_object_grad to send_backward_buffer else: self.send_backward_buffer[model_chunk_id].append(input_object_grad) else: # send to local_send_backward_buffer if self.stage_manager.is_last_stage(ignore_chunk=True): self.local_send_backward_buffer.append(input_object_grad) # send to next else: self.send_backward_buffer[model_chunk_id].append(input_object_grad) def schedule_w( self, scheduled_node, model_chunk: Union[ModuleList, Module], model_chunk_id: int, # optimizer: OptimizerWrapper, ): """A complete backward w schedule; Include get y & dy from buffer --> cal bwd w step(cal dw & update w); Args: scheduled_node: model_chunk (ModuleList or Module): Model Chunk to be run; model_chunk_id (int): The current model chunk idx; Returns: Nothing. """ # get y & dy from buffer output_obj = self.output_tensors_dw[model_chunk_id].pop(0) output_obj_grad = self.output_tensors_grad_dw[model_chunk_id].pop(0) self.backward_w_step( model_chunk=model_chunk, model_chunk_id=model_chunk_id, # optimizer: OptimizerWrapper, output_obj=output_obj, output_obj_grad=output_obj_grad, ) def run_forward_backward( self, model_chunk: Union[ModuleList, Module], data_iter: Iterable, criterion: Callable[..., Any], optimizer: Optional[OptimizerWrapper] = None, return_loss: bool = False, return_outputs: bool = False, ): """ Runs Zerobubble schedule, with communication between pipeline stages. """ # # prepare batch self.load_batch(data_iter) print( f"self.batch_size {self.batch_size}; self.batch shape {self.batch.shape}; self.num_microbatch {self.num_microbatch}; self.microbatch_size {self.microbatch_size}" ) it = 0 # while we still have schedules_node in self.schedules while it < len(self.schedules): scheduled_node = self.schedules[it] print( f"it {it}; manger_stage {self.stage_manager.stage}; node_stage {scheduled_node.stage} chunk {scheduled_node.chunk} {scheduled_node.type};" ) if scheduled_node.type in AUTO_SCHEDULE_COMMUNICATION_TYPES: # communication if scheduled_node.type == "RECV_FORWARD": self.recv_forward(scheduled_node.chunk) elif scheduled_node.type == "RECV_BACKWARD": self.recv_backward(scheduled_node.chunk) elif scheduled_node.type == "SEND_FORWARD": self.send_forward(scheduled_node.chunk) elif scheduled_node.type == "SEND_BACKWARD": self.send_backward(scheduled_node.chunk) if scheduled_node.type == "F": self.schedule_f( scheduled_node=scheduled_node, model_chunk=model_chunk, model_chunk_id=scheduled_node.chunk, criterion=criterion, accum_loss=return_loss, outputs=return_outputs, ) elif scheduled_node.type == "B": self.schedule_b( scheduled_node=scheduled_node, model_chunk=model_chunk, model_chunk_id=scheduled_node.chunk, ) elif scheduled_node.type == "W": self.schedule_w( scheduled_node=scheduled_node, model_chunk=model_chunk, model_chunk_id=scheduled_node.chunk, ) it += 1