From cb5a4778e1039a0f47c899d65c5f43005f103e0e Mon Sep 17 00:00:00 2001
From: Jiarui Fang
Date: Fri, 22 Apr 2022 14:45:57 +0800
Subject: [PATCH] Revert "[WIP] Applying ColoTensor on TP-1D-row Linear. (#831)" (#835)

This reverts commit ac88de6dfc69bc59d4cadbd6432b0b818ca37e60.
---
 colossalai/tensor/_ops/linear.py    | 10 +---
 colossalai/tensor/colo_tensor.py    | 25 +++-------
 tests/test_tensor/test_linear_tp.py | 74 -----------------------------
 3 files changed, 8 insertions(+), 101 deletions(-)
 delete mode 100644 tests/test_tensor/test_linear_tp.py

diff --git a/colossalai/tensor/_ops/linear.py b/colossalai/tensor/_ops/linear.py
index d85893969..d8bc338a5 100644
--- a/colossalai/tensor/_ops/linear.py
+++ b/colossalai/tensor/_ops/linear.py
@@ -19,18 +19,12 @@ def colo_linear(types, args, kwargs, pg):
             bias = None
     else:
         bias = kwargs.get('bias', None)
-
+
     if isinstance(bias, ColoTensor):
         bias = bias.torch_tensor()
 
     # Add communication logic before and after linear call.
     if isinstance(weight, ColoTensor):
-        if weight.shard_spec == None:
-            return torch.nn.functional.linear(input_tensor, weight.torch_tensor(), bias)
-        elif weight.shard_spec == '1Drow':
-            # TODO(jzy): implement 1Drow TP linear here.
-            raise NotImplementedError
-        else:
-            raise NotImplementedError
+        return torch.nn.functional.linear(input_tensor, weight.torch_tensor(), bias)
     else:
         return torch.nn.functional.linear(input_tensor, weight, bias)
diff --git a/colossalai/tensor/colo_tensor.py b/colossalai/tensor/colo_tensor.py
index 8900a42ff..f40034dc1 100644
--- a/colossalai/tensor/colo_tensor.py
+++ b/colossalai/tensor/colo_tensor.py
@@ -1,6 +1,6 @@
 import torch
 from .op_wrapper import _COLOSSAL_OPS
-from typing import Tuple, Optional
+from typing import Tuple
 
 
 class ColoTensor(object):
@@ -21,35 +21,20 @@ class ColoTensor(object):
                  requires_grad=False,
                  pin_memory=False,
                  torch_tensor=torch.empty(0),
-                 shard_spec: str = None,
                  ):
         self._size = size
         self._dtype = dtype
         self._requires_grad = requires_grad
         self._pin_memory = pin_memory
         self._torch_tensor = torch_tensor
-        self._shard_spec = shard_spec
-
-    @property
-    def shard_spec(self) -> Optional[str]:
-        return self._shard_spec
-
-    @property
-    def data(self):
-        return self._torch_tensor.data
-
-    @property
-    def grad(self):
-        return self._torch_tensor.grad
 
     @staticmethod
-    def init_from_torch_tensor(tensor: torch.Tensor, shard_spec: str = None) -> 'ColoTensor':
+    def init_from_torch_tensor(tensor: torch.Tensor):
        colo_t = ColoTensor(*tensor.size(),
                             dtype=tensor.dtype,
                             requires_grad=tensor.requires_grad,
                             pin_memory=tensor.pin_memory,
-                            torch_tensor=tensor,
-                            shard_spec=shard_spec)
+                            torch_tensor=tensor)
         return colo_t
 
     def del_torch_tensor(self) -> None:
@@ -82,5 +67,7 @@ class ColoTensor(object):
         if kwargs is None:
             kwargs = {}
 
-        kwargs = {k: v.torch_tensor() if isinstance(v, ColoTensor) else v for k, v in kwargs.items()}
+        kwargs = {
+            k: v.torch_tensor() if isinstance(v, ColoTensor) else v for k,v in kwargs.items()
+        }
         return func(*args, **kwargs)
diff --git a/tests/test_tensor/test_linear_tp.py b/tests/test_tensor/test_linear_tp.py
deleted file mode 100644
index 4adb848b1..000000000
--- a/tests/test_tensor/test_linear_tp.py
+++ /dev/null
@@ -1,74 +0,0 @@
-from joblib import Parallel
-from numpy import allclose, require
-import torch
-from colossalai.context.parallel_mode import ParallelMode
-from colossalai.tensor import ColoTensor
-from copy import deepcopy
-
-from functools import partial
-
-import colossalai
-import pytest
-import torch
-import torch.multiprocessing as mp
-from colossalai.logging import get_dist_logger
-from colossalai.testing import parameterize, rerun_if_address_is_in_use
-from colossalai.utils import free_port
-from colossalai.core import global_context as gpc
-
-
-def run_linear_tp1d_row_test():
-    in_dim = 4
-    out_dim = 5
-
-    fc = torch.nn.Linear(in_dim, out_dim, bias=True)
-    fc_ref = deepcopy(fc)
-
-    input_ref = torch.randn(1, in_dim)
-    input_tensor = input_ref.clone()
-
-    # sharded_weight = ColoTensor.init_from_torch_tensor(fc_ref.weight, "1Drow")
-
-    # shard weight at begiin
-    world_size = gpc.get_world_size(ParallelMode.PARALLEL_1D)
-    sharded_weight = ColoTensor(in_dim / world_size, out_dim, shard_spec="1Drow")
-    sharded_bias = ColoTensor.init_from_torch_tensor(fc_ref.bias)
-
-    # replace the torch nn.Parameters with ShardedTensor
-    delattr(fc, 'weight')
-    setattr(fc, 'weight', sharded_weight)
-    delattr(fc, 'bias')
-    setattr(fc, 'bias', sharded_bias)
-
-    fc.weight.requires_grad = True
-    fc.bias.requires_grad = True
-
-    # torch.nn.functional.linear(torch.randn(1, in_dim), sharded_weight, sharded_bias)
-    out = fc(input_tensor)
-    loss = out.sum()
-    loss.backward()
-
-    out_ref = fc_ref(input_ref)
-    loss_ref = out_ref.sum()
-    loss_ref.backward()
-
-    assert (loss_ref == loss)
-    assert allclose(fc_ref.weight.grad, fc.weight.torch_tensor().grad)
-
-
-def run_dist(rank, world_size, port):
-    config = dict(parallel=dict(tensor=dict(mode="1d", size=world_size),))
-    colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
-    run_linear_tp1d_row_test()
-
-
-@pytest.mark.dist
-@pytest.mark.parametrize("world_size", [4])
-@rerun_if_address_is_in_use()
-def test_linear_1d(world_size):
-    run_func = partial(run_dist, world_size=world_size, port=free_port())
-    mp.spawn(run_func, nprocs=world_size)
-
-
-if __name__ == '__main__':
-    test_linear_1d(4)
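
Context on the reverted work: the deleted `shard_spec == '1Drow'` branch was still a `TODO(jzy)`, and the deleted test sharded `fc.weight` along the input dimension across the `PARALLEL_1D` group while keeping `fc.bias` unsharded. The following is a minimal single-process sketch of the arithmetic such a row-parallel branch would have to perform, using plain PyTorch only; the names (`world_size`, `weight_shards`, `partial_out`) are illustrative and not part of the ColossalAI API.

import torch

torch.manual_seed(0)

in_dim, out_dim, world_size = 8, 5, 4
x = torch.randn(2, in_dim)
fc = torch.nn.Linear(in_dim, out_dim, bias=True)
ref = fc(x)                                          # unsharded reference result

# "1Drow" shards the weight along the input dimension. nn.Linear stores its
# weight as (out_dim, in_dim), so the split is on dim 1, and the input is
# split along its last dimension to match.
weight_shards = fc.weight.chunk(world_size, dim=1)   # each (out_dim, in_dim // world_size)
input_shards = x.chunk(world_size, dim=-1)           # each (batch, in_dim // world_size)

# Each "rank" produces a partial product; summing the partials stands in for
# the reduction a multi-process implementation would perform. The bias is
# added once, after the reduction.
partial_out = [xs @ ws.t() for xs, ws in zip(input_shards, weight_shards)]
out = torch.stack(partial_out).sum(dim=0) + fc.bias

assert torch.allclose(out, ref, atol=1e-6)

In a real multi-rank run, the sum over partial products would typically be a torch.distributed.all_reduce over the tensor-parallel group, which is why the bias is applied exactly once after the reduction rather than on every shard.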