mirror of https://github.com/hpcaitech/ColossalAI
[engin/schedule] use p2p_v2 to reconstruct pipeline_schedule (#1408)
* support p2p communication with any type of object | pass test * reconstruct pipeline schedule with p2p_v2.py (support communication with List[Any]) | pass test * [communication] add p2p_v2.py to support communication with List[Any] * Delete _pipeline_schedule_v2.py * Delete test_cifar_with_data_pipeline_tensor_v2.py * [engin/schedule] use p2p_v2 to reconstruct pipeline_schedule * [engin/schedule] use p2p_v2 to reconstruct pipeline_schedule * [engin/schedule] use p2p_v2 to reconstruct pipeline_schedule * [engin/schedule] use p2p_v2 to reconstruct pipeline_schedule * [engin/schedule] use p2p_v2 to reconstruct pipeline_schedule * Delete p2p_v2.py * Delete test_boardcast_send_recv_v2.py * Delete test_object_list_p2p_v2.py * [engin/schedule] use p2p_v2 to reconstruct pipeline_schedule * [communication] remove print code * [communication] remove print code * [engin/schedule] shorten the running time of testing file to prevent cancelling in CI pull/1439/head
parent
ae1b58cd16
commit
e9460b45c8
|
@ -0,0 +1,181 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- encoding: utf-8 -*-
|
||||
|
||||
from typing import Tuple, Iterable
|
||||
|
||||
from colossalai import engine
|
||||
import colossalai.communication.p2p_v2 as comm
|
||||
import torch.cuda
|
||||
from colossalai.context.parallel_mode import ParallelMode
|
||||
from colossalai.core import global_context as gpc
|
||||
from colossalai.utils.cuda import get_current_device
|
||||
|
||||
from ._pipeline_schedule import PipelineSchedule
|
||||
|
||||
|
||||
def pack_return_tensors(return_tensors):
    """Merge per-microbatch ``(output, label)`` pairs into one ``(output, label)`` pair.

    Args:
        return_tensors: A non-empty sequence of ``(output, label)`` pairs. The
            type of the first output / first label decides how everything is
            merged:

            * output is a ``torch.Tensor`` -> all outputs are concatenated on dim 0;
            * output is a list/tuple -> the i-th elements of every output are
              concatenated on dim 0 and returned as a tuple;
            * label is a ``torch.Tensor`` -> all labels are concatenated on dim 0;
            * otherwise the label is treated as a mapping (it must provide
              ``keys()`` / ``items()``) and the values are concatenated per key.

    Returns:
        tuple: ``(output, label)`` merged as described above.

    Raises:
        TypeError: If the first output is neither a tensor nor a list/tuple.
    """
    outputs, labels = tuple(zip(*return_tensors))

    # Outputs: dispatch on the type of the first microbatch's output.
    first_output = outputs[0]
    if isinstance(first_output, torch.Tensor):
        packed_output = torch.cat(outputs, dim=0)
    elif isinstance(first_output, (list, tuple)):
        packed_output = tuple(torch.cat(group, dim=0) for group in zip(*outputs))
    else:
        raise TypeError(f'Output of model must be tensor or list/tuple of tensors')

    # Labels: plain tensors are concatenated directly.
    if isinstance(labels[0], torch.Tensor):
        return packed_output, torch.cat(labels, dim=0)

    # Mapping labels: the key set is taken from the first label only, so a
    # later label carrying an unknown key raises KeyError on append.
    gathered = {key: [] for key in labels[0].keys()}
    for entry in labels:
        for key, value in entry.items():
            gathered[key].append(value)
    packed_label = {key: torch.cat(values, dim=0) for key, values in gathered.items()}
    return packed_output, packed_label
|
||||
|
||||
|
||||
class PipelineScheduleV2(PipelineSchedule):
    """A :class:`PipelineSchedule` whose ``forward_backward_step`` is built on p2p_v2.

    Only ``forward_backward_step`` is overridden here; it exchanges objects
    between pipeline stages through ``colossalai.communication.p2p_v2``.

    Args:
        num_microbatches (int): The number of microbatches.
        data_process_func (Callable, optional):
            The preprocessing function which receives a batch of data, and it will be executed in `load_batch`.
        tensor_shape (torch.Size, optional): Specified shape in pipeline communication.
        scatter_gather_tensors (bool, optional):
            If set to `True`, communication will be reduced over pipeline when using 1D tensor parallelization.

    Example:

        # this shows an example of customized data_process_func
        def data_process_func(stage_output, dataloader_output):
            output1, output2 = stage_output
            item1, item2, item3 = dataloader_output

            # assume item2 is not needed
            data = (output1, output2, item1)
            label = item3
            return data, label

    """

    def forward_backward_step(self,
                              engine: engine.Engine,
                              data_iter: Iterable,
                              forward_only=False,
                              return_loss=True,
                              return_output_label=True) -> Tuple[torch.Tensor]:
        """Run one non-interleaved 1F1B step (warmup forwards, steady 1F1B, cooldown backwards).

        Args:
            engine (colossalai.engine.Engine): Colossalai engine for training and inference.
            data_iter (Iterable): Dataloader as the form of an iterator, obtained by calling iter(dataloader).
            forward_only (bool, optional):
                Whether run forward step only. Default is false. If true, no backward will be run.
            return_loss (bool, optional): Whether returns the loss value. Default is true.
                Must be true whenever ``forward_only`` is false.
            return_output_label (bool, optional): If False, the output and label won't be returned.

        Returns:
            Tuple[:class:`torch.Tensor`]: ``(output, label, loss)``. ``output`` and ``label`` are
            ``None`` when the forward steps collected no return tensors. ``loss`` is the
            accumulator tensor created on the last pipeline stage when ``return_loss`` is true,
            and ``None`` otherwise.
        """
        assert forward_only or return_loss, \
            'The argument \'return_loss\' has to be True when \'forward_only\' is False, but got False.'
        self.load_batch(data_iter)

        # Warmup length: the number of stages after this one, capped by the
        # total number of microbatches. Whatever is left runs in steady state.
        pipeline_size = gpc.get_world_size(ParallelMode.PIPELINE)
        pipeline_rank = gpc.get_local_rank(ParallelMode.PIPELINE)
        num_warmup_microbatches = min(pipeline_size - pipeline_rank - 1, self.num_microbatches)
        num_microbatches_remaining = self.num_microbatches - num_warmup_microbatches

        # Inputs/outputs are queued only when a backward pass will consume them.
        run_backward = not forward_only
        input_objs = [] if run_backward else None
        output_objs = [] if run_backward else None

        return_tensors = []
        accum_loss = None
        if return_loss and gpc.is_pipeline_last_stage(ignore_virtual=True):
            accum_loss = torch.zeros(1, device=get_current_device())

        # ---- Warmup: forward passes only. ----
        for _ in range(num_warmup_microbatches):
            input_obj = comm.recv_forward()
            output_obj = self._forward_step(engine,
                                            input_obj,
                                            return_tensors,
                                            return_output_label=return_output_label,
                                            accum_loss=accum_loss)
            comm.send_forward(output_obj)

            if run_backward:
                input_objs.append(input_obj)
                output_objs.append(output_obj)

        # The steady-state loop expects its first input to be available already.
        # When every microbatch was handled in warmup there is nothing to receive.
        if num_microbatches_remaining > 0:
            input_obj = comm.recv_forward()

        # ---- Steady state: one forward followed by one backward. ----
        for step in range(num_microbatches_remaining):
            last_iteration = (step == num_microbatches_remaining - 1)

            output_obj = self._forward_step(engine,
                                            input_obj,
                                            return_tensors,
                                            return_output_label=return_output_label,
                                            accum_loss=accum_loss)
            comm.send_forward(output_obj)

            if forward_only:
                # No backward pass: just fetch the next input, if any remains.
                if not last_iteration:
                    input_obj = comm.recv_forward()
                continue

            # TODO adjust here (note carried over from the original send/recv pair)
            output_obj_grad = comm.recv_backward()

            # Queue the current pair, then run backward on the oldest queued pair.
            input_objs.append(input_obj)
            output_objs.append(output_obj)
            input_obj = input_objs.pop(0)
            output_obj = output_objs.pop(0)

            input_obj_grad = self._backward_step(engine, input_obj, output_obj, output_obj_grad)

            # Receive the next forward input (except after the final step)
            # before handing the input gradient back.
            input_obj = None if last_iteration else comm.recv_forward()
            comm.send_backward(input_obj_grad)

        # ---- Cooldown: backward passes for the pairs queued during warmup. ----
        if run_backward:
            for _ in range(num_warmup_microbatches):
                input_obj = input_objs.pop(0)
                output_obj = output_objs.pop(0)

                output_obj_grad = comm.recv_backward()
                input_obj_grad = self._backward_step(engine, input_obj, output_obj, output_obj_grad)
                comm.send_backward(input_obj_grad)

        if not return_tensors:
            return None, None, accum_loss

        output, label = pack_return_tensors(return_tensors)
        return output, label, accum_loss
|
|
@ -0,0 +1,111 @@
|
|||
import os
|
||||
|
||||
from functools import partial
|
||||
from pathlib import Path
|
||||
|
||||
import colossalai
|
||||
import pytest
|
||||
import torch
|
||||
import torch.multiprocessing as mp
|
||||
from colossalai.amp import AMP_TYPE
|
||||
from colossalai.trainer import Trainer, hooks
|
||||
from colossalai.context import ParallelMode
|
||||
from colossalai.testing import rerun_if_address_is_in_use, skip_if_not_enough_gpus
|
||||
from colossalai.utils import free_port
|
||||
from colossalai.core import global_context as gpc
|
||||
from colossalai.logging import get_dist_logger
|
||||
from colossalai.nn import CrossEntropyLoss
|
||||
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
|
||||
from colossalai.utils import get_dataloader
|
||||
from colossalai.pipeline.pipelinable import PipelinableContext
|
||||
from colossalai.logging import disable_existing_loggers
|
||||
from torchvision.datasets import CIFAR10
|
||||
from torchvision import transforms
|
||||
|
||||
from colossalai.engine.schedule._pipeline_schedule_v2 import PipelineScheduleV2
|
||||
|
||||
# Runs once at import time, before any test function is collected or executed.
disable_existing_loggers()

# Hyper-parameters shared by the trainer process below.
BATCH_SIZE = 4
NUM_EPOCHS = 10
WARMUP_EPOCHS = 5

# Launch configuration handed to `colossalai.launch`.
CONFIG = {
    'NUM_MICRO_BATCHES': 2,
    'parallel': {
        'pipeline': 2,
        'tensor': {'size': 1, 'mode': '1d'},
    },
    'fp16': {'mode': AMP_TYPE.NAIVE},
    'gradient_accumulation': 2,
}
|
||||
|
||||
|
||||
def run_trainer(rank, world_size, port):
    """Worker entry point: launch ColossalAI and train a pipelined ViT on CIFAR10 briefly.

    The worker launches the distributed context with ``CONFIG``, partitions a
    ``vit_tiny_patch4_32`` model across pipeline stages, replaces the engine's
    schedule with ``PipelineScheduleV2`` and calls ``Trainer.fit`` with
    ``max_steps=2``. It returns early, after logging two warnings, when the
    ``titans`` package cannot be imported.

    Args:
        rank: Rank of this process, forwarded to ``colossalai.launch``.
        world_size: Total number of processes, forwarded to ``colossalai.launch``.
        port: Port forwarded to ``colossalai.launch`` (host is ``'localhost'``).

    Raises:
        KeyError: If the ``DATA`` environment variable (dataset root) is not set.
    """
    disable_existing_loggers()
    colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    disable_existing_loggers()

    # get logger
    logger = get_dist_logger()

    pipelinable = PipelinableContext()
    try:
        from titans.model.vit import vit_tiny_patch4_32
    except ImportError:
        logger.warning('skip the test_cifar_with_data_pipeline_tensor test because titan is not installed')
        logger.warning('please install titan from https://github.com/hpcaitech/Titans')
        return

    # Build the model inside the pipelinable context, then keep only the
    # partition that belongs to this process's pipeline rank.
    with pipelinable:
        model = vit_tiny_patch4_32()
    pipelinable.to_layer_list()
    pipelinable.policy = "uniform"
    model = pipelinable.partition(1, gpc.pipeline_parallel_size, gpc.get_local_rank(ParallelMode.PIPELINE))

    # create dataloaders
    data_root = Path(os.environ['DATA'])
    augmentations = [
        transforms.RandomCrop(32, padding=4, pad_if_needed=True),
        transforms.AutoAugment(policy=transforms.AutoAugmentPolicy.CIFAR10),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ]
    train_transform = transforms.Compose(augmentations)
    train_dataset = CIFAR10(root=data_root, train=True, download=True, transform=train_transform)
    train_dataloader = get_dataloader(dataset=train_dataset, shuffle=True, batch_size=BATCH_SIZE, pin_memory=True)

    # create loss function
    criterion = CrossEntropyLoss(label_smoothing=0.1)

    # create optimizer
    optimizer = torch.optim.AdamW(model.parameters(), lr=0.001, weight_decay=0)

    # create lr scheduler
    lr_scheduler = CosineAnnealingWarmupLR(optimizer=optimizer, total_steps=NUM_EPOCHS, warmup_steps=WARMUP_EPOCHS)

    # initialize
    engine, train_dataloader, *_ = colossalai.initialize(
        model=model,
        optimizer=optimizer,
        criterion=criterion,
        train_dataloader=train_dataloader,
    )

    # Swap in the schedule under test.
    engine._schedule = PipelineScheduleV2(num_microbatches=gpc.config.NUM_MICRO_BATCHES)

    logger = get_dist_logger()
    trainer = Trainer(engine=engine, logger=logger)

    hook_list = [hooks.LRSchedulerHook(lr_scheduler=lr_scheduler, by_epoch=False)]

    trainer.fit(
        train_dataloader=train_dataloader,
        max_steps=2,
        epochs=NUM_EPOCHS,
        hooks=hook_list,
        display_progress=True,
    )
|
||||
|
||||
|
||||
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_hybrid_parallel():
    """Spawn two worker processes that each execute ``run_trainer``.

    ``world_size`` and a freshly chosen free port are bound up front with
    ``functools.partial``; ``mp.spawn`` provides the remaining first positional
    argument of ``run_trainer`` (``rank``).
    """
    num_procs = 2
    worker = partial(run_trainer, world_size=num_procs, port=free_port())
    disable_existing_loggers()
    mp.spawn(worker, nprocs=num_procs)


# Allow running this test file directly as a script.
if __name__ == '__main__':
    test_hybrid_parallel()
|
Loading…
Reference in New Issue