Browse Source

[booster] fix no_sync method (#3709)

* [booster] fix no_sync method

* [booster] add test for ddp no_sync

* [booster] fix merge

* [booster] update unit test

* [booster] update unit test

* [booster] update unit test
pull/3713/head
Hongxin Liu 2 years ago committed by GitHub
parent
commit
6552cbf8e1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 5
      colossalai/booster/plugin/gemini_plugin.py
  2. 5
      colossalai/booster/plugin/low_level_zero_plugin.py
  3. 9
      colossalai/booster/plugin/plugin_base.py
  4. 6
      colossalai/booster/plugin/torch_ddp_plugin.py
  5. 5
      tests/test_booster/test_plugin/test_dp_plugin_base.py
  6. 60
      tests/test_booster/test_plugin/test_torch_ddp_plugin.py

5
colossalai/booster/plugin/gemini_plugin.py

@ -2,7 +2,7 @@ import logging
import os import os
import warnings import warnings
from pathlib import Path from pathlib import Path
from typing import Callable, List, Optional, Tuple, Union from typing import Callable, Iterator, List, Optional, Tuple, Union
import torch import torch
import torch.nn as nn import torch.nn as nn
@ -286,3 +286,6 @@ class GeminiPlugin(DPPluginBase):
def get_checkpoint_io(self) -> CheckpointIO: def get_checkpoint_io(self) -> CheckpointIO:
return GeminiCheckpointIO() return GeminiCheckpointIO()
def no_sync(self, model: nn.Module) -> Iterator[None]:
raise NotImplementedError

5
colossalai/booster/plugin/low_level_zero_plugin.py

@ -1,5 +1,5 @@
import warnings import warnings
from typing import Callable, List, Optional, Tuple, Union from typing import Callable, Iterator, List, Optional, Tuple, Union
import torch import torch
import torch.nn as nn import torch.nn as nn
@ -197,3 +197,6 @@ class LowLevelZeroPlugin(DPPluginBase):
def get_checkpoint_io(self) -> CheckpointIO: def get_checkpoint_io(self) -> CheckpointIO:
return LowLevelZeroCheckpointIO() return LowLevelZeroCheckpointIO()
def no_sync(self, model: nn.Module) -> Iterator[None]:
raise NotImplementedError

9
colossalai/booster/plugin/plugin_base.py

@ -1,5 +1,5 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Callable, List, Tuple, Union from typing import Callable, Iterator, List, Tuple, Union
import torch.nn as nn import torch.nn as nn
from torch.optim import Optimizer from torch.optim import Optimizer
@ -60,6 +60,13 @@ class Plugin(ABC):
""" """
pass pass
@abstractmethod
def no_sync(self, model: nn.Module) -> Iterator[None]:
"""
Context manager to disable gradient synchronization.
"""
pass
@abstractmethod @abstractmethod
def prepare_dataloader(self, def prepare_dataloader(self,
dataset: Dataset, dataset: Dataset,

6
colossalai/booster/plugin/torch_ddp_plugin.py

@ -1,4 +1,4 @@
from typing import Callable, List, Tuple, Union from typing import Callable, Iterator, List, Tuple, Union
import torch.nn as nn import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP from torch.nn.parallel import DistributedDataParallel as DDP
@ -142,3 +142,7 @@ class TorchDDPPlugin(DPPluginBase):
def get_checkpoint_io(self) -> CheckpointIO: def get_checkpoint_io(self) -> CheckpointIO:
return TorchDDPCheckpointIO() return TorchDDPCheckpointIO()
def no_sync(self, model: nn.Module) -> Iterator[None]:
assert isinstance(model, TorchDDPModel), 'Model is not boosted by TorchDDPPlugin.'
return model.module.no_sync()

5
tests/test_booster/test_plugin/test_dp_plugin_base.py

@ -1,4 +1,4 @@
from typing import Callable, List, Tuple, Union from typing import Callable, Iterator, List, Tuple, Union
import torch import torch
import torch.distributed as dist import torch.distributed as dist
@ -49,6 +49,9 @@ class DPPluginWrapper(DPPluginBase):
def supported_precisions(self) -> List[str]: def supported_precisions(self) -> List[str]:
pass pass
def no_sync(self, model: nn.Module) -> Iterator[None]:
pass
def check_dataloader_sharding(): def check_dataloader_sharding():
plugin = DPPluginWrapper() plugin = DPPluginWrapper()

60
tests/test_booster/test_plugin/test_torch_ddp_plugin.py

@ -1,5 +1,8 @@
from contextlib import nullcontext
import torch import torch
import torch.distributed as dist import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim import SGD from torch.optim import SGD
@ -44,10 +47,67 @@ def check_torch_ddp_plugin():
torch.cuda.empty_cache() torch.cuda.empty_cache()
class DummyModel(nn.Module):
def __init__(self):
super().__init__()
self.weight = nn.Parameter(torch.rand(1))
def forward(self, x):
return self.weight * x
def check_torch_ddp_no_sync():
plugin = TorchDDPPlugin()
booster = Booster(plugin=plugin)
model = DummyModel()
criterion = lambda x: x.mean()
optimizer = SGD(model.parameters(), lr=1e-3)
# create a custom dasetset with 0 to 10
dataset = torch.arange(0, 10)
train_dataloader = plugin.prepare_dataloader(dataset, batch_size=2)
model, optimizer, criterion, train_dataloader, _ = booster.boost(model,
optimizer,
criterion,
dataloader=train_dataloader)
def fwd_bwd():
output = model(batch.cuda())
loss = criterion(output)
booster.backward(loss, optimizer)
def get_grad_set_over_all_ranks():
for p in model.parameters():
# grad shape is (1, )
assert p.grad.shape == (1,)
grad_list = [torch.empty_like(p.grad) for _ in range(dist.get_world_size())]
dist.all_gather(grad_list, p.grad)
# get grad set of all ranks
grad_set = set([grad.item() for grad in grad_list])
# as the model only has one parameter, we can return here
return grad_set
for i, batch in enumerate(train_dataloader):
if i > 1:
# only check the first two batches
break
# no_sync for the first batch, sync for the second batch
ctx = booster.no_sync(model) if i == 0 else nullcontext()
with ctx:
fwd_bwd()
grad_set = get_grad_set_over_all_ranks()
# for the first batch, all ranks should have different grads
# for the second batch, as grad is synchronized,all ranks should have the same grads
target_num_different_grad = dist.get_world_size() if i == 0 else 1
assert len(grad_set) == target_num_different_grad
def run_dist(rank, world_size, port): def run_dist(rank, world_size, port):
# init dist env # init dist env
colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost') colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost')
check_torch_ddp_plugin() check_torch_ddp_plugin()
check_torch_ddp_no_sync()
@rerun_if_address_is_in_use() @rerun_if_address_is_in_use()

Loading…
Cancel
Save