import torch
import torch.nn.functional as F
from torch import Tensor

from colossalai.tensor import ColoTensor, ColoTensorSpec
from colossalai.tensor.op_wrapper import colo_op_impl

from ._utils import GeneralTensor, convert_to_colo_tensor


def register_elementwise_op(op):

    @colo_op_impl(op)
    def elementwise_op(input_tensor: GeneralTensor, *args, **kwargs):
        """
        Handles ``__torch_function__`` dispatch for an elementwise op such
        as ``torch.nn.functional.gelu`` or ``torch.nn.functional.relu``.
        This method computes on either a normal tensor or a sharded tensor.
        """
        if 'inplace' in kwargs:
            # TODO(jiaruifang): computing in place causes bugs, so compute on a clone instead.
            input_tensor = input_tensor.clone()
            return op(input_tensor, *args, **kwargs)
        else:
            output = op(input_tensor, *args, **kwargs)
            if isinstance(input_tensor, ColoTensor):
                # Some registered ops (e.g. ``Tensor.type`` called without arguments) return a str.
                if isinstance(output, str):
                    return output
                if not isinstance(output, torch.Tensor):
                    raise NotImplementedError
                # Re-wrap the plain tensor result so it keeps the input's
                # process group and distribution spec.
                return ColoTensor.from_torch_tensor(output,
                                                    spec=ColoTensorSpec(input_tensor.get_process_group(),
                                                                        dist_attr=input_tensor.dist_spec))
            # Plain torch.Tensor input: return the raw output unchanged.
            return output

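
# A minimal usage sketch (illustrative only; ``pg`` and the shapes below are
# assumptions, not part of this module): once an op is registered, calling it on
# a ColoTensor dispatches through ``elementwise_op`` above and the result stays
# a ColoTensor.
#
#   from colossalai.tensor import ProcessGroup
#   pg = ProcessGroup()                                  # default process group
#   x = ColoTensor.from_torch_tensor(torch.randn(4, 4),
#                                    spec=ColoTensorSpec(pg))
#   y = F.gelu(x)                                        # routed via __torch_function__
#   assert isinstance(y, ColoTensor)
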
|
# @colo_op_impl(torch.relu_)
# def elementwise_op(input_tensor):
#     torch.relu_(input_tensor.data)
#     return input_tensor

# @colo_op_impl(Tensor.add_)
# def elementwise_op(input_tensor: ColoTensor, *args, **kwargs):
#     input_tensor = input_tensor.data.add_(*args, **kwargs)
#     return input_tensor

# Tensor op
register_elementwise_op(Tensor.abs)
register_elementwise_op(Tensor.absolute)
register_elementwise_op(Tensor.acos)
register_elementwise_op(Tensor.arccos)
register_elementwise_op(Tensor.angle)
register_elementwise_op(Tensor.asin)
register_elementwise_op(Tensor.arcsin)
register_elementwise_op(Tensor.atan)
register_elementwise_op(Tensor.arctan)
register_elementwise_op(Tensor.all)
register_elementwise_op(Tensor.any)
register_elementwise_op(Tensor.bernoulli)
register_elementwise_op(Tensor.bfloat16)
register_elementwise_op(Tensor.bitwise_not)
register_elementwise_op(Tensor.bool)
register_elementwise_op(Tensor.byte)
register_elementwise_op(Tensor.ceil)
register_elementwise_op(Tensor.char)
register_elementwise_op(Tensor.clamp)
register_elementwise_op(Tensor.clamp_max)
register_elementwise_op(Tensor.clamp_min)
register_elementwise_op(Tensor.clip)
register_elementwise_op(Tensor.clone)
register_elementwise_op(Tensor.contiguous)
register_elementwise_op(Tensor.copysign)
register_elementwise_op(Tensor.cos)
register_elementwise_op(Tensor.cosh)
register_elementwise_op(Tensor.acosh)
register_elementwise_op(Tensor.arccosh)
register_elementwise_op(Tensor.cpu)
register_elementwise_op(Tensor.cuda)
register_elementwise_op(Tensor.deg2rad)
register_elementwise_op(Tensor.detach)
register_elementwise_op(Tensor.digamma)
register_elementwise_op(Tensor.double)
register_elementwise_op(Tensor.erf)
register_elementwise_op(Tensor.erfc)
register_elementwise_op(Tensor.erfinv)
register_elementwise_op(Tensor.exp)
register_elementwise_op(Tensor.expm1)
register_elementwise_op(Tensor.fix)
register_elementwise_op(Tensor.trunc)
register_elementwise_op(Tensor.float)
register_elementwise_op(Tensor.float_power)
register_elementwise_op(Tensor.floor)
register_elementwise_op(Tensor.frac)
register_elementwise_op(Tensor.half)
register_elementwise_op(Tensor.hardshrink)
register_elementwise_op(Tensor.heaviside)
register_elementwise_op(Tensor.i0)
register_elementwise_op(Tensor.int)
register_elementwise_op(Tensor.isfinite)
register_elementwise_op(Tensor.isinf)
register_elementwise_op(Tensor.isposinf)
register_elementwise_op(Tensor.isneginf)
register_elementwise_op(Tensor.isnan)
register_elementwise_op(Tensor.lgamma)
register_elementwise_op(Tensor.log)
register_elementwise_op(Tensor.log10)
register_elementwise_op(Tensor.log1p)
register_elementwise_op(Tensor.log2)
register_elementwise_op(Tensor.logical_not)
register_elementwise_op(Tensor.logit)
register_elementwise_op(Tensor.long)
register_elementwise_op(Tensor.nan_to_num)
register_elementwise_op(Tensor.neg)
register_elementwise_op(Tensor.negative)
register_elementwise_op(Tensor.positive)
register_elementwise_op(Tensor.pow)
register_elementwise_op(Tensor.rad2deg)
register_elementwise_op(Tensor.reciprocal)
register_elementwise_op(Tensor.round)
register_elementwise_op(Tensor.rsqrt)
register_elementwise_op(Tensor.short)
register_elementwise_op(Tensor.sigmoid)
register_elementwise_op(Tensor.sign)
register_elementwise_op(Tensor.signbit)
register_elementwise_op(Tensor.sgn)
register_elementwise_op(Tensor.sin)
register_elementwise_op(Tensor.sinc)
register_elementwise_op(Tensor.sinh)
register_elementwise_op(Tensor.asinh)
register_elementwise_op(Tensor.arcsinh)
register_elementwise_op(Tensor.sqrt)
register_elementwise_op(Tensor.square)
register_elementwise_op(Tensor.to)
register_elementwise_op(Tensor.tan)
register_elementwise_op(Tensor.tanh)
register_elementwise_op(Tensor.atanh)
register_elementwise_op(Tensor.arctanh)
register_elementwise_op(Tensor.type)
register_elementwise_op(Tensor.type_as)

# torch OP
register_elementwise_op(torch.abs)
register_elementwise_op(torch.absolute)
register_elementwise_op(torch.acos)
register_elementwise_op(torch.arccos)
register_elementwise_op(torch.angle)
register_elementwise_op(torch.asin)
register_elementwise_op(torch.arcsin)
register_elementwise_op(torch.atan)
register_elementwise_op(torch.arctan)
register_elementwise_op(torch.all)
register_elementwise_op(torch.any)
register_elementwise_op(torch.bernoulli)
register_elementwise_op(torch.bitwise_not)
register_elementwise_op(torch.ceil)
register_elementwise_op(torch.clamp)
register_elementwise_op(torch.clamp_max)
register_elementwise_op(torch.clamp_min)
register_elementwise_op(torch.clip)
register_elementwise_op(torch.clone)
register_elementwise_op(torch.copysign)
register_elementwise_op(torch.cos)
register_elementwise_op(torch.cosh)
register_elementwise_op(torch.acosh)
register_elementwise_op(torch.arccosh)
register_elementwise_op(torch.deg2rad)
register_elementwise_op(torch.digamma)
register_elementwise_op(torch.erf)
register_elementwise_op(torch.erfc)
register_elementwise_op(torch.erfinv)
register_elementwise_op(torch.exp)
register_elementwise_op(torch.expm1)
register_elementwise_op(torch.fix)
register_elementwise_op(torch.trunc)
register_elementwise_op(torch.float_power)
register_elementwise_op(torch.floor)
register_elementwise_op(torch.frac)
register_elementwise_op(torch.hardshrink)
register_elementwise_op(torch.heaviside)
register_elementwise_op(torch.i0)
register_elementwise_op(torch.isfinite)
register_elementwise_op(torch.isinf)
register_elementwise_op(torch.isposinf)
register_elementwise_op(torch.isneginf)
register_elementwise_op(torch.isnan)
register_elementwise_op(torch.lgamma)
register_elementwise_op(torch.log)
register_elementwise_op(torch.log10)
register_elementwise_op(torch.log1p)
register_elementwise_op(torch.log2)
register_elementwise_op(torch.logical_not)
register_elementwise_op(torch.logit)
register_elementwise_op(torch.nan_to_num)
register_elementwise_op(torch.neg)
register_elementwise_op(torch.negative)
register_elementwise_op(torch.positive)
register_elementwise_op(torch.pow)
register_elementwise_op(torch.rad2deg)
register_elementwise_op(torch.reciprocal)
register_elementwise_op(torch.round)
register_elementwise_op(torch.rsqrt)
register_elementwise_op(torch.sigmoid)
register_elementwise_op(torch.sign)
register_elementwise_op(torch.signbit)
register_elementwise_op(torch.sgn)
register_elementwise_op(torch.sin)
register_elementwise_op(torch.sinc)
register_elementwise_op(torch.sinh)
register_elementwise_op(torch.asinh)
register_elementwise_op(torch.arcsinh)
register_elementwise_op(torch.sqrt)
register_elementwise_op(torch.square)
register_elementwise_op(torch.tan)
register_elementwise_op(torch.tanh)
register_elementwise_op(torch.atanh)
register_elementwise_op(torch.arctanh)
register_elementwise_op(torch.zeros_like)

# nn.functional OP
register_elementwise_op(F.threshold)
register_elementwise_op(F.relu)
register_elementwise_op(F.hardtanh)
register_elementwise_op(F.hardswish)
register_elementwise_op(F.relu6)
register_elementwise_op(F.elu)
register_elementwise_op(F.selu)
register_elementwise_op(F.celu)
register_elementwise_op(F.leaky_relu)
register_elementwise_op(F.prelu)
register_elementwise_op(F.rrelu)
register_elementwise_op(F.gelu)
register_elementwise_op(F.logsigmoid)
register_elementwise_op(F.hardshrink)
register_elementwise_op(F.tanhshrink)
register_elementwise_op(F.softsign)
register_elementwise_op(F.softplus)
register_elementwise_op(F.softmin)
register_elementwise_op(F.softmax)
register_elementwise_op(F.softshrink)
register_elementwise_op(F.gumbel_softmax)
register_elementwise_op(F.log_softmax)
register_elementwise_op(F.tanh)
register_elementwise_op(F.sigmoid)
register_elementwise_op(F.hardsigmoid)
register_elementwise_op(F.silu)
register_elementwise_op(F.mish)
# TODO(ver217): dropout handles seed
register_elementwise_op(F.dropout)
register_elementwise_op(F.alpha_dropout)
register_elementwise_op(F.feature_alpha_dropout)
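
# Note (illustrative): several ops are registered in more than one spelling above,
# e.g. ``Tensor.tanh``, ``torch.tanh`` and ``F.tanh``, so ``x.tanh()``,
# ``torch.tanh(x)`` and ``F.tanh(x)`` all take the same ColoTensor-aware path.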