From 96211c2cc8fc8e081eddb5111557d1d474b1446d Mon Sep 17 00:00:00 2001
From: Jiarui Fang
Date: Tue, 26 Apr 2022 13:23:59 +0800
Subject: [PATCH] [tensor] customized op returns ColoTensor (#875)

* [tensor] customized op returns ColoTensor

* polish

* polish code
---
 colossalai/tensor/_ops/__init__.py     |  3 +--
 colossalai/tensor/_ops/element_wise.py | 19 +++++++++++++++++++
 colossalai/tensor/_ops/init.py         | 29 -----------------------------
 colossalai/tensor/_ops/linear.py       |  9 +++++----
 colossalai/tensor/spec.py              |  8 ++++++--
 tests/test_tensor/test_net_tp.py       |  6 ------
 tests/test_tensor/test_op.py           |  4 ++--
 7 files changed, 33 insertions(+), 45 deletions(-)
 delete mode 100644 colossalai/tensor/_ops/init.py

diff --git a/colossalai/tensor/_ops/__init__.py b/colossalai/tensor/_ops/__init__.py
index 7438d6ef7..39b279c01 100644
--- a/colossalai/tensor/_ops/__init__.py
+++ b/colossalai/tensor/_ops/__init__.py
@@ -1,5 +1,4 @@
-from .init import colo_uniform
 from .linear import colo_linear
-from .element_wise import colo_mean
+from .element_wise import *
 from .layernorm import colo_layernorm
 from .loss import colo_cross_entropy
diff --git a/colossalai/tensor/_ops/element_wise.py b/colossalai/tensor/_ops/element_wise.py
index 076e3463e..0bc932ec3 100644
--- a/colossalai/tensor/_ops/element_wise.py
+++ b/colossalai/tensor/_ops/element_wise.py
@@ -29,3 +29,22 @@ def register_elementwise_op(op):
 
 register_elementwise_op(torch.nn.functional.gelu)
 register_elementwise_op(torch.nn.functional.relu)
+
+
+@colo_op_impl(torch.sum)
+def sum_op(types, args=(), kwargs=None, pg=None):
+    """
+    Handles ``__torch_function__`` dispatch for ops such
+    as ``torch.sum``.
+    This method computes on either a normal tensor or a sharded tensor.
+    """
+    if len(args) > 0:
+        input_tensor = args[0]
+    if kwargs is None:
+        kwargs = {}
+    if 'input' in kwargs:
+        input_tensor = kwargs['input']
+    # Validate types
+    if not isinstance(input_tensor, ColoTensor):
+        raise TypeError("input needs to be a ColoTensor")
+    return ColoTensor.init_from_torch_tensor(torch.sum(input_tensor.torch_tensor()))
diff --git a/colossalai/tensor/_ops/init.py b/colossalai/tensor/_ops/init.py
deleted file mode 100644
index 7d4b2cceb..000000000
--- a/colossalai/tensor/_ops/init.py
+++ /dev/null
@@ -1,29 +0,0 @@
-import torch
-from colossalai.tensor.op_wrapper import colo_op_impl
-
-
-def validate_param(param, param_name):
-    if param is None:
-        raise ValueError(f"param: {param_name} shouldn't be None!")
-
-
-@colo_op_impl(torch.nn.init.uniform_)
-def colo_uniform(types, args=(), kwargs=None, pg=None):
-    r"""
-    Fills the Tensor in sharded_tensor.local_shards with values drawn from the uniform
-    distribution :math:`\mathcal{U}(a, b)`.
-    Args:
-        sharded_tensor: tensor sharded across devices
-        a: the lower bound of the uniform distribution
-        b: the upper bound of the uniform distribution
-    """
-    validate_param(kwargs, "kwargs")
-    stateful_tensor = kwargs["tensor"]
-    validate_param(stateful_tensor, "stateful_tensor")
-    a = kwargs['a']
-    validate_param(a, "a")
-    b = kwargs['b']
-    validate_param(b, "b")
-
-    torch.nn.init.uniform_(stateful_tensor.torch_tensor(), a=a, b=b)
-    return stateful_tensor
diff --git a/colossalai/tensor/_ops/linear.py b/colossalai/tensor/_ops/linear.py
index 519678480..8ca80b4ca 100644
--- a/colossalai/tensor/_ops/linear.py
+++ b/colossalai/tensor/_ops/linear.py
@@ -6,7 +6,8 @@ from colossalai.nn.layer.parallel_1d._utils import split_forward_gather_backward
 from colossalai.nn.layer.utils import divide
 from colossalai.core import global_context as gpc
 from packaging import version
-from colossalai.tensor import TensorSpec, ComputePattern, ParallelAction
+from colossalai.tensor import ComputePattern
+
 
 @colo_op_impl(torch.nn.functional.linear)
 def colo_linear(types, args, kwargs, pg):
@@ -25,6 +26,7 @@ def colo_linear(types, args, kwargs, pg):
         bias = kwargs.get('bias', None)
 
     if isinstance(bias, ColoTensor):
+        assert bias.shard_spec.num_action == 0, "We currently only support a bias that is duplicated among processes in the linear operator"
         bias = bias.torch_tensor()
 
     # Add communication logic before and after linear call.
@@ -34,7 +36,7 @@ def colo_linear(types, args, kwargs, pg):
             input_tensor = input_tensor.torch_tensor()
         if isinstance(weight, ColoTensor):
             weight = weight.torch_tensor()
-        return torch.nn.functional.linear(input_tensor, weight, bias)
+        return ColoTensor.init_from_torch_tensor(torch.nn.functional.linear(input_tensor, weight, bias))
     elif weight.shard_spec.num_action == 1:
         if ComputePattern.TP1DRow in weight.shard_spec.compute_patterns:
             # Input:S[1] x Weight:S[0] = Output:P
@@ -54,8 +56,7 @@ def colo_linear(types, args, kwargs, pg):
             output = reduce_input(partial_output, ParallelMode.PARALLEL_1D)
             # Bias
             if bias is not None:
-                bias_ = bias
-                output = output + bias_
+                output = output + bias
             return ColoTensor.init_from_torch_tensor(output)
         else:
             raise NotImplementedError
diff --git a/colossalai/tensor/spec.py b/colossalai/tensor/spec.py
index ccd85d9cb..18099cc0c 100644
--- a/colossalai/tensor/spec.py
+++ b/colossalai/tensor/spec.py
@@ -2,6 +2,7 @@ from enum import Enum
 from typing import Tuple, List
 from colossalai.context.parallel_mode import ParallelMode
 
+
 class ComputePattern(Enum):
     TP1DRow = 1
     TP1DCol = 2
@@ -10,6 +11,7 @@ class ComputePattern(Enum):
 
 
 class ParallelAction(object):
+
     def __init__(self, priority=0, compute_pattern=ComputePattern.DP, parallel_mode=ParallelMode.DATA) -> None:
         self.priority = priority
         self.compute_pattern = compute_pattern
@@ -24,6 +26,7 @@ class TensorSpec(object):
     parallel computation pattern of the Operator (Layer).
     We have to consider the hybrid parallel mode.
     """
+
     # a list of parallel actions.
     # For example: On 8 GPUs, a hybrid parallel strategy is applied using
     # using ZeRO with DP-degree = 4 and 1DRowTP with TP-degree = 2.
@@ -38,6 +41,7 @@ class TensorSpec(object):
     # Before Linear Op, we gather the tensors according to ZeRO.
     # We perform Linear Op according to compute pattern of TP1DRow.
     # After Linear Op, we split the tensors according to ZeRO.
+
     def __init__(self, parallel_action_list: List[ParallelAction] = []):
         self._parallel_action_list = parallel_action_list
         self.sort()
@@ -56,8 +60,8 @@ class TensorSpec(object):
 
     def sort(self):
         if len(self._parallel_action_list) > 0:
-            self._parallel_action_list.sort(key=lambda parallel_action : parallel_action.priority)
-
+            self._parallel_action_list.sort(key=lambda parallel_action: parallel_action.priority)
+
     def get_action_by_compute_pattern(self, compute_pattern: ComputePattern):
         for parallel_action in self._parallel_action_list:
             if parallel_action.compute_pattern == compute_pattern:
diff --git a/tests/test_tensor/test_net_tp.py b/tests/test_tensor/test_net_tp.py
index f21d1b459..0d5ea848d 100644
--- a/tests/test_tensor/test_net_tp.py
+++ b/tests/test_tensor/test_net_tp.py
@@ -1,19 +1,13 @@
-from cProfile import label
-from statistics import mode
-from colossalai.tensor.colo_tensor import ColoTensor
 from tests.components_to_test.registry import non_distributed_component_funcs
 
 import colossalai
 import pytest
-import torch
 import torch.multiprocessing as mp
 from colossalai.testing import parameterize, rerun_if_address_is_in_use
 from colossalai.utils.cuda import get_current_device
 from colossalai.utils import free_port
-from colossalai.core import global_context as gpc
 from colossalai.utils import ColoInitContext
 
-import torch.distributed as dist
 from functools import partial
 
 
diff --git a/tests/test_tensor/test_op.py b/tests/test_tensor/test_op.py
index 7156c536e..40ecc10fe 100644
--- a/tests/test_tensor/test_op.py
+++ b/tests/test_tensor/test_op.py
@@ -53,11 +53,11 @@ def test_linear():
     # torch.nn.functional.linear(torch.randn(1, in_dim), sharded_weight, sharded_bias)
 
     out = fc(input_tensor)
-    loss = out.sum()
+    loss = torch.sum(out)
     loss.backward()
 
     out_ref = fc_ref(input_ref)
-    loss_ref = out_ref.sum()
+    loss_ref = torch.sum(out_ref)
     loss_ref.backward()
 
     assert (loss_ref == loss)
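Illustrative sketch (not part of the diff above): after this patch, a customized op registered through colo_op_impl returns a ColoTensor rather than a plain torch.Tensor. The snippet below mirrors the pattern in tests/test_tensor/test_op.py; it assumes the import path used in the test files (colossalai.tensor.colo_tensor) and a single-process, unsharded ColoTensor.

# Usage sketch only -- assumptions: import path as in the tests, no sharding spec set.
import torch
from colossalai.tensor.colo_tensor import ColoTensor

# Wrap a plain tensor; init_from_torch_tensor is the constructor used throughout the patch.
t = ColoTensor.init_from_torch_tensor(torch.randn(4, 8))

# __torch_function__ dispatch routes this call to sum_op, which unwraps the
# underlying tensor, runs torch.sum, and re-wraps the result as a ColoTensor.
s = torch.sum(t)

assert isinstance(s, ColoTensor)
# The wrapped value matches a plain torch.sum on the underlying tensor.
assert torch.equal(s.torch_tensor(), torch.sum(t.torch_tensor()))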