mirror of https://github.com/hpcaitech/ColossalAI
[booster] fix no_sync method (#3709)
* [booster] fix no_sync method * [booster] add test for ddp no_sync * [booster] fix merge * [booster] update unit test * [booster] update unit test * [booster] update unit testpull/3713/head
parent
3bf09efe74
commit
6552cbf8e1
|
@ -2,7 +2,7 @@ import logging
|
|||
import os
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
from typing import Callable, List, Optional, Tuple, Union
|
||||
from typing import Callable, Iterator, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
@ -286,3 +286,6 @@ class GeminiPlugin(DPPluginBase):
|
|||
|
||||
def get_checkpoint_io(self) -> CheckpointIO:
|
||||
return GeminiCheckpointIO()
|
||||
|
||||
def no_sync(self, model: nn.Module) -> Iterator[None]:
|
||||
raise NotImplementedError
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
import warnings
|
||||
from typing import Callable, List, Optional, Tuple, Union
|
||||
from typing import Callable, Iterator, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
@ -197,3 +197,6 @@ class LowLevelZeroPlugin(DPPluginBase):
|
|||
|
||||
def get_checkpoint_io(self) -> CheckpointIO:
|
||||
return LowLevelZeroCheckpointIO()
|
||||
|
||||
def no_sync(self, model: nn.Module) -> Iterator[None]:
|
||||
raise NotImplementedError
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
from abc import ABC, abstractmethod
|
||||
from typing import Callable, List, Tuple, Union
|
||||
from typing import Callable, Iterator, List, Tuple, Union
|
||||
|
||||
import torch.nn as nn
|
||||
from torch.optim import Optimizer
|
||||
|
@ -60,6 +60,13 @@ class Plugin(ABC):
|
|||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def no_sync(self, model: nn.Module) -> Iterator[None]:
|
||||
"""
|
||||
Context manager to disable gradient synchronization.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def prepare_dataloader(self,
|
||||
dataset: Dataset,
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
from typing import Callable, List, Tuple, Union
|
||||
from typing import Callable, Iterator, List, Tuple, Union
|
||||
|
||||
import torch.nn as nn
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
|
@ -142,3 +142,7 @@ class TorchDDPPlugin(DPPluginBase):
|
|||
|
||||
def get_checkpoint_io(self) -> CheckpointIO:
|
||||
return TorchDDPCheckpointIO()
|
||||
|
||||
def no_sync(self, model: nn.Module) -> Iterator[None]:
|
||||
assert isinstance(model, TorchDDPModel), 'Model is not boosted by TorchDDPPlugin.'
|
||||
return model.module.no_sync()
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
from typing import Callable, List, Tuple, Union
|
||||
from typing import Callable, Iterator, List, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
@ -49,6 +49,9 @@ class DPPluginWrapper(DPPluginBase):
|
|||
def supported_precisions(self) -> List[str]:
|
||||
pass
|
||||
|
||||
def no_sync(self, model: nn.Module) -> Iterator[None]:
|
||||
pass
|
||||
|
||||
|
||||
def check_dataloader_sharding():
|
||||
plugin = DPPluginWrapper()
|
||||
|
|
|
@ -1,5 +1,8 @@
|
|||
from contextlib import nullcontext
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn as nn
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
from torch.optim import SGD
|
||||
|
||||
|
@ -44,10 +47,67 @@ def check_torch_ddp_plugin():
|
|||
torch.cuda.empty_cache()
|
||||
|
||||
|
||||
class DummyModel(nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.weight = nn.Parameter(torch.rand(1))
|
||||
|
||||
def forward(self, x):
|
||||
return self.weight * x
|
||||
|
||||
|
||||
def check_torch_ddp_no_sync():
|
||||
plugin = TorchDDPPlugin()
|
||||
booster = Booster(plugin=plugin)
|
||||
|
||||
model = DummyModel()
|
||||
criterion = lambda x: x.mean()
|
||||
optimizer = SGD(model.parameters(), lr=1e-3)
|
||||
# create a custom dasetset with 0 to 10
|
||||
dataset = torch.arange(0, 10)
|
||||
train_dataloader = plugin.prepare_dataloader(dataset, batch_size=2)
|
||||
model, optimizer, criterion, train_dataloader, _ = booster.boost(model,
|
||||
optimizer,
|
||||
criterion,
|
||||
dataloader=train_dataloader)
|
||||
|
||||
def fwd_bwd():
|
||||
output = model(batch.cuda())
|
||||
loss = criterion(output)
|
||||
booster.backward(loss, optimizer)
|
||||
|
||||
def get_grad_set_over_all_ranks():
|
||||
for p in model.parameters():
|
||||
# grad shape is (1, )
|
||||
assert p.grad.shape == (1,)
|
||||
grad_list = [torch.empty_like(p.grad) for _ in range(dist.get_world_size())]
|
||||
dist.all_gather(grad_list, p.grad)
|
||||
# get grad set of all ranks
|
||||
grad_set = set([grad.item() for grad in grad_list])
|
||||
# as the model only has one parameter, we can return here
|
||||
return grad_set
|
||||
|
||||
for i, batch in enumerate(train_dataloader):
|
||||
if i > 1:
|
||||
# only check the first two batches
|
||||
break
|
||||
# no_sync for the first batch, sync for the second batch
|
||||
ctx = booster.no_sync(model) if i == 0 else nullcontext()
|
||||
with ctx:
|
||||
fwd_bwd()
|
||||
grad_set = get_grad_set_over_all_ranks()
|
||||
# for the first batch, all ranks should have different grads
|
||||
# for the second batch, as grad is synchronized,all ranks should have the same grads
|
||||
target_num_different_grad = dist.get_world_size() if i == 0 else 1
|
||||
assert len(grad_set) == target_num_different_grad
|
||||
|
||||
|
||||
def run_dist(rank, world_size, port):
|
||||
# init dist env
|
||||
colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost')
|
||||
check_torch_ddp_plugin()
|
||||
check_torch_ddp_no_sync()
|
||||
|
||||
|
||||
@rerun_if_address_is_in_use()
|
||||
|
|
Loading…
Reference in New Issue