From a73130482df257e5efd7bdc88435bad0578cb5e4 Mon Sep 17 00:00:00 2001
From: FoolPlayer <45593998+FoolPlayer@users.noreply.github.com>
Date: Mon, 12 Jun 2023 13:56:09 +0800
Subject: [PATCH] [shardformer] Unit test (#3928)

* fix bug in slicer, add slicer unit test
* add dropout test
* use pid as dropout seed
* update dropout test with local pattern
* add todo
---
 colossalai/shardformer/layer/dropout.py             |  4 +-
 colossalai/shardformer/shard/slicer.py              | 16 ++--
 .../test_module/test_dropout.py                     | 51 ++++++++++++
 .../test_module/test_slicer.py                      | 78 +++++++++++++++++++
 4 files changed, 139 insertions(+), 10 deletions(-)
 create mode 100644 tests/test_shardformer/test_module/test_dropout.py
 create mode 100644 tests/test_shardformer/test_module/test_slicer.py

diff --git a/colossalai/shardformer/layer/dropout.py b/colossalai/shardformer/layer/dropout.py
index acc114029..0f653a9be 100644
--- a/colossalai/shardformer/layer/dropout.py
+++ b/colossalai/shardformer/layer/dropout.py
@@ -1,5 +1,4 @@
 import os
-import time
 from contextlib import contextmanager
 
 import torch
@@ -14,7 +13,8 @@ class SeedManager:
 
     def __init__(self):
         original_state = torch.cuda.get_rng_state()
-        seed = int(f"{int(time.time())}{os.environ['RANK']}")
+        # TODO: unify this seed manager with the colossalai.context.random
+        seed = os.getpid()
         torch.cuda.manual_seed(int(seed))
         self.dropout_state = torch.cuda.get_rng_state()
         torch.cuda.set_rng_state(original_state)
diff --git a/colossalai/shardformer/shard/slicer.py b/colossalai/shardformer/shard/slicer.py
index 6d35bd193..09e3219f8 100644
--- a/colossalai/shardformer/shard/slicer.py
+++ b/colossalai/shardformer/shard/slicer.py
@@ -3,7 +3,7 @@ import torch
 from ..policies.basepolicy import Col_Layer, Layer, Row_Layer
 from .shard_config import ShardConfig
 
-dim_mapping = {Col_Layer: 1, Row_Layer: 0}
+dim_mapping = {Col_Layer: 0, Row_Layer: 1}
 
 
 class Slicer():
@@ -40,7 +40,7 @@ class Slicer():
         # print(weight.shape, dim)
         if policy_layer_cls == Col_Layer:
             weight = self.slice_tensor(weight, dim, False, n_cast)
-            bias = self.slice_tensor(bias, 0, True)
+            bias = self.slice_tensor(bias, 0, True, n_cast)
         elif policy_layer_cls == Row_Layer:
             weight = self.slice_tensor(weight, dim, False, n_cast)
         else:
@@ -129,13 +129,13 @@ class Slicer():
 
         """
         if n_cast is None:
-            return tensor.chunk(self.shardconfig.world_size, dim=0)[self.shardconfig.rank].contiguous()
+            return tensor.chunk(self.shardconfig.world_size, dim=1)[self.shardconfig.rank].contiguous()
         else:
-            tensor_chunks = tensor.chunk(self.shardconfig.world_size * n_cast, dim=0)
+            tensor_chunks = tensor.chunk(self.shardconfig.world_size * n_cast, dim=1)
             chunk_list = [
                 tensor_chunks[i] for i in range(self.shardconfig.rank, len(tensor_chunks), self.shardconfig.world_size)
             ]
-            return torch.cat(chunk_list, dim=0).contiguous()
+            return torch.cat(chunk_list, dim=1).contiguous()
 
     def slice_row(
         self,
@@ -152,10 +152,10 @@ class Slicer():
             :class:`torch.Tensor`: The sliced tensor
         """
         if n_cast is None:
-            return tensor.chunk(self.shardconfig.world_size, dim=1)[self.shardconfig.rank].contiguous()
+            return tensor.chunk(self.shardconfig.world_size, dim=0)[self.shardconfig.rank].contiguous()
         else:
-            tensor_chunks = tensor.chunk(self.shardconfig.world_size * n_cast, dim=1)
+            tensor_chunks = tensor.chunk(self.shardconfig.world_size * n_cast, dim=0)
             chunk_list = [
                 tensor_chunks[i] for i in range(self.shardconfig.rank, len(tensor_chunks), self.shardconfig.world_size)
            ]
-            return torch.cat(chunk_list, dim=1).contiguous()
+            return torch.cat(chunk_list, dim=0).contiguous()
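For reference, the slicing convention that the corrected dim_mapping and the new unit test exercise can be sketched in a few lines of plain PyTorch. This sketch is illustrative only and not part of the patch; the world_size, rank, n_cast values and tensor shapes are made-up samples, and it simply mirrors the chunk-and-interleave logic of slice_tensor as the test checks it: a Col_Layer weight and its bias are split along dim 0, while a Row_Layer weight is split along dim 1 with the bias left whole.

    import torch

    # Assumed toy setup: 2-way tensor parallelism, this process is rank 1, 2 casts.
    world_size, rank, n_cast = 2, 1, 2
    weight = torch.randn(24, 48)    # (in_feature, out_feature), matching the unit test below
    bias = torch.randn(48)

    # Col_Layer: chunk along dim 0 and take every world_size-th chunk starting at rank.
    col_chunks = weight.chunk(world_size * n_cast, dim=0)
    col_weight_shard = torch.cat([col_chunks[i] for i in range(rank, world_size * n_cast, world_size)], dim=0)
    bias_chunks = bias.chunk(world_size * n_cast, dim=0)
    col_bias_shard = torch.cat([bias_chunks[i] for i in range(rank, world_size * n_cast, world_size)])

    # Row_Layer: same interleaving along dim 1; the bias is not sliced.
    row_chunks = weight.chunk(world_size * n_cast, dim=1)
    row_weight_shard = torch.cat([row_chunks[i] for i in range(rank, world_size * n_cast, world_size)], dim=1)

    assert col_weight_shard.shape == (12, 48)
    assert col_bias_shard.shape == (24,)
    assert row_weight_shard.shape == (24, 24)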
diff --git a/tests/test_shardformer/test_module/test_dropout.py b/tests/test_shardformer/test_module/test_dropout.py
new file mode 100644
index 000000000..4a13eb61c
--- /dev/null
+++ b/tests/test_shardformer/test_module/test_dropout.py
@@ -0,0 +1,51 @@
+import pytest
+import torch
+import torch.nn.functional as F
+
+import colossalai
+from colossalai.logging import disable_existing_loggers
+from colossalai.shardformer.layer.dropout import Dropout1D
+from colossalai.testing import rerun_if_address_is_in_use, spawn
+
+CONFIG = dict(parallel=dict(data=1, pipeline=1, tensor=dict(size=2, mode='1d')),)
+
+
+def check_dropout(rank, world_size, port):
+    disable_existing_loggers()
+    colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, port=port, host='localhost', backend='nccl')
+
+    # prepare data
+    input = torch.randn(5, 4).to('cuda')
+    dropout = Dropout1D(p=0.4).to('cuda')
+    output_list = []
+    # compare the dropout pattern in each device
+    for i in range(2):
+        output = dropout(input)
+        output_list.append(output)
+        dist_output_list = [torch.zeros(*output.shape).to('cuda') for _ in range(world_size)]
+        torch.distributed.all_gather(dist_output_list, output)
+        for j in range(world_size):
+            for k in range(world_size):
+                if j != k:
+                    mask = torch.eq(dist_output_list[j], 0.0) == torch.eq(dist_output_list[k], 0.0)
+                    assert torch.all(
+                        mask
+                    ) == False, f"The dropout pattern in each device is not unique\n{dist_output_list[j]}\n{dist_output_list[k]}"
+    # compare the dropout pattern in local device
+    for i in range(len(output_list)):
+        for j in range(len(output_list)):
+            if i != j:
+                mask = torch.eq(output_list[i], 0.0) == torch.eq(output_list[j], 0.0)
+                assert torch.all(
+                    mask
+                ) == False, f"The dropout pattern in one device is not unique\n{output_list[i]}\n{output_list[j]}"
+
+
+@pytest.mark.dist
+@rerun_if_address_is_in_use()
+def test_dropout():
+    spawn(check_dropout, 2)
+
+
+if __name__ == '__main__':
+    test_dropout()
diff --git a/tests/test_shardformer/test_module/test_slicer.py b/tests/test_shardformer/test_module/test_slicer.py
new file mode 100644
index 000000000..c72a03575
--- /dev/null
+++ b/tests/test_shardformer/test_module/test_slicer.py
@@ -0,0 +1,78 @@
+import pytest
+import torch
+import torch.nn.functional as F
+
+import colossalai
+from colossalai.logging import disable_existing_loggers
+from colossalai.shardformer.policies.basepolicy import Col_Layer, Layer, Row_Layer
+from colossalai.shardformer.shard.shard_config import ShardConfig
+from colossalai.shardformer.shard.slicer import Slicer
+from colossalai.testing import rerun_if_address_is_in_use, spawn
+
+CONFIG = dict(parallel=dict(data=1, pipeline=1, tensor=dict(size=2, mode='1d')),)
+
+
+def check_slicer(rank, world_size, port, in_feature, out_feature):
+    disable_existing_loggers()
+    colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, port=port, host='localhost', backend='nccl')
+    # initialize slicer
+    shardconfig = ShardConfig(rank=rank, world_size=world_size)
+    slicer = Slicer(shardconfig)
+    # initialize test data
+    weight = torch.randn(in_feature, out_feature)
+    bias = torch.randn(out_feature)
+    policy_layer_cls_list = [Layer, Col_Layer, Row_Layer]
+    n_cast_list = [None, 2, 3, 4]
+    # weight and bias
+    for n_cast in n_cast_list:
+        sliced_weight, sliced_bias = slicer.slice_weight_bias(weight, bias, policy_layer_cls=Layer, n_cast=n_cast)
+        expected_sliced_weight = weight
+        expected_sliced_bias = bias
+        assert torch.equal(
+            sliced_weight, expected_sliced_weight
+        ), f"In Layer case, weight: sliced_weight is not equal to expected_sliced_weight\norg:{weight}\nsliced:{sliced_weight}\nexpected:{expected_sliced_weight}"
+        assert torch.equal(
+            sliced_bias, expected_sliced_bias
+        ), f"In Layer case, bias: sliced_bias is not equal to expected_sliced_bias\norg:{bias}\nsliced:{sliced_bias}\nexpected:{expected_sliced_bias}"
+
+        sliced_weight, sliced_bias = slicer.slice_weight_bias(weight, bias, policy_layer_cls=Col_Layer, n_cast=n_cast)
+        if (n_cast is None):
+            expected_sliced_weight = weight.chunk(world_size, dim=0)[rank]
+            expected_sliced_bias = bias.chunk(world_size)[rank]
+        else:
+            chunks = weight.chunk(world_size * n_cast, dim=0)
+            expected_sliced_weight = torch.cat([chunks[i] for i in range(rank, n_cast * world_size, world_size)], dim=0)
+            chunks = bias.chunk(world_size * n_cast, dim=0)
+            expected_sliced_bias = torch.cat([chunks[i] for i in range(rank, n_cast * world_size, world_size)])
+        assert torch.equal(
+            sliced_weight, expected_sliced_weight
+        ), f"In Col_Layer {n_cast} cast case, weight: sliced_weight is not equal to expected_sliced_weight\norg:{weight}\nsliced:{sliced_weight}\nexpected:{expected_sliced_weight}"
+        assert torch.equal(
+            sliced_bias, expected_sliced_bias
+        ), f"In Col_Layer {n_cast} cast case, bias: sliced_bias is not equal to expected_sliced_bias\norg:{bias}\nsliced:{sliced_bias}\nexpected:{expected_sliced_bias}"
+
+        sliced_weight, sliced_bias = slicer.slice_weight_bias(weight, bias, policy_layer_cls=Row_Layer, n_cast=n_cast)
+        if (n_cast is None):
+            expected_sliced_weight = weight.chunk(world_size, dim=1)[rank]
+            expected_sliced_bias = bias
+        else:
+            chunks = weight.chunk(world_size * n_cast, dim=1)
+            expected_sliced_weight = torch.cat([chunks[i] for i in range(rank, n_cast * world_size, world_size)], dim=1)
+            expected_sliced_bias = bias
+        assert torch.equal(
+            sliced_weight, expected_sliced_weight
+        ), f"In Row_Layer {n_cast} cast case, weight: sliced_weight is not equal to expected_sliced_weight\norg:{weight}\nsliced:{sliced_weight}\nexpected:{expected_sliced_weight}"
+        assert torch.equal(
+            sliced_bias, expected_sliced_bias
+        ), f"In Row_Layer {n_cast} cast case, bias: sliced_bias is not equal to expected_sliced_bias\norg:{bias}\nsliced:{sliced_bias}\nexpected:{expected_sliced_bias}"
+
+
+@pytest.mark.dist
+@rerun_if_address_is_in_use()
+def test_slicer():
+    args = dict(in_feature=24, out_feature=48)
+    spawn(check_slicer, nprocs=2, in_feature=args['in_feature'], out_feature=args['out_feature'])
+
+
+if __name__ == '__main__':
+    test_slicer()
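The dropout test depends on the seed change above: each process now seeds its dropout RNG from its own pid, so the zero pattern of the mask differs across ranks while the surrounding CUDA RNG state is restored afterwards. Below is a minimal single-process sketch of that idea, illustrative only and not part of the patch; it uses CPU generators and two made-up seeds as stand-ins for two ranks' pid-derived seeds, and compares zero patterns the same way check_dropout does.

    import torch

    p = 0.4
    x = torch.randn(5, 4)
    masks = []
    for seed in (1234, 4321):    # stand-ins for two ranks' per-process seeds
        gen = torch.Generator().manual_seed(seed)
        keep = (torch.rand(x.shape, generator=gen) > p).float()    # Bernoulli keep-mask
        masks.append(keep)

    # Different seeds should give different zero patterns, mirroring the cross-rank assert in check_dropout.
    assert not torch.equal(masks[0], masks[1])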