# ColossalAI/tests/test_engine/test_pipeline/debug_schedule.py
# Referenced from Megatron and used to test communication
import os.path as osp
import pytest
import torch
from torch.utils.data import DataLoader
from colossalai.builder import ModelInitializer, build_dataset, build_optimizer, build_loss
from colossalai.communication import p2p as p2p_communication
from colossalai.communication.utils import send_tensor_meta, recv_tensor_meta
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.initialize import initialize
from colossalai.utils import print_rank_0, get_current_device

NUM_BATCH = 128
NUM_MICRO = 6


def get_num_microbatches():
    return NUM_MICRO


def to_cuda(data):
    if isinstance(data, (tuple, list)):
        data = data[0].to(get_current_device())
    else:
        data = data.to(get_current_device())
    return data


def step_func(loss):
    def _step_func(input_tensor, model):
        output = model(input_tensor)
        if isinstance(output, (tuple, list)):
            if len(output) > 1:
                raise NotImplementedError("Multiple output!!!")
            else:
                output = output[0]
        return output, loss

    return _step_func


def forward_step(forward_step_func, data_iterator, model, input_tensor, losses_reduced):
    """Forward step for passed-in model.

    If first stage, input tensor is obtained from data_iterator, otherwise
    passed-in input_tensor is used.
    Returns output tensor."""
    if input_tensor is None:
        data, label = next(data_iterator)
        input_tensor = to_cuda(data)

    output_tensor, loss_func = forward_step_func(input_tensor, model)
    if gpc.is_last_rank(ParallelMode.PIPELINE):
        data, label = next(data_iterator)
        label = to_cuda(label)
        output_tensor = loss_func(output_tensor, label) / get_num_microbatches()
        losses_reduced.append(output_tensor)

    return output_tensor


def backward_step(optimizer, input_tensor, output_tensor, output_tensor_grad):
    """Backward step through passed-in output tensor.

    If last stage, output_tensor_grad is None, otherwise gradient of loss
    with respect to stage's output tensor.
    Returns gradient of loss with respect to input tensor (None if first
    stage)."""
    # Retain the grad on the input_tensor.
    if input_tensor is not None:
        input_tensor.retain_grad()

    # Backward pass.
    torch.autograd.backward(output_tensor, grad_tensors=output_tensor_grad)

    # Collect the grad of the input_tensor.
    input_tensor_grad = None
    if input_tensor is not None:
        input_tensor_grad = input_tensor.grad

    return input_tensor_grad


def forward_backward_pipelining_without_interleaving(forward_step_func, data_iterator,
                                                     model, optimizer, forward_only):
    """Run non-interleaved 1F1B schedule, with communication between pipeline
    stages.

    Returns a list of reduced losses if this is the last stage, an empty list
    otherwise."""
    # Compute number of warmup microbatches.
    num_microbatches = get_num_microbatches()
    num_warmup_microbatches = \
        (gpc.get_world_size(ParallelMode.PIPELINE) -
         gpc.get_local_rank(ParallelMode.PIPELINE) - 1)
    num_warmup_microbatches = min(
        num_warmup_microbatches,
        num_microbatches)
    num_microbatches_remaining = \
        num_microbatches - num_warmup_microbatches
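    # Example: with a 4-stage pipeline and NUM_MICRO = 6, the first stage runs
    # 3 warmup forward passes and 3 steady-state (1F1B) iterations, while the
    # last stage runs no warmup and handles all 6 microbatches in 1F1B.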

    # Input, output tensors only need to be saved when doing backward passes
    input_tensors = None
    output_tensors = None
    if not forward_only:
        input_tensors = []
        output_tensors = []
    losses_reduced = []

    # Used for tensor meta information communication
    ft_shape = None
    bt_shape = None
    fs_checker = True
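    # (The meta helpers above let each stage learn the shape of the tensor it
    # will receive: recv_tensor_meta fills ft_shape from the previous stage so
    # a receive buffer can be allocated, and fs_checker tracks whether this
    # stage still needs to send its own tensor meta downstream.)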

    # Run warmup forward passes.
    for i in range(num_warmup_microbatches):
        if not gpc.is_first_rank(ParallelMode.PIPELINE):
            ft_shape = recv_tensor_meta(ft_shape)
        input_tensor = p2p_communication.recv_forward(ft_shape)
        output_tensor = forward_step(forward_step_func, data_iterator, model,
                                     input_tensor, losses_reduced)
        if not gpc.is_last_rank(ParallelMode.PIPELINE):
            bt_shape = output_tensor.shape
            fs_checker = send_tensor_meta(output_tensor, fs_checker)
        p2p_communication.send_forward(output_tensor)

        if not forward_only:
            input_tensors.append(input_tensor)
            output_tensors.append(output_tensor)
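
    # At this point each stage has stashed up to num_warmup_microbatches
    # (input, output) pairs; their backward passes are issued FIFO during the
    # 1F1B and cooldown phases below.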

    # Before running 1F1B, need to receive first forward tensor.
    # If all microbatches are run in warmup / cooldown phase, then no need to
    # receive this tensor here.
    if num_microbatches_remaining > 0:
        if not gpc.is_first_rank(ParallelMode.PIPELINE):
            ft_shape = recv_tensor_meta(ft_shape)
        input_tensor = p2p_communication.recv_forward(ft_shape)
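
    # In the steady state each iteration performs one forward pass for a fresh
    # microbatch and one backward pass for the oldest stashed microbatch, so
    # the number of live activations per stage stays bounded.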
    # Run 1F1B in steady state.
    for i in range(num_microbatches_remaining):
        last_iteration = (i == (num_microbatches_remaining - 1))

        output_tensor = forward_step(forward_step_func, data_iterator, model,
                                     input_tensor, losses_reduced)
        if forward_only:
            p2p_communication.send_forward(output_tensor)

            if not last_iteration:
                input_tensor = p2p_communication.recv_forward(ft_shape)

        else:
            output_tensor_grad = \
                p2p_communication.send_forward_recv_backward(output_tensor, bt_shape)

            # Add input_tensor and output_tensor to end of list.
            input_tensors.append(input_tensor)
            output_tensors.append(output_tensor)

            # Pop input_tensor and output_tensor from the start of the list for
            # the backward pass.
            input_tensor = input_tensors.pop(0)
            output_tensor = output_tensors.pop(0)

            input_tensor_grad = \
                backward_step(optimizer, input_tensor, output_tensor,
                              output_tensor_grad)

            if last_iteration:
                input_tensor = None
                p2p_communication.send_backward(input_tensor_grad)
            else:
                input_tensor = \
                    p2p_communication.send_backward_recv_forward(input_tensor_grad, ft_shape)
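
    # num_warmup_microbatches (input, output) pairs are still stashed here;
    # their backward passes are drained below.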
    # Run cooldown backward passes.
    if not forward_only:
        for i in range(num_warmup_microbatches):
            input_tensor = input_tensors.pop(0)
            output_tensor = output_tensors.pop(0)

            output_tensor_grad = p2p_communication.recv_backward(bt_shape)

            input_tensor_grad = \
                backward_step(optimizer, input_tensor, output_tensor,
                              output_tensor_grad)

            p2p_communication.send_backward(input_tensor_grad)

    return losses_reduced


DIR_PATH = osp.dirname(osp.realpath(__file__))
CONFIG_PATH = osp.join(DIR_PATH, '../configs/pipeline_vanilla_vit.py')


@pytest.mark.skip(reason="This is only for debugging purposes, please ignore this test")
@pytest.mark.dist
def test_schedule():
    initialize(CONFIG_PATH)

    # build model
    model = ModelInitializer(gpc.config.model, 1).model_initialize()
    print_rank_0('model is created')

    # keep the same sampler for all processes
    torch.manual_seed(1331)
    dataset = build_dataset(gpc.config.data.dataset)
    dataloader = DataLoader(dataset=dataset, **gpc.config.data.dataloader)
    print_rank_0('train data is created')

    # build optimizer and loss
    optim = build_optimizer(gpc.config.optimizer, model)
    loss = build_loss(gpc.config.loss)
    print_rank_0('optim and loss are created')

    forward_backward_pipelining_without_interleaving(
        step_func(loss),
        iter(dataloader),
        model,
        optim,
        False
    )

    gpc.destroy()
    print_rank_0('training finished')


if __name__ == '__main__':
    test_schedule()