#!/usr/bin/env python
# -*- encoding: utf-8 -*-
from functools import partial

import colossalai
import pytest
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils import free_port
from colossalai.zero.init_ctx import ZeroInitContext
from colossalai.zero.shard_utils import TensorShardStrategy
from torchvision.models import resnet50


def run_dist(rank, world_size, port):
    # this test uses resnet50, whose batch normalization layers carry
    # running stats (module buffers) that must stay in sync across ranks
    # need to configure cudnn deterministic so that
    # randomness of convolution layers will be disabled
    zero_config = dict(model_config=dict(shard_strategy=TensorShardStrategy()))
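    # the strategy above, TensorShardStrategy, splits each parameter tensor
    # across the ranks of the process group, so every rank holds only a shard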
    colossalai.launch(config=dict(zero=zero_config, cudnn_deterministic=True, cudnn_benchmark=False),
                      rank=rank,
                      world_size=world_size,
                      host='localhost',
                      port=port,
                      backend='nccl')
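
    # ZeroInitContext shards the model parameters as they are created,
    # so each rank only materializes its own shard of resnet50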
    with ZeroInitContext(target_device=torch.cuda.current_device(),
                         shard_strategy=gpc.config.zero.model_config.shard_strategy,
                         shard_param=True):
        model = resnet50()

    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    criterion = torch.nn.CrossEntropyLoss()

    engine, *args = colossalai.initialize(model, optimizer, criterion)
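
    # under this ZeRO config the sharded model runs in half precision,
    # hence the dummy inputs below are cast with .half()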

    # train for dummy iterations
    engine.train()
    for _ in range(2):
        data = torch.rand(4, 3, 128, 128).cuda().half()
        label = torch.randint(0, 10, size=(4,)).cuda()
        engine.zero_grad()
        out = engine(data)
        loss = engine.criterion(out, label)
        engine.backward(loss)
        engine.step()
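
    # after a couple of training steps the batch norm running stats have been
    # updated on every rank; the check below verifies they stayed in sync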

    # test
    # need to make sure the batch norm stats are synchronized
    # so that given the same input, the model will produce the same
    # output on different ranks
    engine.eval()
    data = torch.rand(4, 3, 128, 128).cuda().half()
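    # broadcast rank 0's batch so that every rank evaluates identical input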
    dist.broadcast(data, src=0, group=gpc.get_group(ParallelMode.DATA))

    # predict
    out = engine(data)

    # test if results are equal
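    # dist.all_gather expects a list with world_size entries; the local output
    # is pre-inserted at index `rank` and the remaining slots are filled with
    # the outputs gathered from the other ranks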
    tensor_list = [torch.empty_like(out) for _ in range(world_size - 1)]
    tensor_list.insert(rank, out)
    dist.all_gather(tensor_list=tensor_list, tensor=out, group=gpc.get_group(ParallelMode.DATA))

    assert torch.all(tensor_list[0] == tensor_list[1]), \
        'expected the output from different ranks to be the same, but got different values'


@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_sharded_optim_with_sync_bn():
    """
    This test makes sure that module buffers are synchronized between ranks
    when using ZeRO. An example of a module buffer is the running stats of a
    BatchNormalization layer, i.e. its running mean and variance.
    If the buffers are not synchronized, the model will produce different
    outputs on different ranks even though the inputs and parameters are the
    same, which is undesirable when doing inference. A sketch of a more
    direct buffer comparison is given after this function.
    """
    world_size = 2
    run_func = partial(run_dist, world_size=world_size, port=free_port())
    mp.spawn(run_func, nprocs=world_size)
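

# The docstring of the test above explains that module buffers, e.g. batch
# norm running stats, must match across ranks. The helper below is a minimal
# sketch of how those buffers could be compared directly; it is not called by
# the test and assumes an already-initialized process group. It broadcasts
# rank 0's copy of every buffer and checks it against the local copy.
def _assert_buffers_in_sync(model, group=None):
    for name, buf in model.named_buffers():
        # take a detached copy so the broadcast does not touch the live buffer
        ref = buf.detach().clone()
        # after the broadcast, `ref` holds rank 0's version of this buffer
        dist.broadcast(ref, src=0, group=group)
        assert torch.equal(ref, buf.detach()), \
            f'buffer {name} differs from the copy held by rank 0'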


if __name__ == '__main__':
    test_sharded_optim_with_sync_bn()