mirror of https://github.com/InternLM/InternLM
#!/usr/bin/env python
# -*- encoding: utf-8 -*-

# adopted from https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/engine

from contextlib import contextmanager
from typing import Callable, List, Optional, Tuple, Union

import torch.cuda

import internlm.core.communication as comm
from internlm.core.context import ParallelMode
from internlm.core.context import global_context as gpc
from internlm.core.engine import Engine
from internlm.core.naive_amp import NaiveAMPModel
from internlm.utils.common import get_current_device, move_to_device
from internlm.utils.logger import get_logger

from .base_scheduler import BaseScheduler, SchedulerHook

logger = get_logger(__file__)


def get_tensor_shape():
    if hasattr(gpc.config, "TENSOR_SHAPE"):
        return gpc.config.TENSOR_SHAPE

    if not gpc.is_initialized(ParallelMode.PIPELINE):
        return None

    if hasattr(gpc.config, "SEQ_LEN") and hasattr(gpc.config.data, "micro_bsz") and hasattr(gpc.config, "HIDDEN_SIZE"):
        if gpc.config.model.use_flash_attn:
            if gpc.config.model.sequence_parallel:
                sequence_world_size = gpc.get_world_size(ParallelMode.TENSOR)
                tensor_shape = (
                    gpc.config.SEQ_LEN * gpc.config.data["micro_bsz"] // sequence_world_size,
                    gpc.config.HIDDEN_SIZE,
                )
            else:
                tensor_shape = (
                    gpc.config.SEQ_LEN * gpc.config.data["micro_bsz"],
                    gpc.config.HIDDEN_SIZE,
                )
        else:
            tensor_shape = (
                gpc.config.data["micro_bsz"],
                gpc.config.SEQ_LEN,
                gpc.config.HIDDEN_SIZE,
            )
        return tensor_shape
    else:
        return None


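# A minimal illustration of what get_tensor_shape() returns, assuming
# hypothetical config values: with SEQ_LEN=2048, data.micro_bsz=2 and
# HIDDEN_SIZE=4096, flash attention enabled and sequence parallelism off,
# the communicated shape is (2048 * 2, 4096) == (4096, 4096); with flash
# attention disabled it is (2, 2048, 4096) instead.

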
def pack_return_tensors(return_tensors):
    output, label = tuple(zip(*return_tensors))
    if isinstance(output[0], torch.Tensor):
        output = torch.cat(output, dim=0)
    elif isinstance(output[0], (list, tuple)):
        output = tuple(torch.cat(tensors, dim=0) for tensors in zip(*output))
    else:
        raise TypeError("Output of model must be tensor or list/tuple of tensors")
    if isinstance(label[0], torch.Tensor):
        label = torch.cat(label, dim=0)
    else:
        merged_label = {k: [] for k in label[0].keys()}
        for d in label:
            for k, v in d.items():
                merged_label[k].append(v)
        label = {k: torch.cat(v, dim=0) for k, v in merged_label.items()}
    return output, label


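# Packing sketch for pack_return_tensors (illustrative shapes, hypothetical
# label key): if four microbatches each yield a (2, vocab_size) output tensor
# paired with a {"labels": (2,)} dict, the packed result is an
# (8, vocab_size) tensor and {"labels": (8,)}, i.e. per-microbatch results
# are concatenated along dim 0.

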
@contextmanager
def switch_virtual_pipeline_parallel_rank(rank):
    prev_rank = gpc.virtual_pipeline_parallel_rank
    try:
        gpc.set_virtual_pipeline_parallel_rank(rank)
        yield
    finally:
        gpc.set_virtual_pipeline_parallel_rank(prev_rank)


@contextmanager
def switch_optimizer_grad_sync_skip_mode(optimizer, skip: bool = True):
    prev_mode = optimizer.skip_grad_reduce
    try:
        optimizer.skip_grad_reduce = skip
        yield
    finally:
        optimizer.skip_grad_reduce = prev_mode


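# Usage sketch (mirrors the call in PipelineScheduler._backward_step below):
# wrap the backward call so that gradient reduction is skipped for every
# microbatch except the last one of the batch:
#
#     with switch_optimizer_grad_sync_skip_mode(engine.optimizer, skip=True):
#         engine.backward(loss)

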
class PipelineScheduler(BaseScheduler):
    """
    A helper schedule class for pipeline parallelism running environment.
    It uses the non-interleaved 1F1B strategy. Other properties are similar to
    :class:`NonPipelineSchedule`.

    Args:
        num_microbatches (int): The number of microbatches.
        dtype (torch.dtype): Type of data. torch.float by default.
        data_process_func (Callable, optional):
            The post-processing function, which receives a micro batch of data and is executed
            in `load_micro_batch`.
        tensor_shape (torch.Size, optional): Specified shape in pipeline communication.
        scatter_gather_tensors (bool, optional):
            If set to `True`, communication will be reduced over pipeline when using 1D tensor parallelization.
        scheduler_hooks (Optional[List[SchedulerHook]], optional): List of scheduler hooks.
    """

    def __init__(
        self,
        num_microbatches: int,
        dtype: torch.dtype = torch.float,
        data_process_func: Callable = None,
        tensor_shape: Union[torch.Size, List[int], Tuple[int]] = None,
        scatter_gather_tensors: bool = False,
        scheduler_hooks: Optional[List[SchedulerHook]] = None,
    ):
        assert num_microbatches > 0, f"expected num_microbatches to be greater than 0, but got {num_microbatches}"

        assert not isinstance(
            tensor_shape, int
        ), "tensor_shape type should be one of Union[torch.Size, List[int], Tuple[int]]."

        super().__init__(data_process_func=data_process_func)

        self.num_microbatches = num_microbatches
        self.dtype = dtype
        self._hooks = scheduler_hooks

        self.tensor_shape = (
            tensor_shape if tensor_shape is None or isinstance(tensor_shape, torch.Size) else torch.Size(tensor_shape)
        )

        self.scatter_gather_tensors = (
            scatter_gather_tensors
            and gpc.is_initialized(ParallelMode.TENSOR)
            and gpc.get_world_size(ParallelMode.TENSOR) > 1
        )

        if gpc.config.model.sequence_parallel:
            self.scatter_gather_tensors = False

        # cache for the batch data
        self.batch_data = None

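    # Construction sketch (hypothetical values; the engine/trainer wiring is
    # omitted):
    #     scheduler = PipelineScheduler(
    #         num_microbatches=4,
    #         dtype=torch.bfloat16,
    #         scatter_gather_tensors=True,
    #     )
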
    def pre_processing(self, engine):
        types = set()

        for param in engine.model.parameters():
            types.add(param.dtype)
        assert len(types) == 1, f"Mixed types of parameter detected, {types}"

        self.dtype = types.pop()

    @staticmethod
    def _call_engine(engine, data):  # pylint: disable=W0237
        if data is None:
            return None

        if isinstance(data, torch.Tensor):
            return engine(data)
        elif isinstance(data, (list, tuple)):
            return engine(*data)
        elif isinstance(data, dict):
            stage_output = data.pop("stage_output", None)

            if stage_output is None:
                return engine(**data)
            elif isinstance(stage_output, torch.Tensor):
                return engine(stage_output, **data)
            elif isinstance(stage_output, (tuple, list)):
                return engine(*stage_output, **data)
            else:
                raise TypeError(
                    f"Expected stage_output to be of type torch.Tensor, list, or tuple, "
                    f"but got {type(stage_output)}"
                )
        else:
            raise TypeError(f"Expected data to be of type torch.Tensor, list, tuple, or dict, but got {type(data)}")

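    # Dispatch sketch for _call_engine (illustrative names): a dict micro batch
    # such as {"stage_output": hidden_states, "input_ids": ids} is unpacked
    # into engine(hidden_states, input_ids=ids); a plain tensor or a
    # list/tuple is passed through positionally.
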
    def load_batch(self, engine, data_iter):
        # Pipeline schedule just puts data in memory
        batch_data, batch_size = engine.load_batch(data_iter, to_gpu=False)
        assert batch_size % self.num_microbatches == 0, "Batch size should be divisible by the number of microbatches"

        self.microbatch_offset = 0
        self.batch_size = batch_size
        self.batch_data, self.batch_label = batch_data
        self.microbatch_size = self.batch_size // self.num_microbatches

    def load_micro_batch(self):
        micro_batch_data, micro_batch_label = self._load_micro_batch(
            data=self.batch_data, label=self.batch_label, offset=self.microbatch_offset, micro_bsz=self.microbatch_size
        )
        if self.data_process_func:
            micro_batch_data["input_ids"] = self.data_process_func(
                micro_batch_data["input_ids"], micro_batch_data["cu_seqlens"]
            )
            micro_batch_label = self.data_process_func(micro_batch_label, micro_batch_data["cu_seqlens"])

            micro_batch_data.pop("cu_seqlens")
            micro_batch_data.pop("indexes")

        micro_batch_data["label"] = micro_batch_label
        self.microbatch_offset += self.microbatch_size

        return move_to_device(micro_batch_data)

    def _get_data_label_for_current_step(self, stage_output, micro_batch_data):
        if isinstance(micro_batch_data, (tuple, list)):
            if gpc.is_first_rank(ParallelMode.PIPELINE):
                # for the first stage, we use the data from the
                # dataloader output by default
                data, label = micro_batch_data
            else:
                # for non-first stages, we use the output passed by the
                # previous stage as the model input
                data = stage_output
                _, label = micro_batch_data
        elif isinstance(micro_batch_data, dict):
            label = micro_batch_data.pop("label", None)
            data = {"stage_output": stage_output, **micro_batch_data}

        return data, label

    def _call_hooks(self, func_name: str, *args, **kwargs) -> None:
        for hook in self._hooks:
            getattr(hook, func_name)(self, *args, **kwargs)

    def _get_current_microbatch_id(self, step_id: int) -> int:
        """
        Get the current microbatch ID based on the step ID.
        In the 1F1B scheduler, the microbatch ID is the same as the step ID,
        but note that the step ID is counted separately for the forward and
        backward passes.
        """
        return step_id

    def _forward_step(self, engine, input_obj, return_tensors, return_output_label=True, accum_loss=None):
        """
        Forward step for the passed-in model. If it is the first stage, the input tensor
        is obtained from the data_iterator; otherwise, the passed-in input_obj is used.
        Returns the output tensor. This is a helper function and can be ignored by users.

        Args:
            engine (colossalai.engine.Engine): Colossalai engine for training and inference.
            input_obj (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): Input tensor for this pipeline stage.
            return_tensors (List[:class:`torch.Tensor`]): A list of tensors to return.
            return_output_label (bool, optional): Whether to return the output and label.
            accum_loss (optional): Where the accumulated loss is stored.
        Returns:
            Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]: output or the loss value of the current
                pipeline stage.
        """
        micro_batch_data = self.load_micro_batch()
        data, label = self._get_data_label_for_current_step(input_obj, micro_batch_data)

        self._call_hooks("before_forward", data)
        output_obj = self._call_engine(engine.model, data)
        self._call_hooks("after_forward", output_obj)

        if gpc.is_last_rank(ParallelMode.PIPELINE):
            self._call_hooks("post_helper_func", output_obj, label)
            if return_output_label:
                return_tensors.append((output_obj, label))
            if accum_loss is not None:
                self._call_hooks("before_criterion", output_obj, label)
                loss = self._call_engine_criterion(engine, output_obj, label)
                self._call_hooks("after_criterion", loss)

                loss_reduced = loss / self.num_microbatches
                accum_loss.add_(loss_reduced.detach())
                output_obj = loss_reduced

        return output_obj

    def _backward_step(self, engine, step_id, input_obj, output_obj, output_obj_grad):
        """
        Backward step through the passed-in output tensor. If it is the last stage, the
        output_obj_grad is None, otherwise it is the gradient with respect to the stage's output tensor.
        Returns the gradient with respect to the input tensor (None if first stage).
        This is a helper function and can be ignored by users.

        Args:
            engine (colossalai.engine.Engine): Colossalai engine for training and inference.
            step_id (int): The ID of the current step.
            input_obj (Union[torch.Tensor, List[torch.Tensor]]): Input tensor for this stage.
            output_obj (Union[torch.Tensor, List[torch.Tensor]]): Output tensor for this stage.
            output_obj_grad (Union[torch.Tensor, List[torch.Tensor]]): Gradient of output tensor for this stage.

        Returns:
            Union[torch.Tensor, List[torch.Tensor]]: Gradient of input tensor.
        """

        # Retain the grad on the input_obj.
        if input_obj is not None:
            if isinstance(input_obj, torch.Tensor):
                input_obj.retain_grad()
            else:
                for in_tensor in input_obj:
                    if in_tensor is not None:
                        in_tensor.retain_grad()

        # Backward pass.

        # Only the last microbatch syncs grads.
        skip_grad_sync = self._get_current_microbatch_id(step_id) != self.num_microbatches - 1

        self._call_hooks("before_backward", output_obj, output_obj_grad)
        with switch_optimizer_grad_sync_skip_mode(engine.optimizer, skip_grad_sync):
            if output_obj_grad is None:
                engine.backward(output_obj)
            else:
                engine.backward_by_grad(output_obj, output_obj_grad)

        # Collect the grad of the input_obj.
        input_obj_grad = None
        if input_obj is not None:
            if isinstance(input_obj, torch.Tensor):
                input_obj_grad = input_obj.grad
            else:
                input_obj_grad = []
                for in_tensor in input_obj:
                    input_obj_grad.append(in_tensor.grad)
        self._call_hooks("after_backward", input_obj_grad)

        return input_obj_grad

    def _forward_only_step(self, engine, return_loss=True, return_output_label=True):
        """
        This function performs the forward-only computation process. The scheduling of microbatches is similar to the
        warmup phase, where each microbatch first receives the forward input from the previous stage, then performs
        the forward computation, and finally passes the forward computation output to the next stage. There are two
        special cases to note:
        1. The first stage of the pipeline does not need to receive forward input; its input comes from the dataloader.
        2. The last stage of the pipeline does not need to send forward output; its output is returned to the user code
           for processing.

        Args:
            engine (colossalai.engine.Engine): InternLM engine for training and inference.
            return_loss (bool, optional): Whether to return the accumulated loss.
            return_output_label (bool, optional): Whether to return outputs and labels.

        Returns:
            Tuple[Union[torch.Tensor, None], Union[torch.Tensor, None], Union[torch.Tensor, None]]:
                output, label, and accumulated loss.
        """

        # Input, output tensors only need to be saved when doing backward passes
        return_tensors = []
        accum_loss = (
            torch.zeros(1, device=get_current_device())
            if return_loss and gpc.is_pipeline_last_stage(ignore_virtual=True)
            else None
        )

        # Used for tensor meta information communication
        forward_recv_shapes = self.tensor_shape
        need_forward_meta = self.tensor_shape is None

        # Run all forward passes.
        for _ in range(self.num_microbatches):
            # Receive input from the previous stage
            if not gpc.is_first_rank(ParallelMode.PIPELINE):
                if forward_recv_shapes is None:
                    forward_recv_shapes = comm.recv_obj_meta()
                input_obj = comm.recv_forward(
                    forward_recv_shapes,
                    dtype=self.dtype,
                    scatter_gather_tensors=self.scatter_gather_tensors,
                )
            else:
                input_obj = None

            # Perform forward computation
            output_obj = self._forward_step(
                engine,
                input_obj,
                return_tensors,
                return_output_label=return_output_label,
                accum_loss=accum_loss,
            )

            if not gpc.is_last_rank(ParallelMode.PIPELINE):
                if need_forward_meta:
                    comm.send_obj_meta(output_obj)
                    need_forward_meta = False  # send only once.
                # Send the forward computation output to the next stage
                comm.send_forward(output_obj, scatter_gather_tensors=self.scatter_gather_tensors)

        output, label = pack_return_tensors(return_tensors) if len(return_tensors) > 0 else (None, None)

        return output, label, accum_loss

    def _forward_backward_step(self, engine, return_loss=True, return_output_label=True):
        """
        This function schedules the forward and backward computation of microbatches in the pipeline in a 1F1B manner.
        It consists of three stages: warmup, 1F1B, and cooldown.

        1. Warmup Stage:
        The warmup stage performs num_warmup forward microsteps. num_warmup is calculated as the pipeline length
        minus the rank of the current pipeline stage minus 1. For each microstep, it receives data as input from the
        previous stage, performs the forward computation, and then sends the result to the next stage.

        2. 1F1B Stage:
        The 1F1B stage consists of pairs of forward and backward microsteps. It performs num_1f1b_micropairs iterations,
        where num_1f1b_micropairs is calculated as the total number of microbatches minus the number of microbatches in
        the warmup stage. In each iteration, it first performs a forward computation, sends the result to the next
        stage, receives input for the backward computation, performs the backward computation, and finally sends the
        result to the previous stage to receive input for the next forward computation.

        3. Cooldown Stage:
        The cooldown stage performs the same number of iterations as the warmup stage. In each iteration, it receives
        input for the backward computation, performs the backward computation, and finally sends the result to the
        previous stage.

        There are two special cases to consider:
        1. The first stage of the pipeline does not need to receive forward input or send backward output. The last
        stage does not need to send forward output or receive backward input.
        2. Pay attention to the communication between stages and use additional communication to bridge the gap.

        Args:
            engine (Engine): The engine used for computation.
            return_loss (bool, optional): Whether to return the accumulated loss.
            return_output_label (bool, optional): Whether to return outputs and labels.

        Returns:
            Tuple[Union[torch.Tensor, None], Union[torch.Tensor, None], Union[torch.Tensor, None]]:
            The output, label, and accumulated loss.
        """

        num_warmup_microsteps = (
            gpc.get_world_size(ParallelMode.PIPELINE) - gpc.get_local_rank(ParallelMode.PIPELINE) - 1
        )
        num_warmup_microsteps = min(num_warmup_microsteps, self.num_microbatches)
        num_1f1b_micropairs = self.num_microbatches - num_warmup_microsteps
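        # Worked example (illustrative numbers): with a pipeline-parallel world
        # size of 4 and 8 microbatches, the stage with rank 1 runs
        # num_warmup_microsteps = 4 - 1 - 1 = 2 and
        # num_1f1b_micropairs = 8 - 2 = 6; the last stage (rank 3) has no
        # warmup microsteps at all.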

        # Input, output tensors only need to be saved when doing backward passes
        input_objs = []
        output_objs = []
        return_tensors = []
        accum_loss = (
            torch.zeros(1, device=get_current_device())
            if return_loss and gpc.is_pipeline_last_stage(ignore_virtual=True)
            else None
        )

        # Used for tensor meta information communication
        forward_recv_shapes = self.tensor_shape
        backward_recv_shapes = None
        need_forward_meta = self.tensor_shape is None

        # Run warmup forward passes.
        for i in range(num_warmup_microsteps):
            # Receive the input from the previous stage
            if not gpc.is_first_rank(ParallelMode.PIPELINE):
                if forward_recv_shapes is None:
                    forward_recv_shapes = comm.recv_obj_meta()
                input_obj = comm.recv_forward(
                    forward_recv_shapes,
                    dtype=self.dtype,
                    scatter_gather_tensors=self.scatter_gather_tensors,
                )
            else:
                input_obj = None

            # Perform forward computation
            output_obj = self._forward_step(
                engine,
                input_obj,
                return_tensors,
                return_output_label=return_output_label,
                accum_loss=accum_loss,
            )

            if not gpc.is_last_rank(ParallelMode.PIPELINE):
                if isinstance(output_obj, torch.Tensor):
                    backward_recv_shapes = output_obj.shape
                else:
                    backward_recv_shapes = [out_tensor.shape for out_tensor in output_obj]

                if need_forward_meta:
                    comm.send_obj_meta(output_obj)
                    need_forward_meta = False  # send only once.

            # Send the output of forward computation of this pipeline stage to the next pipeline stage as input for
            # forward computation
            if not gpc.is_last_rank(ParallelMode.PIPELINE):
                comm.send_forward(output_obj, scatter_gather_tensors=self.scatter_gather_tensors)

            input_objs.append(input_obj)
            output_objs.append(output_obj)

        # Before running 1F1B, need to receive first forward tensor.
        # If all microbatches are run in warmup / cooldown phase, then no need to
        # receive this tensor here.
        if num_1f1b_micropairs > 0:
            if not gpc.is_first_rank(ParallelMode.PIPELINE):
                if forward_recv_shapes is None:
                    forward_recv_shapes = comm.recv_obj_meta(forward_recv_shapes)
                input_obj = comm.recv_forward(
                    forward_recv_shapes,
                    dtype=self.dtype,
                    scatter_gather_tensors=self.scatter_gather_tensors,
                )
            else:
                input_obj = None

        # Run 1F1B in steady state.
        for i in range(num_1f1b_micropairs):
            # Perform forward computation
            output_obj = self._forward_step(
                engine,
                input_obj,
                return_tensors,
                return_output_label=return_output_label,
                accum_loss=accum_loss,
            )

            if gpc.is_last_rank(ParallelMode.PIPELINE):
                output_obj_grad = None
            else:
                output_obj_grad = comm.send_forward_recv_backward(
                    output_obj,
                    backward_recv_shapes,
                    dtype=self.dtype,
                    scatter_gather_tensors=self.scatter_gather_tensors,
                )

            # Add input_obj and output_obj to the end of the list.
            input_objs.append(input_obj)
            output_objs.append(output_obj)

            # Pop input_obj and output_obj from the start of the list for
            # the backward pass.
            input_obj = input_objs.pop(0)
            output_obj = output_objs.pop(0)

            input_obj_grad = self._backward_step(engine, i, input_obj, output_obj, output_obj_grad)

            if i == (num_1f1b_micropairs - 1):
                input_obj = None
                if not gpc.is_first_rank(ParallelMode.PIPELINE):
                    comm.send_backward(
                        input_obj_grad,
                        scatter_gather_tensors=self.scatter_gather_tensors,
                    )
            else:
                if gpc.is_first_rank(ParallelMode.PIPELINE):
                    input_obj = None
                else:
                    input_obj = comm.send_backward_recv_forward(
                        input_obj_grad,
                        forward_recv_shapes,
                        dtype=self.dtype,
                        scatter_gather_tensors=self.scatter_gather_tensors,
                    )

        # Run cooldown backward passes.
        for i in range(num_warmup_microsteps):
            input_obj = input_objs.pop(0)
            output_obj = output_objs.pop(0)

            if not gpc.is_last_rank(ParallelMode.PIPELINE):
                output_obj_grad = comm.recv_backward(
                    backward_recv_shapes,
                    dtype=self.dtype,
                    scatter_gather_tensors=self.scatter_gather_tensors,
                )
            else:
                output_obj_grad = None

            input_obj_grad = self._backward_step(
                engine, num_1f1b_micropairs + i, input_obj, output_obj, output_obj_grad
            )

            if not gpc.is_first_rank(ParallelMode.PIPELINE):
                comm.send_backward(input_obj_grad, scatter_gather_tensors=self.scatter_gather_tensors)

        output, label = pack_return_tensors(return_tensors) if len(return_tensors) > 0 else (None, None)

        return output, label, accum_loss

    def forward_backward_step(self, engine, data_iter, forward_only=False, return_loss=True, return_output_label=True):
        """Runs the non-interleaved 1F1B schedule, with communication between pipeline stages.
        Returns a tuple with losses on the last stage, and an empty tuple otherwise.

        Args:
            engine (colossalai.engine.Engine): Colossalai engine for training and inference.
            data_iter (Iterable): Dataloader in the form of an iterator, obtained by calling iter(dataloader).
            forward_only (bool, optional):
                Whether to run the forward step only. Default is False. If True, no backward pass will be run.
            return_loss (bool, optional): Whether to return the loss value. Default is True.
            return_output_label (bool, optional): If False, the output and label won't be returned.
        Returns:
            Tuple[:class:`torch.Tensor`]: A tuple of (output, label, loss); loss and label could be None.
        """

        assert (
            forward_only or return_loss
        ), "The argument 'return_loss' has to be True when 'forward_only' is False, but got False."

        # Load data first
        self.load_batch(engine, data_iter)

        if forward_only:
            return self._forward_only_step(engine, return_loss, return_output_label)
        else:
            return self._forward_backward_step(engine, return_loss, return_output_label)


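# Usage sketch for the non-interleaved scheduler (hypothetical trainer loop;
# `engine` and `train_iter` are assumed to come from the surrounding training
# code, and `scheduler` is a PipelineScheduler instance):
#     output, label, loss = scheduler.forward_backward_step(
#         engine, train_iter, forward_only=False, return_loss=True
#     )

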
class InterleavedPipelineScheduler(PipelineScheduler):
    """
    Interleaved Pipeline Scheduler.
    """

    def __init__(
        self,
        num_microbatches: int,
        num_chunks: int,
        dtype: torch.dtype = torch.float,
        data_process_func: Callable = None,
        tensor_shape: Union[torch.Size, List[int], Tuple[int]] = None,
        scatter_gather_tensors: bool = False,
        scheduler_hooks: Optional[List[SchedulerHook]] = None,
        communication_overlap: bool = False,
    ):
        """A helper schedule class for pipeline parallelism running environment.
        It uses the interleaved 1F1B strategy. Other properties are similar to
        :class:`NonPipelineSchedule`.

        Args:
            num_microbatches (int): The number of microbatches.
            num_chunks (int): The number of model chunks.
            dtype (torch.dtype, optional): The data type of the tensors. Default is torch.float.
            data_process_func (Callable, optional):
                The preprocessing function which receives a batch of data, and it will be executed in `load_batch`.
            tensor_shape (torch.Size, optional): Specified shape in pipeline communication.
            scatter_gather_tensors (bool, optional):
                If set to `True`, communication will be reduced over pipeline when using 1D tensor parallelization.
            scheduler_hooks (List[SchedulerHook], optional): List of scheduler hooks. Default is None.
            communication_overlap (bool, optional): Whether to enable communication overlap. Default is False.
        """
        assert (
            num_microbatches % gpc.get_world_size(ParallelMode.PIPELINE) == 0
        ), "num_microbatches must be an integer multiple of pipeline parallel world size"

        assert (
            isinstance(num_chunks, int) and num_chunks > 0
        ), f"expected num_chunks to be an integer and larger than 0, but got {num_chunks}"
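
        # For example (illustrative): with a pipeline-parallel world size of 4,
        # num_microbatches must be 4, 8, 12, ..., and num_chunks might be 2,
        # giving 2 virtual (model chunk) stages per pipeline rank.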

        super().__init__(
            num_microbatches,
            dtype=dtype,
            data_process_func=data_process_func,
            tensor_shape=tensor_shape,
            scatter_gather_tensors=scatter_gather_tensors,
            scheduler_hooks=scheduler_hooks,
        )

        gpc.set_virtual_pipeline_parallel_size(num_chunks)
        gpc.set_virtual_pipeline_parallel_rank(0)

        self._num_chunks = num_chunks
        self._communication_overlap = communication_overlap
        # switch 1f1b loop runner function according to communication overlap
        self._run_1f1b_loop = (
            self._run_1f1b_loop_with_overlap if communication_overlap else self._run_1f1b_loop_without_overlap
        )

        # states
        self._pp_size = gpc.get_world_size(ParallelMode.PIPELINE)
        self._pp_rank = gpc.get_local_rank(ParallelMode.PIPELINE)

        self._accum_loss = None
        self._return_tensors = None
        self._input_objs = [[] for _ in range(num_chunks)]
        self._output_objs = [[] for _ in range(num_chunks)]
        self._output_obj_grads = [[] for _ in range(num_chunks)]

        self._input_obj_shapes = [self.tensor_shape for _ in range(num_chunks)]
        self._output_obj_shapes = [None for _ in range(num_chunks)]
        self._send_tensor_shape_flags = [self.tensor_shape is None for _ in range(num_chunks)]

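    # Construction sketch (hypothetical values): an interleaved schedule with
    # 2 model chunks per pipeline rank and 8 microbatches:
    #     scheduler = InterleavedPipelineScheduler(
    #         num_microbatches=8,
    #         num_chunks=2,
    #         dtype=torch.bfloat16,
    #         communication_overlap=True,
    #     )
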
    def _clear_state(self) -> None:
        self._accum_loss = None
        self._return_tensors = None
        self._input_objs = [[] for _ in range(self._num_chunks)]
        self._output_objs = [[] for _ in range(self._num_chunks)]
        self._output_obj_grads = [[] for _ in range(self._num_chunks)]

        self._input_obj_shapes = [self.tensor_shape for _ in range(self._num_chunks)]
        self._output_obj_shapes = [None for _ in range(self._num_chunks)]
        self._send_tensor_shape_flags = [self.tensor_shape is None for _ in range(self._num_chunks)]

    def load_batch(self, engine, data_iter):
        super().load_batch(engine, data_iter)
        # overwrite microbatch_offset, since each model chunk loads the same batch data and should track its own offset
        self.microbatch_offset = [0 for _ in range(self._num_chunks)]

    def load_micro_batch(self, model_chunk_id):
        micro_batch_data, micro_batch_label = self._load_micro_batch(
            data=self.batch_data,
            label=self.batch_label,
            offset=self.microbatch_offset[model_chunk_id],
            micro_bsz=self.microbatch_size,
        )
        micro_batch_data["label"] = micro_batch_label
        self.microbatch_offset[model_chunk_id] += self.microbatch_size
        return move_to_device(micro_batch_data)

    def _forward_step(self, engine, chunk_id):
        """Forward step for the passed-in model. If it is the first stage, the input tensor
        is obtained from the data_iterator; otherwise, the passed-in input_obj is used.
        Returns the output tensor. This is a helper function and can be ignored by users.

        Args:
            engine (colossalai.engine.Engine): Colossalai engine for training and inference.
            chunk_id (int): The id of model chunks.
        Returns:
            Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]: output or the loss value of the current
                pipeline stage.
        """
        gpc.set_virtual_pipeline_parallel_rank(chunk_id)

        if gpc.is_pipeline_first_stage() and len(self._input_objs[chunk_id]) == len(self._output_objs[chunk_id]):
            self._input_objs[chunk_id].append(None)
        input_obj = self._input_objs[chunk_id][-1]

        micro_batch_data = self.load_micro_batch(chunk_id)
        data, label = self._get_data_label_for_current_step(input_obj, micro_batch_data)

        self._call_hooks("before_forward", data)
        output_obj = self._call_engine(engine.model[chunk_id], data)
        # Convert output_obj to fp32 when it is the last model chunk of the last stage
        if gpc.is_pipeline_last_stage(ignore_virtual=False) and isinstance(engine.model[chunk_id], NaiveAMPModel):
            output_obj = engine.model[chunk_id].convert_to_fp32(output_obj)
        self._call_hooks("after_forward", output_obj)

        if gpc.is_pipeline_last_stage():
            self._call_hooks("post_helper_func", output_obj, label)

            if self._return_tensors is not None:
                self._return_tensors.append((output_obj, label))
            if self._accum_loss is not None:
                self._call_hooks("before_criterion", output_obj, label)
                loss = self._call_engine_criterion(engine, output_obj, label)
                self._call_hooks("after_criterion", loss)

                loss_reduced = loss / self.num_microbatches
                self._accum_loss.add_(loss_reduced.detach())
                output_obj = loss_reduced

        self._output_objs[chunk_id].append(output_obj)

        return output_obj

    def _backward_step(self, engine, chunk_id, step_id):
        """
        Backward step for the passed-in model. If it is the last stage, the input tensor
        is obtained from the previous forward step, otherwise the passed-in input_obj is used.
        Returns the input tensor gradient. This is a helper function and can be ignored by users.

        Args:
            engine (colossalai.engine.Engine): Colossalai engine for training and inference.
            chunk_id (int): The id of model chunks.
            step_id (int): The current step id.

        Returns:
            Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]: input tensor gradient.
        """
        gpc.set_virtual_pipeline_parallel_rank(chunk_id)

        if gpc.is_pipeline_last_stage() and len(self._output_obj_grads[chunk_id]) == 0:
            self._output_obj_grads[chunk_id].append(None)

        input_obj = self._input_objs[chunk_id].pop(0)
        output_obj = self._output_objs[chunk_id].pop(0)
        output_obj_grad = self._output_obj_grads[chunk_id].pop(0)

        input_obj_grad = super()._backward_step(engine, step_id, input_obj, output_obj, output_obj_grad)

        return input_obj_grad

    def _get_chunk_by_microbatch(self, step_id: int, backward: bool = False) -> int:
        """Helper method to get the model chunk ID given the iteration number."""
        microbatch_id_in_group = step_id % (self._pp_size * self._num_chunks)
        chunk_id = microbatch_id_in_group // self._pp_size

        if backward:
            chunk_id = self._num_chunks - chunk_id - 1

        return chunk_id

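    # Chunk mapping sketch for _get_chunk_by_microbatch (illustrative): with
    # _pp_size=4 and _num_chunks=2, forward step_ids 0-3 map to chunk 0 and
    # 4-7 map to chunk 1; for backward steps the order is reversed (0-3 -> 1,
    # 4-7 -> 0), and the pattern repeats every 8 microsteps.
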
    def _get_current_microbatch_id(self, step_id: int) -> int:
        # format:
        # microstep_id : 1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16
        # microbatch_id: 1  2  3  4  1  2  3  4  5  6  7  8  5  6  7  8
        num_microbatch_group = step_id // (self._pp_size * self._num_chunks)
        step_id_in_group = step_id % (self._pp_size * self._num_chunks)

        microbatch_id = num_microbatch_group * self._pp_size + step_id_in_group % self._pp_size

        return microbatch_id

    def _run_warmup_loop(
        self,
        engine: Engine,
        num_microsteps: int,
        num_warmup_microsteps: int,
        receive_extra_backward: bool = False,
        forward_only: bool = False,
    ) -> None:
        """
        Run the warm-up loop and prepare data for the 1F1B stage.

        During the warm-up process, for each execution, it first performs a forward computation,
        and then sends the computation result to the next stage.
        It also receives data for the next forward computation.
        Since the input for the first forward computation is not considered initially,
        it needs to receive data once at the beginning.

        After the warm-up is completed, we need to prepare data for the 1F1B stage.
        The data preparation process should be consistent with the communication method of the 1F1B stage.

        Args:
            engine (Engine): The engine to run the warm-up loop.
            num_microsteps (int): The total number of microsteps.
            num_warmup_microsteps (int): The number of warm-up microsteps.
            receive_extra_backward (bool, optional): Whether to receive extra backward input for the 1F1B stage.
                                                     Default is False.
            forward_only (bool, optional): Whether to only perform forward pass. Default is False.
        """
        if not gpc.is_pipeline_first_stage():
            if self._input_obj_shapes[0] is None:
                self._input_obj_shapes[0] = comm.recv_obj_meta(self._input_obj_shapes[0])
            self._input_objs[0].append(
                comm.recv_forward(
                    self._input_obj_shapes[0],
                    dtype=self.dtype,
                    scatter_gather_tensors=self.scatter_gather_tensors,
                )
            )
        else:
            self._input_objs[0].append(None)

        for k in range(num_warmup_microsteps):
            chunk_id = self._get_chunk_by_microbatch(k)

            output_obj = self._forward_step(engine, chunk_id)

            if forward_only:
                # when forward-only, no need to save tensors for a backward pass
                self._input_objs[chunk_id].pop()
                self._output_objs[chunk_id].pop()

            if not gpc.is_pipeline_last_stage():
                if isinstance(output_obj, torch.Tensor):
                    self._output_obj_shapes[chunk_id] = output_obj.shape
                else:
                    self._output_obj_shapes[chunk_id] = [out_tensor.shape for out_tensor in output_obj]

                if self._send_tensor_shape_flags[chunk_id]:
                    comm.send_obj_meta(output_obj)
                    self._send_tensor_shape_flags[chunk_id] = False  # send only once for each chunk.

            # Determine if tensor should be received from previous stage.
            next_forward_chunk_id = self._get_chunk_by_microbatch(k + 1)

            with switch_virtual_pipeline_parallel_rank(next_forward_chunk_id):
                if not gpc.is_pipeline_first_stage() and self._input_obj_shapes[next_forward_chunk_id] is None:
                    self._input_obj_shapes[next_forward_chunk_id] = comm.recv_obj_meta()
                if k == (num_microsteps - 1) or gpc.is_pipeline_first_stage():
                    input_shape = None
                else:
                    input_shape = self._input_obj_shapes[next_forward_chunk_id]

            # Don't send tensor downstream if on last stage.
            if gpc.is_pipeline_last_stage():
                output_obj = None

            # Send and receive tensors as appropriate (send tensors computed
            # in this iteration; receive tensors for next iteration).
            if k != (num_warmup_microsteps - 1) or not receive_extra_backward:
                # Normal warm-up communication process, or no need to prepare backward input for the 1F1B stage
                input_obj = comm.send_forward_recv_forward(
                    output_obj,
                    input_shape,
                    dtype=self.dtype,
                    scatter_gather_tensors=self.scatter_gather_tensors,
                )
            else:
                # Receive output_obj_grad for next backward, if receive_extra_backward is True.
                if self._communication_overlap:
                    # In this case, we should handle forward and backward communication separately, consistent with the
                    # overlap version of the 1F1B stage
                    input_obj = comm.send_forward_recv_forward(
                        output_obj,
                        input_shape,
                        dtype=self.dtype,
                        scatter_gather_tensors=self.scatter_gather_tensors,
                    )
                    output_obj_grad = comm.send_backward_recv_backward(
                        None,  # nothing to send
                        self._output_obj_shapes[self._num_chunks - 1],
                        dtype=self.dtype,
                        scatter_gather_tensors=self.scatter_gather_tensors,
                    )
                    self._output_obj_grads[self._num_chunks - 1].append(output_obj_grad)
                else:
                    # In this case, we should handle forward and backward communication together, consistent with the
                    # non-overlap version of the 1F1B stage
                    input_obj, output_obj_grad = comm.send_forward_backward_recv_forward_backward(
                        output_obj,
                        None,  # no backward grad to send
                        input_shape,
                        self._output_obj_shapes[self._num_chunks - 1],
                        dtype=self.dtype,
                        scatter_gather_tensors=self.scatter_gather_tensors,
                    )
                    self._output_obj_grads[self._num_chunks - 1].append(output_obj_grad)

            self._input_objs[next_forward_chunk_id].append(input_obj)

    def _run_1f1b_loop_with_overlap(
        self,
        engine: Engine,
        num_warmup_microsteps: int,
        num_1f1b_micropairs: int,
        all_warmup_microsteps: bool = False,
    ) -> None:
        """
        Run the 1F1B loop with overlap.

        The 1F1B loop with overlap consists of the following steps:
        1. Perform the forward pass.
        2. Check if the backward input is ready.
        3. Send the forward output and receive the forward input for the next iteration.
        4. Perform the backward pass.
        5. Check if the forward input is ready.
        6. Send the backward output and receive the backward input for the next iteration.

        Args:
            engine (Engine): The engine to run the 1F1B loop.
            num_warmup_microsteps (int): The number of warm-up microsteps.
            num_1f1b_micropairs (int): The number of 1F1B micropairs.
            all_warmup_microsteps (bool, optional): Whether to run all warm-up microsteps. Default is False.
        """

        backward_async_communicator = None

        # Run 1F1B in steady state.
        for k in range(num_1f1b_micropairs):
            forward_microstep_id = k + num_warmup_microsteps
            backward_microstep_id = k
            forward_chunk_id = self._get_chunk_by_microbatch(forward_microstep_id)
            backward_chunk_id = self._get_chunk_by_microbatch(backward_microstep_id, backward=True)

            # 1. Forward pass.
            output_obj = self._forward_step(engine, forward_chunk_id)

            # 2. Check if the backward input is ready.
            if backward_async_communicator is not None:
                output_obj_grad = backward_async_communicator.wait_and_receive()

                if backward_async_communicator.need_receive:
                    self._output_obj_grads[backward_chunk_id].append(output_obj_grad)

            # 3. Send the forward outputs and receive the forward inputs from the previous rank.

            # Check if it is the last model chunk of the last pipeline stage; if so, no need to send forward output.
            gpc.set_virtual_pipeline_parallel_rank(forward_chunk_id)
            if gpc.is_pipeline_last_stage():
                output_obj = None

            # Check if it needs to receive the results from the previous rank.
            next_forward_chunk_id = self._get_chunk_by_microbatch(forward_microstep_id + 1)
            with switch_virtual_pipeline_parallel_rank(next_forward_chunk_id):
                if gpc.is_pipeline_first_stage() or k == num_1f1b_micropairs - 1:
                    input_obj_shape = None
                else:
                    input_obj_shape = self._input_obj_shapes[next_forward_chunk_id]

            forward_async_communicator = comm.AsynCommunicator(
                output_obj,
                input_obj_shape,
                self.dtype,
                self.scatter_gather_tensors,
                forward=True,
            )
            forward_async_communicator.start()

            # 4. Backward pass.
            input_obj_grad = self._backward_step(engine, backward_chunk_id, backward_microstep_id)

            # 5. Check if the forward input is ready.
            input_obj = forward_async_communicator.wait_and_receive()
            if forward_async_communicator.need_receive:
                self._input_objs[next_forward_chunk_id].append(input_obj)

            # 6. Send the backward output and receive the backward input for the next iteration.
            gpc.set_virtual_pipeline_parallel_rank(backward_chunk_id)
            if gpc.is_pipeline_first_stage():
                input_obj_grad = None

            next_backward_chunk_id = self._get_chunk_by_microbatch(backward_microstep_id + 1, backward=True)
            with switch_virtual_pipeline_parallel_rank(next_backward_chunk_id):
                if gpc.is_pipeline_last_stage():
                    output_obj_shape = None
                else:
                    output_obj_shape = self._output_obj_shapes[next_backward_chunk_id]

            backward_async_communicator = comm.AsynCommunicator(
                input_obj_grad,
                output_obj_shape,
                self.dtype,
                self.scatter_gather_tensors,
                forward=False,
            )
            backward_async_communicator.start()

        if all_warmup_microsteps:
            if not gpc.is_pipeline_last_stage():
                self._output_obj_grads[self._num_chunks - 1].append(
                    comm.recv_backward(
                        self._output_obj_shapes[self._num_chunks - 1],
                        dtype=self.dtype,
                        scatter_gather_tensors=self.scatter_gather_tensors,
                    )
                )
            else:
                self._output_obj_grads[self._num_chunks - 1].append(None)
        else:
            output_obj_grad = backward_async_communicator.wait_and_receive()
            if backward_async_communicator.need_receive:
                backward_chunk_id = self._get_chunk_by_microbatch(num_1f1b_micropairs, backward=True)
                self._output_obj_grads[backward_chunk_id].append(output_obj_grad)

|     def _run_1f1b_loop_without_overlap(
 | |
|         self,
 | |
|         engine: Engine,
 | |
|         num_warmup_microsteps: int,
 | |
|         num_1f1b_micropairs: int,
 | |
|         all_warmup_microsteps: bool = False,
 | |
|     ) -> None:
 | |
|         """
 | |
|         Run the 1F1B loop without overlap.
 | |
| 
 | |
|         The 1F1B loop without overlap consists of the following steps:
 | |
|         1. Perform the forward pass.
 | |
|         2. Perform the backward pass.
 | |
|         3. Send the forward output of this iteration to the next stage, and send the backward output of this iteration
 | |
|            to the previous stage,
 | |
|         and receive the forward and backward inputs for the next iteration.
 | |
| 
 | |
|         Args:
 | |
|             engine (Engine): The engine to use for computation.
 | |
|             num_warmup_microsteps (int): The number of warmup microsteps.
 | |
|             num_1f1b_micropairs (int): The number of 1F1B micro-pairs.
 | |
|             all_warmup_microsteps (bool, optional): Whether to run all warmup microsteps. Defaults to False.
 | |
|         """
 | |
        for k in range(num_1f1b_micropairs):
            # Forward pass.
            forward_microstep_id = k + num_warmup_microsteps
            forward_chunk_id = self._get_chunk_by_microbatch(forward_microstep_id)
            output_obj = self._forward_step(engine, forward_chunk_id)

            # Backward pass.
            backward_microstep_id = k
            backward_chunk_id = self._get_chunk_by_microbatch(backward_microstep_id, backward=True)
            input_obj_grad = self._backward_step(engine, backward_chunk_id, backward_microstep_id)

            # Send output_obj and input_obj_grad, receive input_obj
            # and output_obj_grad.

            # Determine if current stage has anything to send in either direction,
            # otherwise set obj to None.
            gpc.set_virtual_pipeline_parallel_rank(forward_chunk_id)
            if gpc.is_pipeline_last_stage():
                output_obj = None

            gpc.set_virtual_pipeline_parallel_rank(backward_chunk_id)
            if gpc.is_pipeline_first_stage():
                input_obj_grad = None

            # Determine if peers are sending, and where in data structure to put
            # received tensors.
            next_forward_chunk_id = self._get_chunk_by_microbatch(forward_microstep_id + 1)
            with switch_virtual_pipeline_parallel_rank(next_forward_chunk_id):
                if gpc.is_pipeline_first_stage() or k == num_1f1b_micropairs - 1:
                    recv_prev = False
                else:
                    recv_prev = True

            next_backward_chunk_id = self._get_chunk_by_microbatch(backward_microstep_id + 1, backward=True)
            with switch_virtual_pipeline_parallel_rank(next_backward_chunk_id):
                if gpc.is_pipeline_last_stage():
                    recv_next = False
                else:
                    recv_next = True

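            # recv_prev/recv_next stay False when the next chunk sits at a pipeline
            # boundary (or, for recv_prev, when this is the final 1F1B pair), so no
            # tensor is expected from that direction and no shape is passed below.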
            input_shape = self._input_obj_shapes[next_forward_chunk_id] if recv_prev else None
            output_shape = self._output_obj_shapes[next_backward_chunk_id] if recv_next else None

            # Communicate objs.
            input_obj, output_obj_grad = comm.send_forward_backward_recv_forward_backward(
                output_obj,
                input_obj_grad,
                input_shape,
                output_shape,
                dtype=self.dtype,
                scatter_gather_tensors=self.scatter_gather_tensors,
            )

            # Put input_obj and output_obj_grad in data structures in the
            # right location.
            if recv_prev:
                self._input_objs[next_forward_chunk_id].append(input_obj)
            if recv_next:
                self._output_obj_grads[next_backward_chunk_id].append(output_obj_grad)

        # Receive the data needed for the next cooldown loop.
        if all_warmup_microsteps:
            if not gpc.is_pipeline_last_stage():
                self._output_obj_grads[self._num_chunks - 1].append(
                    comm.recv_backward(
                        self._output_obj_shapes[self._num_chunks - 1],
                        dtype=self.dtype,
                        scatter_gather_tensors=self.scatter_gather_tensors,
                    )
                )
            else:
                self._output_obj_grads[self._num_chunks - 1].append(None)

    def _run_cooldown_loop(self, engine: Engine, num_microsteps: int, num_1f1b_micropairs: int) -> None:
        """
        Run the cooldown loop.

        The cooldown loop drains the remaining backward passes. Each iteration:
        1. Performs the backward step.
        2. Sends the resulting gradient to the previous stage and receives the output
           gradient needed for the next backward step.

        Args:
            engine (Engine): The engine to use for computation.
            num_microsteps (int): The total number of microsteps.
            num_1f1b_micropairs (int): The number of 1F1B micro-pairs.
        """
        for k in range(num_1f1b_micropairs, num_microsteps):
            chunk_id = self._get_chunk_by_microbatch(k, backward=True)

            input_obj_grad = self._backward_step(engine, chunk_id, k)

            next_backward_chunk_id = self._get_chunk_by_microbatch(k + 1, backward=True)

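            # A gradient only needs to be received when another backward step remains and
            # the next backward chunk is not the final chunk on the (non-virtual) last
            # stage, whose gradient comes from the local loss rather than from a peer.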
            if k != (num_microsteps - 1) and not (
                gpc.is_pipeline_last_stage(ignore_virtual=True) and next_backward_chunk_id == (self._num_chunks - 1)
            ):
                output_shape = self._output_obj_shapes[next_backward_chunk_id]
            else:
                output_shape = None

            self._output_obj_grads[next_backward_chunk_id].append(
                comm.send_backward_recv_backward(
                    input_obj_grad,
                    output_shape,
                    dtype=self.dtype,
                    scatter_gather_tensors=self.scatter_gather_tensors,
                )
            )

    def _forward_only_step(self, engine: Engine):
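        # In forward-only mode every microstep is treated as a warmup (forward-only) step,
        # and no backward data is received.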
        num_microsteps = self.num_microbatches * self._num_chunks
        num_warmup_microsteps = num_microsteps

        self._run_warmup_loop(
            engine,
            num_microsteps,
            num_warmup_microsteps,
            receive_extra_backward=False,
            forward_only=True,
        )

    def _forward_backward_step(self, engine: Engine):
        # Compute the number of warmup and remaining microbatches.
        all_warmup_microsteps = False
        num_microsteps = self.num_microbatches * self._num_chunks

        # Run all forward passes and then all backward passes if the number of
        # microbatches equals the number of pipeline stages.
        # Otherwise, perform (num_chunks - 1) * pipeline_parallel_size warmup microsteps
        # on all workers, plus additional microsteps depending on the stage ID
        # (more forward passes for earlier stages; later stages can start 1F1B sooner).
        if self.num_microbatches == self._pp_size:
            num_warmup_steps = num_microsteps
            all_warmup_microsteps = True
        else:
            num_warmup_steps = (self._pp_size - self._pp_rank - 1) * 2
            num_warmup_steps += (self._num_chunks - 1) * self._pp_size
            num_warmup_steps = min(num_warmup_steps, num_microsteps)
        num_1f1b_micropairs = num_microsteps - num_warmup_steps
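        # Worked example: with pp_size=4, pp_rank=1, num_chunks=2 and 8 microbatches,
        # num_microsteps = 16, num_warmup_steps = (4 - 1 - 1) * 2 + (2 - 1) * 4 = 8,
        # and num_1f1b_micropairs = 16 - 8 = 8.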

        # When the warmup stage ends, we usually need to prepare extra backward input data
        # for the 1F1B stage, because each 1F1B step performs one forward and one backward
        # pass together, except in the following cases:
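        # In those cases the gradient for the first cooldown backward either comes from
        # the local loss (last pipeline stage) or is received by the 1F1B loop itself
        # when every microstep is a warmup step.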
        receive_extra_backward = not (
            all_warmup_microsteps  # Only warmup microsteps
            or gpc.is_pipeline_last_stage(ignore_virtual=True)  # The rank is the last pipeline stage
        )

        # 1. Warmup
        self._run_warmup_loop(
            engine,
            num_microsteps,
            num_warmup_steps,
            receive_extra_backward=receive_extra_backward,
        )

        # 2. 1F1B
        self._run_1f1b_loop(
            engine,
            num_warmup_steps,
            num_1f1b_micropairs=num_1f1b_micropairs,
            all_warmup_microsteps=all_warmup_microsteps,
        )

        # 3. Cooldown
        self._run_cooldown_loop(engine, num_microsteps, num_1f1b_micropairs=num_1f1b_micropairs)

    def forward_backward_step(self, engine, data_iter, forward_only=False, return_loss=True, return_output_label=True):
        """Run the interleaved 1F1B schedule (model split into model chunks), with
        communication between pipeline stages as needed.

        Args:
            engine (colossalai.engine.Engine): Colossalai engine for training and inference.
            data_iter (Iterable): Data iterator obtained by calling iter(dataloader).
            forward_only (bool, optional):
                Whether to run the forward step only. Defaults to False. If True, no backward pass will be run.
            return_loss (bool, optional): Whether to return the loss value. Defaults to True.
            return_output_label (bool, optional): If False, the output and label won't be returned.

        Returns:
            Tuple[:class:`torch.Tensor`]: A tuple of (output, label, loss); loss and label may be None.
                The loss is only returned on the last pipeline stage.
        """
        assert (
            forward_only or return_loss
        ), "The argument 'return_loss' has to be True when 'forward_only' is False, but got False."

        self.load_batch(engine, data_iter)

        if return_loss and gpc.is_pipeline_last_stage(ignore_virtual=True):
            self._accum_loss = torch.zeros(1, device=get_current_device())
        if return_output_label:
            self._return_tensors = []

        if forward_only:
            self._forward_only_step(engine)
        else:
            self._forward_backward_step(engine)

        if return_output_label and len(self._return_tensors) > 0:
            output, label = pack_return_tensors(self._return_tensors)
        else:
            output, label = (None, None)
        accum_loss = self._accum_loss

        self._clear_state()

        return output, label, accum_loss
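
# Rough usage sketch (assumptions: `scheduler` is an instance of this interleaved pipeline
# scheduler, and `engine` / `train_dataloader` have been built elsewhere with gpc initialized):
#
#     data_iter = iter(train_dataloader)
#     output, label, loss = scheduler.forward_backward_step(
#         engine, data_iter, forward_only=False, return_loss=True, return_output_label=False
#     )
#
# Only the last pipeline stage gets a non-None `loss`; other ranks receive None.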