From ae861519687b09934efe431f951b052f7b811b11 Mon Sep 17 00:00:00 2001
From: ver217
Date: Wed, 22 Jun 2022 15:16:47 +0800
Subject: [PATCH] [tensor] add more element-wise ops (#1155)

* add more element-wise ops

* update test_op

* polish unit test
---
 colossalai/nn/_ops/element_wise.py | 210 ++++++++++++++++++++++++++++-
 tests/test_tensor/test_op.py       |  44 +++++-
 2 files changed, 248 insertions(+), 6 deletions(-)

diff --git a/colossalai/nn/_ops/element_wise.py b/colossalai/nn/_ops/element_wise.py
index ab3dd903b..44de07f83 100644
--- a/colossalai/nn/_ops/element_wise.py
+++ b/colossalai/nn/_ops/element_wise.py
@@ -1,5 +1,7 @@
-from copy import copy
 import torch
+import torch.nn.functional as F
+from torch import Tensor
+from copy import copy
 from colossalai.tensor.op_wrapper import colo_op_impl
 from colossalai.tensor import ColoTensor
 from ._utils import GeneralTensor
@@ -21,8 +23,206 @@ def register_elementwise_op(op):
         return ColoTensor.from_torch_tensor(output)


-register_elementwise_op(torch.nn.functional.gelu)
-register_elementwise_op(torch.nn.functional.relu)
+# Tensor op
+register_elementwise_op(Tensor.abs)
+register_elementwise_op(Tensor.absolute)
+register_elementwise_op(Tensor.acos)
+register_elementwise_op(Tensor.arccos)
+register_elementwise_op(Tensor.angle)
+register_elementwise_op(Tensor.asin)
+register_elementwise_op(Tensor.arcsin)
+register_elementwise_op(Tensor.atan)
+register_elementwise_op(Tensor.arctan)
+register_elementwise_op(Tensor.all)
+register_elementwise_op(Tensor.any)
+register_elementwise_op(Tensor.bernoulli)
+register_elementwise_op(Tensor.bfloat16)
+register_elementwise_op(Tensor.bitwise_not)
+register_elementwise_op(Tensor.bool)
+register_elementwise_op(Tensor.byte)
+register_elementwise_op(Tensor.ceil)
+register_elementwise_op(Tensor.char)
+register_elementwise_op(Tensor.clamp)
+register_elementwise_op(Tensor.clamp_max)
+register_elementwise_op(Tensor.clamp_min)
+register_elementwise_op(Tensor.clip)
+register_elementwise_op(Tensor.clone)
+register_elementwise_op(Tensor.contiguous)
+register_elementwise_op(Tensor.copysign)
+register_elementwise_op(Tensor.cos)
+register_elementwise_op(Tensor.cosh)
+register_elementwise_op(Tensor.acosh)
+register_elementwise_op(Tensor.arccosh)
+register_elementwise_op(Tensor.cpu)
+register_elementwise_op(Tensor.cuda)
+register_elementwise_op(Tensor.deg2rad)
+register_elementwise_op(Tensor.detach)
+register_elementwise_op(Tensor.digamma)
+register_elementwise_op(Tensor.double)
+register_elementwise_op(Tensor.erf)
+register_elementwise_op(Tensor.erfc)
+register_elementwise_op(Tensor.erfinv)
+register_elementwise_op(Tensor.exp)
+register_elementwise_op(Tensor.expm1)
+register_elementwise_op(Tensor.fix)
+register_elementwise_op(Tensor.trunc)
+register_elementwise_op(Tensor.float)
+register_elementwise_op(Tensor.float_power)
+register_elementwise_op(Tensor.floor)
+register_elementwise_op(Tensor.frac)
+register_elementwise_op(Tensor.half)
+register_elementwise_op(Tensor.hardshrink)
+register_elementwise_op(Tensor.heaviside)
+register_elementwise_op(Tensor.i0)
+register_elementwise_op(Tensor.int)
+register_elementwise_op(Tensor.isfinite)
+register_elementwise_op(Tensor.isinf)
+register_elementwise_op(Tensor.isposinf)
+register_elementwise_op(Tensor.isneginf)
+register_elementwise_op(Tensor.isnan)
+register_elementwise_op(Tensor.lgamma)
+register_elementwise_op(Tensor.log)
+register_elementwise_op(Tensor.log10)
+register_elementwise_op(Tensor.log1p)
+register_elementwise_op(Tensor.log2)
+register_elementwise_op(Tensor.logical_not)
+register_elementwise_op(Tensor.logit)
+register_elementwise_op(Tensor.long)
+register_elementwise_op(Tensor.nan_to_num)
+register_elementwise_op(Tensor.neg)
+register_elementwise_op(Tensor.negative)
+register_elementwise_op(Tensor.positive)
+register_elementwise_op(Tensor.pow)
+register_elementwise_op(Tensor.rad2deg)
+register_elementwise_op(Tensor.reciprocal)
+register_elementwise_op(Tensor.round)
+register_elementwise_op(Tensor.rsqrt)
+register_elementwise_op(Tensor.short)
+register_elementwise_op(Tensor.sigmoid)
+register_elementwise_op(Tensor.sign)
+register_elementwise_op(Tensor.signbit)
+register_elementwise_op(Tensor.sgn)
+register_elementwise_op(Tensor.sin)
+register_elementwise_op(Tensor.sinc)
+register_elementwise_op(Tensor.sinh)
+register_elementwise_op(Tensor.asinh)
+register_elementwise_op(Tensor.arcsinh)
+register_elementwise_op(Tensor.sqrt)
+register_elementwise_op(Tensor.square)
+register_elementwise_op(Tensor.to)
+register_elementwise_op(Tensor.tan)
+register_elementwise_op(Tensor.tanh)
+register_elementwise_op(Tensor.atanh)
+register_elementwise_op(Tensor.arctanh)
+register_elementwise_op(Tensor.type)
+register_elementwise_op(Tensor.type_as)
+
+# torch OP
+register_elementwise_op(torch.abs)
+register_elementwise_op(torch.absolute)
+register_elementwise_op(torch.acos)
+register_elementwise_op(torch.arccos)
+register_elementwise_op(torch.angle)
+register_elementwise_op(torch.asin)
+register_elementwise_op(torch.arcsin)
+register_elementwise_op(torch.atan)
+register_elementwise_op(torch.arctan)
+register_elementwise_op(torch.all)
+register_elementwise_op(torch.any)
+register_elementwise_op(torch.bernoulli)
+register_elementwise_op(torch.bitwise_not)
+register_elementwise_op(torch.ceil)
+register_elementwise_op(torch.clamp)
+register_elementwise_op(torch.clamp_max)
+register_elementwise_op(torch.clamp_min)
+register_elementwise_op(torch.clip)
 register_elementwise_op(torch.clone)
-register_elementwise_op(torch.Tensor.clone)
-register_elementwise_op(torch.Tensor.detach)
+register_elementwise_op(torch.copysign)
+register_elementwise_op(torch.cos)
+register_elementwise_op(torch.cosh)
+register_elementwise_op(torch.acosh)
+register_elementwise_op(torch.arccosh)
+register_elementwise_op(torch.deg2rad)
+register_elementwise_op(torch.digamma)
+register_elementwise_op(torch.erf)
+register_elementwise_op(torch.erfc)
+register_elementwise_op(torch.erfinv)
+register_elementwise_op(torch.exp)
+register_elementwise_op(torch.expm1)
+register_elementwise_op(torch.fix)
+register_elementwise_op(torch.trunc)
+register_elementwise_op(torch.float_power)
+register_elementwise_op(torch.floor)
+register_elementwise_op(torch.frac)
+register_elementwise_op(torch.hardshrink)
+register_elementwise_op(torch.heaviside)
+register_elementwise_op(torch.i0)
+register_elementwise_op(torch.isfinite)
+register_elementwise_op(torch.isinf)
+register_elementwise_op(torch.isposinf)
+register_elementwise_op(torch.isneginf)
+register_elementwise_op(torch.isnan)
+register_elementwise_op(torch.lgamma)
+register_elementwise_op(torch.log)
+register_elementwise_op(torch.log10)
+register_elementwise_op(torch.log1p)
+register_elementwise_op(torch.log2)
+register_elementwise_op(torch.logical_not)
+register_elementwise_op(torch.logit)
+register_elementwise_op(torch.nan_to_num)
+register_elementwise_op(torch.neg)
+register_elementwise_op(torch.negative)
+register_elementwise_op(torch.positive)
+register_elementwise_op(torch.pow)
+register_elementwise_op(torch.rad2deg)
+register_elementwise_op(torch.reciprocal)
+register_elementwise_op(torch.round)
+register_elementwise_op(torch.rsqrt)
+register_elementwise_op(torch.sigmoid)
+register_elementwise_op(torch.sign)
+register_elementwise_op(torch.signbit)
+register_elementwise_op(torch.sgn)
+register_elementwise_op(torch.sin)
+register_elementwise_op(torch.sinc)
+register_elementwise_op(torch.sinh)
+register_elementwise_op(torch.asinh)
+register_elementwise_op(torch.arcsinh)
+register_elementwise_op(torch.sqrt)
+register_elementwise_op(torch.square)
+register_elementwise_op(torch.tan)
+register_elementwise_op(torch.tanh)
+register_elementwise_op(torch.atanh)
+register_elementwise_op(torch.arctanh)
+
+# nn.functional OP
+register_elementwise_op(F.threshold)
+register_elementwise_op(F.relu)
+register_elementwise_op(F.hardtanh)
+register_elementwise_op(F.hardswish)
+register_elementwise_op(F.relu6)
+register_elementwise_op(F.elu)
+register_elementwise_op(F.selu)
+register_elementwise_op(F.celu)
+register_elementwise_op(F.leaky_relu)
+register_elementwise_op(F.prelu)
+register_elementwise_op(F.rrelu)
+register_elementwise_op(F.gelu)
+register_elementwise_op(F.logsigmoid)
+register_elementwise_op(F.hardshrink)
+register_elementwise_op(F.tanhshrink)
+register_elementwise_op(F.softsign)
+register_elementwise_op(F.softplus)
+register_elementwise_op(F.softmin)
+register_elementwise_op(F.softmax)
+register_elementwise_op(F.softshrink)
+register_elementwise_op(F.gumbel_softmax)
+register_elementwise_op(F.log_softmax)
+register_elementwise_op(F.tanh)
+register_elementwise_op(F.sigmoid)
+register_elementwise_op(F.hardsigmoid)
+register_elementwise_op(F.silu)
+register_elementwise_op(F.mish)
+# TODO(ver217): dropout handles seed
+register_elementwise_op(F.dropout)
+register_elementwise_op(F.alpha_dropout)
+register_elementwise_op(F.feature_alpha_dropout)
diff --git a/tests/test_tensor/test_op.py b/tests/test_tensor/test_op.py
index 510dad108..5298f292d 100644
--- a/tests/test_tensor/test_op.py
+++ b/tests/test_tensor/test_op.py
@@ -1,8 +1,16 @@
 import torch
+import pytest
+import colossalai
+import torch.nn.functional as F
+import torch.multiprocessing as mp
+from functools import partial
 from colossalai.tensor import ColoTensor, ColoParameter
 from colossalai.utils import get_current_device
 from torch.nn import Parameter
-import torch.nn.functional as F
+from torch.distributed.distributed_c10d import _get_default_group
+from colossalai.testing import rerun_if_address_is_in_use
+from colossalai.utils import free_port
+from colossalai.tensor import distspec, TensorSpec


 def test_layernorm():
@@ -26,8 +34,42 @@ def test_layernorm():
     assert torch.allclose(ln_op.weight.grad, weight.grad)


+def check_spec_eq(tensor, other):
+    assert isinstance(tensor, ColoTensor) and isinstance(other, ColoTensor)
+    for k in dir(tensor.spec.dist_spec):
+        if not k.startswith('__'):
+            assert hasattr(other.spec.dist_spec, k)
+            assert getattr(tensor.spec.dist_spec, k) == getattr(other.spec.dist_spec, k)
+
+
+def check_element_wise_ops():
+    pg = _get_default_group()
+    t = torch.rand(2, 2)
+    x = ColoTensor(t, spec=TensorSpec(distspec.shard(pg, [0], [pg.size()])))
+    check_spec_eq(x, x.cuda())
+    assert torch.equal(x.cuda(), t.cuda())
+    check_spec_eq(x, torch.abs(x))
+    assert torch.equal(torch.abs(x), torch.abs(t))
+    check_spec_eq(x, F.sigmoid(x))
+    assert torch.equal(F.sigmoid(x), F.sigmoid(t))
+
+
+def run_dist(rank, world_size, port):
+    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+    check_element_wise_ops()
+
+
+@pytest.mark.dist
+@pytest.mark.parametrize('world_size', [2])
+@rerun_if_address_is_in_use()
+def test_element_wise_ops(world_size):
+    run_func = partial(run_dist, world_size=world_size, port=free_port())
+    mp.spawn(run_func, nprocs=world_size)
+
+
 def check_all():
     test_layernorm()
+    test_element_wise_ops(2)


 if __name__ == '__main__':
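
The hunk for element_wise.py shows only the tail of register_elementwise_op, so the snippet below is a minimal sketch of the wrapping pattern the new registrations rely on, not the verbatim ColossalAI implementation. It assumes that @colo_op_impl(op) installs the inner function as the ColoTensor handler for op, and that re-wrapping the raw output with ColoTensor.from_torch_tensor is what lets the distributed spec checks in the new test pass:

    from colossalai.tensor.op_wrapper import colo_op_impl
    from colossalai.tensor import ColoTensor
    from ._utils import GeneralTensor


    def register_elementwise_op(op):

        @colo_op_impl(op)
        def elementwise_op(input_tensor: GeneralTensor, *args, **kwargs):
            # Run the underlying torch op on the tensor data, then re-wrap the
            # result as a ColoTensor so callers keep working with the colossal type.
            output = op(input_tensor, *args, **kwargs)
            return ColoTensor.from_torch_tensor(output)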
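As a usage sketch mirroring check_element_wise_ops above: once an op is registered, calling it on a ColoTensor yields the same values as the plain torch call. The snippet assumes a ColossalAI distributed context has already been initialized (e.g. via colossalai.launch as in run_dist); the 2x2 tensor is illustrative only.

    import torch
    import torch.nn.functional as F
    from colossalai.tensor import ColoTensor

    t = torch.rand(2, 2)
    x = ColoTensor.from_torch_tensor(t)
    # The newly registered element-wise ops accept ColoTensor inputs and agree
    # with the results of the corresponding torch.Tensor calls.
    assert torch.equal(torch.abs(x), torch.abs(t))
    assert torch.equal(F.sigmoid(x), F.sigmoid(t))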