[tensor] add more element-wise ops (#1155)

* add more element-wise ops

* update test_op

* polish unit test
pull/1158/head
ver217 2022-06-22 15:16:47 +08:00 committed by GitHub
parent e8c34eedfd
commit ae86151968
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 248 additions and 6 deletions

View File

@ -1,5 +1,7 @@
from copy import copy
import torch
import torch.nn.functional as F
from torch import Tensor
from copy import copy
from colossalai.tensor.op_wrapper import colo_op_impl
from colossalai.tensor import ColoTensor
from ._utils import GeneralTensor
@ -21,8 +23,206 @@ def register_elementwise_op(op):
return ColoTensor.from_torch_tensor(output)
register_elementwise_op(torch.nn.functional.gelu)
register_elementwise_op(torch.nn.functional.relu)
# Tensor op
register_elementwise_op(Tensor.abs)
register_elementwise_op(Tensor.absolute)
register_elementwise_op(Tensor.acos)
register_elementwise_op(Tensor.arccos)
register_elementwise_op(Tensor.angle)
register_elementwise_op(Tensor.asin)
register_elementwise_op(Tensor.arcsin)
register_elementwise_op(Tensor.atan)
register_elementwise_op(Tensor.arctan)
register_elementwise_op(Tensor.all)
register_elementwise_op(Tensor.any)
register_elementwise_op(Tensor.bernoulli)
register_elementwise_op(Tensor.bfloat16)
register_elementwise_op(Tensor.bitwise_not)
register_elementwise_op(Tensor.bool)
register_elementwise_op(Tensor.byte)
register_elementwise_op(Tensor.ceil)
register_elementwise_op(Tensor.char)
register_elementwise_op(Tensor.clamp)
register_elementwise_op(Tensor.clamp_max)
register_elementwise_op(Tensor.clamp_min)
register_elementwise_op(Tensor.clip)
register_elementwise_op(Tensor.clone)
register_elementwise_op(Tensor.contiguous)
register_elementwise_op(Tensor.copysign)
register_elementwise_op(Tensor.cos)
register_elementwise_op(Tensor.cosh)
register_elementwise_op(Tensor.acosh)
register_elementwise_op(Tensor.arccosh)
register_elementwise_op(Tensor.cpu)
register_elementwise_op(Tensor.cuda)
register_elementwise_op(Tensor.deg2rad)
register_elementwise_op(Tensor.detach)
register_elementwise_op(Tensor.digamma)
register_elementwise_op(Tensor.double)
register_elementwise_op(Tensor.erf)
register_elementwise_op(Tensor.erfc)
register_elementwise_op(Tensor.erfinv)
register_elementwise_op(Tensor.exp)
register_elementwise_op(Tensor.expm1)
register_elementwise_op(Tensor.fix)
register_elementwise_op(Tensor.trunc)
register_elementwise_op(Tensor.float)
register_elementwise_op(Tensor.float_power)
register_elementwise_op(Tensor.floor)
register_elementwise_op(Tensor.frac)
register_elementwise_op(Tensor.half)
register_elementwise_op(Tensor.hardshrink)
register_elementwise_op(Tensor.heaviside)
register_elementwise_op(Tensor.i0)
register_elementwise_op(Tensor.int)
register_elementwise_op(Tensor.isfinite)
register_elementwise_op(Tensor.isinf)
register_elementwise_op(Tensor.isposinf)
register_elementwise_op(Tensor.isneginf)
register_elementwise_op(Tensor.isnan)
register_elementwise_op(Tensor.lgamma)
register_elementwise_op(Tensor.log)
register_elementwise_op(Tensor.log10)
register_elementwise_op(Tensor.log1p)
register_elementwise_op(Tensor.log2)
register_elementwise_op(Tensor.logical_not)
register_elementwise_op(Tensor.logit)
register_elementwise_op(Tensor.long)
register_elementwise_op(Tensor.nan_to_num)
register_elementwise_op(Tensor.neg)
register_elementwise_op(Tensor.negative)
register_elementwise_op(Tensor.positive)
register_elementwise_op(Tensor.pow)
register_elementwise_op(Tensor.rad2deg)
register_elementwise_op(Tensor.reciprocal)
register_elementwise_op(Tensor.round)
register_elementwise_op(Tensor.rsqrt)
register_elementwise_op(Tensor.short)
register_elementwise_op(Tensor.sigmoid)
register_elementwise_op(Tensor.sign)
register_elementwise_op(Tensor.signbit)
register_elementwise_op(Tensor.sgn)
register_elementwise_op(Tensor.sin)
register_elementwise_op(Tensor.sinc)
register_elementwise_op(Tensor.sinh)
register_elementwise_op(Tensor.asinh)
register_elementwise_op(Tensor.arcsinh)
register_elementwise_op(Tensor.sqrt)
register_elementwise_op(Tensor.square)
register_elementwise_op(Tensor.to)
register_elementwise_op(Tensor.tan)
register_elementwise_op(Tensor.tanh)
register_elementwise_op(Tensor.atanh)
register_elementwise_op(Tensor.arctanh)
register_elementwise_op(Tensor.type)
register_elementwise_op(Tensor.type_as)
# torch OP
register_elementwise_op(torch.abs)
register_elementwise_op(torch.absolute)
register_elementwise_op(torch.acos)
register_elementwise_op(torch.arccos)
register_elementwise_op(torch.angle)
register_elementwise_op(torch.asin)
register_elementwise_op(torch.arcsin)
register_elementwise_op(torch.atan)
register_elementwise_op(torch.arctan)
register_elementwise_op(torch.all)
register_elementwise_op(torch.any)
register_elementwise_op(torch.bernoulli)
register_elementwise_op(torch.bitwise_not)
register_elementwise_op(torch.ceil)
register_elementwise_op(torch.clamp)
register_elementwise_op(torch.clamp_max)
register_elementwise_op(torch.clamp_min)
register_elementwise_op(torch.clip)
register_elementwise_op(torch.clone)
register_elementwise_op(torch.Tensor.clone)
register_elementwise_op(torch.Tensor.detach)
register_elementwise_op(torch.copysign)
register_elementwise_op(torch.cos)
register_elementwise_op(torch.cosh)
register_elementwise_op(torch.acosh)
register_elementwise_op(torch.arccosh)
register_elementwise_op(torch.deg2rad)
register_elementwise_op(torch.digamma)
register_elementwise_op(torch.erf)
register_elementwise_op(torch.erfc)
register_elementwise_op(torch.erfinv)
register_elementwise_op(torch.exp)
register_elementwise_op(torch.expm1)
register_elementwise_op(torch.fix)
register_elementwise_op(torch.trunc)
register_elementwise_op(torch.float_power)
register_elementwise_op(torch.floor)
register_elementwise_op(torch.frac)
register_elementwise_op(torch.hardshrink)
register_elementwise_op(torch.heaviside)
register_elementwise_op(torch.i0)
register_elementwise_op(torch.isfinite)
register_elementwise_op(torch.isinf)
register_elementwise_op(torch.isposinf)
register_elementwise_op(torch.isneginf)
register_elementwise_op(torch.isnan)
register_elementwise_op(torch.lgamma)
register_elementwise_op(torch.log)
register_elementwise_op(torch.log10)
register_elementwise_op(torch.log1p)
register_elementwise_op(torch.log2)
register_elementwise_op(torch.logical_not)
register_elementwise_op(torch.logit)
register_elementwise_op(torch.nan_to_num)
register_elementwise_op(torch.neg)
register_elementwise_op(torch.negative)
register_elementwise_op(torch.positive)
register_elementwise_op(torch.pow)
register_elementwise_op(torch.rad2deg)
register_elementwise_op(torch.reciprocal)
register_elementwise_op(torch.round)
register_elementwise_op(torch.rsqrt)
register_elementwise_op(torch.sigmoid)
register_elementwise_op(torch.sign)
register_elementwise_op(torch.signbit)
register_elementwise_op(torch.sgn)
register_elementwise_op(torch.sin)
register_elementwise_op(torch.sinc)
register_elementwise_op(torch.sinh)
register_elementwise_op(torch.asinh)
register_elementwise_op(torch.arcsinh)
register_elementwise_op(torch.sqrt)
register_elementwise_op(torch.square)
register_elementwise_op(torch.tan)
register_elementwise_op(torch.tanh)
register_elementwise_op(torch.atanh)
register_elementwise_op(torch.arctanh)
# nn.functional OP
register_elementwise_op(F.threshold)
register_elementwise_op(F.relu)
register_elementwise_op(F.hardtanh)
register_elementwise_op(F.hardswish)
register_elementwise_op(F.relu6)
register_elementwise_op(F.elu)
register_elementwise_op(F.selu)
register_elementwise_op(F.celu)
register_elementwise_op(F.leaky_relu)
register_elementwise_op(F.prelu)
register_elementwise_op(F.rrelu)
register_elementwise_op(F.gelu)
register_elementwise_op(F.logsigmoid)
register_elementwise_op(F.hardshrink)
register_elementwise_op(F.tanhshrink)
register_elementwise_op(F.softsign)
register_elementwise_op(F.softplus)
register_elementwise_op(F.softmin)
register_elementwise_op(F.softmax)
register_elementwise_op(F.softshrink)
register_elementwise_op(F.gumbel_softmax)
register_elementwise_op(F.log_softmax)
register_elementwise_op(F.tanh)
register_elementwise_op(F.sigmoid)
register_elementwise_op(F.hardsigmoid)
register_elementwise_op(F.silu)
register_elementwise_op(F.mish)
# TODO(ver217): dropout handles seed
register_elementwise_op(F.dropout)
register_elementwise_op(F.alpha_dropout)
register_elementwise_op(F.feature_alpha_dropout)

View File

@ -1,8 +1,16 @@
import torch
import pytest
import colossalai
import torch.nn.functional as F
import torch.multiprocessing as mp
from functools import partial
from colossalai.tensor import ColoTensor, ColoParameter
from colossalai.utils import get_current_device
from torch.nn import Parameter
import torch.nn.functional as F
from torch.distributed.distributed_c10d import _get_default_group
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils import free_port
from colossalai.tensor import distspec, TensorSpec
def test_layernorm():
@ -26,8 +34,42 @@ def test_layernorm():
assert torch.allclose(ln_op.weight.grad, weight.grad)
def check_spec_eq(tensor, other):
assert isinstance(tensor, ColoTensor) and isinstance(other, ColoTensor)
for k in dir(tensor.spec.dist_spec):
if not k.startswith('__'):
assert hasattr(other.spec.dist_spec, k)
assert getattr(tensor.spec.dist_spec, k) == getattr(other.spec.dist_spec, k)
def check_element_wise_ops():
pg = _get_default_group()
t = torch.rand(2, 2)
x = ColoTensor(t, spec=TensorSpec(distspec.shard(pg, [0], [pg.size()])))
check_spec_eq(x, x.cuda())
assert torch.equal(x.cuda(), t.cuda())
check_spec_eq(x, torch.abs(x))
assert torch.equal(torch.abs(x), torch.abs(t))
check_spec_eq(x, F.sigmoid(x))
assert torch.equal(F.sigmoid(x), F.sigmoid(t))
def run_dist(rank, world_size, port):
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
check_element_wise_ops()
@pytest.mark.dist
@pytest.mark.parametrize('world_size', [2])
@rerun_if_address_is_in_use()
def test_element_wise_ops(world_size):
run_func = partial(run_dist, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size)
def check_all():
test_layernorm()
test_element_wise_ops(2)
if __name__ == '__main__':