[booster] fix no_sync method (#3709)

* [booster] fix no_sync method

* [booster] add test for ddp no_sync

* [booster] fix merge

* [booster] update unit test

* [booster] update unit test

* [booster] update unit test
Hongxin Liu 2023-05-09 11:10:02 +08:00 committed by GitHub
parent 3bf09efe74
commit 6552cbf8e1
6 changed files with 85 additions and 5 deletions

View File

@@ -2,7 +2,7 @@ import logging
 import os
 import warnings
 from pathlib import Path
-from typing import Callable, List, Optional, Tuple, Union
+from typing import Callable, Iterator, List, Optional, Tuple, Union
 
 import torch
 import torch.nn as nn
@@ -286,3 +286,6 @@ class GeminiPlugin(DPPluginBase):
     def get_checkpoint_io(self) -> CheckpointIO:
         return GeminiCheckpointIO()
+
+    def no_sync(self, model: nn.Module) -> Iterator[None]:
+        raise NotImplementedError

View File

@@ -1,5 +1,5 @@
 import warnings
-from typing import Callable, List, Optional, Tuple, Union
+from typing import Callable, Iterator, List, Optional, Tuple, Union
 
 import torch
 import torch.nn as nn
@@ -197,3 +197,6 @@ class LowLevelZeroPlugin(DPPluginBase):
     def get_checkpoint_io(self) -> CheckpointIO:
         return LowLevelZeroCheckpointIO()
+
+    def no_sync(self, model: nn.Module) -> Iterator[None]:
+        raise NotImplementedError

View File

@@ -1,5 +1,5 @@
 from abc import ABC, abstractmethod
-from typing import Callable, List, Tuple, Union
+from typing import Callable, Iterator, List, Tuple, Union
 
 import torch.nn as nn
 from torch.optim import Optimizer
@@ -60,6 +60,13 @@ class Plugin(ABC):
         """
         pass
 
+    @abstractmethod
+    def no_sync(self, model: nn.Module) -> Iterator[None]:
+        """
+        Context manager to disable gradient synchronization.
+        """
+        pass
+
     @abstractmethod
     def prepare_dataloader(self,
                            dataset: Dataset,
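
The abstract `no_sync` above returns `Iterator[None]`, i.e. a plugin is expected to hand back a generator-based context manager. As a rough, illustrative sketch (not part of this commit): a data-parallel plugin could satisfy the contract with `contextlib.contextmanager`, toggling a DDP-style `require_backward_grad_sync` flag around the block. The helper name `no_sync_sketch` and this flag-based handling are assumptions, not ColossalAI API.

from contextlib import contextmanager
from typing import Iterator

import torch.nn as nn


@contextmanager
def no_sync_sketch(model: nn.Module) -> Iterator[None]:
    # Turn off the flag that a DDP-style wrapper checks before all-reducing grads,
    # run the caller's forward/backward, then restore the previous value.
    old_flag = getattr(model, 'require_backward_grad_sync', True)
    model.require_backward_grad_sync = False
    try:
        yield
    finally:
        model.require_backward_grad_sync = old_flag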

View File

@@ -1,4 +1,4 @@
-from typing import Callable, List, Tuple, Union
+from typing import Callable, Iterator, List, Tuple, Union
 
 import torch.nn as nn
 from torch.nn.parallel import DistributedDataParallel as DDP
@@ -142,3 +142,7 @@ class TorchDDPPlugin(DPPluginBase):
     def get_checkpoint_io(self) -> CheckpointIO:
         return TorchDDPCheckpointIO()
+
+    def no_sync(self, model: nn.Module) -> Iterator[None]:
+        assert isinstance(model, TorchDDPModel), 'Model is not boosted by TorchDDPPlugin.'
+        return model.module.no_sync()
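
For reference, here is a sketch (not part of this commit) of how the new API can be used for gradient accumulation once a model is boosted with `TorchDDPPlugin`: `booster.no_sync(model)` returns DDP's own `no_sync()` context, so gradient all-reduce only runs on the step that actually updates the parameters. The `booster`, `model`, `optimizer`, `criterion`, and `train_dataloader` names are assumed to come from `booster.boost(...)` as in the unit test below, and `accumulation_steps` is illustrative.

from contextlib import nullcontext

accumulation_steps = 4
for step, batch in enumerate(train_dataloader):
    sync_step = (step + 1) % accumulation_steps == 0
    # Skip gradient synchronization on accumulation steps; sync on the last micro-batch.
    ctx = nullcontext() if sync_step else booster.no_sync(model)
    with ctx:
        loss = criterion(model(batch.cuda()))
        booster.backward(loss, optimizer)
    if sync_step:
        optimizer.step()
        optimizer.zero_grad()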

View File

@@ -1,4 +1,4 @@
-from typing import Callable, List, Tuple, Union
+from typing import Callable, Iterator, List, Tuple, Union
 
 import torch
 import torch.distributed as dist
@@ -49,6 +49,9 @@ class DPPluginWrapper(DPPluginBase):
     def supported_precisions(self) -> List[str]:
         pass
 
+    def no_sync(self, model: nn.Module) -> Iterator[None]:
+        pass
+
 
 def check_dataloader_sharding():
     plugin = DPPluginWrapper()

View File

@@ -1,5 +1,8 @@
+from contextlib import nullcontext
+
 import torch
+import torch.distributed as dist
+import torch.nn as nn
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.optim import SGD
@@ -44,10 +47,67 @@ def check_torch_ddp_plugin():
     torch.cuda.empty_cache()
 
 
+class DummyModel(nn.Module):
+
+    def __init__(self):
+        super().__init__()
+        self.weight = nn.Parameter(torch.rand(1))
+
+    def forward(self, x):
+        return self.weight * x
+
+
+def check_torch_ddp_no_sync():
+    plugin = TorchDDPPlugin()
+    booster = Booster(plugin=plugin)
+    model = DummyModel()
+    criterion = lambda x: x.mean()
+    optimizer = SGD(model.parameters(), lr=1e-3)
+    # create a custom dataset with 0 to 10
+    dataset = torch.arange(0, 10)
+    train_dataloader = plugin.prepare_dataloader(dataset, batch_size=2)
+    model, optimizer, criterion, train_dataloader, _ = booster.boost(model,
+                                                                     optimizer,
+                                                                     criterion,
+                                                                     dataloader=train_dataloader)
+
+    def fwd_bwd():
+        output = model(batch.cuda())
+        loss = criterion(output)
+        booster.backward(loss, optimizer)
+
+    def get_grad_set_over_all_ranks():
+        for p in model.parameters():
+            # grad shape is (1, )
+            assert p.grad.shape == (1,)
+            grad_list = [torch.empty_like(p.grad) for _ in range(dist.get_world_size())]
+            dist.all_gather(grad_list, p.grad)
+            # get grad set of all ranks
+            grad_set = set([grad.item() for grad in grad_list])
+            # as the model only has one parameter, we can return here
+            return grad_set
+
+    for i, batch in enumerate(train_dataloader):
+        if i > 1:
+            # only check the first two batches
+            break
+        # no_sync for the first batch, sync for the second batch
+        ctx = booster.no_sync(model) if i == 0 else nullcontext()
+        with ctx:
+            fwd_bwd()
+        grad_set = get_grad_set_over_all_ranks()
+        # for the first batch, all ranks should have different grads
+        # for the second batch, as grads are synchronized, all ranks should have the same grads
+        target_num_different_grad = dist.get_world_size() if i == 0 else 1
+        assert len(grad_set) == target_num_different_grad
+
+
 def run_dist(rank, world_size, port):
     # init dist env
     colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost')
     check_torch_ddp_plugin()
+    check_torch_ddp_no_sync()
 
 
 @rerun_if_address_is_in_use()