mirror of https://github.com/hpcaitech/ColossalAI
You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
29 lines
1.0 KiB
29 lines
1.0 KiB
3 years ago
|
from copy import copy
|
||
3 years ago
|
import torch
|
||
3 years ago
|
from colossalai.tensor.op_wrapper import colo_op_impl
|
||
|
from colossalai.tensor import ColoTensor
|
||
3 years ago
|
from ._utils import GeneralTensor
|
||
3 years ago
|
|
||
|
|
||
|
def register_elementwise_op(op):
    """Attach a ColoTensor handler for the element-wise torch callable ``op``.

    The handler defined below is passed through ``colo_op_impl(op)``. It is
    not returned, so this function returns ``None`` and is called purely for
    that registration side effect.

    Args:
        op: A torch callable applied element-wise, e.g.
            ``torch.nn.functional.gelu`` or ``torch.Tensor.clone``.
    """

    @colo_op_impl(op)
    def elementwise_op(input_tensor: GeneralTensor, *args, **kwargs):
        """Handle ``__torch_function__`` dispatch for ``op``.

        Runs ``op`` on ``input_tensor`` (a normal tensor or a ``ColoTensor``),
        forwarding any extra positional/keyword arguments unchanged, and wraps
        the result with ``ColoTensor.from_torch_tensor``.

        If the input is a ``ColoTensor``, a shallow copy of its ``spec`` is
        handed to the result. A plain tensor input still yields a
        ``ColoTensor``, built without an explicit spec.
        """
        result = op(input_tensor, *args, **kwargs)
        if not isinstance(input_tensor, ColoTensor):
            # Plain tensor input: there is no spec to inherit.
            return ColoTensor.from_torch_tensor(result)
        # Copy the spec so the output does not share the input's spec object.
        return ColoTensor.from_torch_tensor(result, spec=copy(input_tensor.spec))
|
||
3 years ago
|
|
||
|
|
||
|
# Install the shared element-wise handler for each op below, in this order.
for _elementwise_op in (
    torch.nn.functional.gelu,
    torch.nn.functional.relu,
    torch.clone,
    torch.Tensor.clone,
    torch.Tensor.detach,
):
    register_elementwise_op(_elementwise_op)
# Remove the loop variable so no extra name is left in the module namespace.
del _elementwise_op
|