From 0803a61412c5f57b7f784fbba19aa92f33cf6885 Mon Sep 17 00:00:00 2001
From: FoolPlayer <45593998+FoolPlayer@users.noreply.github.com>
Date: Thu, 22 Jun 2023 14:40:37 +0800
Subject: [PATCH] [shardformer] add linearconv1d test (#4067)

* add linearconv1d test

* add linearconv1d test
---
 colossalai/shardformer/layer/linear_conv.py   |  36 +++---
 colossalai/shardformer/policies/gpt2.py       |  10 +-
 .../test_layer/test_linearconv_1d.py          | 107 ++++++++++++++++++
 .../test_model/test_shard_gpt2.py             |   3 -
 4 files changed, 122 insertions(+), 34 deletions(-)
 create mode 100644 tests/test_shardformer/test_layer/test_linearconv_1d.py

diff --git a/colossalai/shardformer/layer/linear_conv.py b/colossalai/shardformer/layer/linear_conv.py
index 2adfc1828..2d1dacf2c 100644
--- a/colossalai/shardformer/layer/linear_conv.py
+++ b/colossalai/shardformer/layer/linear_conv.py
@@ -103,10 +103,15 @@ class LinearConv1D_Col(ParallelModule):
         self.reset_parameters(weight_initializer, bias_initializer)
 
     @staticmethod
-    def from_native_module(module: nn.Linear, process_group: Union[ProcessGroup, List[ProcessGroup]], n_cast: int,
+    def from_native_module(module: nn.Linear, process_group: Union[ProcessGroup, List[ProcessGroup]], n_fused: int,
                            *args, **kwargs) -> ParallelModule:
         r"""
         Convert a huggingface layer `Conv1D` in gpt2 to a parallelized linear layer.
+
+        Args:
+            module (`nn.Linear`): The module to be converted.
+            process_group (`Union[ProcessGroup, List[ProcessGroup]]`): The process group to be used for weight sharding and communication.
+            n_fused (int): The number of layers to be fused. In GPT2, Q,K,V are fused in one weight.
         """
         # get the attributes
         in_features = module.weight.shape[0]
@@ -135,20 +140,20 @@ class LinearConv1D_Col(ParallelModule):
 
             # first rearange the order of weight and bias
             world_size = dist.get_world_size(group=process_group)
-            order = torch.arange(world_size * n_cast)
+            order = torch.arange(world_size * n_fused)
             new_order = []
             for i in range(world_size):
                 new_order.append(order[i::world_size])
             new_order = torch.cat(new_order)
 
-            weight_chunks = torch.chunk(module.weight.data, world_size * n_cast, dim=1)
+            weight_chunks = torch.chunk(module.weight.data, world_size * n_fused, dim=1)
             rearanged_weight_chunks = [weight_chunks[i] for i in new_order]
             rearanged_weight = torch.cat(rearanged_weight_chunks, dim=1)
             sharded_weight = shard_colwise(rearanged_weight, process_group)
             linear_1d.weight.data.copy_(sharded_weight.T.contiguous())
 
             if bias:
-                bias_chunks = torch.chunk(module.bias.data, world_size * n_cast, dim=0)
+                bias_chunks = torch.chunk(module.bias.data, world_size * n_fused, dim=0)
                 rearanged_bias_chunks = [bias_chunks[i] for i in new_order]
                 rearanged_bias = torch.cat(rearanged_bias_chunks, dim=0)
                 sharded_bias = shard_colwise(rearanged_bias, process_group)
@@ -260,8 +265,8 @@ class LinearConv1D_Row(ParallelModule):
         self.reset_parameters(weight_initializer, bias_initializer)
 
     @staticmethod
-    def from_native_module(module: nn.Linear, process_group: Union[ProcessGroup, List[ProcessGroup]], n_cast: int,
-                           *args, **kwargs) -> ParallelModule:
+    def from_native_module(module: nn.Linear, process_group: Union[ProcessGroup, List[ProcessGroup]], *args,
+                           **kwargs) -> ParallelModule:
         r"""
         Convert a native PyTorch linear layer to a parallelized linear layer.
""" @@ -289,26 +294,11 @@ class LinearConv1D_Row(ParallelModule): with torch.no_grad(): # the weigh to the linear layer is a transpose # thus shard on col is equal to shard on row - - # first rearange the order of weight and bias - world_size = dist.get_world_size(group=process_group) - order = torch.arange(world_size * n_cast) - new_order = [] - for i in range(world_size): - new_order.append(order[i::world_size]) - new_order = torch.cat(new_order) - - weight_chunks = torch.chunk(module.weight.data, world_size * n_cast, dim=0) - rearanged_weight_chunks = [weight_chunks[i] for i in new_order] - rearanged_weight = torch.cat(rearanged_weight_chunks, dim=0) - sharded_weight = shard_rowwise(rearanged_weight, process_group) + sharded_weight = shard_rowwise(module.weight.data, process_group) linear_1d.weight.data.copy_(sharded_weight.T.contiguous()) if bias: - bias_chunks = torch.chunk(module.bias.data, world_size * n_cast, dim=0) - rearanged_bias_chunks = [bias_chunks[i] for i in new_order] - rearanged_bias = torch.cat(rearanged_bias_chunks, dim=0) - linear_1d.bias.copy_(rearanged_bias.contiguous()) + linear_1d.bias.copy_(module.bias.data) return linear_1d diff --git a/colossalai/shardformer/policies/gpt2.py b/colossalai/shardformer/policies/gpt2.py index 54ea2f6e3..9d5d7d36a 100644 --- a/colossalai/shardformer/policies/gpt2.py +++ b/colossalai/shardformer/policies/gpt2.py @@ -44,29 +44,23 @@ class GPT2Policy(Policy): suffix="attn.c_attn", target_module=col_nn.LinearConv1D_Col, kwargs={ - "n_cast": 3, + "n_fused": 3, }, ), SubModuleReplacementDescription( suffix="attn.c_proj", target_module=col_nn.LinearConv1D_Row, - kwargs={ - "n_cast": 1, - }, ), SubModuleReplacementDescription( suffix="mlp.c_fc", target_module=col_nn.LinearConv1D_Col, kwargs={ - "n_cast": 1, + "n_fused": 1, }, ), SubModuleReplacementDescription( suffix="mlp.c_proj", target_module=col_nn.LinearConv1D_Row, - kwargs={ - "n_cast": 1, - }, ), SubModuleReplacementDescription( suffix="attn.attn_dropout", diff --git a/tests/test_shardformer/test_layer/test_linearconv_1d.py b/tests/test_shardformer/test_layer/test_linearconv_1d.py new file mode 100644 index 000000000..e0c97178d --- /dev/null +++ b/tests/test_shardformer/test_layer/test_linearconv_1d.py @@ -0,0 +1,107 @@ +import torch +import torch.distributed as dist +import torch.nn as nn +from torch.testing import assert_close + +import colossalai +from colossalai.shardformer.layer import LinearConv1D_Col, LinearConv1D_Row +from colossalai.testing import rerun_if_address_is_in_use, spawn + + +# This code is copied from https://github.com/huggingface/transformers +class Conv1D(nn.Module): + """ + 1D-convolutional layer as defined by Radford et al. for OpenAI GPT (and also used in GPT-2). + + Basically works like a linear layer but the weights are transposed. + + Args: + nf (`int`): The number of output features. + nx (`int`): The number of input features. 
+    """
+
+    def __init__(self, nf, nx):
+        super().__init__()
+        self.nf = nf
+        self.weight = nn.Parameter(torch.empty(nx, nf))
+        self.bias = nn.Parameter(torch.zeros(nf))
+        nn.init.normal_(self.weight, std=0.02)
+
+    def forward(self, x):
+        size_out = x.size()[:-1] + (self.nf,)
+        x = torch.addmm(self.bias, x.view(-1, x.size(-1)), self.weight)
+        x = x.view(size_out)
+        return x
+
+
+def rearrange(tensor: torch.Tensor, dim: int):
+    tensor = tensor.clone()
+    world_size = 2
+    order = torch.arange(world_size * 3)
+    new_order = []
+    for i in range(world_size):
+        new_order.append(order[i::world_size])
+    new_order = torch.cat(new_order)
+
+    tensor_chunks = torch.chunk(tensor, world_size * 3, dim=dim)
+    rearanged_tensor_chunks = [tensor_chunks[i] for i in new_order]
+    rearanged_tensor = torch.cat(rearanged_tensor_chunks, dim=dim)
+    return rearanged_tensor
+
+
+def check_linear_conv_1d_col():
+    linear = Conv1D(192, 48).cuda()
+    linear_conv_col = LinearConv1D_Col.from_native_module(linear, process_group=None, gather_output=True, n_fused=3)
+
+    assert linear_conv_col.weight.shape == torch.Size([96, 48])
+    assert linear_conv_col.bias.shape == torch.Size([96])
+
+    # check computation correctness
+    x = torch.rand(4, 48).cuda()
+    out = linear(x)
+    gather_out = linear_conv_col(x)
+    assert_close(rearrange(out, 1), gather_out)
+
+    # check backward correctness
+    out.sum().backward()
+    gather_out.sum().backward()
+
+    rank = dist.get_rank()
+    target_grad = torch.chunk(linear.weight.grad, 2, dim=1)[rank]
+    assert_close(target_grad.transpose(0, 1).contiguous(), linear_conv_col.weight.grad)
+
+
+def check_linear_1d_row():
+    linear = Conv1D(192, 48).cuda()
+    linear_row = LinearConv1D_Row.from_native_module(linear, process_group=None, parallel_input=False)
+
+    assert linear_row.weight.shape == torch.Size([192, 24])
+    assert linear_row.bias.shape == torch.Size([192])
+
+    # check computation correctness
+    x = torch.rand(4, 48).cuda()
+    out = linear(x)
+    gather_out = linear_row(x)
+    assert_close(out, gather_out)
+
+    # check backward correctness
+    out.sum().backward()
+    gather_out.sum().backward()
+
+    rank = dist.get_rank()
+    # the row-parallel weight is the rank's slice of the input dimension, stored transposed as [192, 24]
+    target_grad = torch.chunk(linear.weight.grad, 2, dim=0)[rank]
+    assert_close(target_grad.transpose(0, 1).contiguous(), linear_row.weight.grad)
+
+
+def run_dist(rank, world_size, port):
+    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+    check_linear_conv_1d_col()
+
+
+@rerun_if_address_is_in_use()
+def test_linearconv():
+    spawn(run_dist, nprocs=2)
+
+
+if __name__ == '__main__':
+    test_linearconv()
diff --git a/tests/test_shardformer/test_model/test_shard_gpt2.py b/tests/test_shardformer/test_model/test_shard_gpt2.py
index 0c07f4440..9aa02ec34 100644
--- a/tests/test_shardformer/test_model/test_shard_gpt2.py
+++ b/tests/test_shardformer/test_model/test_shard_gpt2.py
@@ -42,9 +42,6 @@ def check_gpt2(rank, world_size, port):
     sub_model_zoo = model_zoo.get_sub_registry('transformers_gpt')
 
     for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
-        print(name)
-        # if name == 'transformers_gpt':
-        #     continue
         org_model, sharded_model = build_model(world_size, model_fn)
         check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)
 
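The heart of this change is the interleave performed before column sharding, which the new `n_fused` argument controls: GPT-2's `Conv1D` stores its weight as [in_features, out_features] with Q, K and V fused along the output dimension, so the fused columns are re-ordered first and each rank then receives a contiguous (Q, K, V) slice. Below is a minimal stand-alone sketch of that interleaving with toy shapes chosen purely for illustration; it is not part of the patched files.

import torch

# Toy illustration of the n_fused interleave used by LinearConv1D_Col.from_native_module.
# The weight is [in_features, out_features] and the output columns look like [Q | K | V].
world_size, n_fused = 2, 3
weight = torch.arange(6).repeat(4, 1)    # toy [in=4, out=6]: cols 0-1 = Q, 2-3 = K, 4-5 = V

order = torch.arange(world_size * n_fused)                                  # [0, 1, 2, 3, 4, 5]
new_order = torch.cat([order[i::world_size] for i in range(world_size)])    # [0, 2, 4, 1, 3, 5]

chunks = torch.chunk(weight, world_size * n_fused, dim=1)
rearranged = torch.cat([chunks[i] for i in new_order], dim=1)

# After interleaving, a plain 2-way column shard gives every rank one slice of Q, K and V,
# rather than e.g. all of Q plus half of K.
shard_0, shard_1 = torch.chunk(rearranged, world_size, dim=1)
print(shard_0[0])    # tensor([0, 2, 4])  -> first half of Q, K, V
print(shard_1[0])    # tensor([1, 3, 5])  -> second half of Q, K, V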