mirror of https://github.com/hpcaitech/ColossalAI
[booster] fix no_sync method (#3709)
* [booster] fix no_sync method
* [booster] add test for ddp no_sync
* [booster] fix merge
* [booster] update unit test
* [booster] update unit test
* [booster] update unit test
parent 3bf09efe74
commit 6552cbf8e1
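For context, the booster-level usage this commit fixes looks roughly like the sketch below. It is a minimal illustration assembled from the APIs exercised in the test added further down (TorchDDPPlugin, Booster, plugin.prepare_dataloader, booster.boost, booster.no_sync, booster.backward); the model, data, accumulation step count, and the assumption that the distributed environment has already been set up via colossalai.launch are all illustrative, not part of the commit.

# Minimal sketch (not from the commit): gradient accumulation with booster.no_sync.
# Assumes torch.distributed has been initialized, e.g. via colossalai.launch.
import torch
import torch.nn as nn
from torch.optim import SGD

from colossalai.booster import Booster
from colossalai.booster.plugin import TorchDDPPlugin

plugin = TorchDDPPlugin()
booster = Booster(plugin=plugin)

model = nn.Linear(8, 1)                      # illustrative model
criterion = lambda x: x.mean()               # illustrative loss
optimizer = SGD(model.parameters(), lr=1e-3)
dataset = torch.rand(64, 8)                  # illustrative data
dataloader = plugin.prepare_dataloader(dataset, batch_size=8)

model, optimizer, criterion, dataloader, _ = booster.boost(model, optimizer, criterion, dataloader=dataloader)

accumulation_steps = 4
for i, batch in enumerate(dataloader):
    if (i + 1) % accumulation_steps != 0:
        # run forward + backward inside no_sync to skip the gradient all-reduce
        with booster.no_sync(model):
            loss = criterion(model(batch.cuda()))
            booster.backward(loss, optimizer)
    else:
        loss = criterion(model(batch.cuda()))
        booster.backward(loss, optimizer)    # this backward triggers the all-reduce
        optimizer.step()
        optimizer.zero_grad()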
@@ -2,7 +2,7 @@ import logging
 import os
 import warnings
 from pathlib import Path
-from typing import Callable, List, Optional, Tuple, Union
+from typing import Callable, Iterator, List, Optional, Tuple, Union

 import torch
 import torch.nn as nn
@@ -286,3 +286,6 @@ class GeminiPlugin(DPPluginBase):

     def get_checkpoint_io(self) -> CheckpointIO:
         return GeminiCheckpointIO()
+
+    def no_sync(self, model: nn.Module) -> Iterator[None]:
+        raise NotImplementedError
@@ -1,5 +1,5 @@
 import warnings
-from typing import Callable, List, Optional, Tuple, Union
+from typing import Callable, Iterator, List, Optional, Tuple, Union

 import torch
 import torch.nn as nn
@@ -197,3 +197,6 @@ class LowLevelZeroPlugin(DPPluginBase):

     def get_checkpoint_io(self) -> CheckpointIO:
         return LowLevelZeroCheckpointIO()
+
+    def no_sync(self, model: nn.Module) -> Iterator[None]:
+        raise NotImplementedError
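Both hunks above stub no_sync out with NotImplementedError, i.e. GeminiPlugin and LowLevelZeroPlugin do not yet support skipping gradient synchronization. A plugin-agnostic training loop could guard for that; the snippet below is an illustrative pattern only (not part of the commit), reusing the booster, model, criterion, optimizer and batch names from the sketch near the top with the plugin swapped for one of these two.

from contextlib import nullcontext

# Illustrative fallback: use a no-op context when the active plugin
# raises NotImplementedError for no_sync.
try:
    ctx = booster.no_sync(model)
except NotImplementedError:
    ctx = nullcontext()

with ctx:
    loss = criterion(model(batch.cuda()))
    booster.backward(loss, optimizer)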
@@ -1,5 +1,5 @@
 from abc import ABC, abstractmethod
-from typing import Callable, List, Tuple, Union
+from typing import Callable, Iterator, List, Tuple, Union

 import torch.nn as nn
 from torch.optim import Optimizer
@@ -60,6 +60,13 @@ class Plugin(ABC):
         """
         pass

+    @abstractmethod
+    def no_sync(self, model: nn.Module) -> Iterator[None]:
+        """
+        Context manager to disable gradient synchronization.
+        """
+        pass
+
     @abstractmethod
     def prepare_dataloader(self,
                            dataset: Dataset,
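The Iterator[None] return type reflects the expected context-manager protocol: a plugin hands back something usable in a with statement, typically a generator wrapped by contextlib.contextmanager. The sketch below shows how a hypothetical plugin could satisfy this interface; it is illustrative only, the class and the _sync_grads flag are made up, and the other abstract methods of Plugin are omitted.

from contextlib import contextmanager
from typing import Iterator

import torch.nn as nn


# Hypothetical sketch of a Plugin subclass (other abstract methods omitted).
# @contextmanager turns a generator that yields once (hence Iterator[None])
# into a with-statement context manager.
class MyDataParallelPlugin:

    @contextmanager
    def no_sync(self, model: nn.Module) -> Iterator[None]:
        model._sync_grads = False    # hypothetical flag, for illustration only
        try:
            yield                    # caller runs forward/backward here without gradient sync
        finally:
            model._sync_grads = True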
@@ -1,4 +1,4 @@
-from typing import Callable, List, Tuple, Union
+from typing import Callable, Iterator, List, Tuple, Union

 import torch.nn as nn
 from torch.nn.parallel import DistributedDataParallel as DDP
@@ -142,3 +142,7 @@ class TorchDDPPlugin(DPPluginBase):

     def get_checkpoint_io(self) -> CheckpointIO:
         return TorchDDPCheckpointIO()
+
+    def no_sync(self, model: nn.Module) -> Iterator[None]:
+        assert isinstance(model, TorchDDPModel), 'Model is not boosted by TorchDDPPlugin.'
+        return model.module.no_sync()
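Here TorchDDPPlugin simply returns the native no_sync() context manager of the underlying torch DistributedDataParallel wrapper (model.module). As a reminder of the behaviour being delegated to, the plain-PyTorch equivalent is sketched below; it is illustrative only and assumes a process group is already initialized and a GPU is available.

# Plain-PyTorch illustration of DDP's native no_sync (what the plugin delegates to).
import torch
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP

ddp_model = DDP(nn.Linear(8, 1).cuda())
x = torch.rand(4, 8).cuda()

with ddp_model.no_sync():
    ddp_model(x).mean().backward()   # gradients accumulate locally, no all-reduce
ddp_model(x).mean().backward()       # first backward outside the context all-reduces
                                     # the gradients, including those accumulated above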
@@ -1,4 +1,4 @@
-from typing import Callable, List, Tuple, Union
+from typing import Callable, Iterator, List, Tuple, Union

 import torch
 import torch.distributed as dist
@@ -49,6 +49,9 @@ class DPPluginWrapper(DPPluginBase):
     def supported_precisions(self) -> List[str]:
         pass

+    def no_sync(self, model: nn.Module) -> Iterator[None]:
+        pass
+

 def check_dataloader_sharding():
     plugin = DPPluginWrapper()
@@ -1,5 +1,8 @@
+from contextlib import nullcontext
+
 import torch
 import torch.distributed as dist
+import torch.nn as nn
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.optim import SGD

@@ -44,10 +47,67 @@ def check_torch_ddp_plugin():
     torch.cuda.empty_cache()


+class DummyModel(nn.Module):
+
+    def __init__(self):
+        super().__init__()
+        self.weight = nn.Parameter(torch.rand(1))
+
+    def forward(self, x):
+        return self.weight * x
+
+
+def check_torch_ddp_no_sync():
+    plugin = TorchDDPPlugin()
+    booster = Booster(plugin=plugin)
+
+    model = DummyModel()
+    criterion = lambda x: x.mean()
+    optimizer = SGD(model.parameters(), lr=1e-3)
+    # create a custom dataset with 0 to 10
+    dataset = torch.arange(0, 10)
+    train_dataloader = plugin.prepare_dataloader(dataset, batch_size=2)
+    model, optimizer, criterion, train_dataloader, _ = booster.boost(model,
+                                                                     optimizer,
+                                                                     criterion,
+                                                                     dataloader=train_dataloader)
+
+    def fwd_bwd():
+        output = model(batch.cuda())
+        loss = criterion(output)
+        booster.backward(loss, optimizer)
+
+    def get_grad_set_over_all_ranks():
+        for p in model.parameters():
+            # grad shape is (1, )
+            assert p.grad.shape == (1,)
+            grad_list = [torch.empty_like(p.grad) for _ in range(dist.get_world_size())]
+            dist.all_gather(grad_list, p.grad)
+            # get grad set of all ranks
+            grad_set = set([grad.item() for grad in grad_list])
+            # as the model only has one parameter, we can return here
+            return grad_set
+
+    for i, batch in enumerate(train_dataloader):
+        if i > 1:
+            # only check the first two batches
+            break
+        # no_sync for the first batch, sync for the second batch
+        ctx = booster.no_sync(model) if i == 0 else nullcontext()
+        with ctx:
+            fwd_bwd()
+        grad_set = get_grad_set_over_all_ranks()
+        # for the first batch, all ranks should have different grads
+        # for the second batch, as grad is synchronized, all ranks should have the same grads
+        target_num_different_grad = dist.get_world_size() if i == 0 else 1
+        assert len(grad_set) == target_num_different_grad
+
+
 def run_dist(rank, world_size, port):
     # init dist env
     colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost')
     check_torch_ddp_plugin()
+    check_torch_ddp_no_sync()


 @rerun_if_address_is_in_use()
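The test's actual entry point follows the decorated function and is not shown in this diff; a typical way to drive run_dist across processes is sketched below. This launcher is hypothetical, and the world size and port number are arbitrary.

# Hypothetical launcher, not part of the commit: spawn one process per rank,
# each of which calls run_dist(rank, world_size, port).
import torch.multiprocessing as mp

if __name__ == '__main__':
    world_size = 2
    mp.spawn(run_dist, args=(world_size, 29500), nprocs=world_size)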