From ae1b58cd16ca6363c31bf4681607414c831c32d4 Mon Sep 17 00:00:00 2001
From: Frank Lee
Date: Fri, 12 Aug 2022 11:33:09 +0800
Subject: [PATCH] [tensor] added linear implementation for the new sharding spec (#1416)

* [tensor] added linear implementation for the new sharding spec

* polish code
---
 colossalai/nn/_ops/linear.py             |  80 +++++++-
 colossalai/tensor/sharding_spec.py       |   7 +-
 tests/test_tensor/test_sharded_linear.py | 235 +++++++++++++++++++++++
 3 files changed, 315 insertions(+), 7 deletions(-)
 create mode 100644 tests/test_tensor/test_sharded_linear.py

diff --git a/colossalai/nn/_ops/linear.py b/colossalai/nn/_ops/linear.py
index 699667b6c..8835574de 100644
--- a/colossalai/nn/_ops/linear.py
+++ b/colossalai/nn/_ops/linear.py
@@ -4,6 +4,8 @@ from ._utils import GeneralTensor, convert_to_colo_tensor
 from colossalai.tensor.op_wrapper import colo_op_impl
 from ._utils import reduce_input, reduce_grad
 from colossalai.tensor import ComputePattern, ComputeSpec, ColoTensor, ShardSpec, ReplicaSpec, ColoTensorSpec
+from colossalai.tensor.sharding_spec import ShardingSpec
+from copy import deepcopy
 
 
 def colo_linear_1drow(input_tensor: ColoTensor, weight: ColoTensor, bias: Optional[ColoTensor]) -> 'ColoTensor':
@@ -86,8 +88,84 @@ def colo_linear_imp(input_tensor: GeneralTensor,
     return ret_tensor
 
 
+def _new_colo_linear_imp(input_tensor: GeneralTensor,
+                         weight: GeneralTensor,
+                         bias: Optional[GeneralTensor] = None) -> 'ColoTensor':
+    """
+    A tentative function to compute the distributed linear layer with the latest sharding spec.
+    This function is subject to future change as the current sharding API is not stable.
+    """
+    # get mesh info
+    input_sharding_seq = input_tensor.sharding_spec.sharding_sequence
+    weight_sharding_seq = weight.sharding_spec.sharding_sequence
+    if bias is not None:
+        bias_sharding_seq = bias.sharding_spec.sharding_sequence
+    device_mesh = weight.sharding_spec.device_mesh
+    pg_axis0 = weight.pg_axis0
+    pg_axis1 = weight.pg_axis1
+
+    # the last dim of the input should have the same spec as the in-feature dim of the weight;
+    # the weight is transposed in F.linear, so its in-feature dim is dimension 1
+    assert input_sharding_seq[-1] == weight_sharding_seq[1]
+
+    if bias is not None:
+        assert bias_sharding_seq[0] == weight_sharding_seq[0]
+
+    # compute the output sharding sequence;
+    # the weight is transposed, so its out-feature dim is dimension 0
+    output_shard_seq = input_sharding_seq[:-1] + weight_sharding_seq[:1]
+    output_shard_seq = deepcopy(output_shard_seq)
+
+    # TODO: add reduce grad logic
+
+    # compute the local matmul for both column and row parallel linear;
+    # the required all-reduce is applied below based on the input sharding
+    out = F.linear(input_tensor, weight)
+
+    # run all-reduce if necessary
+    last_dim_spec = input_sharding_seq[-1]
+    if last_dim_spec.is_replica:
+        pass
+    elif last_dim_spec.shard_list is not None:
+        for dim in last_dim_spec.shard_list:
+            if dim == 0:
+                reduce_input(out, pg_axis0)
+            elif dim == 1:
+                reduce_input(out, pg_axis1)
+            else:
+                raise RuntimeError(f"Found invalid sharding axis {dim}, only 0 or 1 is expected")
+    # add bias
+    if bias is not None:
+        out += bias
+
+    # convert shard seq to partition dict
+    output_partition_dict = {}
+    for index, dim_spec in enumerate(output_shard_seq):
+        if not dim_spec.is_replica:
+            if index not in output_partition_dict:
+                output_partition_dict[index] = []
+            output_partition_dict[index].extend(dim_spec.shard_list)
+
+    entire_shape = out.shape
+    output_sharding_spec = ShardingSpec(device_mesh, entire_shape, output_partition_dict)
+    ret_tensor = ColoTensor.from_torch_tensor(out)
+    setattr(ret_tensor, 'sharding_spec', output_sharding_spec)
+    return ret_tensor
+
+
+def _has_sharding_spec(tensor):
+    """
+    A tentative function to check whether the tensor is using the new sharding spec API. We assume that the
+    sharding spec object is set as the attribute `sharding_spec` on a tensor.
+    """
+    return hasattr(tensor, 'sharding_spec')
+
+
 @colo_op_impl(F.linear)
 def colo_linear(input_tensor: GeneralTensor,
                 weight: GeneralTensor,
                 bias: Optional[GeneralTensor] = None) -> 'ColoTensor':
-    return colo_linear_imp(input_tensor, weight, bias)
+    if _has_sharding_spec(weight):
+        return _new_colo_linear_imp(input_tensor, weight, bias)
+    else:
+        return colo_linear_imp(input_tensor, weight, bias)
diff --git a/colossalai/tensor/sharding_spec.py b/colossalai/tensor/sharding_spec.py
index e4f7f2490..7fa68b05b 100644
--- a/colossalai/tensor/sharding_spec.py
+++ b/colossalai/tensor/sharding_spec.py
@@ -17,12 +17,7 @@ class _DimSpec:
         self.shard_list = shard_list
 
     def __eq__(self, other):
-        if dir(self) != dir(other):
-            return False
-        for attr in dir(self):
-            if not attr.startswith('__') and getattr(self, attr) != getattr(other, attr):
-                return False
-        return True
+        return str(self) == str(other)
 
     def __repr__(self):
         if self.is_replica:
diff --git a/tests/test_tensor/test_sharded_linear.py b/tests/test_tensor/test_sharded_linear.py
new file mode 100644
index 000000000..96f7d8c0f
--- /dev/null
+++ b/tests/test_tensor/test_sharded_linear.py
@@ -0,0 +1,235 @@
+import colossalai
+import torch
+import pytest
+import torch.multiprocessing as mp
+import torch.nn.functional as F
+from colossalai.testing import rerun_if_address_is_in_use
+from colossalai.utils import free_port
+from functools import partial
+from colossalai.device.device_mesh import DeviceMesh
+from colossalai.tensor.sharding_spec import ShardingSpec
+from colossalai.tensor import ColoTensor, ColoParameter, ProcessGroup
+from colossalai.nn._ops._utils import gather_forward_split_backward
+
+
+def run_dist(rank, world_size, port):
+    config = {}
+    colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+
+    # create mlp vars
+    x = ColoTensor.from_torch_tensor(torch.rand(2, 4, 8, requires_grad=True)).cuda()
+    w = ColoParameter.from_torch_tensor(torch.rand(16, 8, requires_grad=True)).cuda()
+    b = ColoParameter.from_torch_tensor(torch.rand(16, requires_grad=True)).cuda()
+
+    # run normal forward
+    out = F.linear(x, w, b)
+
+    # create mesh meta
+    # the mesh has the following topology
+    # [[0, 1],
+    #  [2, 3]]
+    physical_mesh_id = torch.arange(0, 4).reshape(2, 2)
+    mesh_shape = (2, 2)
+    device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
+    row_id = rank // 2
+    column_id = rank % 2
+
+    # create pg
+    row_process_group = None
+    col_process_group = None
+    row_to_ranks = {0: [0, 1], 1: [2, 3]}
+    col_to_ranks = {0: [0, 2], 1: [1, 3]}
+
+    for idx in range(2):
+        # row ranks
+        row_ranks = row_to_ranks[idx]
+        row_pg = ProcessGroup(ranks=row_ranks, tp_degree=2)
+
+        # col ranks
+        col_ranks = col_to_ranks[idx]
+        col_pg = ProcessGroup(ranks=col_ranks, tp_degree=2)
+
+        if rank in row_ranks:
+            row_process_group = row_pg
+
+        if rank in col_ranks:
+            col_process_group = col_pg
+
+    ########################
+    #   RRR x RS0 -> RRS0  #
+    ########################
+    # w will be transposed in F.linear
+    x_replica = x.detach().clone()
+    w_shard = torch.chunk(w.detach().clone(), chunks=2, dim=0)[row_id]
+    b_shard = torch.chunk(b.detach().clone(), chunks=2, dim=0)[row_id]
+
+    # adding sharding spec
+    x_replica.sharding_spec = ShardingSpec(device_mesh, x.shape, dim_partition_dict={})
+    w_shard.sharding_spec = ShardingSpec(device_mesh, w.shape, dim_partition_dict={0: [0]})
+    b_shard.sharding_spec = ShardingSpec(device_mesh, b.shape, dim_partition_dict={0: [0]})
+
+    # check sharding spec
+    assert str(x_replica.sharding_spec.sharding_sequence) == "[R, R, R]"
+    assert str(w_shard.sharding_spec.sharding_sequence) == "[S0, R]"
+    assert str(b_shard.sharding_spec.sharding_sequence) == "[S0]"
+
+    w_shard.pg_axis0 = col_process_group
+    w_shard.pg_axis1 = row_process_group
+
+    out_shard = F.linear(x_replica, w_shard, b_shard)
+    assert str(out_shard.sharding_spec.sharding_sequence) == "[R, R, S0]"
+
+    # each row only holds a shard of the output features
+    expected_out_shard = torch.chunk(out, chunks=2, dim=2)[row_id]
+    assert torch.allclose(out_shard, expected_out_shard)
+
+    ########################
+    #  S0RR x RS1 -> S0RS1 #
+    ########################
+    # w will be transposed in F.linear
+    x_shard = torch.chunk(x.detach().clone(), chunks=2, dim=0)[row_id]
+    w_shard = torch.chunk(w.detach().clone(), chunks=2, dim=0)[column_id]
+    b_shard = torch.chunk(b.detach().clone(), chunks=2, dim=0)[column_id]
+
+    # adding sharding spec
+    x_shard.sharding_spec = ShardingSpec(device_mesh, x.shape, dim_partition_dict={0: [0]})
+    w_shard.sharding_spec = ShardingSpec(device_mesh, w.shape, dim_partition_dict={0: [1]})
+    b_shard.sharding_spec = ShardingSpec(device_mesh, b.shape, dim_partition_dict={0: [1]})
+
+    # check sharding spec
+    assert str(x_shard.sharding_spec.sharding_sequence) == "[S0, R, R]"
+    assert str(w_shard.sharding_spec.sharding_sequence) == "[S1, R]"
+    assert str(b_shard.sharding_spec.sharding_sequence) == "[S1]"
+
+    w_shard.pg_axis0 = col_process_group
+    w_shard.pg_axis1 = row_process_group
+
+    out_shard = F.linear(x_shard, w_shard, b_shard)
+
+    # each row holds a mini-batch and each column holds a shard of the output features
+    expected_out_shard = torch.chunk(out, chunks=2, dim=0)[row_id]
+    expected_out_shard = torch.chunk(expected_out_shard, chunks=2, dim=2)[column_id]
+    assert torch.allclose(out_shard, expected_out_shard)
+
+    ########################
+    #  S0RS1 x S1R -> S0RR #
+    ########################
+    # w will be transposed in F.linear
+    x_shard = torch.chunk(x.clone(), chunks=2, dim=0)[row_id]
+    x_shard = torch.chunk(x_shard, chunks=2, dim=2)[column_id]
+    w_shard = torch.chunk(w.clone(), chunks=2, dim=1)[column_id]
+    b_replica = b.clone()
+
+    # adding sharding spec
+    x_shard.sharding_spec = ShardingSpec(device_mesh, x.shape, dim_partition_dict={0: [0], 2: [1]})
+    w_shard.sharding_spec = ShardingSpec(device_mesh, w.shape, dim_partition_dict={1: [1]})
+    b_replica.sharding_spec = ShardingSpec(device_mesh, b.shape, dim_partition_dict={})
+
+    # check sharding spec
+    assert str(x_shard.sharding_spec.sharding_sequence) == "[S0, R, S1]"
+    assert str(w_shard.sharding_spec.sharding_sequence) == "[R, S1]"
+    assert str(b_replica.sharding_spec.sharding_sequence) == "[R]"
+
+    w_shard.pg_axis0 = col_process_group
+    w_shard.pg_axis1 = row_process_group
+
+    out_shard = F.linear(x_shard, w_shard, b_replica)
+
+    # each row only has a mini-batch
+    expected_out_shard = torch.chunk(out, chunks=2, dim=0)[row_id]
+    assert torch.allclose(out_shard, expected_out_shard)
+
+    ########################
+    #   RRS0 x S0R -> RRR  #
+    ########################
+    # w will be transposed in F.linear
+    x_shard = torch.chunk(x.clone(), chunks=2, dim=2)[row_id]
+    w_shard = torch.chunk(w.clone(), chunks=2, dim=1)[row_id]
+    b_replica = b.clone()
+
+    # adding sharding spec
+    x_shard.sharding_spec = ShardingSpec(device_mesh, x.shape, dim_partition_dict={2: [0]})
+    w_shard.sharding_spec = ShardingSpec(device_mesh, w.shape, dim_partition_dict={1: [0]})
+    b_replica.sharding_spec = ShardingSpec(device_mesh, b.shape, dim_partition_dict={})
+
+    # check sharding spec
+    assert str(x_shard.sharding_spec.sharding_sequence) == "[R, R, S0]"
+    assert str(w_shard.sharding_spec.sharding_sequence) == "[R, S0]"
+    assert str(b_replica.sharding_spec.sharding_sequence) == "[R]"
+
+    w_shard.pg_axis0 = col_process_group
+    w_shard.pg_axis1 = row_process_group
+
+    out_shard = F.linear(x_shard, w_shard, b_replica)
+
+    # the output is fully replicated after the all-reduce
+    expected_out_shard = out
+    assert torch.allclose(out_shard, expected_out_shard)
+
+    ########################
+    #  RS0S1 x S1R -> RS0R #
+    ########################
+    # w will be transposed in F.linear
+    x_shard = torch.chunk(x.clone(), chunks=2, dim=1)[row_id]
+    x_shard = torch.chunk(x_shard, chunks=2, dim=2)[column_id]
+    w_shard = torch.chunk(w.clone(), chunks=2, dim=1)[column_id]
+    b_replica = b.clone()
+
+    # adding sharding spec
+    x_shard.sharding_spec = ShardingSpec(device_mesh, x.shape, dim_partition_dict={1: [0], 2: [1]})
+    w_shard.sharding_spec = ShardingSpec(device_mesh, w.shape, dim_partition_dict={1: [1]})
+    b_replica.sharding_spec = ShardingSpec(device_mesh, b.shape, dim_partition_dict={})
+
+    # check sharding spec
+    assert str(x_shard.sharding_spec.sharding_sequence) == "[R, S0, S1]"
+    assert str(w_shard.sharding_spec.sharding_sequence) == "[R, S1]"
+    assert str(b_replica.sharding_spec.sharding_sequence) == "[R]"
+
+    w_shard.pg_axis0 = col_process_group
+    w_shard.pg_axis1 = row_process_group
+
+    out_shard = F.linear(x_shard, w_shard, b_replica)
+
+    # each row only holds a shard along dim 1
+    expected_out_shard = torch.chunk(out, chunks=2, dim=1)[row_id]
+    assert torch.allclose(out_shard, expected_out_shard)
+
+    ########################
+    #  RRS0 x S0S1 -> RRS1 #
+    ########################
+    # w will be transposed in F.linear
+    x_shard = torch.chunk(x.clone(), chunks=2, dim=2)[row_id]
+    w_shard = torch.chunk(w.clone(), chunks=2, dim=1)[row_id]
+    w_shard = torch.chunk(w_shard, chunks=2, dim=0)[column_id]
+    b_shard = torch.chunk(b.clone(), chunks=2, dim=0)[column_id]
+
+    # adding sharding spec
+    x_shard.sharding_spec = ShardingSpec(device_mesh, x.shape, dim_partition_dict={2: [0]})
+    w_shard.sharding_spec = ShardingSpec(device_mesh, w.shape, dim_partition_dict={0: [1], 1: [0]})
+    b_shard.sharding_spec = ShardingSpec(device_mesh, b.shape, dim_partition_dict={0: [1]})
+
+    # check sharding spec
+    assert str(x_shard.sharding_spec.sharding_sequence) == "[R, R, S0]"
+    assert str(w_shard.sharding_spec.sharding_sequence) == "[S1, S0]"
+    assert str(b_shard.sharding_spec.sharding_sequence) == "[S1]"
+
+    w_shard.pg_axis0 = col_process_group
+    w_shard.pg_axis1 = row_process_group
+
+    out_shard = F.linear(x_shard, w_shard, b_shard)
+
+    # each column only holds a shard of the output features
+    expected_out_shard = torch.chunk(out, chunks=2, dim=2)[column_id]
+    assert torch.allclose(out_shard, expected_out_shard)
+
+
+@pytest.mark.dist
+@pytest.mark.parametrize('world_size', [4])
+@rerun_if_address_is_in_use()
+def test_sharded_mlp(world_size):
+    run_func = partial(run_dist, world_size=world_size, port=free_port())
+    mp.spawn(run_func, nprocs=world_size)
+
+
+if __name__ == '__main__':
+    test_sharded_mlp(4)
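
For reference, the snippet below is a minimal sketch (not part of the patch) of how the new code path is exercised. It mirrors the row-parallel case from the test above and assumes it runs on one rank inside a 4-process colossalai launch, with `device_mesh`, `row_id`, `row_process_group` and `col_process_group` built exactly as in `run_dist`.

    # minimal usage sketch: RRR x RS0 -> RRS0 through the sharding-spec dispatch in colo_linear
    import torch
    import torch.nn.functional as F
    from colossalai.tensor import ColoTensor, ColoParameter
    from colossalai.tensor.sharding_spec import ShardingSpec

    x = ColoTensor.from_torch_tensor(torch.rand(2, 4, 8)).cuda()
    w = ColoParameter.from_torch_tensor(torch.rand(16, 8)).cuda()

    # replicate the input and shard the weight's out-feature dim along mesh axis 0
    x_local = x.detach().clone()
    w_local = torch.chunk(w.detach().clone(), chunks=2, dim=0)[row_id]
    x_local.sharding_spec = ShardingSpec(device_mesh, x.shape, dim_partition_dict={})
    w_local.sharding_spec = ShardingSpec(device_mesh, w.shape, dim_partition_dict={0: [0]})

    # the new implementation reads the process groups from the weight
    w_local.pg_axis0 = col_process_group
    w_local.pg_axis1 = row_process_group

    # dispatches to _new_colo_linear_imp because the weight carries a sharding_spec
    out_local = F.linear(x_local, w_local)
    assert str(out_local.sharding_spec.sharding_sequence) == "[R, R, S0]"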