@@ -1,5 +1,8 @@
from contextlib import nullcontext
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim import SGD
@@ -44,10 +47,67 @@ def check_torch_ddp_plugin():
    torch.cuda.empty_cache()


class DummyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.weight = nn.Parameter(torch.rand(1))

    def forward(self, x):
        return self.weight * x
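

# Checks booster.no_sync(): gradients produced inside the context should stay local
# (differ across ranks), while gradients produced outside it should be all-reduced
# and therefore identical on every rank.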
def check_torch_ddp_no_sync():
    plugin = TorchDDPPlugin()
    booster = Booster(plugin=plugin)
    model = DummyModel()
    criterion = lambda x: x.mean()
    optimizer = SGD(model.parameters(), lr=1e-3)
    # create a dummy dataset with the integers 0 through 9
    dataset = torch.arange(0, 10)
    train_dataloader = plugin.prepare_dataloader(dataset, batch_size=2)
    model, optimizer, criterion, train_dataloader, _ = booster.boost(model,
                                                                     optimizer,
                                                                     criterion,
                                                                     dataloader=train_dataloader)
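    # boost() is expected to hand back a DDP-wrapped model here, so booster.no_sync(model)
    # below can temporarily turn off gradient all-reduce.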

    def fwd_bwd():
        # `batch` is taken from the enclosing training loop below
        output = model(batch.cuda())
        loss = criterion(output)
        booster.backward(loss, optimizer)
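
    # Gather the scalar gradient from every rank and return the set of distinct values:
    # a single value means grads are synchronized, `world_size` values means each rank
    # kept its own local grad.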
    def get_grad_set_over_all_ranks():
        for p in model.parameters():
            # grad shape is (1,)
            assert p.grad.shape == (1,)
            grad_list = [torch.empty_like(p.grad) for _ in range(dist.get_world_size())]
            dist.all_gather(grad_list, p.grad)
            # get grad set of all ranks
            grad_set = set([grad.item() for grad in grad_list])
            # as the model only has one parameter, we can return here
            return grad_set

    for i, batch in enumerate(train_dataloader):
        if i > 1:
            # only check the first two batches
            break
        # no_sync for the first batch, sync for the second batch
        ctx = booster.no_sync(model) if i == 0 else nullcontext()
        with ctx:
            fwd_bwd()
        grad_set = get_grad_set_over_all_ranks()
        # for the first batch, all ranks should have different grads
        # for the second batch, as grad is synchronized, all ranks should have the same grads
        target_num_different_grad = dist.get_world_size() if i == 0 else 1
        assert len(grad_set) == target_num_different_grad


def run_dist(rank, world_size, port):
    # init dist env
    colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost')
    check_torch_ddp_plugin()
    check_torch_ddp_no_sync()


@rerun_if_address_is_in_use()