from functools import partial from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union import torch import torch.cuda from torch.nn import Module, ModuleList from torch.utils._pytree import tree_map from colossalai.accelerator import get_accelerator from colossalai.interface import OptimizerWrapper from colossalai.pipeline.p2p import PipelineP2PCommunication from colossalai.pipeline.schedule.v_schedule import ScheduledNode from colossalai.pipeline.stage_manager import PipelineStageManager from ._utils import detach, get_batch_size, get_micro_batch, merge_batch, retain_grad, to_device from .base import PipelineSchedule AUTO_SCHEDULE_COMMUNICATION_TYPES = {"RECV_FORWARD", "RECV_BACKWARD", "SEND_FORWARD", "SEND_BACKWARD"} def _wait_p2p(wait_handles: List[torch.cuda.Event]) -> None: if wait_handles is not None: for req in wait_handles: req.wait() def deallocate_output_tensor(out, deallocate_pipeline_outputs=False): """Pseudo-deallocate (i.e., set to scalar) the output tensor's '.data' field. This method should be called right after the output tensor has been sent to the next pipeline stage. At this point, the output tensor is only useful for its '.grad_fn' field, and not its '.data'. """ if (out is None) or (not deallocate_pipeline_outputs): return assert isinstance(out, torch.Tensor), "expected Tensor, found %s." % type(out).__name__ assert out._base is None, "counter-productive to free a view of another tensor." # out.data = torch.empty((1,), device=out.device, dtype=out.dtype,) out.data.untyped_storage().resize_(0) class ZeroBubbleVPipeScheduler(PipelineSchedule): def __init__( self, stage_manager: PipelineStageManager, schedule: List[ScheduledNode], num_model_chunks: int, num_microbatch: Optional[int] = None, microbatch_size: Optional[int] = None, enable_metadata_cache: bool = True, overlap_p2p: bool = True, ): super().__init__(stage_manager) # batch info self.num_microbatch = num_microbatch self.microbatch_size = microbatch_size self.num_model_chunks = num_model_chunks self.batch: Any self.batch_size: int self.last_batch_size: Optional[int] = None self.microbatch_offset: List[int] self.schedules = schedule # TODO: optim post valid self.do_post_validation = False # P2PMeta cache # self.enable_metadata_cache = enable_metadata_cache # self.send_tensor_metadata = True # self.send_grad_metadata = True # self.tensor_metadata_recv = None # self.grad_metadata_recv = None # P2P communication self.comm = PipelineP2PCommunication(stage_manager, overlap_p2p=overlap_p2p) # init communication map self.communication_map = { "SEND_FORWARD": self.send_forward, "RECV_FORWARD": self.recv_forward, "SEND_BACKWARD": self.send_backward, "RECV_BACKWARD": self.recv_backward, } # init buffer self._free_buffers() def _free_buffers(self): # free local buffer # two dim array, first dim is the model chunk, second dim is the microbatch queue # x & y buffer for schedule b self.input_tensors = [[], []] self.output_tensors = [[], []] # y & dy buffer for schedule w self.output_tensors_dw = [[], []] self.output_tensors_grad_dw = [[], []] # buffer for communication self.send_forward_buffer = [[], []] self.recv_forward_buffer = [[], []] self.send_backward_buffer = [[], []] self.recv_backward_buffer = [[], []] # y buffer for local send fwd self.local_send_forward_buffer = [] # dy buffer for local send bwd self.local_send_backward_buffer = [] def assert_buffer_empty(self): # assert buuffer is empty at end assert len(self.input_tensors[0]) == 0 assert len(self.input_tensors[1]) == 0 assert len(self.output_tensors[0]) == 0 assert len(self.output_tensors[1]) == 0 assert len(self.output_tensors_dw[0]) == 0 assert len(self.output_tensors_dw[1]) == 0 assert len(self.output_tensors_grad_dw[0]) == 0 assert len(self.output_tensors_grad_dw[1]) == 0 assert len(self.send_forward_buffer[0]) == 0 assert len(self.send_forward_buffer[1]) == 0 assert len(self.recv_forward_buffer[0]) == 0 assert len(self.recv_forward_buffer[1]) == 0 assert len(self.send_backward_buffer[0]) == 0 assert len(self.send_backward_buffer[1]) == 0 assert len(self.recv_backward_buffer[0]) == 0 assert len(self.recv_backward_buffer[1]) == 0 assert len(self.local_send_forward_buffer) == 0 assert len(self.local_send_backward_buffer) == 0 def load_batch(self, data_iter: Iterable, device: Optional[torch.device] = None) -> None: """Load a batch from data iterator. Args: data_iter (Iterable): Data iterator. device (Optional[torch.device], optional): Target device. Defaults to None. """ batch = next(data_iter) if device is not None: batch = tree_map(partial(to_device, device=device), batch) self.microbatch_offset = [0 for _ in range(self.num_model_chunks)] self.batch = batch self.batch_size = get_batch_size(batch) if self.microbatch_size is None: assert self.batch_size % self.num_microbatch == 0, "Batch size should divided by the number of microbatch" self.microbatch_size = self.batch_size // self.num_microbatch if self.num_microbatch is None: assert self.batch_size % self.microbatch_size == 0, "Batch size should divided by the microbatch size" self.num_microbatch = self.batch_size // self.microbatch_size if not self.forward_only: assert self.last_batch_size is None or self.last_batch_size == self.batch_size assert self.batch_size == self.microbatch_size * self.num_microbatch assert ( self.num_microbatch % self.stage_manager.num_stages == 0 ), "Number of microbatch should be an integer multiple of number of pipeline parallel devices" if self.forward_only: self.num_microbatch = (self.batch_size - 1) // self.microbatch_size + 1 # NOTE: disable metadata cache when batch size changes (not valid anymore) # if self.batch_size != self.last_batch_size: # self.enable_metadata_cache = False # self.send_tensor_metadata = True # self.send_grad_metadata = True # self.tensor_metadata_recv = None # self.grad_metadata_recv = None self.last_batch_size = self.batch_size def load_micro_batch(self, model_chunk_id: int) -> Any: """Load a micro batch from the current batch. Args: microbatch_id (int): the current model chunk idx. Returns: Any: Micro batch. """ assert self.microbatch_offset[model_chunk_id] <= self.batch_size, "Microbatches exhausted" micro_batch = get_micro_batch(self.batch, self.microbatch_offset[model_chunk_id], self.microbatch_size) self.microbatch_offset[model_chunk_id] += self.microbatch_size return tree_map(partial(to_device, device=get_accelerator().get_current_device()), micro_batch) def get_model_chunk_id(self, microbatch_id: int, is_forward: bool) -> int: """Helper method to get the model chunk ID given the iteration number. Args: microbatch_id (int): the current microbatch idx forward (bool): if is the forward process Returns: int: The model chunk idx of the input microbatch_id """ assert ( microbatch_id < self.num_microbatch * self.num_model_chunks ), f"microbatch_id {microbatch_id} is out of range ({self.num_microbatch * self.num_model_chunks})" microbatch_id_in_group = microbatch_id % (self.stage_manager.num_stages * self.num_model_chunks) model_chunk_id = microbatch_id_in_group // self.stage_manager.num_stages if not is_forward: # Reverse order model_chunk_id = self.num_model_chunks - model_chunk_id - 1 return model_chunk_id def recv_forward(self, model_chunk_id: int, prev_rank: int = None) -> Tuple[Any, List]: """Copy the forward output from the previous stage in pipeline as the input tensor of this stage. For ZBV. Args: model_chunk_id (int): The current model chunk idx. prev_rank (int, optional): The rank of the source of the tensor. Returns: Any: The input tensor or input tensor list. Any: The wait handles for the communication. """ with self.stage_manager.switch_model_chunk_id(model_chunk_id): if model_chunk_id == 0: ################ # chunk = 0 & is_first_stage # do nothing; cause u are chunk 0 in first rank, u have no prev rank; ################# if self.stage_manager.is_first_stage(ignore_chunk=True): return None, [] ################ # chunk = 0 & not is_first_stage # Recv y from PREV_rank as input ################# else: prev_rank = self.stage_manager.get_prev_rank() input_tensor, wait_handles = self.comm.recv_forward(prev_rank=prev_rank) self.recv_forward_buffer[model_chunk_id].append(input_tensor) return input_tensor, wait_handles else: ################ # chunk = 1 & is_last_stage # do nothing; cause u get y from local_send_forward_buffer in schedule f ################ if self.stage_manager.is_last_stage(ignore_chunk=True): return None, [] ################ # chunk = 1 & not is_last_stage # recv y from NEXT_rank as input ################ else: next_rank = self.stage_manager.get_next_rank() input_tensor, wait_handles = self.comm.recv_forward(next_rank) self.recv_forward_buffer[model_chunk_id].append(input_tensor) return input_tensor, wait_handles def recv_backward(self, model_chunk_id: int, next_rank: int = None) -> Tuple[Any, List]: """Copy the gradient tensor from the next stage in pipeline as the input gradient of this stage. For ZBV. Args: model_chunk_id (int): The current model chunk idx. next_rank (int, optional): The rank of the source of the tensor. Returns: Any: The input gradient tensor or gradient tensor list. Any: The wait handles for the communication. """ with self.stage_manager.switch_model_chunk_id(model_chunk_id): if model_chunk_id == 0: # bwd chunk0 is right V; ################ # chunk = 0 & is_last_stage # do nothing; Already get dy from local_send_backward_buffer in schedule b ################ if self.stage_manager.is_last_stage(ignore_chunk=True): return None, [] ################ # chunk = 0 & not is_last_stage # Recv bwd from next stage; ################ else: next_rank = self.stage_manager.get_next_rank() output_tensor_grad, wait_handles = self.comm.recv_backward(next_rank) self.recv_backward_buffer[model_chunk_id].append(output_tensor_grad) return output_tensor_grad, wait_handles else: # bwd chunk1 is left V; ################ # chunk = 1 & is_first_stage # do nothing; get loss from local ################ if self.stage_manager.is_first_stage(ignore_chunk=True): return None, [] ################ # chunk = 1 & not first stage # recv_backward recv bwd from prev stage; ################ else: prev_rank = self.stage_manager.get_prev_rank() output_tensor_grad, wait_handles = self.comm.recv_backward(next_rank=prev_rank) self.recv_backward_buffer[model_chunk_id].append(output_tensor_grad) return output_tensor_grad, wait_handles def send_forward(self, model_chunk_id: int, next_rank: int = None) -> List: """Sends the input tensor to the next stage in pipeline. For ZBV. Args: model_chunk_id (int): The current model chunk idx. next_rank (int, optional): The rank of the recipient of the tensor. Returns: Any: The wait handles for the communication. """ with self.stage_manager.switch_model_chunk_id(model_chunk_id): if model_chunk_id == 0: ################ # chunk = 0 && is_last_stage # do nothing; hold y on local_send_forward_buffer ################ if self.stage_manager.is_last_stage(ignore_chunk=True): return [] ################ # chunk = 0 && not is_last_stage # self.comm.send_forward send y to NEXT stage ################ else: next_rank = self.stage_manager.get_next_rank() output_tensor = self.send_forward_buffer[model_chunk_id].pop(0) send_handles = self.comm.send_forward(output_object=output_tensor, next_rank=next_rank) return send_handles else: ################ # chunk = 1 && is_first_stage # do nothing; Already send LOSS to local_send_backward_buffer in schedule f send part ################ if self.stage_manager.is_first_stage(ignore_chunk=True): return [] ################ # chunk = 1 && not is_first_stage # self.comm.send_forward send y to PREV stage ################ else: prev_rank = self.stage_manager.get_prev_rank() output_tensor = self.send_forward_buffer[model_chunk_id].pop(0) send_handles = self.comm.send_forward(output_tensor, prev_rank) return send_handles def send_backward(self, model_chunk_id: int, prev_rank: int = None) -> List: """Sends the gradient tensor to the previous stage in pipeline. For ZBV. Args: model_chunk_id (int): The current model chunk idx. prev_rank (int, optional): The rank of the recipient of the tensor Returns: Any: The wait handles for the communication. """ with self.stage_manager.switch_model_chunk_id(model_chunk_id): if model_chunk_id == 0: # bwd chunk0 is right V; ################ # chunk = 0 && is_first_stage # do nothing; cause u are the first chunk in first stage; bwd end ################ if self.stage_manager.is_first_stage(ignore_chunk=True): return [] ################ # chunk = 0 && not is_first_stage # Send dx to PREV stage; ################ else: prev_rank = self.stage_manager.get_prev_rank() input_tensor_grad = self.send_backward_buffer[model_chunk_id].pop(0) send_handles = self.comm.send_backward(input_tensor_grad, prev_rank) return send_handles # bwd chunk1 is left V; else: ################ # chunk = 1 && is_last_stage # do nothing; Already send input_tensor_grad to local_send_bwd_buffer in schedule b; ################ if self.stage_manager.is_last_stage(ignore_chunk=True): return [] ################ # chunk = 1 && not is_last_stage # Send dx to NEXT stage; ################ else: next_rank = self.stage_manager.get_next_rank() input_tensor_grad = self.send_backward_buffer[model_chunk_id].pop(0) send_handles = self.comm.send_backward(input_tensor_grad, next_rank) return send_handles def forward_step( self, model_chunk: Union[ModuleList, Module], model_chunk_id: int, input_obj: Optional[dict], criterion: Callable, accum_loss: Optional[torch.Tensor] = None, outputs: Optional[List[Any]] = None, ) -> Union[torch.Tensor, dict]: """Forward one step of the pipeline Args: model_chunk (ModuleList or Module): Model Chunk to be run; model_chunk_id (int): The current model chunk idx; input_obj (Optional[dict]): x; criterion (Callable): loss function; accum_loss (Optional[torch.Tensor], optional): Accumulated loss. Defaults to None. outputs (Optional[List[Any]], optional): List to store the output of the last stage (final output). Defaults to None. Returns: Union[torch.Tensor, dict]: The intermediate output (dict) of the current stage. If it is the last stage, the output is the loss (Tensor). """ # Load input ids, attention mask and labels # micro_batch = self.load_micro_batch(model_chunk_id=model_chunk_id) # for the first stage, input_obj is None # for other stages, input_obj is the output of the previous/next stage containing hidden_states etc. # Only attention_mask from micro_batch is used with self.stage_manager.switch_model_chunk_id(model_chunk_id): # fwd calculate output_obj = model_chunk[model_chunk_id](input_obj) # last layer in model if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True): loss = criterion(output_obj) / self.num_microbatch if accum_loss is not None: accum_loss.add_(loss.detach()) if outputs is not None: outputs.append(tree_map(detach, output_obj)) return loss else: return output_obj def backward_b_step( self, model_chunk: Union[ModuleList, Module], model_chunk_id: int, optimizer: OptimizerWrapper, input_obj: Optional[dict], output_obj: Union[dict, torch.Tensor], output_obj_grad: Optional[dict], ) -> Optional[dict]: """Backward dx step of the pipeline; we calculate "dx = w*dy" here; Args: model_chunk (ModuleList or Module): Model Chunk to be run; model_chunk_id (int): The current model chunk idx; optimizer (OptimizerWrapper): Optimizer to update the model input_obj (Optional[dict]): x. output_obj (Union[dict, torch.Tensor]): y. output_obj_grad (dict): dy. Returns: Optional[dict]: dx. """ # calculate bwd b step ; only dx = w*dy; # Retain the grad on the input_obj. tree_map(retain_grad, input_obj) if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True): # loss backward; output_obj is loss output_obj_grad = None optimizer.backward_by_grad( tensor=output_obj, grad=output_obj_grad, inputs=input_obj, retain_graph=True, ) return input_obj.grad def backward_w_step( self, model_chunk: Union[ModuleList, Module], model_chunk_id: int, optimizer: OptimizerWrapper, output_obj: Union[dict, torch.Tensor], output_obj_grad: Optional[dict], ): """Backward dw step of the pipeline; we calculate "dw = x*dy" here; Args: model_chunk (ModuleList or Module): Model Chunk to be run; model_chunk_id (int): The current model chunk idx; optimizer (OptimizerWrapper): Optimizer to update the model output_obj (Union[dict, torch.Tensor]): y. output_obj_grad (dict): dy. Returns: Nothing need to return; we only calculate dw then update w; """ # calculate bwd w step ; only dw = x*dy; if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True): # loss backward; output_obj is loss output_obj_grad = None optimizer.backward_by_grad( tensor=output_obj, grad=output_obj_grad, inputs=list(model_chunk[model_chunk_id].parameters()), retain_graph=False, ) def schedule_f( self, scheduled_node, model_chunk: torch.nn.ModuleList, model_chunk_id: int, criterion: Callable, accum_loss: Optional[torch.Tensor] = None, outputs: Optional[List[Any]] = None, ): """A complete forward schedule; Include recv fwd --> cal fwd --> send fwd; Args: scheduled_node: model_chunk (ModuleList or Module): Model Chunk to be run; model_chunk_id (int): The current model chunk idx; criterion (Callable): loss function; accum_loss (Optional[torch.Tensor], optional): Accumulated loss. Defaults to None. outputs (Optional[List[Any]], optional): List to store the output of the last stage (final output). Defaults to None. Returns: Nothing. """ micro_batch = self.load_micro_batch(model_chunk_id=model_chunk_id) # Step1: recv fwd if model_chunk_id == 0: # is first stage; get input from func param if self.stage_manager.is_first_stage(ignore_chunk=True): input_obj = micro_batch else: input_obj = self.recv_forward_buffer[model_chunk_id].pop(0) else: # is last stage; recv from local if self.stage_manager.is_last_stage(ignore_chunk=True): input_obj = self.local_send_forward_buffer.pop(0) # not last stage; recv from next else: input_obj = self.recv_forward_buffer[model_chunk_id].pop(0) input_obj.requires_grad_() # Step2: fwd step output_obj = self.forward_step( model_chunk=model_chunk, model_chunk_id=model_chunk_id, input_obj=input_obj, criterion=criterion, accum_loss=accum_loss, outputs=outputs, ) if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True): # We should not detach bwd LOSS pass else: detached_output_obj = output_obj.clone().detach() # Step3: send fwd # add output to send_fwd_buffer if model_chunk_id == 0: # is last stage; send to local_send_forward_buffer if self.stage_manager.is_last_stage(ignore_chunk=True): self.local_send_forward_buffer.append(detached_output_obj) else: self.send_forward_buffer[model_chunk_id].append(detached_output_obj) else: # is first stage; end of fwd; append LOSS to local_send_backward_buffer if self.stage_manager.is_first_stage(ignore_chunk=True): pass else: self.send_forward_buffer[model_chunk_id].append(detached_output_obj) # add input and output object for backward b self.input_tensors[model_chunk_id].append(input_obj) # detached output; for bwd b&w, we only need the graph(grad_fn) of output_obj deallocate_output_tensor(output_obj, deallocate_pipeline_outputs=True) self.output_tensors[model_chunk_id].append(output_obj) # add output object for backward w self.output_tensors_dw[model_chunk_id].append(output_obj) def schedule_b( self, scheduled_node, model_chunk: Union[ModuleList, Module], model_chunk_id: int, optimizer: OptimizerWrapper, # input_obj: Optional[dict], # output_obj: Union[dict, torch.Tensor], # output_obj_grad: Optional[dict], ): """A complete backward b schedule; Include recv bwd --> cal bwd step --> send bwd; Args: scheduled_node: model_chunk (ModuleList or Module): Model Chunk to be run; model_chunk_id (int): The current model chunk idx; Returns: Nothing. """ # Step1: recv bwd if model_chunk_id == 0: # chunk0 is last stage; recv output_grad from local_send_backward_buffer if self.stage_manager.is_last_stage(ignore_chunk=True): output_tensor_grad = self.local_send_backward_buffer.pop(0) # chunk 0 not last stage; recv output_grad from recv_backward_buffer else: output_tensor_grad = self.recv_backward_buffer[model_chunk_id].pop(0) else: # chunk1, is first stage; recv LOSS from local send bwd buffer if self.stage_manager.is_first_stage(ignore_chunk=True): output_tensor_grad = None # chunk1, not first stage; recv output_grad from recv_backward_buffer else: output_tensor_grad = self.recv_backward_buffer[model_chunk_id].pop(0) # get input and output object from buffer; input_obj = self.input_tensors[model_chunk_id].pop(0) output_obj = self.output_tensors[model_chunk_id].pop(0) # save output_tensor_grad for dw if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True): # we save loss here self.output_tensors_grad_dw[model_chunk_id].append(output_obj) else: # we save output_tensor_grad here self.output_tensors_grad_dw[model_chunk_id].append(output_tensor_grad) # _wait_p2p(recv_bwd_handles) # Step2: bwd step input_object_grad = self.backward_b_step( model_chunk=model_chunk, model_chunk_id=model_chunk_id, optimizer=optimizer, input_obj=input_obj, output_obj=output_obj, output_obj_grad=output_tensor_grad, ) # Step3: send bwd if model_chunk_id == 0: # do nothing; end of bwd; if self.stage_manager.is_first_stage(ignore_chunk=True): pass # save input_object_grad to send_backward_buffer else: self.send_backward_buffer[model_chunk_id].append(input_object_grad) else: # send to local_send_backward_buffer if self.stage_manager.is_last_stage(ignore_chunk=True): self.local_send_backward_buffer.append(input_object_grad) # send to next else: self.send_backward_buffer[model_chunk_id].append(input_object_grad) def schedule_w( self, scheduled_node, model_chunk: Union[ModuleList, Module], model_chunk_id: int, optimizer: OptimizerWrapper, ): """A complete backward w schedule; Include get y & dy from buffer --> cal bwd w step(cal dw & update w); Args: scheduled_node: model_chunk (ModuleList or Module): Model Chunk to be run; model_chunk_id (int): The current model chunk idx; Returns: Nothing. """ # get y & dy from buffer output_obj = self.output_tensors_dw[model_chunk_id].pop(0) output_obj_grad = self.output_tensors_grad_dw[model_chunk_id].pop(0) self.backward_w_step( model_chunk=model_chunk, model_chunk_id=model_chunk_id, optimizer=optimizer, output_obj=output_obj, output_obj_grad=output_obj_grad, ) def run_forward_only( self, model_chunk: Union[ModuleList, Module], data_iter: Iterable, criterion: Callable[..., Any], return_loss: bool = False, return_outputs: bool = False, ) -> Dict: assert self.forward_only # prepare batch self.load_batch(data_iter) # prepare accum loss & output accum_loss = None # reset accum loss at fwd end; if return_loss and self.stage_manager.is_first_stage(ignore_chunk=True): accum_loss = torch.scalar_tensor(0, device=get_accelerator().get_current_device()) outputs = [] if return_outputs and self.stage_manager.is_first_stage(ignore_chunk=True) else None # while we still have schedules_node in self.schedules for it in range(len(self.schedules)): scheduled_node = self.schedules[it] if scheduled_node.type in {"RECV_FORWARD", "SEND_FORWARD"}: # communication communication_func = self.communication_map[scheduled_node.type] communication_func(scheduled_node.chunk) if scheduled_node.type == "F": self.schedule_f( scheduled_node=scheduled_node, model_chunk=model_chunk, model_chunk_id=scheduled_node.chunk, criterion=criterion, accum_loss=accum_loss, outputs=outputs, ) # return loss & output if outputs is not None: outputs = merge_batch(outputs) return {"loss": accum_loss, "outputs": outputs} def run_forward_backward( self, model_chunk: Union[ModuleList, Module], data_iter: Iterable, criterion: Callable[..., Any], optimizer: Optional[OptimizerWrapper] = None, return_loss: bool = False, return_outputs: bool = False, ) -> Dict: """ Runs Zerobubble schedule, with communication between pipeline stages. """ # prepare batch self.load_batch(data_iter) # prepare accum loss & output accum_loss = None # reset accum loss at fwd end; if return_loss and self.stage_manager.is_first_stage(ignore_chunk=True): accum_loss = torch.scalar_tensor(0, device=get_accelerator().get_current_device()) outputs = [] if return_outputs and self.stage_manager.is_first_stage(ignore_chunk=True) else None # while we still have schedules_node in self.schedules schedule = self.schedules[self.stage_manager.stage] # get schedule by stage (rank) for it in range(len(schedule)): scheduled_node = schedule[it] if scheduled_node.type in AUTO_SCHEDULE_COMMUNICATION_TYPES: # communication communication_func = self.communication_map[scheduled_node.type] communication_func(scheduled_node.chunk) if scheduled_node.type == "F": self.schedule_f( scheduled_node=scheduled_node, model_chunk=model_chunk, model_chunk_id=scheduled_node.chunk, criterion=criterion, accum_loss=accum_loss, outputs=outputs, ) elif scheduled_node.type == "B": self.schedule_b( scheduled_node=scheduled_node, model_chunk=model_chunk, model_chunk_id=scheduled_node.chunk, optimizer=optimizer, ) elif scheduled_node.type == "W": self.schedule_w( scheduled_node=scheduled_node, model_chunk=model_chunk, model_chunk_id=scheduled_node.chunk, optimizer=optimizer, ) # return loss & output if outputs is not None: outputs = merge_batch(outputs) return {"loss": accum_loss, "outputs": outputs} def forward_backward_step( self, model_chunk: Union[ModuleList, Module], data_iter: Iterable, criterion: Callable[..., Any], optimizer: Optional[OptimizerWrapper] = None, return_loss: bool = False, return_outputs: bool = False, ) -> dict: """ Args: model_chunk (ModuleList or Module): Model Chunk to be trained. Original interleaved uses a module list whereas shardformer uses entire model + layer specification data_iter (Iterable): Data iterator. criterion (Callable[[Any, Any], Tensor]): Criterion to be used. It should take two arguments: model outputs and inputs, and returns loss tensor. optimizer (OptimizerWrapper, optional): Optimizer to be used. Can be None when only forward is executed. Defaults to None. return_loss (bool, optional): Whether to return loss. Defaults to False. Whether to return loss. return_outputs (bool, optional): Whether to return model outputs. Defaults to False. Whether to return model outputs. Returns: dict: A dict with keys: 'loss' and 'outputs'. """ self.forward_only = not torch.is_grad_enabled() if optimizer is None: assert self.forward_only, "Optimizer should be passed when doing backward." if self.forward_only: result = self.run_forward_only(model_chunk, data_iter, criterion, return_loss, return_outputs) else: result = self.run_forward_backward( model_chunk, data_iter, criterion, optimizer, return_loss, return_outputs ) self.assert_buffer_empty() return result