from copy import copy

import torch

from colossalai.tensor.op_wrapper import colo_op_impl
from colossalai.tensor import ColoTensor

from ._utils import GeneralTensor

def register_elementwise_op(op):

    @colo_op_impl(op)
    def elementwise_op(input_tensor: GeneralTensor, *args, **kwargs):
        """
        Handles ``__torch_function__`` dispatch for an elementwise op such
        as ``torch.nn.functional.gelu`` or ``torch.nn.functional.relu``.
        This method computes on either a normal tensor or a sharded tensor.
        """
        output = op(input_tensor, *args, **kwargs)
        if isinstance(input_tensor, ColoTensor):
            # Elementwise ops preserve the distribution layout, so the output
            # can safely reuse a copy of the input tensor's spec.
            spec = copy(input_tensor.spec)
            return ColoTensor.from_torch_tensor(output, spec=spec)
        # Plain torch.Tensor input: wrap the result as a ColoTensor with the
        # default spec.
        return ColoTensor.from_torch_tensor(output)

# Register common elementwise/unary ops so that calling them on a ColoTensor
# returns a ColoTensor rather than a plain torch.Tensor.
register_elementwise_op(torch.nn.functional.gelu)
register_elementwise_op(torch.nn.functional.relu)
register_elementwise_op(torch.clone)
register_elementwise_op(torch.Tensor.clone)
register_elementwise_op(torch.Tensor.detach)
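
# Usage sketch (illustrative comments only; nothing here runs on import).
# After the registrations above, calling a stock torch op on a ColoTensor is
# routed through ``elementwise_op`` via ``__torch_function__``, and the result
# keeps the input's spec. The ``silu`` line is a hypothetical extra
# registration, not something this module performs:
#
#   x = ColoTensor.from_torch_tensor(torch.randn(2, 2))
#   y = torch.nn.functional.relu(x)   # dispatched through elementwise_op
#   assert isinstance(y, ColoTensor)  # output stays a ColoTensor
#
#   register_elementwise_op(torch.nn.functional.silu)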