You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
ColossalAI/tests/test_shardformer/test_layer/test_dropout.py

71 lines
2.2 KiB

import torch
import torch.distributed as dist
import torch.nn as nn
import colossalai
from colossalai.shardformer.layer import DropoutForParallelInput, DropoutForReplicatedInput
from colossalai.testing import assert_equal, assert_not_equal, rerun_if_address_is_in_use, spawn
def check_dropout_parallel_input():
    """Compare ``DropoutForParallelInput`` with a plain ``nn.Dropout`` across ranks.

    Both modules are applied to the same input, each right after the CUDA
    generator is re-seeded with the same value. The native module's output
    is asserted to be equal on every rank, whereas the converted module's
    output is asserted to differ between rank 0 and every other rank.

    The function uses ``.cuda()`` and ``torch.distributed`` collectives, so it
    is meant to be called from an already-launched distributed worker.
    """
    native_dropout = nn.Dropout().cuda()
    parallel_dropout = DropoutForParallelInput.from_native_module(native_dropout, process_group=None)

    # Shared input for both modules.
    inputs = torch.rand(4, 128).cuda()

    # Fix the CUDA seed so the native dropout mask is reproducible.
    torch.cuda.manual_seed(1024)
    native_out = native_dropout(inputs)

    # Re-seed with the very same value to mimic the identical scenario; the
    # converted module is nevertheless expected to produce a different mask
    # because it controls its randomness internally.
    torch.cuda.manual_seed(1024)
    parallel_out = parallel_dropout(inputs)

    num_ranks = dist.get_world_size()

    # The native output must match on every rank.
    gathered_native = [torch.empty_like(native_out) for _ in range(num_ranks)]
    dist.all_gather(gathered_native, native_out)
    for rank_out in gathered_native:
        assert_equal(rank_out, gathered_native[0])

    # The converted module's output must differ from rank 0 on all other ranks.
    gathered_parallel = [torch.zeros_like(parallel_out) for _ in range(num_ranks)]
    dist.all_gather(gathered_parallel, parallel_out)
    for rank_out in gathered_parallel[1:]:
        assert_not_equal(rank_out, gathered_parallel[0])
def check_dropout_replicated_input():
    """Check that ``DropoutForReplicatedInput`` gives the same output on every rank.

    A native ``nn.Dropout`` is converted via ``from_native_module`` and applied
    to a random ``(4, 128)`` CUDA tensor. The outputs of all ranks are gathered
    and each one is asserted to be equal to rank 0's output.

    The function uses ``.cuda()`` and ``torch.distributed`` collectives, so it
    is meant to be called from an already-launched distributed worker.
    """
    native_dropout = nn.Dropout().cuda()
    replicated_dropout = DropoutForReplicatedInput.from_native_module(native_dropout, process_group=None)

    # NOTE(review): the equality check below presumably relies on every rank
    # drawing the same input here (e.g. a seed fixed at launch) — confirm.
    inputs = torch.rand(4, 128).cuda()
    replicated_out = replicated_dropout(inputs)

    # Ensure the output is identical across ranks (rank 0 is the reference).
    num_ranks = dist.get_world_size()
    gathered = [torch.zeros_like(replicated_out) for _ in range(num_ranks)]
    dist.all_gather(gathered, replicated_out)
    for rank_out in gathered[1:]:
        assert_equal(rank_out, gathered[0])
def run_dist(rank, world_size, port):
    """Per-process entry point: launch ColossalAI, then run both dropout checks.

    Args:
        rank: Rank of this process, forwarded to ``colossalai.launch``.
        world_size: Total number of processes, forwarded to ``colossalai.launch``.
        port: Port forwarded to ``colossalai.launch`` (host is ``localhost``).
    """
    launch_kwargs = {
        "config": {},
        "rank": rank,
        "world_size": world_size,
        "host": "localhost",
        "port": port,
        "backend": "nccl",
    }
    colossalai.launch(**launch_kwargs)

    # Run the checks in their original order.
    for check in (check_dropout_parallel_input, check_dropout_replicated_input):
        check()
@rerun_if_address_is_in_use()
def test_dropout():
    """Spawn two worker processes that each execute ``run_dist``."""
    num_procs = 2
    spawn(run_dist, nprocs=num_procs)
# Allow running this test module directly as a script.
if __name__ == "__main__":
    test_dropout()