mirror of https://github.com/hpcaitech/ColossalAI
You cannot select more than 25 topics.
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
93 lines
3.3 KiB
93 lines
3.3 KiB
#!/usr/bin/env python
|
|
# -*- encoding: utf-8 -*-
|
|
|
|
from functools import partial
|
|
|
|
import colossalai
|
|
import pytest
|
|
import torch
|
|
import torch.distributed as dist
|
|
import torch.multiprocessing as mp
|
|
from colossalai.context.parallel_mode import ParallelMode
|
|
from colossalai.core import global_context as gpc
|
|
from colossalai.testing import rerun_if_address_is_in_use
|
|
from colossalai.utils import free_port
|
|
from colossalai.zero.init_ctx import ZeroInitContext
|
|
from colossalai.zero.shard_utils import TensorShardStrategy
|
|
from torchvision.models import resnet50
|
|
|
|
|
|
def run_dist(rank, world_size, port):
    """Per-rank worker: train a ZeRO-sharded ResNet-50 briefly, then check
    that every data-parallel rank produces the same eval-mode output.

    Args:
        rank: Rank of this process (supplied by ``mp.spawn``).
        world_size: Total number of spawned processes.
        port: Free TCP port used for the ``localhost`` rendezvous.

    Raises:
        AssertionError: If any rank's output differs from rank 0's.
    """
    # The model under test is torchvision's ResNet-50. Its BatchNorm layers keep
    # running statistics (mean/var) as module buffers, and those buffers must be
    # identical on every rank for eval-mode outputs to match.
    # cuDNN is made deterministic (and benchmarking disabled) so that
    # nondeterministic convolution algorithms cannot cause rank-to-rank
    # differences of their own.
    zero_config = dict(model_config=dict(shard_strategy=TensorShardStrategy()))
    # NOTE: the key must be spelled `cudnn_deterministic`; a misspelled key is
    # silently ignored and leaves cuDNN non-deterministic.
    config = dict(zero=zero_config, cudnn_deterministic=True, cudnn_benchmark=False)
    colossalai.launch(config=config,
                      rank=rank,
                      world_size=world_size,
                      host='localhost',
                      port=port,
                      backend='nccl')

    # Create the model inside ZeroInitContext so its parameters are sharded
    # with the configured strategy as they are created.
    with ZeroInitContext(target_device=torch.cuda.current_device(),
                         shard_strategy=gpc.config.zero.model_config.shard_strategy,
                         shard_param=True):
        model = resnet50()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    criterion = torch.nn.CrossEntropyLoss()

    # Only the engine is needed; the remaining return values are discarded.
    engine, *_ = colossalai.initialize(model, optimizer, criterion)

    # Run a few dummy training iterations so the BatchNorm running stats get updated.
    engine.train()
    for _ in range(2):
        data = torch.rand(4, 3, 128, 128).cuda().half()
        label = torch.randint(0, 10, size=(4,)).cuda()
        engine.zero_grad()
        out = engine(data)
        loss = engine.criterion(out, label)
        engine.backward(loss)
        engine.step()

    # Evaluate on an input that is identical on every rank (broadcast from
    # rank 0). With synchronized buffers, all ranks must then agree.
    engine.eval()
    data_group = gpc.get_group(ParallelMode.DATA)
    data = torch.rand(4, 3, 128, 128).cuda().half()
    dist.broadcast(data, src=0, group=data_group)
    out = engine(data)

    # Gather every rank's output. This rank's slot is `out` itself; the other
    # slots are filled by all_gather.
    # NOTE(review): assumes the DATA group spans all `world_size` ranks (no
    # other parallelism is configured here).
    tensor_list = [torch.empty_like(out) for _ in range(world_size - 1)]
    tensor_list.insert(rank, out)
    dist.all_gather(tensor_list=tensor_list, tensor=out, group=data_group)

    # Every rank's output must match rank 0's.
    reference = tensor_list[0]
    for other in tensor_list[1:]:
        assert torch.all(reference == other), \
            'expected the output from different ranks to be the same, but got different values'
|
|
|
|
|
|
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_sharded_optim_with_sync_bn():
    """Verify that module buffers stay synchronized across ranks under ZeRO.

    Buffers are module state that is not a parameter, e.g. the running mean
    and variance of a BatchNorm layer. If the buffers drift apart between
    ranks, the model produces different outputs on different ranks even when
    the input and parameters are identical. That is unacceptable when the
    model is used for prediction.

    Two worker processes are spawned, each running ``run_dist``. They share a
    single free port for the rendezvous.
    """
    nprocs = 2
    rendezvous_port = free_port()
    worker = partial(run_dist, world_size=nprocs, port=rendezvous_port)
    mp.spawn(worker, nprocs=nprocs)
|
|
|
|
|
|
if __name__ == '__main__':
    # Allow running this distributed test directly as a script, outside pytest.
    test_sharded_optim_with_sync_bn()
|