#!/usr/bin/env python
# -*- encoding: utf-8 -*-
# adopted from https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/engine
from contextlib import contextmanager
from typing import Callable, List, Optional, Tuple, Union
import torch.cuda
import internlm.core.communication as comm
from internlm.core.context import ParallelMode
from internlm.core.context import global_context as gpc
from internlm.core.engine import Engine
from internlm.core.naive_amp import NaiveAMPModel
from internlm.utils.common import get_current_device, move_to_device
from internlm.utils.logger import get_logger
from .base_scheduler import BaseScheduler, SchedulerHook
logger = get_logger(__file__)
def get_tensor_shape():
if hasattr(gpc.config, "TENSOR_SHAPE"):
return gpc.config.TENSOR_SHAPE
if not gpc.is_initialized(ParallelMode.PIPELINE):
return None
if hasattr(gpc.config, "SEQ_LEN") and hasattr(gpc.config.data, "micro_bsz") and hasattr(gpc.config, "HIDDEN_SIZE"):
if gpc.config.model.use_flash_attn:
if gpc.config.parallel.sequence_parallel:
sequence_world_size = gpc.get_world_size(ParallelMode.TENSOR)
tensor_shape = (
gpc.config.SEQ_LEN * gpc.config.data["micro_bsz"] // sequence_world_size,
gpc.config.HIDDEN_SIZE,
)
else:
tensor_shape = (
gpc.config.SEQ_LEN * gpc.config.data["micro_bsz"],
gpc.config.HIDDEN_SIZE,
)
else:
tensor_shape = (
gpc.config.data["micro_bsz"],
gpc.config.SEQ_LEN,
gpc.config.HIDDEN_SIZE,
)
return tensor_shape
else:
return None
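# Illustrative note (not part of the original logic): a worked example of the shapes
# get_tensor_shape() produces, assuming a hypothetical config with SEQ_LEN=2048,
# data.micro_bsz=2, HIDDEN_SIZE=4096 and a tensor-parallel world size of 8:
#
#   use_flash_attn=True,  sequence_parallel=True  -> (2048 * 2 // 8, 4096) == (512, 4096)
#   use_flash_attn=True,  sequence_parallel=False -> (2048 * 2, 4096)      == (4096, 4096)
#   use_flash_attn=False                          -> (2, 2048, 4096)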
def pack_return_tensors(return_tensors):
output, label = tuple(zip(*return_tensors))
if isinstance(output[0], torch.Tensor):
output = torch.cat(output, dim=0)
elif isinstance(output[0], (list, tuple)):
output = tuple(torch.cat(tensors, dim=0) for tensors in zip(*output))
else:
raise TypeError("Output of model must be tensor or list/tuple of tensors")
if isinstance(label[0], torch.Tensor):
label = torch.cat(label, dim=0)
else:
merged_label = {k: [] for k in label[0].keys()}
for d in label:
for k, v in d.items():
merged_label[k].append(v)
label = {k: torch.cat(v, dim=0) for k, v in merged_label.items()}
return output, label
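# Illustrative sketch (hypothetical tensors, not part of the original module): how
# pack_return_tensors merges the per-microbatch (output, label) pairs collected during a step.
#
#   >>> import torch
#   >>> pairs = [
#   ...     (torch.ones(2, 4), {"label": torch.ones(2)}),
#   ...     (torch.zeros(2, 4), {"label": torch.zeros(2)}),
#   ... ]
#   >>> output, label = pack_return_tensors(pairs)
#   >>> output.shape, label["label"].shape
#   (torch.Size([4, 4]), torch.Size([4]))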
@contextmanager
def switch_virtual_pipeline_parallel_rank(rank):
prev_rank = gpc.virtual_pipeline_parallel_rank
try:
gpc.set_virtual_pipeline_parallel_rank(rank)
yield
finally:
gpc.set_virtual_pipeline_parallel_rank(prev_rank)
@contextmanager
def switch_optimizer_grad_sync_skip_mode(optimizer, skip: bool = True):
prev_mode = optimizer.skip_grad_reduce
try:
optimizer.skip_grad_reduce = skip
yield
finally:
optimizer.skip_grad_reduce = prev_mode
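# Usage sketch (hedged; `engine` and `output` are placeholders): the backward steps below wrap
# engine.backward()/engine.backward_by_grad() in this context manager so that gradient
# reduction is skipped for every microbatch except the last one of the batch:
#
#   with switch_optimizer_grad_sync_skip_mode(engine.optimizer, skip=True):
#       engine.backward(output)  # grads accumulate locally; reduction is deferred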
class PipelineScheduler(BaseScheduler):
"""
A helper schedule class for pipeline parallelism running environment.
It uses the non-interleaved 1F1B strategy. Other properties are similar to
:class:`NonPipelineSchedule`.
Args:
num_microbatches (int): The number of microbatches.
dtype (torch.dtype): Type of data. torch.float by default.
data_process_func (Callable, optional):
The post-processing function, which receives a micro-batch of data and is executed
in `load_micro_batch`.
tensor_shape (torch.Size, optional): Specified shape in pipeline communication.
scatter_gather_tensors (bool, optional):
If set to `True`, communication will be reduced over the pipeline when using 1D tensor parallelization.
scheduler_hooks (Optional[List[SchedulerHook]], optional): List of scheduler hooks.
"""
def __init__(
self,
num_microbatches: int,
dtype: torch.dtype = torch.float,
data_process_func: Callable = None,
tensor_shape: Union[torch.Size, List[int], Tuple[int]] = None,
scatter_gather_tensors: bool = False,
scheduler_hooks: Optional[List[SchedulerHook]] = None,
):
assert num_microbatches > 0, f"expected num_microbatches to be larger than 0, but got {num_microbatches}"
assert not isinstance(
tensor_shape, int
), "tensor_shape type should be one of Union[torch.Size, List[int], Tuple[int]]."
super().__init__(data_process_func=data_process_func)
self.num_microbatches = num_microbatches
self.dtype = dtype
self._hooks = scheduler_hooks
self._tensor_shape = (
tensor_shape if tensor_shape is None or isinstance(tensor_shape, torch.Size) else torch.Size(tensor_shape)
)
self.scatter_gather_tensors = (
scatter_gather_tensors
and gpc.is_initialized(ParallelMode.TENSOR)
and gpc.get_world_size(ParallelMode.TENSOR) > 1
)
if gpc.config.parallel.sequence_parallel:
self.scatter_gather_tensors = False
# cache for the batch data
self.batch_data = None
@property
def tensor_shape(self) -> torch.Size:
return self._tensor_shape
@tensor_shape.setter
def tensor_shape(self, tensor_shape: torch.Size):
self._tensor_shape = tensor_shape
def pre_processing(self, engine):
types = set()
for param in engine.model.parameters():
types.add(param.dtype)
assert len(types) == 1, f"Mixed types of parameter detected, {types}"
self.dtype = types.pop()
@staticmethod
def _call_engine(engine, data): # pylint: disable=W0237
if data is None:
return None
if isinstance(data, torch.Tensor):
return engine(data)
elif isinstance(data, (list, tuple)):
return engine(*data)
elif isinstance(data, dict):
stage_output = data.pop("stage_output", None)
if stage_output is None:
return engine(**data)
elif isinstance(stage_output, torch.Tensor):
return engine(stage_output, **data)
elif isinstance(stage_output, (tuple, list)):
return engine(*stage_output, **data)
else:
raise TypeError(
f"Expected stage_output to be of type torch.Tensor, list, or tuple, "
f"but got {type(stage_output)}"
)
else:
raise TypeError(f"Expected data to be of type torch.Tensor, list, tuple, or dict, but got {type(data)}")
def load_batch(self, engine, data_iter):
# Pipeline schedule just puts data in memory
batch_data, batch_size = engine.load_batch(data_iter, to_gpu=False)
assert batch_size % self.num_microbatches == 0, "Batch size should be divisible by the number of microbatches"
self.microbatch_offset = 0
self.batch_size = batch_size
self.batch_data, self.batch_label = batch_data
self.microbatch_size = self.batch_size // self.num_microbatches
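# Worked example (illustrative only; relies on _load_micro_batch slicing
# [offset : offset + micro_bsz] along the batch dimension, as its arguments suggest):
# with batch_size=16 and num_microbatches=4, microbatch_size is 4 and successive
# load_micro_batch() calls consume offsets 0, 4, 8 and 12 of the cached batch.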
def load_micro_batch(self):
micro_batch_data, micro_batch_label = self._load_micro_batch(
data=self.batch_data, label=self.batch_label, offset=self.microbatch_offset, micro_bsz=self.microbatch_size
)
if self.data_process_func:
micro_batch_data["input_ids"] = self.data_process_func(
micro_batch_data["input_ids"], micro_batch_data["cu_seqlens"]
)
micro_batch_label = self.data_process_func(micro_batch_label, micro_batch_data["cu_seqlens"])
micro_batch_data.pop("cu_seqlens")
micro_batch_data.pop("indexes")
micro_batch_data["label"] = micro_batch_label
self.microbatch_offset += self.microbatch_size
return move_to_device(micro_batch_data)
def _get_data_label_for_current_step(self, stage_output, micro_batch_data):
if isinstance(micro_batch_data, (tuple, list)):
if gpc.is_first_rank(ParallelMode.PIPELINE):
# for the first stage, we use the data from the
# dataloader output by default
data, label = micro_batch_data
else:
# for non-first stages, we use the output passed
# by the previous stage as the model input
data = stage_output
_, label = micro_batch_data
elif isinstance(micro_batch_data, dict):
label = micro_batch_data.pop("label", None)
data = {"stage_output": stage_output, **micro_batch_data}
return data, label
def _call_hooks(self, func_name: str, *args, **kwargs) -> None:
for hook in self._hooks:
getattr(hook, func_name)(self, *args, **kwargs)
def _get_current_microbatch_id(self, step_id: int) -> int:
"""
Get the current microbatch ID based on the step ID.
In the 1F1B scheduler, the microbatch ID is the same as the step ID,
but it is important to note that the step ID is calculated separately
for forward and backward passes.
"""
return step_id
def _forward_step(self, engine, input_obj, return_tensors, return_output_label=True, accum_loss=None):
"""
Forward step for passed-in model. If it is the first stage, the input tensor
is obtained from data_iterator, otherwise the passed-in input_obj is used.
Returns output tensor. This is a helper function and can be ignored by users.
Args:
engine (colossalai.engine.Engine): Colossalai engine for training and inference.
input_obj (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): Input tensor for this pipeline stage.
return_tensors (List[:class:`torch.Tensor`]): A list of tensors to return.
return_output_label (bool, optional): Whether to return output labels.
accum_loss (optional): Where the accumulated loss is stored.
Returns:
Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]: output or the loss value of the current
pipeline stage.
"""
micro_batch_data = self.load_micro_batch()
data, label = self._get_data_label_for_current_step(input_obj, micro_batch_data)
self._call_hooks("before_forward", data)
output_obj = self._call_engine(engine.model, data)
self._call_hooks("after_forward", output_obj)
if gpc.is_last_rank(ParallelMode.PIPELINE):
self._call_hooks("post_helper_func", output_obj, label)
if return_output_label:
return_tensors.append((output_obj, label))
if accum_loss is not None:
self._call_hooks("before_criterion", output_obj, label)
loss = self._call_engine_criterion(engine, output_obj, label)
self._call_hooks("after_criterion", loss)
loss_reduced = loss / self.num_microbatches
accum_loss.add_(loss_reduced.detach())
output_obj = loss_reduced
return output_obj
def _backward_step(self, engine, step_id, input_obj, output_obj, output_obj_grad):
"""
Backward step through the passed-in output tensor. If it is the last stage, the
output_obj_grad is None, otherwise it is the gradients with respect to stage's output tensor.
Returns the gradients with respect to the input tensor (None if first stage).
This is a helper function and can be ignored by users.
Args:
engine (colossalai.engine.Engine): Colossalai engine for training and inference.
step_id (int): The ID of the current step.
input_obj (Union[torch.Tensor, List[torch.Tensor]]): Input tensor for this stage.
output_obj (Union[torch.Tensor, List[torch.Tensor]]): Output tensor for this stage.
output_obj_grad (Union[torch.Tensor, List[torch.Tensor]]): Gradient of output tensor for this stage.
Returns:
Union[torch.Tensor, List[torch.Tensor]]: Gradient of input tensor.
"""
# Retain the grad on the input_obj.
if input_obj is not None:
if isinstance(input_obj, torch.Tensor):
input_obj.retain_grad()
else:
for in_tensor in input_obj:
if in_tensor is not None:
in_tensor.retain_grad()
# Backward pass.
# Only the last microbatch needs to synchronize gradients.
skip_grad_sync = self._get_current_microbatch_id(step_id) != self.num_microbatches - 1
self._call_hooks("before_backward", output_obj, output_obj_grad)
with switch_optimizer_grad_sync_skip_mode(engine.optimizer, skip_grad_sync):
if output_obj_grad is None:
engine.backward(output_obj)
else:
engine.backward_by_grad(output_obj, output_obj_grad)
# Collect the grad of the input_obj.
input_obj_grad = None
if input_obj is not None:
if isinstance(input_obj, torch.Tensor):
input_obj_grad = input_obj.grad
else:
input_obj_grad = []
for in_tensor in input_obj:
input_obj_grad.append(in_tensor.grad)
self._call_hooks("after_backward", input_obj_grad)
return input_obj_grad
def _forward_only_step(self, engine, return_loss=True, return_output_label=True):
"""
This function performs forward only computation process. The scheduling of microbatches is similar to the
warmup phase, where each microbatch first receives the forward input from the previous stage, then performs
the forward computation, and finally passes the forward computation output to the next stage. There are two
special cases to note:
1. The first stage of the pipeline does not need to receive forward input; its input comes from the dataloader.
2. The last stage of the pipeline does not need to send forward output; its output is returned to the user code
for processing.
Args:
engine (colossalai.engine.Engine): InternLM engine for training and inference.
return_loss (bool, optional): Whether to return the accumulated loss.
return_output_label (bool, optional): Whether to return outputs and labels.
Returns:
Tuple[Union[torch.Tensor, None], Union[torch.Tensor, None], Union[torch.Tensor, None]]:
output, label, and accumulated loss.
"""
# Input, output tensors only need to be saved when doing backward passes
return_tensors = []
accum_loss = (
torch.zeros(1, device=get_current_device())
if return_loss and gpc.is_pipeline_last_stage(ignore_virtual=True)
else None
)
# Used for tensor meta information communication
forward_recv_shapes = self.tensor_shape
need_forward_meta = self.tensor_shape is None
# Run all forward passes.
for _ in range(self.num_microbatches):
# Receive input from the previous stage
if not gpc.is_first_rank(ParallelMode.PIPELINE):
if forward_recv_shapes is None:
forward_recv_shapes = comm.recv_obj_meta()
input_obj = comm.recv_forward(
forward_recv_shapes,
dtype=self.dtype,
scatter_gather_tensors=self.scatter_gather_tensors,
)
else:
input_obj = None
# Perform forward computation
output_obj = self._forward_step(
engine,
input_obj,
return_tensors,
return_output_label=return_output_label,
accum_loss=accum_loss,
)
if not gpc.is_last_rank(ParallelMode.PIPELINE):
if need_forward_meta:
comm.send_obj_meta(output_obj)
need_forward_meta = False # send only once.
# Send the forward computation output to the next stage
comm.send_forward(output_obj, scatter_gather_tensors=self.scatter_gather_tensors)
output, label = pack_return_tensors(return_tensors) if len(return_tensors) > 0 else (None, None)
return output, label, accum_loss
def _forward_backward_step(self, engine, return_loss=True, return_output_label=True):
"""
This function schedules the forward and backward computation of microbatches in the pipeline in a 1F1B manner.
It consists of three stages: warmup, 1F1B, and cooldown.
1. Warmup Stage:
The warmup stage performs num_warmup forward microsteps, where num_warmup is the pipeline world size
minus the rank of the current pipeline stage minus 1. In each microstep, it receives input from the
previous stage, performs the forward computation, and then sends the result to the next stage.
2. 1F1B Stage:
The 1F1B stage consists of pairs of forward and backward microsteps. It performs num_1f1b_micropairs iterations,
where num_1f1b_micropairs is calculated as the total number of microbatches minus the number of microbatches in
the warmup stage. In each iteration, it first performs a forward computation, sends the result to the next
stage, receives input for the backward computation, performs the backward computation, and finally sends the
result to the previous stage to receive input for the next forward computation.
3. Cooldown Stage:
The cooldown stage performs the same number of iterations as the warmup stage. In each iteration, it receives
input for the backward computation, performs the backward computation, and finally sends the result to the
previous stage.
There are two special cases to consider:
1. The first stage of the pipeline does not need to receive forward input or send backward output. The last
stage does not need to send forward output or receive backward input.
2. At the boundaries between phases the send/receive pattern changes, so extra communication is inserted
to keep adjacent stages matched.
Args:
engine (Engine): The engine used for computation.
return_loss (bool, optional): Whether to return the accumulated loss.
return_output_label (bool, optional): Whether to return outputs and labels.
Returns:
Tuple[Union[torch.Tensor, None], Union[torch.Tensor, None], Union[torch.Tensor, None]]:
The output, label, and accumulated loss.
"""
num_warmup_microsteps = (
gpc.get_world_size(ParallelMode.PIPELINE) - gpc.get_local_rank(ParallelMode.PIPELINE) - 1
)
num_warmup_microsteps = min(num_warmup_microsteps, self.num_microbatches)
num_1f1b_micropairs = self.num_microbatches - num_warmup_microsteps
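# Worked example (illustrative, assuming 4 pipeline stages and num_microbatches=8):
#   stage 0: 3 warmup microsteps, 5 1F1B pairs, 3 cooldown microsteps
#   stage 1: 2 warmup microsteps, 6 1F1B pairs, 2 cooldown microsteps
#   stage 2: 1 warmup microstep,  7 1F1B pairs, 1 cooldown microstep
#   stage 3: 0 warmup microsteps, 8 1F1B pairs, 0 cooldown microsteps
# Every stage still executes num_microbatches forward and backward passes in total.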
# Input, output tensors only need to be saved when doing backward passes
input_objs = []
output_objs = []
return_tensors = []
accum_loss = (
torch.zeros(1, device=get_current_device())
if return_loss and gpc.is_pipeline_last_stage(ignore_virtual=True)
else None
)
# Used for tensor meta information communication
forward_recv_shapes = self.tensor_shape
backward_recv_shapes = None
need_forward_meta = self.tensor_shape is None
# Run warmup forward passes.
for i in range(num_warmup_microsteps):
# Receive the input from the previous stage
if not gpc.is_first_rank(ParallelMode.PIPELINE):
if forward_recv_shapes is None:
forward_recv_shapes = comm.recv_obj_meta()
input_obj = comm.recv_forward(
forward_recv_shapes,
dtype=self.dtype,
scatter_gather_tensors=self.scatter_gather_tensors,
)
else:
input_obj = None
# Perform forward computation
output_obj = self._forward_step(
engine,
input_obj,
return_tensors,
return_output_label=return_output_label,
accum_loss=accum_loss,
)
if not gpc.is_last_rank(ParallelMode.PIPELINE):
if isinstance(output_obj, torch.Tensor):
backward_recv_shapes = output_obj.shape
else:
backward_recv_shapes = [out_tensor.shape for out_tensor in output_obj]
if need_forward_meta:
comm.send_obj_meta(output_obj)
need_forward_meta = False # send only once.
# Send the output of forward computation of this pipeline stage to the next pipeline stage as input for
# forward computation
if not gpc.is_last_rank(ParallelMode.PIPELINE):
comm.send_forward(output_obj, scatter_gather_tensors=self.scatter_gather_tensors)
input_objs.append(input_obj)
output_objs.append(output_obj)
# Before running 1F1B, need to receive first forward tensor.
# If all microbatches are run in warmup / cooldown phase, then no need to
# receive this tensor here.
if num_1f1b_micropairs > 0:
if not gpc.is_first_rank(ParallelMode.PIPELINE):
if forward_recv_shapes is None:
forward_recv_shapes = comm.recv_obj_meta(forward_recv_shapes)
input_obj = comm.recv_forward(
forward_recv_shapes,
dtype=self.dtype,
scatter_gather_tensors=self.scatter_gather_tensors,
)
else:
input_obj = None
# Run 1F1B in steady state.
for i in range(num_1f1b_micropairs):
# Perform forward computation
output_obj = self._forward_step(
engine,
input_obj,
return_tensors,
return_output_label=return_output_label,
accum_loss=accum_loss,
)
if gpc.is_last_rank(ParallelMode.PIPELINE):
output_obj_grad = None
else:
output_obj_grad = comm.send_forward_recv_backward(
output_obj,
backward_recv_shapes,
dtype=self.dtype,
scatter_gather_tensors=self.scatter_gather_tensors,
)
# Add input_obj and output_obj to end of list.
input_objs.append(input_obj)
output_objs.append(output_obj)
# Pop input_obj and output_obj from the start of the list for
# the backward pass.
input_obj = input_objs.pop(0)
output_obj = output_objs.pop(0)
input_obj_grad = self._backward_step(engine, i, input_obj, output_obj, output_obj_grad)
if i == (num_1f1b_micropairs - 1):
input_obj = None
if not gpc.is_first_rank(ParallelMode.PIPELINE):
comm.send_backward(
input_obj_grad,
scatter_gather_tensors=self.scatter_gather_tensors,
)
else:
if gpc.is_first_rank(ParallelMode.PIPELINE):
input_obj = None
else:
input_obj = comm.send_backward_recv_forward(
input_obj_grad,
forward_recv_shapes,
dtype=self.dtype,
scatter_gather_tensors=self.scatter_gather_tensors,
)
# Run cooldown backward passes.
for i in range(num_warmup_microsteps):
input_obj = input_objs.pop(0)
output_obj = output_objs.pop(0)
if not gpc.is_last_rank(ParallelMode.PIPELINE):
output_obj_grad = comm.recv_backward(
backward_recv_shapes,
dtype=self.dtype,
scatter_gather_tensors=self.scatter_gather_tensors,
)
else:
output_obj_grad = None
input_obj_grad = self._backward_step(
engine, num_1f1b_micropairs + i, input_obj, output_obj, output_obj_grad
)
if not gpc.is_first_rank(ParallelMode.PIPELINE):
comm.send_backward(input_obj_grad, scatter_gather_tensors=self.scatter_gather_tensors)
output, label = pack_return_tensors(return_tensors) if len(return_tensors) > 0 else (None, None)
return output, label, accum_loss
def forward_backward_step(self, engine, data_iter, forward_only=False, return_loss=True, return_output_label=True):
"""Runs non-interleaved 1F1B schedule, with communication between pipeline stages.
Returns a tuple of (output, label, loss); the loss is only meaningful on the last stage.
Args:
engine (colossalai.engine.Engine): Colossalai engine for training and inference.
data_iter (Iterable): Dataloader as the form of an iterator, obtained by calling iter(dataloader).
forward_only (bool, optional):
Whether to run the forward step only. Default is False. If True, no backward pass will be run.
return_loss (bool, optional): Whether to return the loss value. Default is True.
return_output_label (bool, optional): If False, the output and label won't be returned.
Returns:
Tuple[:class:`torch.Tensor`]: A tuple of (output, label, loss), loss and label could be None.
"""
assert (
forward_only or return_loss
), "The argument 'return_loss' has to be True when 'forward_only' is False, but got False."
# Load data first
self.load_batch(engine, data_iter)
if forward_only:
return self._forward_only_step(engine, return_loss, return_output_label)
else:
return self._forward_backward_step(engine, return_loss, return_output_label)
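# Minimal usage sketch (hedged; assumes a fully initialized InternLM parallel context,
# `engine` being an initialized Engine and `train_dl` a dataloader created by the caller;
# all hyper-parameters below are placeholders):
#
#   scheduler = PipelineScheduler(
#       num_microbatches=4,
#       dtype=torch.bfloat16,
#       tensor_shape=get_tensor_shape(),
#       scatter_gather_tensors=True,
#       scheduler_hooks=[],
#   )
#   scheduler.pre_processing(engine)
#   output, label, loss = scheduler.forward_backward_step(
#       engine, iter(train_dl), forward_only=False, return_loss=True
#   )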
class InterleavedPipelineScheduler(PipelineScheduler):
"""
Interleaved Pipeline Scheduler.
"""
def __init__(
self,
num_microbatches: int,
num_chunks: int,
dtype: torch.dtype = torch.float,
data_process_func: Callable = None,
tensor_shape: Union[torch.Size, List[int], Tuple[int]] = None,
scatter_gather_tensors: bool = False,
scheduler_hooks: Optional[List[SchedulerHook]] = None,
communication_overlap: bool = False,
):
"""A helper schedule class for pipeline parallelism running environment.
It uses the interleaved 1F1B strategy. Other properties are similar to
:class:`NonPipelineSchedule`.
Args:
num_microbatches (int): The number of microbatches.
num_chunks (int): The number of model chunks.
dtype (torch.dtype, optional): The data type of the tensors. Default is torch.float.
data_process_func (Callable, optional):
The preprocessing function, which receives a batch of data and is executed in `load_batch`.
tensor_shape (torch.Size, optional): Specified shape in pipeline communication.
scatter_gather_tensors (bool, optional):
If set to `True`, communication will be reduced over the pipeline when using 1D tensor parallelization.
scheduler_hooks (List[SchedulerHook], optional): List of scheduler hooks. Default is None.
communication_overlap (bool, optional): Whether to enable communication overlap. Default is False.
"""
assert (
num_microbatches % gpc.get_world_size(ParallelMode.PIPELINE) == 0
), "num_microbatches must be an integer multiple of pipeline parallel world size"
assert (
isinstance(num_chunks, int) and num_chunks > 0
), f"expected num_chunks to be an integer and larger than 0, but got {num_chunks}"
super().__init__(
num_microbatches,
dtype=dtype,
data_process_func=data_process_func,
tensor_shape=tensor_shape,
scatter_gather_tensors=scatter_gather_tensors,
scheduler_hooks=scheduler_hooks,
)
gpc.set_virtual_pipeline_parallel_size(num_chunks)
gpc.set_virtual_pipeline_parallel_rank(0)
self._num_chunks = num_chunks
self._communication_overlap = communication_overlap
# switch 1f1b loop runner function according to communication overlap
self._run_1f1b_loop = (
self._run_1f1b_loop_with_overlap if communication_overlap else self._run_1f1b_loop_without_overlap
)
# states
self._pp_size = gpc.get_world_size(ParallelMode.PIPELINE)
self._pp_rank = gpc.get_local_rank(ParallelMode.PIPELINE)
self._accum_loss = None
self._return_tensors = None
self._input_objs = [[] for _ in range(num_chunks)]
self._output_objs = [[] for _ in range(num_chunks)]
self._output_obj_grads = [[] for _ in range(num_chunks)]
self._input_obj_shapes = [self.tensor_shape for _ in range(num_chunks)]
self._output_obj_shapes = [None for _ in range(num_chunks)]
self._send_tensor_shape_flags = [self.tensor_shape is None for _ in range(num_chunks)]
@property
def tensor_shape(self) -> torch.Size:
return self._tensor_shape
@tensor_shape.setter
def tensor_shape(self, tensor_shape: torch.Size):
self._tensor_shape = tensor_shape
self._input_obj_shapes = [self._tensor_shape for _ in range(self._num_chunks)]
self._send_tensor_shape_flags = [self._tensor_shape is None for _ in range(self._num_chunks)]
def _clear_state(self) -> None:
self._accum_loss = None
self._return_tensors = None
self._input_objs = [[] for _ in range(self._num_chunks)]
self._output_objs = [[] for _ in range(self._num_chunks)]
self._output_obj_grads = [[] for _ in range(self._num_chunks)]
self._input_obj_shapes = [self.tensor_shape for _ in range(self._num_chunks)]
self._output_obj_shapes = [None for _ in range(self._num_chunks)]
self._send_tensor_shape_flags = [self.tensor_shape is None for _ in range(self._num_chunks)]
def load_batch(self, engine, data_iter):
super().load_batch(engine, data_iter)
# overwrite microbatch_offset, since every model chunk loads the same microbatches and should track its own offset
self.microbatch_offset = [0 for _ in range(self._num_chunks)]
def load_micro_batch(self, model_chunk_id):
micro_batch_data, micro_batch_label = self._load_micro_batch(
data=self.batch_data,
label=self.batch_label,
offset=self.microbatch_offset[model_chunk_id],
micro_bsz=self.microbatch_size,
)
micro_batch_data["label"] = micro_batch_label
self.microbatch_offset[model_chunk_id] += self.microbatch_size
return move_to_device(micro_batch_data)
def _forward_step(self, engine, chunk_id):
"""Forward step for passed-in model. If it is the first stage, the input tensor
is obtained from data_iterator, otherwise the passed-in input_obj is used.
Returns output tensor. This is a helper function and can be ignored by users.
Args:
engine (colossalai.engine.Engine): Colossalai engine for training and inference.
chunk_id (int): The id of model chunks.
Returns:
Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]: output or the loss value of the current
pipeline stage.
"""
gpc.set_virtual_pipeline_parallel_rank(chunk_id)
if gpc.is_pipeline_first_stage() and len(self._input_objs[chunk_id]) == len(self._output_objs[chunk_id]):
self._input_objs[chunk_id].append(None)
input_obj = self._input_objs[chunk_id][-1]
micro_batch_data = self.load_micro_batch(chunk_id)
data, label = self._get_data_label_for_current_step(input_obj, micro_batch_data)
self._call_hooks("before_forward", data)
output_obj = self._call_engine(engine.model[chunk_id], data)
# Convert output_obj to fp32 when last model chunk of last stage
if gpc.is_pipeline_last_stage(ignore_virtual=False) and isinstance(engine.model[chunk_id], NaiveAMPModel):
output_obj = engine.model[chunk_id].convert_to_fp32(output_obj)
self._call_hooks("after_forward", output_obj)
if gpc.is_pipeline_last_stage():
self._call_hooks("post_helper_func", output_obj, label)
if self._return_tensors is not None:
self._return_tensors.append((output_obj, label))
if self._accum_loss is not None:
self._call_hooks("before_criterion", output_obj, label)
loss = self._call_engine_criterion(engine, output_obj, label)
self._call_hooks("after_criterion", loss)
loss_reduced = loss / self.num_microbatches
self._accum_loss.add_(loss_reduced.detach())
output_obj = loss_reduced
self._output_objs[chunk_id].append(output_obj)
return output_obj
def _backward_step(self, engine, chunk_id, step_id):
"""
Backward step for the passed-in model chunk. It pops the cached input and output tensors of the
chunk's oldest pending microbatch; for the last stage the output gradient is None (the backward
pass starts from the loss), otherwise the gradient received from the next stage is used.
Returns the input tensor gradient. This is a helper function and can be ignored by users.
Args:
engine (colossalai.engine.Engine): Colossalai engine for training and inference.
chunk_id (int): The id of model chunks.
step_id (int): The current step id.
Returns:
Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]: input tensor gradient.
"""
gpc.set_virtual_pipeline_parallel_rank(chunk_id)
if gpc.is_pipeline_last_stage() and len(self._output_obj_grads[chunk_id]) == 0:
self._output_obj_grads[chunk_id].append(None)
input_obj = self._input_objs[chunk_id].pop(0)
output_obj = self._output_objs[chunk_id].pop(0)
output_obj_grad = self._output_obj_grads[chunk_id].pop(0)
input_obj_grad = super()._backward_step(engine, step_id, input_obj, output_obj, output_obj_grad)
return input_obj_grad
def _get_chunk_by_microbatch(self, step_id: int, backward: bool = False) -> int:
"""Helper method to get the model chunk ID given the iteration number."""
microbatch_id_in_group = step_id % (self._pp_size * self._num_chunks)
chunk_id = microbatch_id_in_group // self._pp_size
if backward:
chunk_id = self._num_chunks - chunk_id - 1
return chunk_id
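# Worked example (illustrative, assuming self._pp_size = 4 and self._num_chunks = 2):
#   forward : steps 0-3 -> chunk 0, steps 4-7 -> chunk 1, steps 8-11 -> chunk 0, ...
#   backward: steps 0-3 -> chunk 1, steps 4-7 -> chunk 0, steps 8-11 -> chunk 1, ...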
def _get_current_microbatch_id(self, step_id: int) -> int:
# format:
# microstep_id : 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
# microbatch_id: 1 2 3 4 1 2 3 4 5 6 7 8 5 6 7 8
num_microbatch_group = step_id // (self._pp_size * self._num_chunks)
step_id_in_group = step_id % (self._pp_size * self._num_chunks)
microbatch_id = num_microbatch_group * self._pp_size + step_id_in_group % self._pp_size
return microbatch_id
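# Equivalent 0-indexed trace of the table above (illustrative, assuming
# self._pp_size = 4 and self._num_chunks = 2; `sched` is a hypothetical scheduler instance):
#   >>> [sched._get_current_microbatch_id(s) for s in range(16)]
#   [0, 1, 2, 3, 0, 1, 2, 3, 4, 5, 6, 7, 4, 5, 6, 7]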
def _run_warmup_loop(
self,
engine: Engine,
num_microsteps: int,
num_warmup_microsteps: int,
receive_extra_backward: bool = False,
forward_only: bool = False,
) -> None:
"""
Run the warm-up loop and prepare data for the 1F1B stage.
During the warm-up process, for each execution, it first performs a forward computation,
and then sends the computation result to the next stage.
It also receives data for the next forward computation.
Since the first forward computation has no previously received input,
data is received once before the loop starts.
After the warm-up is completed, we need to prepare data for the 1F1B stage.
The data preparation process should be consistent with the communication method of the 1F1B stage.
Args:
engine (Engine): The engine to run the warm-up loop.
num_microsteps (int): The total number of microsteps.
num_warmup_microsteps (int): The number of warm-up microsteps.
receive_extra_backward (bool, optional): Whether to receive extra backward input for the 1F1B stage.
Default is False.
forward_only (bool, optional): Whether to only perform forward pass. Default is False.
"""
if not gpc.is_pipeline_first_stage():
if self._input_obj_shapes[0] is None:
self._input_obj_shapes[0] = comm.recv_obj_meta(self._input_obj_shapes[0])
self._input_objs[0].append(
comm.recv_forward(
self._input_obj_shapes[0],
dtype=self.dtype,
scatter_gather_tensors=self.scatter_gather_tensors,
)
)
else:
self._input_objs[0].append(None)
for k in range(num_warmup_microsteps):
chunk_id = self._get_chunk_by_microbatch(k)
output_obj = self._forward_step(engine, chunk_id)
if forward_only:
# when forward-only, no need to save tensors for a backward pass
self._input_objs[chunk_id].pop()
self._output_objs[chunk_id].pop()
if not gpc.is_pipeline_last_stage():
if isinstance(output_obj, torch.Tensor):
self._output_obj_shapes[chunk_id] = output_obj.shape
else:
self._output_obj_shapes[chunk_id] = [out_tensor.shape for out_tensor in output_obj]
if self._send_tensor_shape_flags[chunk_id]:
comm.send_obj_meta(output_obj)
self._send_tensor_shape_flags[chunk_id] = False # send only once for each chunk.
# Determine if tensor should be received from previous stage.
next_forward_chunk_id = self._get_chunk_by_microbatch(k + 1)
with switch_virtual_pipeline_parallel_rank(next_forward_chunk_id):
if not gpc.is_pipeline_first_stage() and self._input_obj_shapes[next_forward_chunk_id] is None:
self._input_obj_shapes[next_forward_chunk_id] = comm.recv_obj_meta()
if k == (num_microsteps - 1) or gpc.is_pipeline_first_stage():
input_shape = None
else:
input_shape = self._input_obj_shapes[next_forward_chunk_id]
# Don't send tensor downstream if on last stage.
if gpc.is_pipeline_last_stage():
output_obj = None
# Send and receive tensors as appropriate (send tensors computed
# in this iteration; receive tensors for next iteration).
if k != (num_warmup_microsteps - 1) or not receive_extra_backward:
# Normal warm-up communication process, or no need to prepare backward input for the 1F1B stage
input_obj = comm.send_forward_recv_forward(
output_obj,
input_shape,
dtype=self.dtype,
scatter_gather_tensors=self.scatter_gather_tensors,
)
else:
# Receive output_obj_grad for next backward, if receive_extra_backward is True.
if self._communication_overlap:
# In this case, we should handle forward and backward communication separately, consistent with the
# overlap version of the 1F1B stage
input_obj = comm.send_forward_recv_forward(
output_obj,
input_shape,
dtype=self.dtype,
scatter_gather_tensors=self.scatter_gather_tensors,
)
output_obj_grad = comm.send_backward_recv_backward(
None, # nothing to send
self._output_obj_shapes[self._num_chunks - 1],
dtype=self.dtype,
scatter_gather_tensors=self.scatter_gather_tensors,
)
self._output_obj_grads[self._num_chunks - 1].append(output_obj_grad)
else:
# In this case, we should handle forward and backward communication together, consistent with the
# non-overlap version of the 1F1B stage
input_obj, output_obj_grad = comm.send_forward_backward_recv_forward_backward(
output_obj,
None, # no backward grad to send
input_shape,
self._output_obj_shapes[self._num_chunks - 1],
dtype=self.dtype,
scatter_gather_tensors=self.scatter_gather_tensors,
)
self._output_obj_grads[self._num_chunks - 1].append(output_obj_grad)
self._input_objs[next_forward_chunk_id].append(input_obj)
def _run_1f1b_loop_with_overlap(
self,
engine: Engine,
num_warmup_microsteps: int,
num_1f1b_micropairs: int,
all_warmup_microsteps: bool = False,
) -> None:
"""
Run the 1F1B loop with overlap.
The 1F1B loop with overlap consists of the following steps:
1. Perform the forward pass.
2. Check if the backward input is ready.
3. Send the forward output and receive the forward input for the next iteration.
4. Perform the backward pass.
5. Check if the forward input is ready.
6. Send the backward output and receive the backward input for the next iteration.
Args:
engine (Engine): The engine to run the 1F1B loop.
num_warmup_microsteps (int): The number of warm-up microsteps.
num_1f1b_micropairs (int): The number of 1F1B micropairs.
all_warmup_microsteps (bool, optional): Whether to run all warm-up microsteps. Default is False.
"""
backward_async_communicator = None
# Run 1F1B in steady state.
for k in range(num_1f1b_micropairs):
forward_microstep_id = k + num_warmup_microsteps
backward_microstep_id = k
forward_chunk_id = self._get_chunk_by_microbatch(forward_microstep_id)
backward_chunk_id = self._get_chunk_by_microbatch(backward_microstep_id, backward=True)
# 1. Forward pass.
output_obj = self._forward_step(engine, forward_chunk_id)
# 2. Check if the backward input is ready.
if backward_async_communicator is not None:
output_obj_grad = backward_async_communicator.wait_and_receive()
if backward_async_communicator.need_receive:
self._output_obj_grads[backward_chunk_id].append(output_obj_grad)
# 3. Send the forward outputs and receive the forward inputs from the previous rank.
# If this is the last model chunk of the last pipeline stage, there is no need to send the forward output.
gpc.set_virtual_pipeline_parallel_rank(forward_chunk_id)
if gpc.is_pipeline_last_stage():
output_obj = None
# Check if it needs to receive the results from the previous rank.
next_forward_chunk_id = self._get_chunk_by_microbatch(forward_microstep_id + 1)
with switch_virtual_pipeline_parallel_rank(next_forward_chunk_id):
if gpc.is_pipeline_first_stage() or k == num_1f1b_micropairs - 1:
input_obj_shape = None
else:
input_obj_shape = self._input_obj_shapes[next_forward_chunk_id]
forward_async_communicator = comm.AsynCommunicator(
output_obj,
input_obj_shape,
self.dtype,
self.scatter_gather_tensors,
forward=True,
)
forward_async_communicator.start()
# 4. Backward pass.
input_obj_grad = self._backward_step(engine, backward_chunk_id, backward_microstep_id)
# 5. Check if the forward input is ready.
input_obj = forward_async_communicator.wait_and_receive()
if forward_async_communicator.need_receive:
self._input_objs[next_forward_chunk_id].append(input_obj)
# 6. Send the backward output and receive the backward input for the next iteration.
gpc.set_virtual_pipeline_parallel_rank(backward_chunk_id)
if gpc.is_pipeline_first_stage():
input_obj_grad = None
next_backward_chunk_id = self._get_chunk_by_microbatch(backward_microstep_id + 1, backward=True)
with switch_virtual_pipeline_parallel_rank(next_backward_chunk_id):
if gpc.is_pipeline_last_stage():
output_obj_shape = None
else:
output_obj_shape = self._output_obj_shapes[next_backward_chunk_id]
backward_async_communicator = comm.AsynCommunicator(
input_obj_grad,
output_obj_shape,
self.dtype,
self.scatter_gather_tensors,
forward=False,
)
backward_async_communicator.start()
if all_warmup_microsteps:
if not gpc.is_pipeline_last_stage():
self._output_obj_grads[self._num_chunks - 1].append(
comm.recv_backward(
self._output_obj_shapes[self._num_chunks - 1],
dtype=self.dtype,
scatter_gather_tensors=self.scatter_gather_tensors,
)
)
else:
self._output_obj_grads[self._num_chunks - 1].append(None)
else:
output_obj_grad = backward_async_communicator.wait_and_receive()
if backward_async_communicator.need_receive:
backward_chunk_id = self._get_chunk_by_microbatch(num_1f1b_micropairs, backward=True)
self._output_obj_grads[backward_chunk_id].append(output_obj_grad)
def _run_1f1b_loop_without_overlap(
self,
engine: Engine,
num_warmup_microsteps: int,
num_1f1b_micropairs: int,
all_warmup_microsteps: bool = False,
) -> None:
"""
Run the 1F1B loop without overlap.
The 1F1B loop without overlap consists of the following steps:
1. Perform the forward pass.
2. Perform the backward pass.
3. Send the forward output of this iteration to the next stage, and send the backward output of this iteration
to the previous stage, and receive the forward and backward inputs for the next iteration.
Args:
engine (Engine): The engine to use for computation.
num_warmup_microsteps (int): The number of warmup microsteps.
num_1f1b_micropairs (int): The number of 1F1B micro-pairs.
all_warmup_microsteps (bool, optional): Whether to run all warmup microsteps. Defaults to False.
"""
for k in range(num_1f1b_micropairs):
# Forward pass.
forward_microstep_id = k + num_warmup_microsteps
forward_chunk_id = self._get_chunk_by_microbatch(forward_microstep_id)
output_obj = self._forward_step(engine, forward_chunk_id)
# Backward pass.
backward_microstep_id = k
backward_chunk_id = self._get_chunk_by_microbatch(backward_microstep_id, backward=True)
input_obj_grad = self._backward_step(engine, backward_chunk_id, backward_microstep_id)
# Send output_obj and input_obj_grad, receive input_obj
# and output_obj_grad.
# Determine if current stage has anything to send in either direction,
# otherwise set obj to None.
gpc.set_virtual_pipeline_parallel_rank(forward_chunk_id)
if gpc.is_pipeline_last_stage():
output_obj = None
gpc.set_virtual_pipeline_parallel_rank(backward_chunk_id)
if gpc.is_pipeline_first_stage():
input_obj_grad = None
# Determine if peers are sending, and where in data structure to put
# received tensors.
next_forward_chunk_id = self._get_chunk_by_microbatch(forward_microstep_id + 1)
with switch_virtual_pipeline_parallel_rank(next_forward_chunk_id):
if gpc.is_pipeline_first_stage() or k == num_1f1b_micropairs - 1:
recv_prev = False
else:
recv_prev = True
next_backward_chunk_id = self._get_chunk_by_microbatch(backward_microstep_id + 1, backward=True)
with switch_virtual_pipeline_parallel_rank(next_backward_chunk_id):
if gpc.is_pipeline_last_stage():
recv_next = False
else:
recv_next = True
input_shape = self._input_obj_shapes[next_forward_chunk_id] if recv_prev else None
output_shape = self._output_obj_shapes[next_backward_chunk_id] if recv_next else None
# Communicate objs.
input_obj, output_obj_grad = comm.send_forward_backward_recv_forward_backward(
output_obj,
input_obj_grad,
input_shape,
output_shape,
dtype=self.dtype,
scatter_gather_tensors=self.scatter_gather_tensors,
)
# Put input_obj and output_obj_grad in data structures in the
# right location.
if recv_prev:
self._input_objs[next_forward_chunk_id].append(input_obj)
if recv_next:
self._output_obj_grads[next_backward_chunk_id].append(output_obj_grad)
# receive necessary data for next cooldown loop
if all_warmup_microsteps:
if not gpc.is_pipeline_last_stage():
self._output_obj_grads[self._num_chunks - 1].append(
comm.recv_backward(
self._output_obj_shapes[self._num_chunks - 1],
dtype=self.dtype,
scatter_gather_tensors=self.scatter_gather_tensors,
)
)
else:
self._output_obj_grads[self._num_chunks - 1].append(None)
def _run_cooldown_loop(self, engine: Engine, num_microsteps: int, num_1f1b_micropairs: int) -> None:
"""
Run the cooldown loop.
The cooldown loop consists of the following steps:
1. Perform the backward step.
2. Send the backward output to the previous stage and receive the input for the next backward step.
Args:
engine (Engine): The engine to use for computation.
num_microsteps (int): The total number of microsteps.
num_1f1b_micropairs (int): The number of 1F1B micro-pairs.
"""
for k in range(num_1f1b_micropairs, num_microsteps):
chunk_id = self._get_chunk_by_microbatch(k, backward=True)
input_obj_grad = self._backward_step(engine, chunk_id, k)
next_backward_chunk_id = self._get_chunk_by_microbatch(k + 1, backward=True)
if k != (num_microsteps - 1) and not (
gpc.is_pipeline_last_stage(ignore_virtual=True) and next_backward_chunk_id == (self._num_chunks - 1)
):
output_shape = self._output_obj_shapes[next_backward_chunk_id]
else:
output_shape = None
self._output_obj_grads[next_backward_chunk_id].append(
comm.send_backward_recv_backward(
input_obj_grad,
output_shape,
dtype=self.dtype,
scatter_gather_tensors=self.scatter_gather_tensors,
)
)
def _forward_only_step(self, engine: Engine):
num_microsteps = self.num_microbatches * self._num_chunks
num_warmup_microsteps = num_microsteps
self._run_warmup_loop(
engine,
num_microsteps,
num_warmup_microsteps,
receive_extra_backward=False,
forward_only=True,
)
def _forward_backward_step(self, engine: Engine):
# Compute number of warmup and remaining microbatches.
all_warmup_microsteps = False
num_microsteps = self.num_microbatches * self._num_chunks
# Run all forward passes and then all backward passes if the number of
# microbatches is just the number of pipeline stages.
# Otherwise, perform (num_chunks-1)*pipeline_parallel_size warmup microsteps
# on all workers, followed by more microsteps depending on the stage ID
# (more forward passes for earlier stages; later stages can
# immediately start with 1F1B).
if self.num_microbatches == self._pp_size:
num_warmup_steps = num_microsteps
all_warmup_microsteps = True
else:
num_warmup_steps = (self._pp_size - self._pp_rank - 1) * 2
num_warmup_steps += (self._num_chunks - 1) * self._pp_size
num_warmup_steps = min(num_warmup_steps, num_microsteps)
num_1f1b_micropairs = num_microsteps - num_warmup_steps
# We usually need to prepare an extra backward input for the 1F1B stage when the warmup stage ends,
# because the 1F1B stage typically performs one forward and one backward pass together,
# except in the following cases:
receive_extra_backward = not (
all_warmup_microsteps # Only warmup microsteps
or gpc.is_pipeline_last_stage(ignore_virtual=True) # The rank is the last pipeline stage
)
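# Worked example (illustrative, assuming 4 pipeline stages, 2 chunks and 8 microbatches,
# i.e. 16 microsteps): num_warmup_steps is 10/8/6/4 for stages 0/1/2/3, so
# num_1f1b_micropairs is 6/8/10/12, and receive_extra_backward is True everywhere
# except on the last pipeline stage.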
# 1. Warmup
self._run_warmup_loop(
engine,
num_microsteps,
num_warmup_steps,
receive_extra_backward=receive_extra_backward,
)
# 2. 1F1B
self._run_1f1b_loop(
engine,
num_warmup_steps,
num_1f1b_micropairs=num_1f1b_micropairs,
all_warmup_microsteps=all_warmup_microsteps,
)
# 3. Cooldown
self._run_cooldown_loop(engine, num_microsteps, num_1f1b_micropairs=num_1f1b_micropairs)
def forward_backward_step(self, engine, data_iter, forward_only=False, return_loss=True, return_output_label=True):
"""Run interleaved 1F1B schedule (model split into model chunks), with
communication between pipeline stages as needed.
Args:
engine (colossalai.engine.Engine): Colossalai engine for training and inference.
data_iter (Iterable): Dataloader as the form of an iterator, obtained by calling iter(dataloader).
forward_only (bool, optional):
Whether to run the forward step only. Default is False. If True, no backward pass will be run.
return_loss (bool, optional): Whether to return the loss value. Default is True.
return_output_label (bool, optional): If False, the output and label won't be returned.
Returns:
Tuple[:class:`torch.Tensor`]: A tuple of (output, label, loss); loss and label may be None.
The loss is returned only on the last pipeline stage.
"""
assert (
forward_only or return_loss
), "The argument 'return_loss' has to be True when 'forward_only' is False, but got False."
gpc.set_virtual_pipeline_parallel_rank(0)
self.load_batch(engine, data_iter)
if return_loss and gpc.is_pipeline_last_stage(ignore_virtual=True):
self._accum_loss = torch.zeros(1, device=get_current_device())
if return_output_label:
self._return_tensors = []
if forward_only:
self._forward_only_step(engine)
else:
self._forward_backward_step(engine)
if return_output_label and len(self._return_tensors) > 0:
output, label = pack_return_tensors(self._return_tensors)
else:
output, label = (None, None)
accum_loss = self._accum_loss
self._clear_state()
return output, label, accum_loss