From e7f3bed2d36c5406e9a9ab92438be46a5f9258d7 Mon Sep 17 00:00:00 2001 From: Frank Lee Date: Tue, 21 Mar 2023 17:39:30 +0800 Subject: [PATCH] [booster] added the plugin base and torch ddp plugin (#3180) * [booster] added the plugin base and torch ddp plugin * polish code * polish code * polish code --- colossalai/booster/booster.py | 88 +++++++---- colossalai/booster/plugin.py | 46 ------ colossalai/booster/plugin/__init__.py | 4 + colossalai/booster/plugin/plugin_base.py | 51 ++++++ colossalai/booster/plugin/torch_ddp_plugin.py | 147 ++++++++++++++++++ tests/test_booster/test_accelerator.py | 22 ++- .../test_mixed_precision/test_fp16_torch.py | 21 ++- .../test_plugin/test_torch_ddp_plugin.py | 85 ++++++++++ 8 files changed, 378 insertions(+), 86 deletions(-) delete mode 100644 colossalai/booster/plugin.py create mode 100644 colossalai/booster/plugin/__init__.py create mode 100644 colossalai/booster/plugin/plugin_base.py create mode 100644 colossalai/booster/plugin/torch_ddp_plugin.py create mode 100644 tests/test_booster/test_plugin/test_torch_ddp_plugin.py diff --git a/colossalai/booster/booster.py b/colossalai/booster/booster.py index 7d7f21ca6..230c65a9e 100644 --- a/colossalai/booster/booster.py +++ b/colossalai/booster/booster.py @@ -1,9 +1,9 @@ +import warnings from contextlib import contextmanager -from typing import Callable, Iterable, Iterator, List, Optional, Tuple, Union +from typing import Callable, Iterator, List, Optional, Tuple, Union import torch import torch.nn as nn -from torch import Tensor from torch.optim import Optimizer from torch.optim.lr_scheduler import _LRScheduler as LRScheduler from torch.utils.data import DataLoader @@ -55,27 +55,43 @@ class Booster: device: str = 'cuda', mixed_precision: Union[MixedPrecision, str] = None, plugin: Optional[Plugin] = None) -> None: - # TODO(FrankLeeeee): add plugin control logic - # if self.plugin is not None and self.plugin.control_accelerator: - # ... - # create acclerator - self.acceleartor = Accelerator(device) - self.acceleartor.set_default_device() + if plugin is not None: + assert isinstance( + plugin, Plugin), f'Expected the argument plugin to be an instance of Plugin, but got {type(plugin)}.' + self.plugin = plugin - # validate and set precision - if isinstance(MixedPrecision, str): - # the user will take the default arguments for amp training - self.mixed_precision = mixed_precision_factory(mixed_precision) - elif isinstance(mixed_precision, MixedPrecision): - # the user can customize the arguments by passing the precision object - self.mixed_precision = mixed_precision + # set accelerator + if self.plugin and self.plugin.control_device: + self.accelerator = None + warnings.warn('The plugin will control the accelerator, so the device argument will be ignored.') else: - raise ValueError( - f'Expected the argument mixed_precision to be a string or an instance of Precision, but got {type(mixed_precision)}.' 
- ) + self.accelerator = Accelerator(device) - def boost(self, model: nn.Module, optimizer: Optimizer, criterion: Callable, lr_scheduler: LRScheduler, - dataloader: DataLoader) -> List[Union[nn.Module, Optimizer, LRScheduler, DataLoader]]: + # set precision + if mixed_precision is None or (self.plugin and self.plugin.control_precision): + self.mixed_precision = None + warnings.warn('The plugin will control the precision, so the mixed_precision argument will be ignored.') + else: + # validate and set precision + if isinstance(MixedPrecision, str): + # the user will take the default arguments for amp training + self.mixed_precision = mixed_precision_factory(mixed_precision) + elif isinstance(mixed_precision, MixedPrecision): + # the user can customize the arguments by passing the precision object + self.mixed_precision = mixed_precision + else: + raise ValueError( + f'Expected the argument mixed_precision to be a string or an instance of Precision, but got {type(mixed_precision)}.' + ) + + def boost( + self, + model: nn.Module, + optimizer: Optimizer, + criterion: Callable = None, + dataloader: DataLoader = None, + lr_scheduler: LRScheduler = None, + ) -> List[Union[nn.Module, Optimizer, LRScheduler, DataLoader]]: """ Boost the model, optimizer, criterion, lr_scheduler, and dataloader. @@ -83,22 +99,25 @@ class Booster: model (nn.Module): The model to be boosted. optimizer (Optimizer): The optimizer to be boosted. criterion (Callable): The criterion to be boosted. - lr_scheduler (LRScheduler): The lr_scheduler to be boosted. dataloader (DataLoader): The dataloader to be boosted. + lr_scheduler (LRScheduler): The lr_scheduler to be boosted. """ - # TODO(FrankLeeeee): add plugin control logic - # if self.plugin is not None and self.plugin.control_accelerator: - # ... - model = self.acceleartor.configure_model(model) - # TODO(FrankLeeeee): consider multi-model and multi-optimizer case - # TODO(lsg): Add plugin control logic - # e.g. - # if self.plugin is not None and self.plugin.control_boost: - # ... + # TODO(FrankLeeeee): consider multi-dataloader case # transform model for mixed precision - model, optimizer, criterion = self.mixed_precision.configure(model, optimizer, criterion) - return model, optimizer, criterion, lr_scheduler, dataloader + if self.plugin: + model, optimizer, criterion, dataloader, lr_scheduler = self.plugin.configure( + model, optimizer, criterion, dataloader, lr_scheduler) + + if self.plugin and not self.plugin.control_device: + # transform model for accelerator + model = self.accelerator.configure(model) + + if self.mixed_precision and self.plugin and not self.plugin.control_precision: + # transform model for mixed precision + model, optimizer, criterion = self.mixed_precision.configure(model, optimizer, criterion) + + return model, optimizer, criterion, dataloader, lr_scheduler def backward(self, loss: torch.Tensor, optimizer: Optimizer) -> None: # TODO: implement this method with plugin @@ -117,8 +136,9 @@ class Booster: pass def no_sync(self, model: nn.Module) -> contextmanager: - # TODO: implement this method - pass + assert self.plugin is not None, f'no_sync is only enabled when a plugin is provided and the plugin supports no_sync.' + assert self.plugin.support_no_sync, f'The plugin {self.plugin.__class__.__name__} does not support no_sync.' 
+ return self.plugin.no_sync(model) def save(self, obj: Union[nn.Module, Optimizer, LRScheduler], diff --git a/colossalai/booster/plugin.py b/colossalai/booster/plugin.py deleted file mode 100644 index 32e0a7bde..000000000 --- a/colossalai/booster/plugin.py +++ /dev/null @@ -1,46 +0,0 @@ -from typing import List, Tuple - -import torch -import torch.nn as nn -from torch.optim import Optimizer -from torch.utils.data import DataLoader - -from colossalai.device.device_mesh import DeviceMesh - -__all__ = ['Plugin'] - - -class Plugin: - - @property - def supported_devices(self) -> List[torch.device]: - pass - - @property - def supported_precisions(self) -> List[str]: - pass - - @property - def control_precision(self) -> bool: - pass - - @property - def control_device(self) -> bool: - pass - - @property - def support_no_sync(self) -> bool: - pass - - def setup_model(self, model: nn.Module, device_mesh_pool: DeviceMesh) -> nn.Module: - pass - - def setup_optimizer(self, optimizer: Optimizer) -> Optimizer: - pass - - def setup_dataloader(self, dataloader: DataLoader) -> DataLoader: - pass - - @property - def device_mesh_shape(self) -> List[Tuple[int, ...]]: - pass diff --git a/colossalai/booster/plugin/__init__.py b/colossalai/booster/plugin/__init__.py new file mode 100644 index 000000000..3328fe2b9 --- /dev/null +++ b/colossalai/booster/plugin/__init__.py @@ -0,0 +1,4 @@ +from .plugin_base import Plugin +from .torch_ddp_plugin import TorchDDPPlugin + +__all__ = ['Plugin', 'TorchDDPPlugin'] diff --git a/colossalai/booster/plugin/plugin_base.py b/colossalai/booster/plugin/plugin_base.py new file mode 100644 index 000000000..3c347cb42 --- /dev/null +++ b/colossalai/booster/plugin/plugin_base.py @@ -0,0 +1,51 @@ +from abc import ABC, abstractmethod +from typing import Callable, List, Tuple, Union + +import torch.nn as nn +from torch.optim import Optimizer +from torch.optim.lr_scheduler import _LRScheduler as LRScheduler +from torch.utils.data import DataLoader + +from colossalai.booster.interface import OptimizerWrapper + +__all__ = ['Plugin'] + + +class Plugin(ABC): + + @property + @abstractmethod + def supported_devices(self) -> List[str]: + pass + + @property + @abstractmethod + def supported_precisions(self) -> List[str]: + pass + + @property + @abstractmethod + def control_precision(self) -> bool: + pass + + @property + @abstractmethod + def control_device(self) -> bool: + pass + + @property + @abstractmethod + def support_no_sync(self) -> bool: + pass + + @abstractmethod + def configure( + self, + model: nn.Module, + optimizer: Optimizer, + criterion: Callable = None, + dataloader: DataLoader = None, + lr_scheduler: LRScheduler = None, + ) -> Tuple[Union[nn.Module, OptimizerWrapper, LRScheduler, DataLoader]]: + # implement this method + pass diff --git a/colossalai/booster/plugin/torch_ddp_plugin.py b/colossalai/booster/plugin/torch_ddp_plugin.py new file mode 100644 index 000000000..07d6be8c7 --- /dev/null +++ b/colossalai/booster/plugin/torch_ddp_plugin.py @@ -0,0 +1,147 @@ +import random +from typing import Callable, List, Tuple, Union + +import numpy as np +import torch +import torch.distributed as dist +import torch.nn as nn +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.optim import Optimizer +from torch.optim.lr_scheduler import _LRScheduler as LRScheduler +from torch.utils.data import DataLoader +from torch.utils.data.distributed import DistributedSampler + +from colossalai.booster.interface import OptimizerWrapper + +from .plugin_base import Plugin + +__all__ 
= ['TorchDDPPlugin'] + + +class TorchDDPPlugin(Plugin): + """ + Plugin for PyTorch DDP. + + Example: + >>> from colossalai.booster import Booster + >>> from colossalai.booster.plugin import TorchDDPPlugin + >>> + >>> model, train_dataset, optimizer, criterion = ... + >>> plugin = TorchDDPPlugin() + + >>> train_dataloader = plugin.prepare_train_dataloader(train_dataset, batch_size=8) + >>> booster = Booster(plugin=plugin) + >>> model, optimizer, train_dataloader, criterion = booster.boost(model, optimizer, train_dataloader, criterion) + + Args: + broadcast_buffers (bool, optional): Whether to broadcast buffers in the beginning of training. Defaults to True. + bucket_cap_mb (int, optional): The bucket size in MB. Defaults to 25. + find_unused_parameters (bool, optional): Whether to find unused parameters. Defaults to False. + check_reduction (bool, optional): Whether to check reduction. Defaults to False. + gradient_as_bucket_view (bool, optional): Whether to use gradient as bucket view. Defaults to False. + static_graph (bool, optional): Whether to use static graph. Defaults to False. + """ + + def __init__(self, + broadcast_buffers: bool = True, + bucket_cap_mb: int = 25, + find_unused_parameters: bool = False, + check_reduction: bool = False, + gradient_as_bucket_view: bool = False, + static_graph: bool = False) -> None: + + assert dist.is_initialized( + ), 'torch.distributed is not initialized, please use colossalai.launch to create the distributed environment' + self.rank = dist.get_rank() + self.world_size = dist.get_world_size() + self.ddp_kwargs = dict(broadcast_buffers=broadcast_buffers, + bucket_cap_mb=bucket_cap_mb, + find_unused_parameters=find_unused_parameters, + check_reduction=check_reduction, + gradient_as_bucket_view=gradient_as_bucket_view, + static_graph=static_graph) + + def support_no_sync(self) -> bool: + return True + + def control_precision(self) -> bool: + return False + + def supported_precisions(self) -> List[str]: + return ['fp16', 'fp16_apex', 'bf16', 'fp8'] + + def control_device(self) -> bool: + return True + + def supported_devices(self) -> List[str]: + return ['cuda'] + + def prepare_train_dataloader(self, + dataset, + batch_size, + shuffle=False, + seed=1024, + drop_last=False, + pin_memory=False, + num_workers=0, + **kwargs): + r""" + Prepare a dataloader for distributed training. The dataloader will be wrapped by + `torch.utils.data.DataLoader` and `torch.utils.data.DistributedSampler`. + + Note: + 1. Evaluation datasets should not be passed to this function. + + Args: + dataset (`torch.utils.data.Dataset`): The dataset to be loaded. + shuffle (bool, optional): Whether to shuffle the dataset. Defaults to False. + seed (int, optional): Random worker seed for sampling, defaults to 1024. + add_sampler: Whether to add ``DistributedDataParallelSampler`` to the dataset. Defaults to True. + drop_last (bool, optional): Set to True to drop the last incomplete batch, if the dataset size + is not divisible by the batch size. If False and the size of dataset is not divisible by + the batch size, then the last batch will be smaller, defaults to False. + pin_memory (bool, optional): Whether to pin memory address in CPU memory. Defaults to False. + num_workers (int, optional): Number of worker threads for this dataloader. Defaults to 0. + kwargs (dict): optional parameters for ``torch.utils.data.DataLoader``, more details could be found in + `DataLoader `_. + + Returns: + :class:`torch.utils.data.DataLoader`: A DataLoader used for training or testing. 
+ """ + _kwargs = kwargs.copy() + sampler = DistributedSampler(dataset, num_replicas=self.world_size, rank=self.rank, shuffle=shuffle) + + # Deterministic dataloader + def seed_worker(worker_id): + worker_seed = seed + np.random.seed(worker_seed) + torch.manual_seed(worker_seed) + random.seed(worker_seed) + + return DataLoader(dataset, + batch_size=batch_size, + sampler=sampler, + worker_init_fn=seed_worker, + drop_last=drop_last, + pin_memory=pin_memory, + num_workers=num_workers, + **_kwargs) + + def configure( + self, + model: nn.Module, + optimizer: Optimizer, + criterion: Callable = None, + dataloader: DataLoader = None, + lr_scheduler: LRScheduler = None, + ) -> Tuple[Union[nn.Module, OptimizerWrapper, LRScheduler, DataLoader]]: + # cast model to cuda + model = model.cuda() + + # wrap the model with PyTorch DDP + model = DDP(model, **self.ddp_kwargs) + + if not isinstance(optimizer, OptimizerWrapper): + optimizer = OptimizerWrapper(optimizer) + + return model, optimizer, criterion, dataloader, lr_scheduler diff --git a/tests/test_booster/test_accelerator.py b/tests/test_booster/test_accelerator.py index 4bfa3fd06..6958a87e2 100644 --- a/tests/test_booster/test_accelerator.py +++ b/tests/test_booster/test_accelerator.py @@ -1,13 +1,27 @@ -import pytest +from functools import partial + +import torch.multiprocessing as mp import torch.nn as nn -from torchvision.models import resnet18 from colossalai.booster.accelerator import Accelerator +from colossalai.testing import parameterize, rerun_if_address_is_in_use -@pytest.mark.parametrize('device', ['cpu', 'cuda']) -def test_accelerator(device): +@parameterize('device', ['cpu', 'cuda']) +def run_accelerator(device): acceleartor = Accelerator(device) model = nn.Linear(8, 8) model = acceleartor.configure_model(model) assert next(model.parameters()).device.type == device + del model, acceleartor + + +def run_dist(rank): + run_accelerator() + + +@rerun_if_address_is_in_use() +def test_accelerator(): + world_size = 1 + run_func = partial(run_dist) + mp.spawn(run_func, nprocs=world_size) diff --git a/tests/test_booster/test_mixed_precision/test_fp16_torch.py b/tests/test_booster/test_mixed_precision/test_fp16_torch.py index 98d00cd2c..bacf29014 100644 --- a/tests/test_booster/test_mixed_precision/test_fp16_torch.py +++ b/tests/test_booster/test_mixed_precision/test_fp16_torch.py @@ -1,12 +1,21 @@ +from functools import partial + import torch +import torch.multiprocessing as mp from torch.optim import Adam +import colossalai from colossalai.booster.mixed_precision import FP16TorchMixedPrecision +from colossalai.testing import rerun_if_address_is_in_use +from colossalai.utils import free_port from tests.kit.model_zoo import model_zoo -def test_torch_amp(): - for name, (model_fn, data_gen_fn, output_transform_fn, _) in model_zoo.items(): +def run_torch_amp(rank, world_size, port): + # init dist env + colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost') + sub_model_zoo = model_zoo.get_sub_registry('timm') + for name, (model_fn, data_gen_fn, output_transform_fn, _) in sub_model_zoo.items(): # dlrm_interactionarch has not parameters, so skip if name == 'dlrm_interactionarch': continue @@ -27,3 +36,11 @@ def test_torch_amp(): optimizer.backward(loss) optimizer.clip_grad_by_norm(1.0) optimizer.step() + del model, optimizer, criterion, data, output, mixed_precision + + +@rerun_if_address_is_in_use() +def test_torch_ddp_plugin(): + world_size = 1 + run_func = partial(run_torch_amp, world_size=world_size, 
port=free_port())
+    mp.spawn(run_func, nprocs=world_size)
diff --git a/tests/test_booster/test_plugin/test_torch_ddp_plugin.py b/tests/test_booster/test_plugin/test_torch_ddp_plugin.py
new file mode 100644
index 000000000..58aef54c4
--- /dev/null
+++ b/tests/test_booster/test_plugin/test_torch_ddp_plugin.py
@@ -0,0 +1,85 @@
+from functools import partial
+
+import torch
+import torch.distributed as dist
+import torch.multiprocessing as mp
+from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.optim import SGD
+
+import colossalai
+from colossalai.booster import Booster
+from colossalai.booster.interface import OptimizerWrapper
+from colossalai.booster.plugin import TorchDDPPlugin
+from colossalai.testing import rerun_if_address_is_in_use
+from colossalai.utils import free_port
+from tests.kit.model_zoo import model_zoo
+
+
+def check_torch_ddp_plugin():
+    plugin = TorchDDPPlugin()
+    booster = Booster(plugin=plugin)
+
+    for name, (model_fn, data_gen_fn, output_transform_fn, _) in model_zoo.items():
+        if name == 'dlrm_interactionarch':
+            continue
+
+        model = model_fn()
+        optimizer = SGD(model.parameters(), lr=1e-3)
+        criterion = lambda x: x.mean()
+        data = data_gen_fn()
+
+        data = {
+            k: v.to('cuda') if torch.is_tensor(v) or 'Tensor' in v.__class__.__name__ else v for k, v in data.items()
+        }
+
+        model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion)
+
+        assert isinstance(model, DDP)
+        assert isinstance(optimizer, OptimizerWrapper)
+
+        output = model(**data)
+        output = output_transform_fn(output)
+        output_key = list(output.keys())[0]
+        loss = criterion(output[output_key])
+
+        booster.backward(loss, optimizer)
+        optimizer.clip_grad_by_norm(1.0)
+        optimizer.step()
+
+
+def check_dataloader_sharding():
+    plugin = TorchDDPPlugin()
+
+    # create a custom dataset holding the values 0 to 9
+    dataset = torch.utils.data.TensorDataset(torch.arange(0, 10))
+    train_dataloader = plugin.prepare_train_dataloader(dataset, batch_size=2)
+
+    # get the first batch of data
+    batch = next(iter(train_dataloader))[0].cuda()
+    is_rank_0 = dist.get_rank() == 0
+
+    if is_rank_0:
+        batch_to_compare = batch.clone()
+    else:
+        batch_to_compare = batch
+    # broadcast rank 1's batch so that rank 0 can compare against it
+    dist.broadcast(batch_to_compare, src=1)
+
+    # compare on rank 0
+    if is_rank_0:
+        assert not torch.equal(batch,
+                               batch_to_compare), 'Same number was found across ranks but expected it to be different'
+
+
+def run_dist(rank, world_size, port):
+    # init dist env
+    colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost')
+    check_dataloader_sharding()
+    check_torch_ddp_plugin()
+
+
+@rerun_if_address_is_in_use()
+def test_torch_ddp_plugin():
+    world_size = 2
+    run_func = partial(run_dist, world_size=world_size, port=free_port())
+    mp.spawn(run_func, nprocs=world_size)
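
A minimal end-to-end sketch of the plugin API introduced by this patch, following the flow of tests/test_booster/test_plugin/test_torch_ddp_plugin.py. The toy nn.Linear model, the synthetic TensorDataset, the single-process world size, the training loop, and the optimizer.zero_grad() call (assumed to be forwarded by OptimizerWrapper) are illustrative assumptions rather than code taken from the patch; it also assumes at least one CUDA device is available.

    from functools import partial

    import torch
    import torch.multiprocessing as mp
    import torch.nn as nn
    from torch.optim import SGD
    from torch.utils.data import TensorDataset

    import colossalai
    from colossalai.booster import Booster
    from colossalai.booster.plugin import TorchDDPPlugin
    from colossalai.utils import free_port


    def run(rank, world_size, port):
        # colossalai.launch initializes torch.distributed, which TorchDDPPlugin's constructor asserts on
        colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost')

        model = nn.Linear(8, 8)
        optimizer = SGD(model.parameters(), lr=1e-3)

        def criterion(output):
            return output.mean()

        # prepare_train_dataloader wraps the dataset with a DistributedSampler
        # and a seeded worker_init_fn for deterministic loading
        dataset = TensorDataset(torch.randn(64, 8))
        plugin = TorchDDPPlugin()
        train_dataloader = plugin.prepare_train_dataloader(dataset, batch_size=8, shuffle=True)

        # boost returns (model, optimizer, criterion, dataloader, lr_scheduler); the model comes back
        # wrapped in DistributedDataParallel and the optimizer in an OptimizerWrapper
        booster = Booster(plugin=plugin)
        model, optimizer, criterion, train_dataloader, _ = booster.boost(
            model, optimizer, criterion=criterion, dataloader=train_dataloader)

        for (data,) in train_dataloader:
            loss = criterion(model(data.cuda()))
            booster.backward(loss, optimizer)
            optimizer.clip_grad_by_norm(1.0)
            optimizer.step()
            optimizer.zero_grad()


    if __name__ == '__main__':
        world_size = 1
        mp.spawn(partial(run, world_size=world_size, port=free_port()), nprocs=world_size)

Raising world_size shards the training set across ranks through the DistributedSampler installed by prepare_train_dataloader, which is the behavior check_dataloader_sharding verifies above.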