diff --git a/colossalai/tensor/_ops/linear.py b/colossalai/tensor/_ops/linear.py
index d8bc338a5..f9b1d2815 100644
--- a/colossalai/tensor/_ops/linear.py
+++ b/colossalai/tensor/_ops/linear.py
@@ -1,9 +1,10 @@
 import torch
 from colossalai.tensor.op_wrapper import colo_op_impl
 from colossalai.tensor.colo_tensor import ColoTensor
+from colossalai.context import ParallelMode
+from colossalai.nn.layer.parallel_1d._utils import split_forward_gather_backward, reduce_input
 from packaging import version
-
 
 @colo_op_impl(torch.nn.functional.linear)
 def colo_linear(types, args, kwargs, pg):
     """Handles ``__torch_function__`` dispatch for ``torch.nn.functional.linear``.
@@ -19,12 +20,31 @@ def colo_linear(types, args, kwargs, pg):
         bias = None
     else:
         bias = kwargs.get('bias', None)
 
     if isinstance(bias, ColoTensor):
         bias = bias.torch_tensor()
 
     # Add communication logic before and after linear call.
     if isinstance(weight, ColoTensor):
-        return torch.nn.functional.linear(input_tensor, weight.torch_tensor(), bias)
+        if weight.shard_spec is None:
+            return torch.nn.functional.linear(input_tensor, weight.torch_tensor(), bias)
+        elif weight.shard_spec == '1Drow':
+            """
+            Input:S[1] x Weight:S[0] = Output:P
+            All-Reduce(Output) + bias = res
+            """
+            # Input:S[1]
+            input_per_partition = split_forward_gather_backward(input_tensor, ParallelMode.PARALLEL_1D, dim=-1)
+            # Output:P
+            partial_output = torch.nn.functional.linear(input_per_partition, weight.torch_tensor())
+            # Reduce(Output)
+            output = reduce_input(partial_output, ParallelMode.PARALLEL_1D)
+            # Bias
+            if bias is not None:
+                output = output + bias
+            return output
+
+        else:
+            raise NotImplementedError
     else:
         return torch.nn.functional.linear(input_tensor, weight, bias)
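The '1Drow' branch above implements the usual row-parallel linear: the input is split along its last dimension, the weight along its in_features dimension, each rank computes a partial output of full shape, and an all-reduce sums the partials before the bias is added once. Below is a minimal single-process sketch of that arithmetic in plain PyTorch; there is no real communication, the Python `sum` stands in for the all-reduce, and the tensor names and sizes (`world_size`, `x`, `w`, `b`) are illustrative rather than taken from the patch.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
world_size = 4                        # pretend tensor-parallel size
x = torch.randn(2, 8)                 # input, last dim = in_features
w = torch.randn(5, 8)                 # weight: (out_features, in_features)
b = torch.randn(5)

# Input:S[1] -- split the input along its last dimension.
x_chunks = torch.chunk(x, world_size, dim=-1)
# Weight sharded along in_features, matching the '1Drow' layout.
w_chunks = torch.chunk(w, world_size, dim=-1)

# Output:P -- each "rank" holds a partial output of the full output shape.
partials = [F.linear(xc, wc) for xc, wc in zip(x_chunks, w_chunks)]

# All-Reduce(Output) + bias: summing the partials recovers the exact result.
out = sum(partials) + b
assert torch.allclose(out, F.linear(x, w, b), atol=1e-5)
```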
diff --git a/colossalai/tensor/colo_tensor.py b/colossalai/tensor/colo_tensor.py
index 8d7e96120..f72cd02be 100644
--- a/colossalai/tensor/colo_tensor.py
+++ b/colossalai/tensor/colo_tensor.py
@@ -4,8 +4,7 @@
-from typing import Tuple
+from typing import Tuple, Optional
 import numpy
 from .op_wrapper import _COLOSSAL_OPS
-
 
 class ColoTensor(object):
     """ Data Structure for Tensor in Colossal-AI
     1. It contains a torch.Tensor as an attribute.
@@ -24,6 +23,7 @@ class ColoTensor(object):
                  pin_memory=False,
                  device=None,
                  torch_tensor=torch.empty(0),
+                 shard_spec: Optional[str] = None,
                  ):
         self._size = size
         self._dtype = dtype
@@ -31,11 +31,28 @@ class ColoTensor(object):
         self._pin_memory = pin_memory
         self._device = device
         self._torch_tensor = torch_tensor
+        self._shard_spec = shard_spec
+
+    @property
+    def shard_spec(self) -> Optional[str]:
+        return self._shard_spec
+
+    @property
+    def data(self):
+        return self._torch_tensor.data
+
+    @property
+    def grad(self):
+        return self._torch_tensor.grad
+
+    @property
+    def size(self):
+        return self._size
 
     def numel(self):
         return product(self._size)
 
     @staticmethod
     def init_from_torch_tensor(tensor: torch.Tensor, save_payload=True) -> 'ColoTensor':
         colo_t = ColoTensor(*tensor.size(),
                             dtype=tensor.dtype,
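The properties added here let a ColoTensor stand in for an `nn.Parameter` in code that only reads `.data` and `.grad`, while `shard_spec` carries the layout tag that `colo_linear` dispatches on. The snippet below is a rough, self-contained model of that delegation pattern; `_WrappedParam` is a hypothetical stand-in used only for illustration, not the real ColoTensor API.

```python
from typing import Optional

import torch
import torch.nn.functional as F


class _WrappedParam:
    """Toy wrapper: holds a torch.Tensor payload plus a shard annotation."""

    def __init__(self, payload: torch.Tensor, shard_spec: Optional[str] = None):
        self._torch_tensor = payload
        self._shard_spec = shard_spec

    @property
    def shard_spec(self) -> Optional[str]:
        return self._shard_spec

    @property
    def data(self) -> torch.Tensor:
        return self._torch_tensor.data

    @property
    def grad(self) -> Optional[torch.Tensor]:
        return self._torch_tensor.grad


payload = torch.randn(5, 4, requires_grad=True)
w = _WrappedParam(payload, shard_spec='1Drow')

F.linear(torch.randn(2, 4), payload).sum().backward()

# Callers keep using the familiar parameter attributes.
print(w.shard_spec)      # 1Drow
print(w.data.shape)      # torch.Size([5, 4])
print(w.grad.shape)      # torch.Size([5, 4])
```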
diff --git a/tests/test_tensor/test_linear_tp.py b/tests/test_tensor/test_linear_tp.py
new file mode 100644
index 000000000..a6147463a
--- /dev/null
+++ b/tests/test_tensor/test_linear_tp.py
@@ -0,0 +1,91 @@
+import torch
+from colossalai.context.parallel_mode import ParallelMode
+from colossalai.tensor import ColoTensor
+
+from functools import partial
+
+import colossalai
+import pytest
+import torch
+import torch.multiprocessing as mp
+from colossalai.testing import parameterize, rerun_if_address_is_in_use
+from colossalai.utils.cuda import get_current_device
+from colossalai.utils import free_port
+from colossalai.core import global_context as gpc
+import torch.distributed as dist
+
+from test_tensor_utils import check_equal, replace_parameter_add_grad, broadcast_tensor_chunk
+
+def run_linear_tp1d_row_test():
+    device = get_current_device()
+    dtype = torch.float32
+    DEPTH = gpc.get_world_size(ParallelMode.PARALLEL_1D)
+    in_features = 4
+    out_features = 5
+
+    local_rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
+
+    layer_master = torch.nn.Linear(in_features, out_features)
+    layer = torch.nn.Linear(in_features, out_features)
+
+    A_shape = (2, in_features)
+    A_master = torch.randn(A_shape, dtype=dtype, device=device)
+    A = broadcast_tensor_chunk(A_master, chunk_size=1)
+    A.requires_grad = True
+
+    W_shape = (out_features, in_features)
+    W_master = torch.randn(W_shape, dtype=dtype, device=device)
+    W = broadcast_tensor_chunk(W_master, chunk_size=DEPTH, local_rank=local_rank)
+    W.requires_grad = True
+
+    B_shape = (out_features, )
+    B_master = torch.randn(B_shape, dtype=dtype, device=device)
+    B = broadcast_tensor_chunk(B_master, chunk_size=1)
+    B.requires_grad = True
+
+    # replace the torch nn.Parameters with ColoTensor
+    sharded_weight = ColoTensor.init_from_torch_tensor(W)
+    sharded_weight._shard_spec = "1Drow"
+    sharded_bias = ColoTensor.init_from_torch_tensor(B)
+    replace_parameter_add_grad(layer, sharded_weight, sharded_bias)
+    out = layer(A)
+
+    replace_parameter_add_grad(layer_master, W_master, B_master)
+    A_master.requires_grad = True
+    # C_master = torch.matmul(A_master, W_master.transpose(0, 1)) + B_master
+    C_master = layer_master(A_master)
+    C = C_master.clone()
+
+    check_equal(out, C)
+
+    grad_shape = C_master.shape
+    grad_master = torch.randn(grad_shape, dtype=dtype, device=get_current_device())
+    grad = broadcast_tensor_chunk(grad_master, chunk_size=1)
+    out.backward(grad)
+
+    grad_master = grad_master.clone()
+    C_master.backward(grad_master)
+
+    W_grad = W_master.grad
+    W_grad = torch.chunk(W_grad, DEPTH, dim=-1)[local_rank]
+    check_equal(W_grad, layer.weight.grad)
+
+    B_grad = B_master.grad
+    check_equal(B_grad, layer.bias.grad)
+
+def run_dist(rank, world_size, port):
+    config = dict(parallel=dict(tensor=dict(mode="1d", size=world_size),))
+    colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+    run_linear_tp1d_row_test()
+
+
+@pytest.mark.dist
+@parameterize('world_size', [1, 4])
+@rerun_if_address_is_in_use()
+def test_linear_1d(world_size):
+    run_func = partial(run_dist, world_size=world_size, port=free_port())
+    mp.spawn(run_func, nprocs=world_size)
+
+
+if __name__ == '__main__':
+    test_linear_1d()
diff --git a/tests/test_tensor/test_tensor_utils/__init__.py b/tests/test_tensor/test_tensor_utils/__init__.py
new file mode 100644
index 000000000..8b2ce749a
--- /dev/null
+++ b/tests/test_tensor/test_tensor_utils/__init__.py
@@ -0,0 +1 @@
+from ._util import *
\ No newline at end of file
diff --git a/tests/test_tensor/test_tensor_utils/_util.py b/tests/test_tensor/test_tensor_utils/_util.py
new file mode 100644
index 000000000..88a938879
--- /dev/null
+++ b/tests/test_tensor/test_tensor_utils/_util.py
@@ -0,0 +1,20 @@
+import torch
+import torch.distributed as dist
+
+def check_equal(A, B):
+    assert torch.allclose(A, B, rtol=1e-3, atol=1e-1)
+
+def replace_parameter_add_grad(layer, weight=None, bias=None):
+    if weight is not None:
+        delattr(layer, 'weight')
+        setattr(layer, 'weight', weight)
+        layer.weight.requires_grad = True
+    if bias is not None:
+        delattr(layer, 'bias')
+        setattr(layer, 'bias', bias)
+        layer.bias.requires_grad = True
+
+def broadcast_tensor_chunk(tensor, chunk_size=1, local_rank=0):
+    dist.broadcast(tensor, src=0)
+    tensor_chunk = torch.chunk(tensor, chunk_size, dim=-1)[local_rank]
+    return tensor_chunk.clone()
\ No newline at end of file
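A note on the test fixture: `broadcast_tensor_chunk` broadcasts the full tensor from rank 0 and then keeps one chunk of the last dimension per rank, which is the layout the '1Drow' weight path expects. The sketch below reproduces that layout in a single process; the rank loop is simulated, no `torch.distributed` call is made, and `world_size` and `full_weight` are illustrative names.

```python
import torch

world_size = 4
full_weight = torch.randn(5, 8)       # (out_features, in_features), as seen on rank 0

# What each rank keeps after broadcast + chunk along the last dimension.
local_shards = [torch.chunk(full_weight, world_size, dim=-1)[rank].clone()
                for rank in range(world_size)]
for rank, shard in enumerate(local_shards):
    print(f"rank {rank}: shard shape {tuple(shard.shape)}")   # (5, 2) on every rank

# Concatenating the shards along the last dim recovers the full weight, which is
# why the all-reduce in the '1Drow' linear reconstructs the unsharded output.
assert torch.allclose(torch.cat(local_shards, dim=-1), full_weight)
```

Since `colossalai.launch` is called with the `nccl` backend, running the test (directly via `python tests/test_tensor/test_linear_tp.py` or through pytest) needs a CUDA-capable environment; `mp.spawn` then starts one process per configured world size.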