mirror of https://github.com/hpcaitech/ColossalAI
You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
177 lines
7.0 KiB
177 lines
7.0 KiB
#!/usr/bin/env python
|
|
# -*- encoding: utf-8 -*-
|
|
|
|
from typing import Any, Iterable, Optional, Tuple

import torch.cuda

import colossalai.legacy.communication.p2p_v2 as comm
from colossalai.accelerator import get_accelerator
from colossalai.legacy.context.parallel_mode import ParallelMode
from colossalai.legacy.core import global_context as gpc
from colossalai.legacy.engine import Engine

from ._pipeline_schedule import PipelineSchedule
|
|
|
|
|
|
def pack_return_tensors(return_tensors):
    """Merge per-microbatch ``(output, label)`` pairs by concatenating along dim 0.

    Args:
        return_tensors: A non-empty sequence of ``(output, label)`` pairs, one per
            microbatch. ``output`` is either a tensor or a list/tuple of tensors;
            ``label`` is either a tensor or a mapping from keys to tensors.

    Returns:
        A ``(output, label)`` pair. ``output`` is a single tensor, or a tuple with
        one concatenated tensor per output position. ``label`` is a single tensor,
        or a dict with one concatenated tensor per key.

    Raises:
        TypeError: If the first microbatch's output is neither a tensor nor a
            list/tuple.
    """
    outputs, labels = tuple(zip(*return_tensors))

    first_output = outputs[0]
    if isinstance(first_output, torch.Tensor):
        packed_output = torch.cat(outputs, dim=0)
    elif isinstance(first_output, (list, tuple)):
        # Regroup position-wise: the i-th tensor of every microbatch is joined together.
        packed_output = tuple(torch.cat(group, dim=0) for group in zip(*outputs))
    else:
        raise TypeError(f"Output of model must be tensor or list/tuple of tensors")

    if isinstance(labels[0], torch.Tensor):
        return packed_output, torch.cat(labels, dim=0)

    # Non-tensor labels are treated as dicts; the key set comes from the first one.
    grouped = {key: [] for key in labels[0].keys()}
    for label_dict in labels:
        for key, value in label_dict.items():
            grouped[key].append(value)
    packed_label = {key: torch.cat(values, dim=0) for key, values in grouped.items()}
    return packed_output, packed_label
