diff --git a/colossalai/booster/plugin/gemini_plugin.py b/colossalai/booster/plugin/gemini_plugin.py
index 4850b52de..a3789a39d 100644
--- a/colossalai/booster/plugin/gemini_plugin.py
+++ b/colossalai/booster/plugin/gemini_plugin.py
@@ -2,7 +2,7 @@ import logging
 import os
 import warnings
 from pathlib import Path
-from typing import Callable, List, Optional, Tuple, Union
+from typing import Callable, Iterator, List, Optional, Tuple, Union
 
 import torch
 import torch.nn as nn
@@ -286,3 +286,6 @@ class GeminiPlugin(DPPluginBase):
 
     def get_checkpoint_io(self) -> CheckpointIO:
         return GeminiCheckpointIO()
+
+    def no_sync(self, model: nn.Module) -> Iterator[None]:
+        raise NotImplementedError
diff --git a/colossalai/booster/plugin/low_level_zero_plugin.py b/colossalai/booster/plugin/low_level_zero_plugin.py
index f0f576856..edc0b7679 100644
--- a/colossalai/booster/plugin/low_level_zero_plugin.py
+++ b/colossalai/booster/plugin/low_level_zero_plugin.py
@@ -1,5 +1,5 @@
 import warnings
-from typing import Callable, List, Optional, Tuple, Union
+from typing import Callable, Iterator, List, Optional, Tuple, Union
 
 import torch
 import torch.nn as nn
@@ -197,3 +197,6 @@ class LowLevelZeroPlugin(DPPluginBase):
 
     def get_checkpoint_io(self) -> CheckpointIO:
         return LowLevelZeroCheckpointIO()
+
+    def no_sync(self, model: nn.Module) -> Iterator[None]:
+        raise NotImplementedError
diff --git a/colossalai/booster/plugin/plugin_base.py b/colossalai/booster/plugin/plugin_base.py
index eb5478595..561f58bc5 100644
--- a/colossalai/booster/plugin/plugin_base.py
+++ b/colossalai/booster/plugin/plugin_base.py
@@ -1,5 +1,5 @@
 from abc import ABC, abstractmethod
-from typing import Callable, List, Tuple, Union
+from typing import Callable, Iterator, List, Tuple, Union
 
 import torch.nn as nn
 from torch.optim import Optimizer
@@ -60,6 +60,13 @@ class Plugin(ABC):
         """
         pass
 
+    @abstractmethod
+    def no_sync(self, model: nn.Module) -> Iterator[None]:
+        """
+        Context manager to disable gradient synchronization.
+        """
+        pass
+
     @abstractmethod
     def prepare_dataloader(self,
                            dataset: Dataset,
diff --git a/colossalai/booster/plugin/torch_ddp_plugin.py b/colossalai/booster/plugin/torch_ddp_plugin.py
index 76906d844..99cd2f779 100644
--- a/colossalai/booster/plugin/torch_ddp_plugin.py
+++ b/colossalai/booster/plugin/torch_ddp_plugin.py
@@ -1,4 +1,4 @@
-from typing import Callable, List, Tuple, Union
+from typing import Callable, Iterator, List, Tuple, Union
 
 import torch.nn as nn
 from torch.nn.parallel import DistributedDataParallel as DDP
@@ -142,3 +142,7 @@ class TorchDDPPlugin(DPPluginBase):
 
     def get_checkpoint_io(self) -> CheckpointIO:
         return TorchDDPCheckpointIO()
+
+    def no_sync(self, model: nn.Module) -> Iterator[None]:
+        assert isinstance(model, TorchDDPModel), 'Model is not boosted by TorchDDPPlugin.'
+        return model.module.no_sync()
diff --git a/tests/test_booster/test_plugin/test_dp_plugin_base.py b/tests/test_booster/test_plugin/test_dp_plugin_base.py
index eab949828..61aeded12 100644
--- a/tests/test_booster/test_plugin/test_dp_plugin_base.py
+++ b/tests/test_booster/test_plugin/test_dp_plugin_base.py
@@ -1,4 +1,4 @@
-from typing import Callable, List, Tuple, Union
+from typing import Callable, Iterator, List, Tuple, Union
 
 import torch
 import torch.distributed as dist
@@ -49,6 +49,9 @@ class DPPluginWrapper(DPPluginBase):
     def supported_precisions(self) -> List[str]:
         pass
 
+    def no_sync(self, model: nn.Module) -> Iterator[None]:
+        pass
+
 
 def check_dataloader_sharding():
     plugin = DPPluginWrapper()
diff --git a/tests/test_booster/test_plugin/test_torch_ddp_plugin.py b/tests/test_booster/test_plugin/test_torch_ddp_plugin.py
index 30c4db123..fbe44e5ce 100644
--- a/tests/test_booster/test_plugin/test_torch_ddp_plugin.py
+++ b/tests/test_booster/test_plugin/test_torch_ddp_plugin.py
@@ -1,5 +1,8 @@
+from contextlib import nullcontext
+
 import torch
 import torch.distributed as dist
+import torch.nn as nn
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.optim import SGD
 
@@ -44,10 +47,67 @@ def check_torch_ddp_plugin():
     torch.cuda.empty_cache()
 
 
+class DummyModel(nn.Module):
+
+    def __init__(self):
+        super().__init__()
+        self.weight = nn.Parameter(torch.rand(1))
+
+    def forward(self, x):
+        return self.weight * x
+
+
+def check_torch_ddp_no_sync():
+    plugin = TorchDDPPlugin()
+    booster = Booster(plugin=plugin)
+
+    model = DummyModel()
+    criterion = lambda x: x.mean()
+    optimizer = SGD(model.parameters(), lr=1e-3)
+    # create a custom dataset with values 0 to 9
+    dataset = torch.arange(0, 10)
+    train_dataloader = plugin.prepare_dataloader(dataset, batch_size=2)
+    model, optimizer, criterion, train_dataloader, _ = booster.boost(model,
+                                                                     optimizer,
+                                                                     criterion,
+                                                                     dataloader=train_dataloader)
+
+    def fwd_bwd():
+        output = model(batch.cuda())
+        loss = criterion(output)
+        booster.backward(loss, optimizer)
+
+    def get_grad_set_over_all_ranks():
+        for p in model.parameters():
+            # grad shape is (1, )
+            assert p.grad.shape == (1,)
+            grad_list = [torch.empty_like(p.grad) for _ in range(dist.get_world_size())]
+            dist.all_gather(grad_list, p.grad)
+            # get grad set of all ranks
+            grad_set = set([grad.item() for grad in grad_list])
+            # as the model only has one parameter, we can return here
+            return grad_set
+
+    for i, batch in enumerate(train_dataloader):
+        if i > 1:
+            # only check the first two batches
+            break
+        # no_sync for the first batch, sync for the second batch
+        ctx = booster.no_sync(model) if i == 0 else nullcontext()
+        with ctx:
+            fwd_bwd()
+        grad_set = get_grad_set_over_all_ranks()
+        # for the first batch, all ranks should have different grads
+        # for the second batch, as grads are synchronized, all ranks should have the same grads
+        target_num_different_grad = dist.get_world_size() if i == 0 else 1
+        assert len(grad_set) == target_num_different_grad
+
+
 def run_dist(rank, world_size, port):
     # init dist env
     colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost')
     check_torch_ddp_plugin()
+    check_torch_ddp_no_sync()
 
 
 @rerun_if_address_is_in_use()
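
Not part of the diff: below is a minimal usage sketch showing how the new booster.no_sync(model) context manager could be used for gradient accumulation with TorchDDPPlugin. It mirrors the boosted-model setup in check_torch_ddp_no_sync above; the toy linear model, random dataset, accumulation_steps value, and the optimizer step/zero_grad placement are illustrative assumptions, and the sketch assumes the distributed environment has already been initialized (e.g. via colossalai.launch).

# Hypothetical gradient-accumulation sketch using the no_sync API added in this diff.
from contextlib import nullcontext

import torch
from torch.optim import SGD

from colossalai.booster import Booster
from colossalai.booster.plugin import TorchDDPPlugin

# assumes colossalai.launch(...) has already set up the distributed environment

plugin = TorchDDPPlugin()
booster = Booster(plugin=plugin)

model = torch.nn.Linear(4, 1)                      # illustrative toy model
criterion = lambda x: x.mean()
optimizer = SGD(model.parameters(), lr=1e-3)
dataset = torch.rand(32, 4)                        # illustrative random dataset
train_dataloader = plugin.prepare_dataloader(dataset, batch_size=4)
model, optimizer, criterion, train_dataloader, _ = booster.boost(model,
                                                                 optimizer,
                                                                 criterion,
                                                                 dataloader=train_dataloader)

accumulation_steps = 4    # illustrative choice
for i, batch in enumerate(train_dataloader):
    # only all-reduce gradients on the last micro-batch of each accumulation window
    is_sync_step = (i + 1) % accumulation_steps == 0
    ctx = nullcontext() if is_sync_step else booster.no_sync(model)
    with ctx:
        loss = criterion(model(batch.cuda()))
        booster.backward(loss, optimizer)
    if is_sync_step:
        optimizer.step()
        optimizer.zero_grad()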