[engine/schedule] use p2p_v2 to reconstruct pipeline_schedule (#1408)

* support p2p communication with any type of object | pass test

* reconstruct pipeline schedule with p2p_v2.py (supports communication with List[Any]) | pass test

* [communication] add p2p_v2.py to support communication with List[Any]

* Delete _pipeline_schedule_v2.py

* Delete test_cifar_with_data_pipeline_tensor_v2.py

* [engine/schedule] use p2p_v2 to reconstruct pipeline_schedule

* Delete p2p_v2.py

* Delete test_boardcast_send_recv_v2.py

* Delete test_object_list_p2p_v2.py

* [engine/schedule] use p2p_v2 to reconstruct pipeline_schedule

* [communication] remove print code

* [engine/schedule] shorten the test's running time to prevent cancellation in CI
Kirigaya Kazuto 2022-08-12 11:33:26 +08:00 committed by GitHub
parent ae1b58cd16
commit e9460b45c8
2 changed files with 292 additions and 0 deletions
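The new schedule drives all stage-to-stage traffic through four calls from colossalai.communication.p2p_v2: recv_forward, send_forward, recv_backward and send_backward. Below is a minimal sketch of the calling pattern for a forward-only pass; only the four comm calls are taken from this diff, while run_stage and the boundary behaviour noted in the comments are assumptions for illustration.

import colossalai.communication.p2p_v2 as comm

def forward_only_pass(run_stage, num_microbatches: int):
    # run_stage is a hypothetical stand-in for the engine's forward step.
    for _ in range(num_microbatches):
        # Presumably returns None on the first pipeline stage.
        input_obj = comm.recv_forward()
        output_obj = run_stage(input_obj)
        # Presumably a no-op on the last pipeline stage.
        comm.send_forward(output_obj)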

colossalai/engine/schedule/_pipeline_schedule_v2.py

@@ -0,0 +1,181 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
from typing import Tuple, Iterable
from colossalai import engine
import colossalai.communication.p2p_v2 as comm
import torch.cuda
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.utils.cuda import get_current_device
from ._pipeline_schedule import PipelineSchedule
def pack_return_tensors(return_tensors):
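    """Merge per-microbatch (output, label) pairs along the batch dimension.

    Outputs may be tensors or lists/tuples of tensors; labels may be tensors
    or dicts of tensors (merged key by key).
    """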
output, label = tuple(zip(*return_tensors))
if isinstance(output[0], torch.Tensor):
output = torch.cat(output, dim=0)
elif isinstance(output[0], (list, tuple)):
output = tuple(torch.cat(tensors, dim=0) for tensors in zip(*output))
else:
        raise TypeError('Output of model must be tensor or list/tuple of tensors')
if isinstance(label[0], torch.Tensor):
label = torch.cat(label, dim=0)
else:
merged_label = {k: [] for k in label[0].keys()}
for d in label:
for k, v in d.items():
merged_label[k].append(v)
label = {k: torch.cat(v, dim=0) for k, v in merged_label.items()}
return output, label
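# Illustrative example (not from this PR): two microbatches of (output, label)
# pairs collapse into batch-level tensors, e.g.
#   rt = [(torch.ones(2, 4), torch.zeros(2)), (torch.ones(2, 4), torch.zeros(2))]
#   out, lbl = pack_return_tensors(rt)    # out: shape [4, 4], lbl: shape [4]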
class PipelineScheduleV2(PipelineSchedule):
"""Derived class of PipelineSchedule, the only difference is that
forward_backward_step is reconstructed with p2p_v2
Args:
num_microbatches (int): The number of microbatches.
data_process_func (Callable, optional):
The preprocessing function which receives a batch of data, and it will be executed in `load_batch`.
tensor_shape (torch.Size, optional): Specified shape in pipeline communication.
scatter_gather_tensors (bool, optional):
If set to `True`, communication will be reduced over pipeline when using 1D tensor parallelization.
Example:
# this shows an example of customized data_process_func
def data_process_func(stage_output, dataloader_output):
output1, output2 = stage_output
item1, item2, item3 = dataloader_output
# assume item2 is not needed
data = (output1, output2, item1)
label = item3
return data, label
"""
def forward_backward_step(self,
engine: engine.Engine,
data_iter: Iterable,
forward_only=False,
return_loss=True,
return_output_label=True) -> Tuple[torch.Tensor]:
"""Runs non-interleaved 1F1B schedule, with communication between pipeline stages.
        Returns a tuple with losses if this is the last stage, an empty tuple otherwise.
Args:
engine (colossalai.engine.Engine): Colossalai engine for training and inference.
data_iter (Iterable): Dataloader as the form of an iterator, obtained by calling iter(dataloader).
forward_only (bool, optional):
                Whether to run the forward step only. Default is False. If True, no backward pass will be run.
return_loss (bool, optional): Whether returns the loss value. Default is true.
return_output_label (bool, optional): If False, the output and label won't be returned.
Returns:
Tuple[:class:`torch.Tensor`]: A tuple of (output, label, loss), loss and label could be None.
"""
assert forward_only or return_loss, \
'The argument \'return_loss\' has to be True when \'forward_only\' is False, but got False.'
self.load_batch(data_iter)
        # num_warmup_microbatches is the number of warmup steps during which not all pipeline stages are active
num_warmup_microbatches = \
(gpc.get_world_size(ParallelMode.PIPELINE)
- gpc.get_local_rank(ParallelMode.PIPELINE) - 1)
num_warmup_microbatches = min(num_warmup_microbatches, self.num_microbatches)
num_microbatches_remaining = self.num_microbatches - num_warmup_microbatches
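        # Example: with 4 pipeline stages and 8 microbatches, stage 0 runs
        # 3 warmup forward passes and stage 3 runs none; the remaining 5
        # (respectively 8) microbatches are processed in the 1F1B steady state.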
# Input, output tensors only need to be saved when doing backward passes
input_objs = None
output_objs = None
# local_rank = gpc.get_local_rank(ParallelMode.PIPELINE)
if not forward_only:
input_objs = []
output_objs = []
return_tensors = []
if return_loss and gpc.is_pipeline_last_stage(ignore_virtual=True):
accum_loss = torch.zeros(1, device=get_current_device())
else:
accum_loss = None
# Run warmup forward passes.
for i in range(num_warmup_microbatches):
input_obj = comm.recv_forward()
output_obj = self._forward_step(engine,
input_obj,
return_tensors,
return_output_label=return_output_label,
accum_loss=accum_loss)
comm.send_forward(output_obj)
if not forward_only:
input_objs.append(input_obj)
output_objs.append(output_obj)
# Before running 1F1B, need to receive first forward tensor.
# If all microbatches are run in warmup / cooldown phase, then no need to
# receive this tensor here.
if num_microbatches_remaining > 0:
input_obj = comm.recv_forward()
# Run 1F1B in steady state.
for i in range(num_microbatches_remaining):
last_iteration = (i == (num_microbatches_remaining - 1))
output_obj = self._forward_step(engine,
input_obj,
return_tensors,
return_output_label=return_output_label,
accum_loss=accum_loss)
if forward_only:
comm.send_forward(output_obj)
if not last_iteration:
input_obj = comm.recv_forward()
else:
# TODO adjust here
comm.send_forward(output_obj)
output_obj_grad = comm.recv_backward()
# Add input_obj and output_obj to end of list.
input_objs.append(input_obj)
output_objs.append(output_obj)
                # Pop input_obj and output_obj from the start of the list for
                # the backward pass.
input_obj = input_objs.pop(0)
output_obj = output_objs.pop(0)
input_obj_grad = self._backward_step(engine, input_obj, output_obj, output_obj_grad)
if last_iteration:
input_obj = None
comm.send_backward(input_obj_grad)
else:
input_obj = comm.recv_forward()
comm.send_backward(input_obj_grad)
# Run cooldown backward passes.
if not forward_only:
for i in range(num_warmup_microbatches):
input_obj = input_objs.pop(0)
output_obj = output_objs.pop(0)
output_obj_grad = comm.recv_backward()
input_obj_grad = self._backward_step(engine, input_obj, output_obj, output_obj_grad)
comm.send_backward(input_obj_grad)
if len(return_tensors) > 0:
output, label = pack_return_tensors(return_tensors)
return output, label, accum_loss
else:
return None, None, accum_loss
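For reference, a standalone sketch (illustrative only, not part of this commit) that reproduces the warmup/steady/cooldown bookkeeping of forward_backward_step for every pipeline rank:

def phase_sizes(world_size: int, num_microbatches: int) -> None:
    # Mirrors the arithmetic above: later stages need fewer warmup steps.
    for rank in range(world_size):
        warmup = min(world_size - rank - 1, num_microbatches)
        steady = num_microbatches - warmup
        print(f'rank {rank}: {warmup} warmup, {steady} 1F1B, {warmup} cooldown')

phase_sizes(world_size=4, num_microbatches=8)
# rank 0: 3 warmup, 5 1F1B, 3 cooldown
# ...
# rank 3: 0 warmup, 8 1F1B, 0 cooldown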


@@ -0,0 +1,111 @@
import os
from functools import partial
from pathlib import Path
import colossalai
import pytest
import torch
import torch.multiprocessing as mp
from colossalai.amp import AMP_TYPE
from colossalai.trainer import Trainer, hooks
from colossalai.context import ParallelMode
from colossalai.testing import rerun_if_address_is_in_use, skip_if_not_enough_gpus
from colossalai.utils import free_port
from colossalai.core import global_context as gpc
from colossalai.logging import get_dist_logger
from colossalai.nn import CrossEntropyLoss
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
from colossalai.utils import get_dataloader
from colossalai.pipeline.pipelinable import PipelinableContext
from colossalai.logging import disable_existing_loggers
from torchvision.datasets import CIFAR10
from torchvision import transforms
from colossalai.engine.schedule._pipeline_schedule_v2 import PipelineScheduleV2
disable_existing_loggers()
BATCH_SIZE = 4
NUM_EPOCHS = 10
WARMUP_EPOCHS = 5
CONFIG = dict(NUM_MICRO_BATCHES=2,
parallel=dict(pipeline=2, tensor=dict(size=1, mode='1d')),
fp16=dict(mode=AMP_TYPE.NAIVE),
gradient_accumulation=2)
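# With pipeline=2 and NUM_MICRO_BATCHES=2, stage 0 runs 1 warmup microbatch
# followed by 1 steady-state 1F1B step, while stage 1 runs both microbatches
# in the 1F1B phase (warmup = world_size - rank - 1).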
def run_trainer(rank, world_size, port):
disable_existing_loggers()
colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
disable_existing_loggers()
# get logger
logger = get_dist_logger()
pipelinable = PipelinableContext()
try:
from titans.model.vit import vit_tiny_patch4_32
except ImportError:
        logger.warning('skip the test_cifar_with_data_pipeline_tensor test because titans is not installed')
        logger.warning('please install titans from https://github.com/hpcaitech/Titans')
return
with pipelinable:
model = vit_tiny_patch4_32()
pipelinable.to_layer_list()
pipelinable.policy = "uniform"
model = pipelinable.partition(1, gpc.pipeline_parallel_size, gpc.get_local_rank(ParallelMode.PIPELINE))
    # create dataloaders
root = Path(os.environ['DATA'])
transform_train = transforms.Compose([
transforms.RandomCrop(32, padding=4, pad_if_needed=True),
transforms.AutoAugment(policy=transforms.AutoAugmentPolicy.CIFAR10),
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])
train_dataset = CIFAR10(root=root, train=True, download=True, transform=transform_train)
train_dataloader = get_dataloader(dataset=train_dataset, shuffle=True, batch_size=BATCH_SIZE, pin_memory=True)
# create loss function
criterion = CrossEntropyLoss(label_smoothing=0.1)
# create optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=0.001, weight_decay=0)
# create lr scheduler
lr_scheduler = CosineAnnealingWarmupLR(optimizer=optimizer, total_steps=NUM_EPOCHS, warmup_steps=WARMUP_EPOCHS)
    # initialize
engine, train_dataloader, *_ = colossalai.initialize(model=model,
optimizer=optimizer,
criterion=criterion,
train_dataloader=train_dataloader)
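    # Swap in the p2p_v2-based schedule built above so the trainer exercises
    # the new communication path instead of the default pipeline schedule.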
engine._schedule = PipelineScheduleV2(num_microbatches=gpc.config.NUM_MICRO_BATCHES)
logger = get_dist_logger()
trainer = Trainer(engine=engine, logger=logger)
hook_list = [
hooks.LRSchedulerHook(lr_scheduler=lr_scheduler, by_epoch=False),
]
trainer.fit(train_dataloader=train_dataloader,
max_steps=2,
epochs=NUM_EPOCHS,
hooks=hook_list,
display_progress=True)
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_hybrid_parallel():
world_size = 2
run_func = partial(run_trainer, world_size=world_size, port=free_port())
disable_existing_loggers()
mp.spawn(run_func, nprocs=world_size)
if __name__ == '__main__':
test_hybrid_parallel()