# Mirror of https://github.com/hpcaitech/ColossalAI (29 lines, 1.0 KiB, Python)
from copy import copy
|
|
import torch
|
|
from colossalai.tensor.op_wrapper import colo_op_impl
|
|
from colossalai.tensor import ColoTensor
|
|
from ._utils import GeneralTensor
|
|
|
|
|
|
def register_elementwise_op(op):
    """Build an elementwise wrapper for ``op`` and hand it to ``colo_op_impl``.

    A fresh wrapper (with its own closure over ``op``) is created on every
    call. This function itself returns ``None``; the wrapper is only reachable
    through whatever ``colo_op_impl`` does with it (presumably registering it
    in a dispatch table -- ``colo_op_impl`` is defined elsewhere, confirm).

    Args:
        op: The torch callable to wrap, e.g. ``torch.nn.functional.gelu``.
    """

    @colo_op_impl(op)
    def elementwise_op(input_tensor: GeneralTensor, *args, **kwargs):
        """
        Handle ``__torch_function__`` dispatch for an elementwise op such as
        ``torch.nn.functional.gelu`` or ``torch.nn.functional.relu``.

        Args:
            input_tensor: A ColoTensor or, presumably, a plain torch tensor
                (``GeneralTensor`` is defined elsewhere -- confirm).
            *args, **kwargs: Forwarded unchanged to ``op``.

        Returns:
            ``op``'s output wrapped with ``ColoTensor.from_torch_tensor``.
            When ``input_tensor`` is a ColoTensor, the result is given a
            shallow copy of its ``spec``; otherwise no ``spec`` is passed.
        """
        result = op(input_tensor, *args, **kwargs)
        if not isinstance(input_tensor, ColoTensor):
            # Plain input: there is no spec to propagate.
            return ColoTensor.from_torch_tensor(result)
        # Shallow copy: the output gets its own spec object rather than the
        # input's (nested attributes, if any, are still shared).
        return ColoTensor.from_torch_tensor(result, spec=copy(input_tensor.spec))

    # NOTE: implicit ``None`` return, matching the original.
|
|
|
|
|
|
# Register the elementwise wrapper for each op below, in this exact order.
for _elementwise_target in (
    torch.nn.functional.gelu,
    torch.nn.functional.relu,
    torch.clone,
    torch.Tensor.clone,
    torch.Tensor.detach,
):
    register_elementwise_op(_elementwise_target)
# Drop the loop variable so the module namespace matches the unrolled original.
del _elementwise_target
|