From e9460b45c8583b943dc10cbc3ec7cfe7a9c385c9 Mon Sep 17 00:00:00 2001
From: Kirigaya Kazuto <59416203+LSTM-Kirigaya@users.noreply.github.com>
Date: Fri, 12 Aug 2022 11:33:26 +0800
Subject: [PATCH] [engine/schedule] use p2p_v2 to reconstruct pipeline_schedule (#1408)

* support p2p communication with any type of object | pass test
* reconstruct pipeline schedule with p2p_v2.py (support communication with List[Any]) | pass test
* [communication] add p2p_v2.py to support communication with List[Any]
* Delete _pipeline_schedule_v2.py
* Delete test_cifar_with_data_pipeline_tensor_v2.py
* [engine/schedule] use p2p_v2 to reconstruct pipeline_schedule
* [engine/schedule] use p2p_v2 to reconstruct pipeline_schedule
* [engine/schedule] use p2p_v2 to reconstruct pipeline_schedule
* [engine/schedule] use p2p_v2 to reconstruct pipeline_schedule
* [engine/schedule] use p2p_v2 to reconstruct pipeline_schedule
* Delete p2p_v2.py
* Delete test_boardcast_send_recv_v2.py
* Delete test_object_list_p2p_v2.py
* [engine/schedule] use p2p_v2 to reconstruct pipeline_schedule
* [communication] remove print code
* [communication] remove print code
* [engine/schedule] shorten the running time of the testing file to prevent cancellation in CI
---
 .../engine/schedule/_pipeline_schedule_v2.py  | 181 ++++++++++++++++++
 ...test_cifar_with_data_pipeline_tensor_v2.py | 111 +++++++++++
 2 files changed, 292 insertions(+)
 create mode 100644 colossalai/engine/schedule/_pipeline_schedule_v2.py
 create mode 100644 tests/test_data_pipeline_tensor_parallel/test_cifar_with_data_pipeline_tensor_v2.py

diff --git a/colossalai/engine/schedule/_pipeline_schedule_v2.py b/colossalai/engine/schedule/_pipeline_schedule_v2.py
new file mode 100644
index 000000000..50a87aafa
--- /dev/null
+++ b/colossalai/engine/schedule/_pipeline_schedule_v2.py
@@ -0,0 +1,181 @@
+#!/usr/bin/env python
+# -*- encoding: utf-8 -*-
+
+from typing import Tuple, Iterable
+
+from colossalai import engine
+import colossalai.communication.p2p_v2 as comm
+import torch.cuda
+from colossalai.context.parallel_mode import ParallelMode
+from colossalai.core import global_context as gpc
+from colossalai.utils.cuda import get_current_device
+
+from ._pipeline_schedule import PipelineSchedule
+
+
+def pack_return_tensors(return_tensors):
+    output, label = tuple(zip(*return_tensors))
+    if isinstance(output[0], torch.Tensor):
+        output = torch.cat(output, dim=0)
+    elif isinstance(output[0], (list, tuple)):
+        output = tuple(torch.cat(tensors, dim=0) for tensors in zip(*output))
+    else:
+        raise TypeError('Output of model must be a tensor or a list/tuple of tensors')
+    if isinstance(label[0], torch.Tensor):
+        label = torch.cat(label, dim=0)
+    else:
+        merged_label = {k: [] for k in label[0].keys()}
+        for d in label:
+            for k, v in d.items():
+                merged_label[k].append(v)
+        label = {k: torch.cat(v, dim=0) for k, v in merged_label.items()}
+    return output, label
+
+
+class PipelineScheduleV2(PipelineSchedule):
+    """Derived class of PipelineSchedule; the only difference is that
+    forward_backward_step is reconstructed with p2p_v2.
+
+    Args:
+        num_microbatches (int): The number of microbatches.
+        data_process_func (Callable, optional):
+            The preprocessing function which receives a batch of data; it will be executed in `load_batch`.
+        tensor_shape (torch.Size, optional): Specified shape in pipeline communication.
+        scatter_gather_tensors (bool, optional):
+            If set to `True`, communication will be reduced over pipeline when using 1D tensor parallelization.
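+
+    Usage (a minimal sketch mirroring the test added in this patch; it assumes
+    an engine built by colossalai.initialize and NUM_MICRO_BATCHES set in the
+    config):
+
+        engine._schedule = PipelineScheduleV2(num_microbatches=gpc.config.NUM_MICRO_BATCHES)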
+
+    Example:
+
+        # this shows an example of a customized data_process_func
+        def data_process_func(stage_output, dataloader_output):
+            output1, output2 = stage_output
+            item1, item2, item3 = dataloader_output
+
+            # assume item2 is not needed
+            data = (output1, output2, item1)
+            label = item3
+            return data, label
+
+    """
+
+    def forward_backward_step(self,
+                              engine: engine.Engine,
+                              data_iter: Iterable,
+                              forward_only=False,
+                              return_loss=True,
+                              return_output_label=True) -> Tuple[torch.Tensor]:
+        """Runs the non-interleaved 1F1B schedule, with communication between pipeline stages.
+        Returns a tuple with losses if this is the last stage, an empty tuple otherwise.
+
+        Args:
+            engine (colossalai.engine.Engine): Colossalai engine for training and inference.
+            data_iter (Iterable): Dataloader in the form of an iterator, obtained by calling iter(dataloader).
+            forward_only (bool, optional):
+                Whether to run the forward step only. Default is False. If True, no backward pass will be run.
+            return_loss (bool, optional): Whether to return the loss value. Default is True.
+            return_output_label (bool, optional): If False, the output and label won't be returned.
+
+        Returns:
+            Tuple[:class:`torch.Tensor`]: A tuple of (output, label, loss); loss and label could be None.
+        """
+
+        assert forward_only or return_loss, \
+            'The argument \'return_loss\' has to be True when \'forward_only\' is False, but got False.'
+        self.load_batch(data_iter)
+
+        # num_warmup_microbatches is the number of warmup steps during which
+        # not all pipeline stages are active yet
+        num_warmup_microbatches = \
+            (gpc.get_world_size(ParallelMode.PIPELINE)
+             - gpc.get_local_rank(ParallelMode.PIPELINE) - 1)
+        num_warmup_microbatches = min(num_warmup_microbatches, self.num_microbatches)
+        num_microbatches_remaining = self.num_microbatches - num_warmup_microbatches
+
+        # Input and output tensors only need to be saved when doing backward passes
+        input_objs = None
+        output_objs = None
+
+        if not forward_only:
+            input_objs = []
+            output_objs = []
+        return_tensors = []
+        if return_loss and gpc.is_pipeline_last_stage(ignore_virtual=True):
+            accum_loss = torch.zeros(1, device=get_current_device())
+        else:
+            accum_loss = None
+
+        # Run warmup forward passes.
+        for i in range(num_warmup_microbatches):
+            input_obj = comm.recv_forward()
+
+            output_obj = self._forward_step(engine,
+                                            input_obj,
+                                            return_tensors,
+                                            return_output_label=return_output_label,
+                                            accum_loss=accum_loss)
+
+            comm.send_forward(output_obj)
+
+            if not forward_only:
+                input_objs.append(input_obj)
+                output_objs.append(output_obj)
+
+        # Before running 1F1B, need to receive the first forward tensor.
+        # If all microbatches are run in the warmup / cooldown phases, there is
+        # no need to receive this tensor here.
+        if num_microbatches_remaining > 0:
+            input_obj = comm.recv_forward()
+
+        # Run 1F1B in steady state.
+        for i in range(num_microbatches_remaining):
+            last_iteration = (i == (num_microbatches_remaining - 1))
+
+            output_obj = self._forward_step(engine,
+                                            input_obj,
+                                            return_tensors,
+                                            return_output_label=return_output_label,
+                                            accum_loss=accum_loss)
+            if forward_only:
+                comm.send_forward(output_obj)
+
+                if not last_iteration:
+                    input_obj = comm.recv_forward()
+
+            else:
+                # TODO adjust here
+                comm.send_forward(output_obj)
+                output_obj_grad = comm.recv_backward()
+
+                # Add input_obj and output_obj to the end of the lists.
+                input_objs.append(input_obj)
+                output_objs.append(output_obj)
+
+                # Pop input_obj and output_obj from the start of the lists for
+                # the backward pass.
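+                # Note (added comment): the two lists act as FIFO queues. The
+                # gradient just received belongs to the oldest in-flight
+                # microbatch, because downstream stages process microbatches
+                # strictly in order in the non-interleaved 1F1B schedule.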
+                input_obj = input_objs.pop(0)
+                output_obj = output_objs.pop(0)
+
+                input_obj_grad = self._backward_step(engine, input_obj, output_obj, output_obj_grad)
+
+                if last_iteration:
+                    input_obj = None
+                    comm.send_backward(input_obj_grad)
+                else:
+                    input_obj = comm.recv_forward()
+                    comm.send_backward(input_obj_grad)
+
+        # Run cooldown backward passes.
+        if not forward_only:
+            for i in range(num_warmup_microbatches):
+                input_obj = input_objs.pop(0)
+                output_obj = output_objs.pop(0)
+
+                output_obj_grad = comm.recv_backward()
+                input_obj_grad = self._backward_step(engine, input_obj, output_obj, output_obj_grad)
+                comm.send_backward(input_obj_grad)
+
+        if len(return_tensors) > 0:
+            output, label = pack_return_tensors(return_tensors)
+            return output, label, accum_loss
+        else:
+            return None, None, accum_loss

diff --git a/tests/test_data_pipeline_tensor_parallel/test_cifar_with_data_pipeline_tensor_v2.py b/tests/test_data_pipeline_tensor_parallel/test_cifar_with_data_pipeline_tensor_v2.py
new file mode 100644
index 000000000..2bafe0f7e
--- /dev/null
+++ b/tests/test_data_pipeline_tensor_parallel/test_cifar_with_data_pipeline_tensor_v2.py
@@ -0,0 +1,111 @@
+import os
+
+from functools import partial
+from pathlib import Path
+
+import colossalai
+import pytest
+import torch
+import torch.multiprocessing as mp
+from colossalai.amp import AMP_TYPE
+from colossalai.trainer import Trainer, hooks
+from colossalai.context import ParallelMode
+from colossalai.testing import rerun_if_address_is_in_use, skip_if_not_enough_gpus
+from colossalai.utils import free_port
+from colossalai.core import global_context as gpc
+from colossalai.logging import get_dist_logger
+from colossalai.nn import CrossEntropyLoss
+from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
+from colossalai.utils import get_dataloader
+from colossalai.pipeline.pipelinable import PipelinableContext
+from colossalai.logging import disable_existing_loggers
+from torchvision.datasets import CIFAR10
+from torchvision import transforms
+
+from colossalai.engine.schedule._pipeline_schedule_v2 import PipelineScheduleV2
+
+disable_existing_loggers()
+BATCH_SIZE = 4
+NUM_EPOCHS = 10
+WARMUP_EPOCHS = 5
+CONFIG = dict(NUM_MICRO_BATCHES=2,
+              parallel=dict(pipeline=2, tensor=dict(size=1, mode='1d')),
+              fp16=dict(mode=AMP_TYPE.NAIVE),
+              gradient_accumulation=2)
+
+
+def run_trainer(rank, world_size, port):
+    disable_existing_loggers()
+    colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+
+    disable_existing_loggers()
+    # get logger
+    logger = get_dist_logger()
+
+    pipelinable = PipelinableContext()
+    try:
+        from titans.model.vit import vit_tiny_patch4_32
+    except ImportError:
+        logger.warning('skipping the test_cifar_with_data_pipeline_tensor test because Titans is not installed')
+        logger.warning('please install Titans from https://github.com/hpcaitech/Titans')
+        return
+    with pipelinable:
+        model = vit_tiny_patch4_32()
+    pipelinable.to_layer_list()
+    pipelinable.policy = "uniform"
+    model = pipelinable.partition(1, gpc.pipeline_parallel_size, gpc.get_local_rank(ParallelMode.PIPELINE))
+
+    # create dataloaders
+    root = Path(os.environ['DATA'])
+    transform_train = transforms.Compose([
+        transforms.RandomCrop(32, padding=4, pad_if_needed=True),
+        transforms.AutoAugment(policy=transforms.AutoAugmentPolicy.CIFAR10),
+        transforms.ToTensor(),
+        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
+    ])
+    train_dataset = CIFAR10(root=root, train=True, download=True,
+                            transform=transform_train)
+    train_dataloader = get_dataloader(dataset=train_dataset, shuffle=True, batch_size=BATCH_SIZE, pin_memory=True)
+
+    # create loss function
+    criterion = CrossEntropyLoss(label_smoothing=0.1)
+
+    # create optimizer
+    optimizer = torch.optim.AdamW(model.parameters(), lr=0.001, weight_decay=0)
+
+    # create lr scheduler
+    lr_scheduler = CosineAnnealingWarmupLR(optimizer=optimizer, total_steps=NUM_EPOCHS, warmup_steps=WARMUP_EPOCHS)
+
+    # initialize
+    engine, train_dataloader, *_ = colossalai.initialize(model=model,
+                                                         optimizer=optimizer,
+                                                         criterion=criterion,
+                                                         train_dataloader=train_dataloader)
+
+    engine._schedule = PipelineScheduleV2(num_microbatches=gpc.config.NUM_MICRO_BATCHES)
+
+    logger = get_dist_logger()
+
+    trainer = Trainer(engine=engine, logger=logger)
+
+    hook_list = [
+        hooks.LRSchedulerHook(lr_scheduler=lr_scheduler, by_epoch=False),
+    ]
+
+    trainer.fit(train_dataloader=train_dataloader,
+                max_steps=2,
+                epochs=NUM_EPOCHS,
+                hooks=hook_list,
+                display_progress=True)
+
+
+@pytest.mark.dist
+@rerun_if_address_is_in_use()
+def test_hybrid_parallel():
+    world_size = 2
+    run_func = partial(run_trainer, world_size=world_size, port=free_port())
+    disable_existing_loggers()
+    mp.spawn(run_func, nprocs=world_size)
+
+
+if __name__ == '__main__':
+    test_hybrid_parallel()
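
A quick way to sanity-check `pack_return_tensors` from this patch without a distributed run (a minimal sketch with synthetic tensors; the shapes below are illustrative only, not taken from the test above):

    import torch
    from colossalai.engine.schedule._pipeline_schedule_v2 import pack_return_tensors

    # two microbatches, each contributing an (output, label) pair
    return_tensors = [(torch.ones(2, 4), torch.zeros(2, dtype=torch.long)),
                      (torch.ones(2, 4), torch.zeros(2, dtype=torch.long))]
    output, label = pack_return_tensors(return_tensors)
    # microbatches are concatenated along dim 0
    assert output.shape == (4, 4) and label.shape == (4,)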