[booster] added the plugin base and torch ddp plugin (#3180)

* [booster] added the plugin base and torch ddp plugin

* polish code

* polish code

* polish code
Frank Lee 2023-03-21 17:39:30 +08:00 committed by GitHub
parent e5f668f280
commit e7f3bed2d3
8 changed files with 378 additions and 86 deletions


@@ -1,9 +1,9 @@
import warnings
from contextlib import contextmanager
from typing import Callable, Iterable, Iterator, List, Optional, Tuple, Union
from typing import Callable, Iterator, List, Optional, Tuple, Union
import torch
import torch.nn as nn
from torch import Tensor
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
from torch.utils.data import DataLoader
@@ -55,27 +55,43 @@ class Booster:
device: str = 'cuda',
mixed_precision: Union[MixedPrecision, str] = None,
plugin: Optional[Plugin] = None) -> None:
# TODO(FrankLeeeee): add plugin control logic
# if self.plugin is not None and self.plugin.control_accelerator:
# ...
# create acclerator
self.acceleartor = Accelerator(device)
self.acceleartor.set_default_device()
if plugin is not None:
assert isinstance(
plugin, Plugin), f'Expected the argument plugin to be an instance of Plugin, but got {type(plugin)}.'
self.plugin = plugin
# validate and set precision
if isinstance(MixedPrecision, str):
# the user will take the default arguments for amp training
self.mixed_precision = mixed_precision_factory(mixed_precision)
elif isinstance(mixed_precision, MixedPrecision):
# the user can customize the arguments by passing the precision object
self.mixed_precision = mixed_precision
# set accelerator
if self.plugin and self.plugin.control_device:
self.accelerator = None
warnings.warn('The plugin will control the accelerator, so the device argument will be ignored.')
else:
raise ValueError(
f'Expected the argument mixed_precision to be a string or an instance of Precision, but got {type(mixed_precision)}.'
)
self.accelerator = Accelerator(device)
def boost(self, model: nn.Module, optimizer: Optimizer, criterion: Callable, lr_scheduler: LRScheduler,
dataloader: DataLoader) -> List[Union[nn.Module, Optimizer, LRScheduler, DataLoader]]:
# set precision
if mixed_precision is None or (self.plugin and self.plugin.control_precision):
self.mixed_precision = None
warnings.warn('The plugin will control the precision, so the mixed_precision argument will be ignored.')
else:
# validate and set precision
if isinstance(mixed_precision, str):
# the user will take the default arguments for amp training
self.mixed_precision = mixed_precision_factory(mixed_precision)
elif isinstance(mixed_precision, MixedPrecision):
# the user can customize the arguments by passing the precision object
self.mixed_precision = mixed_precision
else:
raise ValueError(
f'Expected the argument mixed_precision to be a string or an instance of MixedPrecision, but got {type(mixed_precision)}.'
)
def boost(
self,
model: nn.Module,
optimizer: Optimizer,
criterion: Callable = None,
dataloader: DataLoader = None,
lr_scheduler: LRScheduler = None,
) -> List[Union[nn.Module, Optimizer, LRScheduler, DataLoader]]:
"""
Boost the model, optimizer, criterion, lr_scheduler, and dataloader.
@@ -83,22 +99,25 @@ class Booster:
model (nn.Module): The model to be boosted.
optimizer (Optimizer): The optimizer to be boosted.
criterion (Callable): The criterion to be boosted.
lr_scheduler (LRScheduler): The lr_scheduler to be boosted.
dataloader (DataLoader): The dataloader to be boosted.
lr_scheduler (LRScheduler): The lr_scheduler to be boosted.
"""
# TODO(FrankLeeeee): add plugin control logic
# if self.plugin is not None and self.plugin.control_accelerator:
# ...
model = self.acceleartor.configure_model(model)
# TODO(FrankLeeeee): consider multi-model and multi-optimizer case
# TODO(lsg): Add plugin control logic
# e.g.
# if self.plugin is not None and self.plugin.control_boost:
# ...
# TODO(FrankLeeeee): consider multi-dataloader case
# transform model for mixed precision
model, optimizer, criterion = self.mixed_precision.configure(model, optimizer, criterion)
return model, optimizer, criterion, lr_scheduler, dataloader
if self.plugin:
model, optimizer, criterion, dataloader, lr_scheduler = self.plugin.configure(
model, optimizer, criterion, dataloader, lr_scheduler)
if self.plugin and not self.plugin.control_device:
# transform model for accelerator
model = self.accelerator.configure(model)
if self.mixed_precision and self.plugin and not self.plugin.control_precision:
# transform model for mixed precision
model, optimizer, criterion = self.mixed_precision.configure(model, optimizer, criterion)
return model, optimizer, criterion, dataloader, lr_scheduler
def backward(self, loss: torch.Tensor, optimizer: Optimizer) -> None:
# TODO: implement this method with plugin
@@ -117,8 +136,9 @@ class Booster:
pass
def no_sync(self, model: nn.Module) -> contextmanager:
# TODO: implement this method
pass
assert self.plugin is not None, f'no_sync is only enabled when a plugin is provided and the plugin supports no_sync.'
assert self.plugin.support_no_sync, f'The plugin {self.plugin.__class__.__name__} does not support no_sync.'
return self.plugin.no_sync(model)
def save(self,
obj: Union[nn.Module, Optimizer, LRScheduler],
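
Taken together, the reworked constructor and `boost` signature are meant to be driven roughly as follows. This is a minimal usage sketch, not part of the commit: the toy model, the SGD optimizer and the assumption of an available CUDA device are illustrative, and no plugin or mixed precision is passed, so the Booster keeps control of the accelerator. The plugin-driven training path (including `booster.backward`) is sketched after the TorchDDPPlugin diff below.

# Illustrative sketch only: assumes a CUDA device is available.
import torch
import torch.nn as nn
from torch.optim import SGD

from colossalai.booster import Booster

model = nn.Linear(8, 8)
optimizer = SGD(model.parameters(), lr=1e-3)
criterion = lambda x: x.mean()

booster = Booster(device='cuda')
# boost() now returns the objects in the order
# (model, optimizer, criterion, dataloader, lr_scheduler)
model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion)

output = model(torch.rand(4, 8, device='cuda'))
loss = criterion(output)
loss.backward()
optimizer.step()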


@@ -1,46 +0,0 @@
from typing import List, Tuple
import torch
import torch.nn as nn
from torch.optim import Optimizer
from torch.utils.data import DataLoader
from colossalai.device.device_mesh import DeviceMesh
__all__ = ['Plugin']
class Plugin:
@property
def supported_devices(self) -> List[torch.device]:
pass
@property
def supported_precisions(self) -> List[str]:
pass
@property
def control_precision(self) -> bool:
pass
@property
def control_device(self) -> bool:
pass
@property
def support_no_sync(self) -> bool:
pass
def setup_model(self, model: nn.Module, device_mesh_pool: DeviceMesh) -> nn.Module:
pass
def setup_optimizer(self, optimizer: Optimizer) -> Optimizer:
pass
def setup_dataloader(self, dataloader: DataLoader) -> DataLoader:
pass
@property
def device_mesh_shape(self) -> List[Tuple[int, ...]]:
pass


@@ -0,0 +1,4 @@
from .plugin_base import Plugin
from .torch_ddp_plugin import TorchDDPPlugin
__all__ = ['Plugin', 'TorchDDPPlugin']


@@ -0,0 +1,51 @@
from abc import ABC, abstractmethod
from typing import Callable, List, Tuple, Union
import torch.nn as nn
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
from torch.utils.data import DataLoader
from colossalai.booster.interface import OptimizerWrapper
__all__ = ['Plugin']
class Plugin(ABC):
@property
@abstractmethod
def supported_devices(self) -> List[str]:
pass
@property
@abstractmethod
def supported_precisions(self) -> List[str]:
pass
@property
@abstractmethod
def control_precision(self) -> bool:
pass
@property
@abstractmethod
def control_device(self) -> bool:
pass
@property
@abstractmethod
def support_no_sync(self) -> bool:
pass
@abstractmethod
def configure(
self,
model: nn.Module,
optimizer: Optimizer,
criterion: Callable = None,
dataloader: DataLoader = None,
lr_scheduler: LRScheduler = None,
) -> Tuple[Union[nn.Module, OptimizerWrapper, LRScheduler, DataLoader]]:
# implement this method
pass
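
Concrete plugins subclass this ABC and fill in every abstract member. A hypothetical pass-through plugin (shown only to illustrate the contract; `NoOpPlugin` is not part of the commit) would look roughly like this:

# Hypothetical example of the Plugin contract; it hands everything back unchanged.
from typing import List

from colossalai.booster.interface import OptimizerWrapper
from colossalai.booster.plugin import Plugin


class NoOpPlugin(Plugin):

    @property
    def supported_devices(self) -> List[str]:
        return ['cpu', 'cuda']

    @property
    def supported_precisions(self) -> List[str]:
        return ['fp16']

    @property
    def control_precision(self) -> bool:
        return False

    @property
    def control_device(self) -> bool:
        return False

    @property
    def support_no_sync(self) -> bool:
        return False

    def configure(self, model, optimizer, criterion=None, dataloader=None, lr_scheduler=None):
        # a real plugin would wrap the model/optimizer here (see TorchDDPPlugin below)
        if not isinstance(optimizer, OptimizerWrapper):
            optimizer = OptimizerWrapper(optimizer)
        return model, optimizer, criterion, dataloader, lr_scheduler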


@@ -0,0 +1,147 @@
import random
from typing import Callable, List, Tuple, Union
import numpy as np
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from colossalai.booster.interface import OptimizerWrapper
from .plugin_base import Plugin
__all__ = ['TorchDDPPlugin']
class TorchDDPPlugin(Plugin):
"""
Plugin for PyTorch DDP.
Example:
>>> from colossalai.booster import Booster
>>> from colossalai.booster.plugin import TorchDDPPlugin
>>>
>>> model, train_dataset, optimizer, criterion = ...
>>> plugin = TorchDDPPlugin()
>>> train_dataloader = plugin.prepare_train_dataloader(train_dataset, batch_size=8)
>>> booster = Booster(plugin=plugin)
>>> model, optimizer, criterion, train_dataloader, _ = booster.boost(model, optimizer, criterion, train_dataloader)
Args:
broadcast_buffers (bool, optional): Whether to broadcast buffers at the beginning of training. Defaults to True.
bucket_cap_mb (int, optional): The bucket size in MB. Defaults to 25.
find_unused_parameters (bool, optional): Whether to find unused parameters. Defaults to False.
check_reduction (bool, optional): Whether to check reduction. Defaults to False.
gradient_as_bucket_view (bool, optional): Whether to use gradient as bucket view. Defaults to False.
static_graph (bool, optional): Whether to use static graph. Defaults to False.
"""
def __init__(self,
broadcast_buffers: bool = True,
bucket_cap_mb: int = 25,
find_unused_parameters: bool = False,
check_reduction: bool = False,
gradient_as_bucket_view: bool = False,
static_graph: bool = False) -> None:
assert dist.is_initialized(
), 'torch.distributed is not initialized, please use colossalai.launch to create the distributed environment'
self.rank = dist.get_rank()
self.world_size = dist.get_world_size()
self.ddp_kwargs = dict(broadcast_buffers=broadcast_buffers,
bucket_cap_mb=bucket_cap_mb,
find_unused_parameters=find_unused_parameters,
check_reduction=check_reduction,
gradient_as_bucket_view=gradient_as_bucket_view,
static_graph=static_graph)
def support_no_sync(self) -> bool:
return True
def control_precision(self) -> bool:
return False
def supported_precisions(self) -> List[str]:
return ['fp16', 'fp16_apex', 'bf16', 'fp8']
def control_device(self) -> bool:
return True
def supported_devices(self) -> List[str]:
return ['cuda']
def prepare_train_dataloader(self,
dataset,
batch_size,
shuffle=False,
seed=1024,
drop_last=False,
pin_memory=False,
num_workers=0,
**kwargs):
r"""
Prepare a dataloader for distributed training. The dataset is wrapped in a
`torch.utils.data.DataLoader` whose sampler is a `torch.utils.data.distributed.DistributedSampler`,
so each rank iterates over a distinct shard of the data.
Note:
1. Evaluation datasets should not be passed to this function.
Args:
dataset (`torch.utils.data.Dataset`): The dataset to be loaded.
batch_size (int): The number of samples per batch.
shuffle (bool, optional): Whether to shuffle the dataset. Defaults to False.
seed (int, optional): Random worker seed for sampling. Defaults to 1024.
drop_last (bool, optional): Set to True to drop the last incomplete batch, if the dataset size
is not divisible by the batch size. If False and the size of dataset is not divisible by
the batch size, then the last batch will be smaller. Defaults to False.
pin_memory (bool, optional): Whether to pin memory address in CPU memory. Defaults to False.
num_workers (int, optional): Number of worker processes for this dataloader. Defaults to 0.
kwargs (dict): optional parameters for ``torch.utils.data.DataLoader``, more details could be found in
`DataLoader <https://pytorch.org/docs/stable/_modules/torch/utils/data/dataloader.html#DataLoader>`_.
Returns:
:class:`torch.utils.data.DataLoader`: A DataLoader used for training or testing.
"""
_kwargs = kwargs.copy()
sampler = DistributedSampler(dataset, num_replicas=self.world_size, rank=self.rank, shuffle=shuffle)
# Deterministic dataloader
def seed_worker(worker_id):
worker_seed = seed
np.random.seed(worker_seed)
torch.manual_seed(worker_seed)
random.seed(worker_seed)
return DataLoader(dataset,
batch_size=batch_size,
sampler=sampler,
worker_init_fn=seed_worker,
drop_last=drop_last,
pin_memory=pin_memory,
num_workers=num_workers,
**_kwargs)
def configure(
self,
model: nn.Module,
optimizer: Optimizer,
criterion: Callable = None,
dataloader: DataLoader = None,
lr_scheduler: LRScheduler = None,
) -> Tuple[Union[nn.Module, OptimizerWrapper, LRScheduler, DataLoader]]:
# cast model to cuda
model = model.cuda()
# wrap the model with PyTorch DDP
model = DDP(model, **self.ddp_kwargs)
if not isinstance(optimizer, OptimizerWrapper):
optimizer = OptimizerWrapper(optimizer)
return model, optimizer, criterion, dataloader, lr_scheduler
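
Put together with the Booster changes above, a training loop built on this plugin looks roughly like the sketch below. It expands the class docstring's example; the toy model, TensorDataset and learning rate are illustrative, and the snippet assumes the process group has already been set up (e.g. via colossalai.launch) on a CUDA machine.

# Illustrative sketch: assumes torch.distributed is already initialized,
# e.g. through colossalai.launch(...), and that CUDA is available.
import torch
import torch.nn as nn
from torch.optim import SGD
from torch.utils.data import TensorDataset

from colossalai.booster import Booster
from colossalai.booster.plugin import TorchDDPPlugin

model = nn.Linear(8, 8)
optimizer = SGD(model.parameters(), lr=1e-3)
criterion = lambda x: x.mean()
dataset = TensorDataset(torch.rand(64, 8))

plugin = TorchDDPPlugin()
# each rank gets its own DistributedSampler-backed shard of the dataset
train_dataloader = plugin.prepare_train_dataloader(dataset, batch_size=8, shuffle=True)

booster = Booster(plugin=plugin)
model, optimizer, criterion, train_dataloader, _ = booster.boost(
    model, optimizer, criterion, dataloader=train_dataloader)

for (batch,) in train_dataloader:
    loss = criterion(model(batch.cuda()))
    booster.backward(loss, optimizer)
    optimizer.clip_grad_by_norm(1.0)
    optimizer.step()

Because the plugin advertises support_no_sync, the no_sync context manager added to the Booster above can also be used with it for local gradient accumulation.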


@@ -1,13 +1,27 @@
import pytest
from functools import partial
import torch.multiprocessing as mp
import torch.nn as nn
from torchvision.models import resnet18
from colossalai.booster.accelerator import Accelerator
from colossalai.testing import parameterize, rerun_if_address_is_in_use
@pytest.mark.parametrize('device', ['cpu', 'cuda'])
def test_accelerator(device):
@parameterize('device', ['cpu', 'cuda'])
def run_accelerator(device):
acceleartor = Accelerator(device)
model = nn.Linear(8, 8)
model = acceleartor.configure_model(model)
assert next(model.parameters()).device.type == device
del model, acceleartor
def run_dist(rank):
run_accelerator()
@rerun_if_address_is_in_use()
def test_accelerator():
world_size = 1
run_func = partial(run_dist)
mp.spawn(run_func, nprocs=world_size)


@@ -1,12 +1,21 @@
from functools import partial
import torch
import torch.multiprocessing as mp
from torch.optim import Adam
import colossalai
from colossalai.booster.mixed_precision import FP16TorchMixedPrecision
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils import free_port
from tests.kit.model_zoo import model_zoo
def test_torch_amp():
for name, (model_fn, data_gen_fn, output_transform_fn, _) in model_zoo.items():
def run_torch_amp(rank, world_size, port):
# init dist env
colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost')
sub_model_zoo = model_zoo.get_sub_registry('timm')
for name, (model_fn, data_gen_fn, output_transform_fn, _) in sub_model_zoo.items():
# dlrm_interactionarch has no parameters, so skip it
if name == 'dlrm_interactionarch':
continue
@@ -27,3 +36,11 @@ def test_torch_amp():
optimizer.backward(loss)
optimizer.clip_grad_by_norm(1.0)
optimizer.step()
del model, optimizer, criterion, data, output, mixed_precision
@rerun_if_address_is_in_use()
def test_torch_amp():
world_size = 1
run_func = partial(run_torch_amp, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size)


@@ -0,0 +1,85 @@
from functools import partial
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim import SGD
import colossalai
from colossalai.booster import Booster
from colossalai.booster.interface import OptimizerWrapper
from colossalai.booster.plugin import TorchDDPPlugin
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils import free_port
from tests.kit.model_zoo import model_zoo
def check_torch_ddp_plugin():
plugin = TorchDDPPlugin()
booster = Booster(plugin=plugin)
for name, (model_fn, data_gen_fn, output_transform_fn, _) in model_zoo.items():
if name == 'dlrm_interactionarch':
continue
model = model_fn()
optimizer = SGD(model.parameters(), lr=1e-3)
criterion = lambda x: x.mean()
data = data_gen_fn()
data = {
k: v.to('cuda') if torch.is_tensor(v) or 'Tensor' in v.__class__.__name__ else v for k, v in data.items()
}
model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion)
assert isinstance(model, DDP)
assert isinstance(optimizer, OptimizerWrapper)
output = model(**data)
output = output_transform_fn(output)
output_key = list(output.keys())[0]
loss = criterion(output[output_key])
booster.backward(loss, optimizer)
optimizer.clip_grad_by_norm(1.0)
optimizer.step()
def check_dataloader_sharding():
plugin = TorchDDPPlugin()
# create a custom dataset with values 0 to 9
dataset = torch.utils.data.TensorDataset(torch.arange(0, 10))
train_dataloader = plugin.prepare_train_dataloader(dataset, batch_size=2)
# get the first batch of data
batch = next(iter(train_dataloader))[0].cuda()
is_rank_0 = dist.get_rank() == 0
if is_rank_0:
batch_to_compare = batch.clone()
else:
batch_to_compare = batch
# broadcast the value from rank 1 to rank 0
dist.broadcast(batch_to_compare, src=1)
# compare on rank 0
if is_rank_0:
assert not torch.equal(batch,
batch_to_compare), 'Same number was found across ranks but expected it to be different'
def run_dist(rank, world_size, port):
# init dist env
colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost')
check_dataloader_sharding()
check_torch_ddp_plugin()
@rerun_if_address_is_in_use()
def test_torch_ddp_plugin():
world_size = 2
run_func = partial(run_dist, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size)
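
The sharding check above passes because DistributedSampler hands each rank a disjoint, strided slice of the dataset indices. That partitioning can be inspected on a single process by passing num_replicas and rank explicitly, so no process group is needed (an illustrative sketch, not part of the commit):

# Single-process look at how DistributedSampler splits the 10-element dataset
# used in check_dataloader_sharding across two ranks.
import torch
from torch.utils.data import TensorDataset
from torch.utils.data.distributed import DistributedSampler

dataset = TensorDataset(torch.arange(0, 10))
for rank in range(2):
    sampler = DistributedSampler(dataset, num_replicas=2, rank=rank, shuffle=False)
    print(rank, list(sampler))
# with shuffle=False this prints [0, 2, 4, 6, 8] for rank 0 and [1, 3, 5, 7, 9] for rank 1,
# so the first batch seen by each rank differs, which is exactly what the broadcast
# comparison in the test asserts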