mirror of https://github.com/hpcaitech/ColossalAI
[tensor] added linear implementation for the new sharding spec (#1416)
* [tensor] added linear implementation for the new sharding spec
* polish code

parent d40a9392ba
commit ae1b58cd16
@@ -4,6 +4,8 @@ from ._utils import GeneralTensor, convert_to_colo_tensor
 from colossalai.tensor.op_wrapper import colo_op_impl
 from ._utils import reduce_input, reduce_grad
 from colossalai.tensor import ComputePattern, ComputeSpec, ColoTensor, ShardSpec, ReplicaSpec, ColoTensorSpec
+from colossalai.tensor.sharding_spec import ShardingSpec
+from copy import deepcopy
 
 
 def colo_linear_1drow(input_tensor: ColoTensor, weight: ColoTensor, bias: Optional[ColoTensor]) -> 'ColoTensor':
@@ -86,8 +88,84 @@ def colo_linear_imp(input_tensor: GeneralTensor,
     return ret_tensor
 
 
+def _new_colo_linear_imp(input_tensor: GeneralTensor,
+                         weight: GeneralTensor,
+                         bias: Optional[GeneralTensor] = None) -> 'ColoTensor':
+    """
+    A tentative function to compute the distributed linear layer with the latest sharding spec.
+    This function is subject to future change as the current sharding API is not stable.
+    """
+    # get mesh info
+    input_sharding_seq = input_tensor.sharding_spec.sharding_sequence
+    weight_sharding_seq = weight.sharding_spec.sharding_sequence
+    if bias is not None:
+        bias_sharding_seq = bias.sharding_spec.sharding_sequence
+    device_mesh = weight.sharding_spec.device_mesh
+    pg_axis0 = weight.pg_axis0
+    pg_axis1 = weight.pg_axis1
+
+    # the last dim of the input should have the same spec as the first dim of the weight
+    # the weight is transposed, so we look at the second dimension
+    assert input_sharding_seq[-1] == weight_sharding_seq[1]
+
+    if bias is not None:
+        assert bias_sharding_seq[0] == weight_sharding_seq[0]
+
+    # compute the output sharding sequence
+    # as the weight is transposed, we look at the first dimension
+    output_shard_seq = input_sharding_seq[:-1] + weight_sharding_seq[:1]
+    output_shard_seq = deepcopy(output_shard_seq)
+
+    # TODO: add reduce grad logic
+
+    # handle column and row parallel linear
+    # by reusing the implementation above
+    out = F.linear(input_tensor, weight)
+
+    # run all-reduce if necessary
+    last_dim_spec = input_sharding_seq[-1]
+    if last_dim_spec.is_replica:
+        pass
+    elif last_dim_spec.shard_list is not None:
+        for dim in last_dim_spec.shard_list:
+            if dim == 0:
+                reduce_input(out, pg_axis0)
+            elif dim == 1:
+                reduce_input(out, pg_axis1)
+            else:
+                raise RuntimeError(f"Found invalid sharding axis {dim}, only 0 or 1 is expected")
+    # add bias
+    if bias is not None:
+        out += bias
+
+    # convert the shard seq to a partition dict
+    output_partition_dict = {}
+    for index, dim_spec in enumerate(output_shard_seq):
+        if not dim_spec.is_replica:
+            if index not in output_partition_dict:
+                output_partition_dict[index] = []
+            output_partition_dict[index].extend(dim_spec.shard_list)
+
+    entire_shape = out.shape
+    output_sharding_spec = ShardingSpec(device_mesh, entire_shape, output_partition_dict)
+    ret_tensor = ColoTensor.from_torch_tensor(out)
+    setattr(ret_tensor, 'sharding_spec', output_sharding_spec)
+    return ret_tensor
+
+
+def _has_sharding_spec(tensor):
+    """
+    A tentative function to check whether the tensor is using the new sharding spec API. We assume that the sharding spec object is
+    set as the attribute `sharding_spec` on a tensor.
+    """
+    return hasattr(tensor, 'sharding_spec')
+
+
 @colo_op_impl(F.linear)
 def colo_linear(input_tensor: GeneralTensor,
                 weight: GeneralTensor,
                 bias: Optional[GeneralTensor] = None) -> 'ColoTensor':
-    return colo_linear_imp(input_tensor, weight, bias)
+    if _has_sharding_spec(weight):
+        return _new_colo_linear_imp(input_tensor, weight, bias)
+    else:
+        return colo_linear_imp(input_tensor, weight, bias)
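As an aside, the output sharding rule used by `_new_colo_linear_imp` can be illustrated in isolation. The sketch below is not part of the diff; it uses plain strings in place of `_DimSpec` objects and mirrors the S0RS1 x S1R -> S0RR case exercised in the test file further down.

# Sketch of the rule out_seq = input_seq[:-1] + weight_seq[:1], with the weight stored
# transposed as (out_features, in_features), which is what F.linear expects.
input_seq = ['S0', 'R', 'S1']     # x sharded over mesh axis 0 (dim 0) and mesh axis 1 (last dim)
weight_seq = ['R', 'S1']          # weight sharded over mesh axis 1 on its input-feature dim
assert input_seq[-1] == weight_seq[1]        # the contraction dims must carry the same spec

out_seq = input_seq[:-1] + weight_seq[:1]    # -> ['S0', 'R', 'R']
# The sharded contraction dim (S1) disappears from the output, which is why the partial
# results still need an all-reduce over mesh axis 1 (reduce_input with pg_axis1).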
@@ -17,12 +17,7 @@ class _DimSpec:
         self.shard_list = shard_list
 
     def __eq__(self, other):
-        if dir(self) != dir(other):
-            return False
-        for attr in dir(self):
-            if not attr.startswith('__') and getattr(self, attr) != getattr(other, attr):
-                return False
-        return True
+        return str(self) == str(other)
 
     def __repr__(self):
         if self.is_replica:
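The `__eq__` change above means two `_DimSpec` objects now compare equal exactly when their printed forms match. A minimal sketch of the resulting behaviour, assuming the `_DimSpec(shard_list)` constructor shown in this class and its 'R'/'S0' repr convention:

from colossalai.tensor.sharding_spec import _DimSpec

shard_on_axis0 = _DimSpec([0])    # printed as 'S0'
replica = _DimSpec([])            # printed as 'R'

assert shard_on_axis0 == _DimSpec([0])    # same string form -> equal
assert shard_on_axis0 != replica          # 'S0' vs 'R'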
@@ -0,0 +1,236 @@
+import colossalai
+import torch
+import pytest
+import torch.multiprocessing as mp
+import torch.nn.functional as F
+from colossalai.testing import rerun_if_address_is_in_use
+from colossalai.utils import free_port
+from functools import partial
+from colossalai.device.device_mesh import DeviceMesh
+from colossalai.tensor.sharding_spec import ShardingSpec
+from colossalai.tensor import ColoTensor, ColoParameter, ProcessGroup
+from colossalai.nn._ops._utils import gather_forward_split_backward
+
+
+def run_dist(rank, world_size, port):
+    config = {}
+    colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+
+    # create mlp vars
+    x = ColoTensor.from_torch_tensor(torch.rand(2, 4, 8, requires_grad=True)).cuda()
+    w = ColoParameter.from_torch_tensor(torch.rand(16, 8, requires_grad=True)).cuda()
+    b = ColoParameter.from_torch_tensor(torch.rand(16, requires_grad=True)).cuda()
+
+    # run normal forward
+    out = F.linear(x, w, b)
+
+    # create mesh meta
+    # the mesh is in the following topo
+    # [[0, 1],
+    #  [2, 3]]
+    physical_mesh_id = torch.arange(0, 4).reshape(2, 2)
+    mesh_shape = (2, 2)
+    device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
+    row_id = rank // 2
+    column_id = rank % 2
+
+    # create pg
+    row_process_group = None
+    col_process_group = None
+    row_to_ranks = {0: [0, 1], 1: [2, 3]}
+    col_to_ranks = {0: [0, 2], 1: [1, 3]}
+
+    for idx in range(2):
+        # row ranks
+        row_ranks = row_to_ranks[idx]
+        row_pg = ProcessGroup(ranks=row_ranks, tp_degree=2)
+
+        # col ranks
+        col_ranks = col_to_ranks[idx]
+        col_pg = ProcessGroup(ranks=col_ranks, tp_degree=2)
+
+        if rank in row_ranks:
+            row_process_group = row_pg
+
+        if rank in col_ranks:
+            col_process_group = col_pg
+
+    ########################
+    # RRR x RS0 -> RRS0 #
+    ########################
+    # w will be transposed in F.linear
+    x_replica = x.detach().clone()
+    w_shard = torch.chunk(w.detach().clone(), chunks=2, dim=0)[row_id]
+    b_shard = torch.chunk(b.detach().clone(), chunks=2, dim=0)[row_id]
+
+    # adding sharding spec
+    x_replica.sharding_spec = ShardingSpec(device_mesh, x.shape, dim_partition_dict={})
+    w_shard.sharding_spec = ShardingSpec(device_mesh, w.shape, dim_partition_dict={0: [0]})
+    b_shard.sharding_spec = ShardingSpec(device_mesh, b.shape, dim_partition_dict={0: [0]})
+
+    # check sharding spec
+    assert str(x_replica.sharding_spec.sharding_sequence) == "[R, R, R]"
+    assert str(w_shard.sharding_spec.sharding_sequence) == "[S0, R]"
+    assert str(b_shard.sharding_spec.sharding_sequence) == "[S0]"
+
+    # pg_axis0 spans mesh axis 0 (ranks in the same mesh column), pg_axis1 spans mesh axis 1 (ranks in the same mesh row)
+    w_shard.pg_axis0 = col_process_group
+    w_shard.pg_axis1 = row_process_group
+
+    out_shard = F.linear(x_replica, w_shard, b_shard)
+    assert str(out_shard.sharding_spec.sharding_sequence) == "[R, R, S0]"
+
+    # the output features (dim 2) are sharded over mesh axis 0
+    expected_out_shard = torch.chunk(out, chunks=2, dim=2)[row_id]
+    assert torch.allclose(out_shard, expected_out_shard)
+
+    ########################
+    # S0RR x RS1 -> S0RS1 #
+    ########################
+    # w will be transposed in F.linear
+    x_shard = torch.chunk(x.detach().clone(), chunks=2, dim=0)[row_id]
+    w_shard = torch.chunk(w.detach().clone(), chunks=2, dim=0)[column_id]
+    b_shard = torch.chunk(b.detach().clone(), chunks=2, dim=0)[column_id]
+
+    # adding sharding spec
+    x_shard.sharding_spec = ShardingSpec(device_mesh, x.shape, dim_partition_dict={0: [0]})
+    w_shard.sharding_spec = ShardingSpec(device_mesh, w.shape, dim_partition_dict={0: [1]})
+    b_shard.sharding_spec = ShardingSpec(device_mesh, b.shape, dim_partition_dict={0: [1]})
+
+    # check sharding spec
+    assert str(x_shard.sharding_spec.sharding_sequence) == "[S0, R, R]"
+    assert str(w_shard.sharding_spec.sharding_sequence) == "[S1, R]"
+    assert str(b_shard.sharding_spec.sharding_sequence) == "[S1]"
+
+    w_shard.pg_axis0 = col_process_group
+    w_shard.pg_axis1 = row_process_group
+
+    out_shard = F.linear(x_shard, w_shard, b_shard)
+
+    # dim 0 is sharded over mesh axis 0 and dim 2 over mesh axis 1
+    expected_out_shard = torch.chunk(out, chunks=2, dim=0)[row_id]
+    expected_out_shard = torch.chunk(expected_out_shard, chunks=2, dim=2)[column_id]
+    assert torch.allclose(out_shard, expected_out_shard)
+
+    ########################
+    # S0RS1 x S1R -> S0RR #
+    ########################
+    # w will be transposed in F.linear
+    x_shard = torch.chunk(x.clone(), chunks=2, dim=0)[row_id]
+    x_shard = torch.chunk(x_shard, chunks=2, dim=2)[column_id]
+    w_shard = torch.chunk(w.clone(), chunks=2, dim=1)[column_id]
+    b_replica = b.clone()
+
+    # adding sharding spec
+    x_shard.sharding_spec = ShardingSpec(device_mesh, x.shape, dim_partition_dict={0: [0], 2: [1]})
+    w_shard.sharding_spec = ShardingSpec(device_mesh, w.shape, dim_partition_dict={1: [1]})
+    b_replica.sharding_spec = ShardingSpec(device_mesh, b.shape, dim_partition_dict={})
+
+    # check sharding spec
+    assert str(x_shard.sharding_spec.sharding_sequence) == "[S0, R, S1]"
+    assert str(w_shard.sharding_spec.sharding_sequence) == "[R, S1]"
+    assert str(b_replica.sharding_spec.sharding_sequence) == "[R]"
+
+    w_shard.pg_axis0 = col_process_group
+    w_shard.pg_axis1 = row_process_group
+
+    out_shard = F.linear(x_shard, w_shard, b_replica)
+
+    # dim 0 (the mini-batch) is sharded over mesh axis 0
+    expected_out_shard = torch.chunk(out, chunks=2, dim=0)[row_id]
+    assert torch.allclose(out_shard, expected_out_shard)
+
+    ########################
+    # RRS0 x S0R -> RRR #
+    ########################
+    # w will be transposed in F.linear
+    x_shard = torch.chunk(x.clone(), chunks=2, dim=2)[row_id]
+    w_shard = torch.chunk(w.clone(), chunks=2, dim=1)[row_id]
+    b_replica = b.clone()
+
+    # adding sharding spec
+    x_shard.sharding_spec = ShardingSpec(device_mesh, x.shape, dim_partition_dict={2: [0]})
+    w_shard.sharding_spec = ShardingSpec(device_mesh, w.shape, dim_partition_dict={1: [0]})
+    b_replica.sharding_spec = ShardingSpec(device_mesh, b.shape, dim_partition_dict={})
+
+    # check sharding spec
+    assert str(x_shard.sharding_spec.sharding_sequence) == "[R, R, S0]"
+    assert str(w_shard.sharding_spec.sharding_sequence) == "[R, S0]"
+    assert str(b_replica.sharding_spec.sharding_sequence) == "[R]"
+
+    w_shard.pg_axis0 = col_process_group
+    w_shard.pg_axis1 = row_process_group
+
+    out_shard = F.linear(x_shard, w_shard, b_replica)
+
+    # the output is fully replicated after the all-reduce
+    expected_out_shard = out
+    assert torch.allclose(out_shard, expected_out_shard)
+
+    ########################
+    # RS0S1 x S1R -> RS0R #
+    ########################
+    # w will be transposed in F.linear
+    x_shard = torch.chunk(x.clone(), chunks=2, dim=1)[row_id]
+    x_shard = torch.chunk(x_shard, chunks=2, dim=2)[column_id]
+    w_shard = torch.chunk(w.clone(), chunks=2, dim=1)[column_id]
+    b_replica = b.clone()
+
+    # adding sharding spec
+    x_shard.sharding_spec = ShardingSpec(device_mesh, x.shape, dim_partition_dict={1: [0], 2: [1]})
+    w_shard.sharding_spec = ShardingSpec(device_mesh, w.shape, dim_partition_dict={1: [1]})
+    b_replica.sharding_spec = ShardingSpec(device_mesh, b.shape, dim_partition_dict={})
+
+    # check sharding spec
+    assert str(x_shard.sharding_spec.sharding_sequence) == "[R, S0, S1]"
+    assert str(w_shard.sharding_spec.sharding_sequence) == "[R, S1]"
+    assert str(b_replica.sharding_spec.sharding_sequence) == "[R]"
+
+    w_shard.pg_axis0 = col_process_group
+    w_shard.pg_axis1 = row_process_group
+
+    out_shard = F.linear(x_shard, w_shard, b_replica)
+
+    # dim 1 is sharded over mesh axis 0
+    expected_out_shard = torch.chunk(out, chunks=2, dim=1)[row_id]
+    assert torch.allclose(out_shard, expected_out_shard)
+
+    ########################
+    # RRS0 x S0S1 -> RRS1 #
+    ########################
+    # w will be transposed in F.linear
+    x_shard = torch.chunk(x.clone(), chunks=2, dim=2)[row_id]
+    w_shard = torch.chunk(w.clone(), chunks=2, dim=1)[row_id]
+    w_shard = torch.chunk(w_shard, chunks=2, dim=0)[column_id]
+    b_shard = torch.chunk(b.clone(), chunks=2, dim=0)[column_id]
+
+    # adding sharding spec
+    x_shard.sharding_spec = ShardingSpec(device_mesh, x.shape, dim_partition_dict={2: [0]})
+    w_shard.sharding_spec = ShardingSpec(device_mesh, w.shape, dim_partition_dict={0: [1], 1: [0]})
+    b_shard.sharding_spec = ShardingSpec(device_mesh, b.shape, dim_partition_dict={0: [1]})
+
+    # check sharding spec
+    assert str(x_shard.sharding_spec.sharding_sequence) == "[R, R, S0]"
+    assert str(w_shard.sharding_spec.sharding_sequence) == "[S1, S0]"
+    assert str(b_shard.sharding_spec.sharding_sequence) == "[S1]"
+
+    w_shard.pg_axis0 = col_process_group
+    w_shard.pg_axis1 = row_process_group
+
+    out_shard = F.linear(x_shard, w_shard, b_shard)
+
+    # the output features (dim 2) are sharded over mesh axis 1
+    expected_out_shard = torch.chunk(out, chunks=2, dim=2)[column_id]
+    assert torch.allclose(out_shard, expected_out_shard)
+
+
+@pytest.mark.dist
+@pytest.mark.parametrize('world_size', [4])
+@rerun_if_address_is_in_use()
+def test_sharded_mlp(world_size):
+    run_func = partial(run_dist, world_size=world_size, port=free_port())
+    mp.spawn(run_func, nprocs=world_size)
+
+
+if __name__ == '__main__':
+    test_sharded_mlp(4)