mirror of https://github.com/hpcaitech/ColossalAI
[booster] added the plugin base and torch ddp plugin (#3180)
* [booster] added the plugin base and torch ddp plugin * polish code * polish code * polish codepull/3197/head
parent
e5f668f280
commit
e7f3bed2d3
|
@ -1,9 +1,9 @@
|
||||||
|
import warnings
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from typing import Callable, Iterable, Iterator, List, Optional, Tuple, Union
|
from typing import Callable, Iterator, List, Optional, Tuple, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from torch import Tensor
|
|
||||||
from torch.optim import Optimizer
|
from torch.optim import Optimizer
|
||||||
from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
|
from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
|
@ -55,27 +55,43 @@ class Booster:
|
||||||
device: str = 'cuda',
|
device: str = 'cuda',
|
||||||
mixed_precision: Union[MixedPrecision, str] = None,
|
mixed_precision: Union[MixedPrecision, str] = None,
|
||||||
plugin: Optional[Plugin] = None) -> None:
|
plugin: Optional[Plugin] = None) -> None:
|
||||||
# TODO(FrankLeeeee): add plugin control logic
|
if plugin is not None:
|
||||||
# if self.plugin is not None and self.plugin.control_accelerator:
|
assert isinstance(
|
||||||
# ...
|
plugin, Plugin), f'Expected the argument plugin to be an instance of Plugin, but got {type(plugin)}.'
|
||||||
# create acclerator
|
self.plugin = plugin
|
||||||
self.acceleartor = Accelerator(device)
|
|
||||||
self.acceleartor.set_default_device()
|
|
||||||
|
|
||||||
# validate and set precision
|
# set accelerator
|
||||||
if isinstance(MixedPrecision, str):
|
if self.plugin and self.plugin.control_device:
|
||||||
# the user will take the default arguments for amp training
|
self.accelerator = None
|
||||||
self.mixed_precision = mixed_precision_factory(mixed_precision)
|
warnings.warn('The plugin will control the accelerator, so the device argument will be ignored.')
|
||||||
elif isinstance(mixed_precision, MixedPrecision):
|
|
||||||
# the user can customize the arguments by passing the precision object
|
|
||||||
self.mixed_precision = mixed_precision
|
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
self.accelerator = Accelerator(device)
|
||||||
f'Expected the argument mixed_precision to be a string or an instance of Precision, but got {type(mixed_precision)}.'
|
|
||||||
)
|
|
||||||
|
|
||||||
def boost(self, model: nn.Module, optimizer: Optimizer, criterion: Callable, lr_scheduler: LRScheduler,
|
# set precision
|
||||||
dataloader: DataLoader) -> List[Union[nn.Module, Optimizer, LRScheduler, DataLoader]]:
|
if mixed_precision is None or (self.plugin and self.plugin.control_precision):
|
||||||
|
self.mixed_precision = None
|
||||||
|
warnings.warn('The plugin will control the precision, so the mixed_precision argument will be ignored.')
|
||||||
|
else:
|
||||||
|
# validate and set precision
|
||||||
|
if isinstance(MixedPrecision, str):
|
||||||
|
# the user will take the default arguments for amp training
|
||||||
|
self.mixed_precision = mixed_precision_factory(mixed_precision)
|
||||||
|
elif isinstance(mixed_precision, MixedPrecision):
|
||||||
|
# the user can customize the arguments by passing the precision object
|
||||||
|
self.mixed_precision = mixed_precision
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
f'Expected the argument mixed_precision to be a string or an instance of Precision, but got {type(mixed_precision)}.'
|
||||||
|
)
|
||||||
|
|
||||||
|
def boost(
|
||||||
|
self,
|
||||||
|
model: nn.Module,
|
||||||
|
optimizer: Optimizer,
|
||||||
|
criterion: Callable = None,
|
||||||
|
dataloader: DataLoader = None,
|
||||||
|
lr_scheduler: LRScheduler = None,
|
||||||
|
) -> List[Union[nn.Module, Optimizer, LRScheduler, DataLoader]]:
|
||||||
"""
|
"""
|
||||||
Boost the model, optimizer, criterion, lr_scheduler, and dataloader.
|
Boost the model, optimizer, criterion, lr_scheduler, and dataloader.
|
||||||
|
|
||||||
|
@ -83,22 +99,25 @@ class Booster:
|
||||||
model (nn.Module): The model to be boosted.
|
model (nn.Module): The model to be boosted.
|
||||||
optimizer (Optimizer): The optimizer to be boosted.
|
optimizer (Optimizer): The optimizer to be boosted.
|
||||||
criterion (Callable): The criterion to be boosted.
|
criterion (Callable): The criterion to be boosted.
|
||||||
lr_scheduler (LRScheduler): The lr_scheduler to be boosted.
|
|
||||||
dataloader (DataLoader): The dataloader to be boosted.
|
dataloader (DataLoader): The dataloader to be boosted.
|
||||||
|
lr_scheduler (LRScheduler): The lr_scheduler to be boosted.
|
||||||
"""
|
"""
|
||||||
# TODO(FrankLeeeee): add plugin control logic
|
|
||||||
# if self.plugin is not None and self.plugin.control_accelerator:
|
|
||||||
# ...
|
|
||||||
model = self.acceleartor.configure_model(model)
|
|
||||||
|
|
||||||
# TODO(FrankLeeeee): consider multi-model and multi-optimizer case
|
# TODO(FrankLeeeee): consider multi-model and multi-optimizer case
|
||||||
# TODO(lsg): Add plugin control logic
|
# TODO(FrankLeeeee): consider multi-dataloader case
|
||||||
# e.g.
|
|
||||||
# if self.plugin is not None and self.plugin.control_boost:
|
|
||||||
# ...
|
|
||||||
# transform model for mixed precision
|
# transform model for mixed precision
|
||||||
model, optimizer, criterion = self.mixed_precision.configure(model, optimizer, criterion)
|
if self.plugin:
|
||||||
return model, optimizer, criterion, lr_scheduler, dataloader
|
model, optimizer, criterion, dataloader, lr_scheduler = self.plugin.configure(
|
||||||
|
model, optimizer, criterion, dataloader, lr_scheduler)
|
||||||
|
|
||||||
|
if self.plugin and not self.plugin.control_device:
|
||||||
|
# transform model for accelerator
|
||||||
|
model = self.accelerator.configure(model)
|
||||||
|
|
||||||
|
if self.mixed_precision and self.plugin and not self.plugin.control_precision:
|
||||||
|
# transform model for mixed precision
|
||||||
|
model, optimizer, criterion = self.mixed_precision.configure(model, optimizer, criterion)
|
||||||
|
|
||||||
|
return model, optimizer, criterion, dataloader, lr_scheduler
|
||||||
|
|
||||||
def backward(self, loss: torch.Tensor, optimizer: Optimizer) -> None:
|
def backward(self, loss: torch.Tensor, optimizer: Optimizer) -> None:
|
||||||
# TODO: implement this method with plugin
|
# TODO: implement this method with plugin
|
||||||
|
@ -117,8 +136,9 @@ class Booster:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def no_sync(self, model: nn.Module) -> contextmanager:
|
def no_sync(self, model: nn.Module) -> contextmanager:
|
||||||
# TODO: implement this method
|
assert self.plugin is not None, f'no_sync is only enabled when a plugin is provided and the plugin supports no_sync.'
|
||||||
pass
|
assert self.plugin.support_no_sync, f'The plugin {self.plugin.__class__.__name__} does not support no_sync.'
|
||||||
|
return self.plugin.no_sync(model)
|
||||||
|
|
||||||
def save(self,
|
def save(self,
|
||||||
obj: Union[nn.Module, Optimizer, LRScheduler],
|
obj: Union[nn.Module, Optimizer, LRScheduler],
|
||||||
|
|
|
@ -1,46 +0,0 @@
|
||||||
from typing import List, Tuple
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
from torch.optim import Optimizer
|
|
||||||
from torch.utils.data import DataLoader
|
|
||||||
|
|
||||||
from colossalai.device.device_mesh import DeviceMesh
|
|
||||||
|
|
||||||
__all__ = ['Plugin']
|
|
||||||
|
|
||||||
|
|
||||||
class Plugin:
|
|
||||||
|
|
||||||
@property
|
|
||||||
def supported_devices(self) -> List[torch.device]:
|
|
||||||
pass
|
|
||||||
|
|
||||||
@property
|
|
||||||
def supported_precisions(self) -> List[str]:
|
|
||||||
pass
|
|
||||||
|
|
||||||
@property
|
|
||||||
def control_precision(self) -> bool:
|
|
||||||
pass
|
|
||||||
|
|
||||||
@property
|
|
||||||
def control_device(self) -> bool:
|
|
||||||
pass
|
|
||||||
|
|
||||||
@property
|
|
||||||
def support_no_sync(self) -> bool:
|
|
||||||
pass
|
|
||||||
|
|
||||||
def setup_model(self, model: nn.Module, device_mesh_pool: DeviceMesh) -> nn.Module:
|
|
||||||
pass
|
|
||||||
|
|
||||||
def setup_optimizer(self, optimizer: Optimizer) -> Optimizer:
|
|
||||||
pass
|
|
||||||
|
|
||||||
def setup_dataloader(self, dataloader: DataLoader) -> DataLoader:
|
|
||||||
pass
|
|
||||||
|
|
||||||
@property
|
|
||||||
def device_mesh_shape(self) -> List[Tuple[int, ...]]:
|
|
||||||
pass
|
|
|
@ -0,0 +1,4 @@
|
||||||
|
from .plugin_base import Plugin
|
||||||
|
from .torch_ddp_plugin import TorchDDPPlugin
|
||||||
|
|
||||||
|
__all__ = ['Plugin', 'TorchDDPPlugin']
|
|
@ -0,0 +1,51 @@
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from typing import Callable, List, Tuple, Union
|
||||||
|
|
||||||
|
import torch.nn as nn
|
||||||
|
from torch.optim import Optimizer
|
||||||
|
from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
|
||||||
|
from torch.utils.data import DataLoader
|
||||||
|
|
||||||
|
from colossalai.booster.interface import OptimizerWrapper
|
||||||
|
|
||||||
|
__all__ = ['Plugin']
|
||||||
|
|
||||||
|
|
||||||
|
class Plugin(ABC):
|
||||||
|
|
||||||
|
@property
|
||||||
|
@abstractmethod
|
||||||
|
def supported_devices(self) -> List[str]:
|
||||||
|
pass
|
||||||
|
|
||||||
|
@property
|
||||||
|
@abstractmethod
|
||||||
|
def supported_precisions(self) -> List[str]:
|
||||||
|
pass
|
||||||
|
|
||||||
|
@property
|
||||||
|
@abstractmethod
|
||||||
|
def control_precision(self) -> bool:
|
||||||
|
pass
|
||||||
|
|
||||||
|
@property
|
||||||
|
@abstractmethod
|
||||||
|
def control_device(self) -> bool:
|
||||||
|
pass
|
||||||
|
|
||||||
|
@property
|
||||||
|
@abstractmethod
|
||||||
|
def support_no_sync(self) -> bool:
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def configure(
|
||||||
|
self,
|
||||||
|
model: nn.Module,
|
||||||
|
optimizer: Optimizer,
|
||||||
|
criterion: Callable = None,
|
||||||
|
dataloader: DataLoader = None,
|
||||||
|
lr_scheduler: LRScheduler = None,
|
||||||
|
) -> Tuple[Union[nn.Module, OptimizerWrapper, LRScheduler, DataLoader]]:
|
||||||
|
# implement this method
|
||||||
|
pass
|
|
@ -0,0 +1,147 @@
|
||||||
|
import random
|
||||||
|
from typing import Callable, List, Tuple, Union
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
import torch.distributed as dist
|
||||||
|
import torch.nn as nn
|
||||||
|
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||||
|
from torch.optim import Optimizer
|
||||||
|
from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
|
||||||
|
from torch.utils.data import DataLoader
|
||||||
|
from torch.utils.data.distributed import DistributedSampler
|
||||||
|
|
||||||
|
from colossalai.booster.interface import OptimizerWrapper
|
||||||
|
|
||||||
|
from .plugin_base import Plugin
|
||||||
|
|
||||||
|
__all__ = ['TorchDDPPlugin']
|
||||||
|
|
||||||
|
|
||||||
|
class TorchDDPPlugin(Plugin):
|
||||||
|
"""
|
||||||
|
Plugin for PyTorch DDP.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
>>> from colossalai.booster import Booster
|
||||||
|
>>> from colossalai.booster.plugin import TorchDDPPlugin
|
||||||
|
>>>
|
||||||
|
>>> model, train_dataset, optimizer, criterion = ...
|
||||||
|
>>> plugin = TorchDDPPlugin()
|
||||||
|
|
||||||
|
>>> train_dataloader = plugin.prepare_train_dataloader(train_dataset, batch_size=8)
|
||||||
|
>>> booster = Booster(plugin=plugin)
|
||||||
|
>>> model, optimizer, train_dataloader, criterion = booster.boost(model, optimizer, train_dataloader, criterion)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
broadcast_buffers (bool, optional): Whether to broadcast buffers in the beginning of training. Defaults to True.
|
||||||
|
bucket_cap_mb (int, optional): The bucket size in MB. Defaults to 25.
|
||||||
|
find_unused_parameters (bool, optional): Whether to find unused parameters. Defaults to False.
|
||||||
|
check_reduction (bool, optional): Whether to check reduction. Defaults to False.
|
||||||
|
gradient_as_bucket_view (bool, optional): Whether to use gradient as bucket view. Defaults to False.
|
||||||
|
static_graph (bool, optional): Whether to use static graph. Defaults to False.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
broadcast_buffers: bool = True,
|
||||||
|
bucket_cap_mb: int = 25,
|
||||||
|
find_unused_parameters: bool = False,
|
||||||
|
check_reduction: bool = False,
|
||||||
|
gradient_as_bucket_view: bool = False,
|
||||||
|
static_graph: bool = False) -> None:
|
||||||
|
|
||||||
|
assert dist.is_initialized(
|
||||||
|
), 'torch.distributed is not initialized, please use colossalai.launch to create the distributed environment'
|
||||||
|
self.rank = dist.get_rank()
|
||||||
|
self.world_size = dist.get_world_size()
|
||||||
|
self.ddp_kwargs = dict(broadcast_buffers=broadcast_buffers,
|
||||||
|
bucket_cap_mb=bucket_cap_mb,
|
||||||
|
find_unused_parameters=find_unused_parameters,
|
||||||
|
check_reduction=check_reduction,
|
||||||
|
gradient_as_bucket_view=gradient_as_bucket_view,
|
||||||
|
static_graph=static_graph)
|
||||||
|
|
||||||
|
def support_no_sync(self) -> bool:
|
||||||
|
return True
|
||||||
|
|
||||||
|
def control_precision(self) -> bool:
|
||||||
|
return False
|
||||||
|
|
||||||
|
def supported_precisions(self) -> List[str]:
|
||||||
|
return ['fp16', 'fp16_apex', 'bf16', 'fp8']
|
||||||
|
|
||||||
|
def control_device(self) -> bool:
|
||||||
|
return True
|
||||||
|
|
||||||
|
def supported_devices(self) -> List[str]:
|
||||||
|
return ['cuda']
|
||||||
|
|
||||||
|
def prepare_train_dataloader(self,
|
||||||
|
dataset,
|
||||||
|
batch_size,
|
||||||
|
shuffle=False,
|
||||||
|
seed=1024,
|
||||||
|
drop_last=False,
|
||||||
|
pin_memory=False,
|
||||||
|
num_workers=0,
|
||||||
|
**kwargs):
|
||||||
|
r"""
|
||||||
|
Prepare a dataloader for distributed training. The dataloader will be wrapped by
|
||||||
|
`torch.utils.data.DataLoader` and `torch.utils.data.DistributedSampler`.
|
||||||
|
|
||||||
|
Note:
|
||||||
|
1. Evaluation datasets should not be passed to this function.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
dataset (`torch.utils.data.Dataset`): The dataset to be loaded.
|
||||||
|
shuffle (bool, optional): Whether to shuffle the dataset. Defaults to False.
|
||||||
|
seed (int, optional): Random worker seed for sampling, defaults to 1024.
|
||||||
|
add_sampler: Whether to add ``DistributedDataParallelSampler`` to the dataset. Defaults to True.
|
||||||
|
drop_last (bool, optional): Set to True to drop the last incomplete batch, if the dataset size
|
||||||
|
is not divisible by the batch size. If False and the size of dataset is not divisible by
|
||||||
|
the batch size, then the last batch will be smaller, defaults to False.
|
||||||
|
pin_memory (bool, optional): Whether to pin memory address in CPU memory. Defaults to False.
|
||||||
|
num_workers (int, optional): Number of worker threads for this dataloader. Defaults to 0.
|
||||||
|
kwargs (dict): optional parameters for ``torch.utils.data.DataLoader``, more details could be found in
|
||||||
|
`DataLoader <https://pytorch.org/docs/stable/_modules/torch/utils/data/dataloader.html#DataLoader>`_.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
:class:`torch.utils.data.DataLoader`: A DataLoader used for training or testing.
|
||||||
|
"""
|
||||||
|
_kwargs = kwargs.copy()
|
||||||
|
sampler = DistributedSampler(dataset, num_replicas=self.world_size, rank=self.rank, shuffle=shuffle)
|
||||||
|
|
||||||
|
# Deterministic dataloader
|
||||||
|
def seed_worker(worker_id):
|
||||||
|
worker_seed = seed
|
||||||
|
np.random.seed(worker_seed)
|
||||||
|
torch.manual_seed(worker_seed)
|
||||||
|
random.seed(worker_seed)
|
||||||
|
|
||||||
|
return DataLoader(dataset,
|
||||||
|
batch_size=batch_size,
|
||||||
|
sampler=sampler,
|
||||||
|
worker_init_fn=seed_worker,
|
||||||
|
drop_last=drop_last,
|
||||||
|
pin_memory=pin_memory,
|
||||||
|
num_workers=num_workers,
|
||||||
|
**_kwargs)
|
||||||
|
|
||||||
|
def configure(
|
||||||
|
self,
|
||||||
|
model: nn.Module,
|
||||||
|
optimizer: Optimizer,
|
||||||
|
criterion: Callable = None,
|
||||||
|
dataloader: DataLoader = None,
|
||||||
|
lr_scheduler: LRScheduler = None,
|
||||||
|
) -> Tuple[Union[nn.Module, OptimizerWrapper, LRScheduler, DataLoader]]:
|
||||||
|
# cast model to cuda
|
||||||
|
model = model.cuda()
|
||||||
|
|
||||||
|
# wrap the model with PyTorch DDP
|
||||||
|
model = DDP(model, **self.ddp_kwargs)
|
||||||
|
|
||||||
|
if not isinstance(optimizer, OptimizerWrapper):
|
||||||
|
optimizer = OptimizerWrapper(optimizer)
|
||||||
|
|
||||||
|
return model, optimizer, criterion, dataloader, lr_scheduler
|
|
@ -1,13 +1,27 @@
|
||||||
import pytest
|
from functools import partial
|
||||||
|
|
||||||
|
import torch.multiprocessing as mp
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from torchvision.models import resnet18
|
|
||||||
|
|
||||||
from colossalai.booster.accelerator import Accelerator
|
from colossalai.booster.accelerator import Accelerator
|
||||||
|
from colossalai.testing import parameterize, rerun_if_address_is_in_use
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize('device', ['cpu', 'cuda'])
|
@parameterize('device', ['cpu', 'cuda'])
|
||||||
def test_accelerator(device):
|
def run_accelerator(device):
|
||||||
acceleartor = Accelerator(device)
|
acceleartor = Accelerator(device)
|
||||||
model = nn.Linear(8, 8)
|
model = nn.Linear(8, 8)
|
||||||
model = acceleartor.configure_model(model)
|
model = acceleartor.configure_model(model)
|
||||||
assert next(model.parameters()).device.type == device
|
assert next(model.parameters()).device.type == device
|
||||||
|
del model, acceleartor
|
||||||
|
|
||||||
|
|
||||||
|
def run_dist(rank):
|
||||||
|
run_accelerator()
|
||||||
|
|
||||||
|
|
||||||
|
@rerun_if_address_is_in_use()
|
||||||
|
def test_accelerator():
|
||||||
|
world_size = 1
|
||||||
|
run_func = partial(run_dist)
|
||||||
|
mp.spawn(run_func, nprocs=world_size)
|
||||||
|
|
|
@ -1,12 +1,21 @@
|
||||||
|
from functools import partial
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
import torch.multiprocessing as mp
|
||||||
from torch.optim import Adam
|
from torch.optim import Adam
|
||||||
|
|
||||||
|
import colossalai
|
||||||
from colossalai.booster.mixed_precision import FP16TorchMixedPrecision
|
from colossalai.booster.mixed_precision import FP16TorchMixedPrecision
|
||||||
|
from colossalai.testing import rerun_if_address_is_in_use
|
||||||
|
from colossalai.utils import free_port
|
||||||
from tests.kit.model_zoo import model_zoo
|
from tests.kit.model_zoo import model_zoo
|
||||||
|
|
||||||
|
|
||||||
def test_torch_amp():
|
def run_torch_amp(rank, world_size, port):
|
||||||
for name, (model_fn, data_gen_fn, output_transform_fn, _) in model_zoo.items():
|
# init dist env
|
||||||
|
colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost')
|
||||||
|
sub_model_zoo = model_zoo.get_sub_registry('timm')
|
||||||
|
for name, (model_fn, data_gen_fn, output_transform_fn, _) in sub_model_zoo.items():
|
||||||
# dlrm_interactionarch has not parameters, so skip
|
# dlrm_interactionarch has not parameters, so skip
|
||||||
if name == 'dlrm_interactionarch':
|
if name == 'dlrm_interactionarch':
|
||||||
continue
|
continue
|
||||||
|
@ -27,3 +36,11 @@ def test_torch_amp():
|
||||||
optimizer.backward(loss)
|
optimizer.backward(loss)
|
||||||
optimizer.clip_grad_by_norm(1.0)
|
optimizer.clip_grad_by_norm(1.0)
|
||||||
optimizer.step()
|
optimizer.step()
|
||||||
|
del model, optimizer, criterion, data, output, mixed_precision
|
||||||
|
|
||||||
|
|
||||||
|
@rerun_if_address_is_in_use()
|
||||||
|
def test_torch_ddp_plugin():
|
||||||
|
world_size = 1
|
||||||
|
run_func = partial(run_torch_amp, world_size=world_size, port=free_port())
|
||||||
|
mp.spawn(run_func, nprocs=world_size)
|
||||||
|
|
|
@ -0,0 +1,85 @@
|
||||||
|
from functools import partial
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.distributed as dist
|
||||||
|
import torch.multiprocessing as mp
|
||||||
|
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||||
|
from torch.optim import SGD
|
||||||
|
|
||||||
|
import colossalai
|
||||||
|
from colossalai.booster import Booster
|
||||||
|
from colossalai.booster.interface import OptimizerWrapper
|
||||||
|
from colossalai.booster.plugin import TorchDDPPlugin
|
||||||
|
from colossalai.testing import rerun_if_address_is_in_use
|
||||||
|
from colossalai.utils import free_port
|
||||||
|
from tests.kit.model_zoo import model_zoo
|
||||||
|
|
||||||
|
|
||||||
|
def check_torch_ddp_plugin():
|
||||||
|
plugin = TorchDDPPlugin()
|
||||||
|
booster = Booster(plugin=plugin)
|
||||||
|
|
||||||
|
for name, (model_fn, data_gen_fn, output_transform_fn, _) in model_zoo.items():
|
||||||
|
if name == 'dlrm_interactionarch':
|
||||||
|
continue
|
||||||
|
|
||||||
|
model = model_fn()
|
||||||
|
optimizer = SGD(model.parameters(), lr=1e-3)
|
||||||
|
criterion = lambda x: x.mean()
|
||||||
|
data = data_gen_fn()
|
||||||
|
|
||||||
|
data = {
|
||||||
|
k: v.to('cuda') if torch.is_tensor(v) or 'Tensor' in v.__class__.__name__ else v for k, v in data.items()
|
||||||
|
}
|
||||||
|
|
||||||
|
model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion)
|
||||||
|
|
||||||
|
assert isinstance(model, DDP)
|
||||||
|
assert isinstance(optimizer, OptimizerWrapper)
|
||||||
|
|
||||||
|
output = model(**data)
|
||||||
|
output = output_transform_fn(output)
|
||||||
|
output_key = list(output.keys())[0]
|
||||||
|
loss = criterion(output[output_key])
|
||||||
|
|
||||||
|
booster.backward(loss, optimizer)
|
||||||
|
optimizer.clip_grad_by_norm(1.0)
|
||||||
|
optimizer.step()
|
||||||
|
|
||||||
|
|
||||||
|
def check_dataloader_sharding():
|
||||||
|
plugin = TorchDDPPlugin()
|
||||||
|
|
||||||
|
# create a custom dasetset with 0 to 10
|
||||||
|
dataset = torch.utils.data.TensorDataset(torch.arange(0, 10))
|
||||||
|
train_dataloader = plugin.prepare_train_dataloader(dataset, batch_size=2)
|
||||||
|
|
||||||
|
# get the first batch of data
|
||||||
|
batch = next(iter(train_dataloader))[0].cuda()
|
||||||
|
is_rank_0 = dist.get_rank() == 0
|
||||||
|
|
||||||
|
if is_rank_0:
|
||||||
|
batch_to_compare = batch.clone()
|
||||||
|
else:
|
||||||
|
batch_to_compare = batch
|
||||||
|
# pass to the rank 1 value to rank 0
|
||||||
|
dist.broadcast(batch_to_compare, src=1)
|
||||||
|
|
||||||
|
# compare on rank 0
|
||||||
|
if is_rank_0:
|
||||||
|
assert not torch.equal(batch,
|
||||||
|
batch_to_compare), 'Same number was found across ranks but expected it to be different'
|
||||||
|
|
||||||
|
|
||||||
|
def run_dist(rank, world_size, port):
|
||||||
|
# init dist env
|
||||||
|
colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost')
|
||||||
|
check_dataloader_sharding()
|
||||||
|
check_torch_ddp_plugin()
|
||||||
|
|
||||||
|
|
||||||
|
@rerun_if_address_is_in_use()
|
||||||
|
def test_torch_ddp_plugin():
|
||||||
|
world_size = 2
|
||||||
|
run_func = partial(run_dist, world_size=world_size, port=free_port())
|
||||||
|
mp.spawn(run_func, nprocs=world_size)
|
Loading…
Reference in New Issue