|
|
|
|
|
|
class PipelineScheduleV2(PipelineSchedule):
    """Variant of :class:`PipelineSchedule` whose ``forward_backward_step`` is
    reimplemented on top of the ``p2p_v2`` communication primitives.

    This class defines no constructor of its own, so it is built with the same
    arguments as :class:`PipelineSchedule`.

    Args:
        num_microbatches (int): The number of microbatches.
        data_process_func (Callable, optional):
            The preprocessing function applied to the data; it is executed in `load_batch`.
            See the example below for a customized function's signature.
        tensor_shape (torch.Size, optional): Specified shape in pipeline communication.
        scatter_gather_tensors (bool, optional):
            If set to `True`, communication will be reduced over pipeline when using 1D tensor parallelization.

    Example:

        # this shows an example of customized data_process_func
        def data_process_func(stage_output, dataloader_output):
            output1, output2 = stage_output
            item1, item2, item3 = dataloader_output

            # assume item2 is not needed
            data = (output1, output2, item1)
            label = item3
            return data, label

    """

    def forward_backward_step(
        self, engine: Engine, data_iter: Iterable, forward_only=False, return_loss=True, return_output_label=True
    ) -> Tuple[Any, Any, Optional[torch.Tensor]]:
        """Run the non-interleaved 1F1B schedule, communicating between pipeline stages via ``p2p_v2``.

        Each stage runs three phases:

        1. Warmup: ``world_size - local_rank - 1`` forward steps (capped at
           ``num_microbatches``), during which not all stages are busy yet.
        2. Steady state: each remaining microbatch gets one forward step followed,
           unless ``forward_only``, by one backward step on the oldest saved microbatch.
        3. Cooldown (skipped when ``forward_only``): backward steps for the
           microbatches that were forwarded during warmup.

        Args:
            engine (colossalai.legacy.engine.Engine): Colossalai engine for training and inference.
            data_iter (Iterable): Dataloader as the form of an iterator, obtained by calling iter(dataloader).
            forward_only (bool, optional):
                Whether run forward step only. Default is false. If true, no backward will be run.
            return_loss (bool, optional): Whether returns the loss value. Default is true.
            return_output_label (bool, optional): If False, the output and label won't be returned.

        Returns:
            Tuple: ``(output, label, accum_loss)``.

            ``output`` and ``label`` are the per-microbatch results merged by
            :func:`pack_return_tensors`. Both are ``None`` when ``return_tensors``
            stayed empty.

            ``accum_loss`` is a 1-element tensor on the current accelerator device
            when ``return_loss`` is True on the last pipeline stage, and ``None``
            otherwise.

        Raises:
            AssertionError: If ``forward_only`` is False while ``return_loss`` is False.
        """

        # Running backward passes (forward_only=False) is only allowed together with return_loss=True.
        assert (
            forward_only or return_loss
        ), "The argument 'return_loss' has to be True when 'forward_only' is False, but got False."
        self.load_batch(data_iter)

        # num_warmup_microbatches is the number of steps during which not all the
        # processes are working; lower pipeline ranks run more warmup forwards.
        num_warmup_microbatches = (
            gpc.get_world_size(ParallelMode.PIPELINE) - gpc.get_local_rank(ParallelMode.PIPELINE) - 1
        )
        num_warmup_microbatches = min(num_warmup_microbatches, self.num_microbatches)
        num_microbatches_remaining = self.num_microbatches - num_warmup_microbatches

        # Input, output tensors only need to be saved when doing backward passes
        input_objs = None
        output_objs = None

        if not forward_only:
            input_objs = []
            output_objs = []
        # Filled by self._forward_step (defined in PipelineSchedule, not shown here).
        # NOTE(review): presumably it receives one (output, label) pair per
        # microbatch when return_output_label is True, because
        # pack_return_tensors unpacks pairs. Confirm against the parent class.
        return_tensors = []
        # The loss accumulator is only allocated on the last pipeline stage.
        # NOTE(review): presumably _forward_step adds each microbatch's loss into it.
        if return_loss and gpc.is_pipeline_last_stage(ignore_virtual=True):
            accum_loss = torch.zeros(1, device=get_accelerator().get_current_device())
        else:
            accum_loss = None

        # Run warmup forward passes.
        for i in range(num_warmup_microbatches):
            # NOTE(review): p2p_v2 defines what recv_forward yields on the first
            # stage (presumably None). Verify in p2p_v2.
            input_obj = comm.recv_forward()

            output_obj = self._forward_step(
                engine, input_obj, return_tensors, return_output_label=return_output_label, accum_loss=accum_loss
            )

            comm.send_forward(output_obj)

            # Keep the activations so the matching backward pass can run later.
            if not forward_only:
                input_objs.append(input_obj)
                output_objs.append(output_obj)

        # Before running 1F1B, need to receive first forward tensor.
        # If all microbatches are run in warmup / cooldown phase, then no need to
        # receive this tensor here.
        if num_microbatches_remaining > 0:
            input_obj = comm.recv_forward()

        # Run 1F1B in steady state.
        for i in range(num_microbatches_remaining):
            last_iteration = i == (num_microbatches_remaining - 1)

            output_obj = self._forward_step(
                engine, input_obj, return_tensors, return_output_label=return_output_label, accum_loss=accum_loss
            )
            if forward_only:
                comm.send_forward(output_obj)

                # The next input is received at the end of each iteration; after the
                # last iteration there is no further microbatch to receive.
                if not last_iteration:
                    input_obj = comm.recv_forward()

            else:
                # TODO adjust here
                # NOTE(review): the TODO above does not say what needs adjusting.
                comm.send_forward(output_obj)
                output_obj_grad = comm.recv_backward()

                # Add input_obj and output_obj to end of list.
                input_objs.append(input_obj)
                output_objs.append(output_obj)

                # Pop input_obj and output_obj from the start of the list for
                # the backward pass, so the oldest microbatch is processed first (FIFO).
                input_obj = input_objs.pop(0)
                output_obj = output_objs.pop(0)

                input_obj_grad = self._backward_step(engine, input_obj, output_obj, output_obj_grad)

                # NOTE(review): the order of p2p calls here presumably must match the
                # neighbouring stages. Do not reorder without checking p2p_v2.
                if last_iteration:
                    input_obj = None
                    comm.send_backward(input_obj_grad)
                else:
                    input_obj = comm.recv_forward()
                    comm.send_backward(input_obj_grad)

        # Run cooldown backward passes. Each steady-state iteration appended one
        # entry and popped one, so exactly num_warmup_microbatches saved entries remain.
        if not forward_only:
            for i in range(num_warmup_microbatches):
                input_obj = input_objs.pop(0)
                output_obj = output_objs.pop(0)

                output_obj_grad = comm.recv_backward()
                input_obj_grad = self._backward_step(engine, input_obj, output_obj, output_obj_grad)
                comm.send_backward(input_obj_grad)

        # Merge collected per-microbatch results; without any, output/label are None.
        if len(return_tensors) > 0:
            output, label = pack_return_tensors(return_tensors)
            return output, label, accum_loss
        else:
            return None, None, accum_loss
|