mirror of https://github.com/hpcaitech/ColossalAI
[hotfix] make Gemini work for conv DNN (#1998)
parent 155891113e
commit a2d3266648
@@ -1,9 +1,11 @@
 import torch
 import torch.nn.functional as F
 from torch import Tensor
-from colossalai.tensor.op_wrapper import colo_op_impl
+
 from colossalai.tensor import ColoTensor, ColoTensorSpec
-from ._utils import GeneralTensor
+from colossalai.tensor.op_wrapper import colo_op_impl
+
+from ._utils import GeneralTensor, convert_to_colo_tensor
 
 
 def register_elementwise_op(op):
@@ -15,16 +17,21 @@ def register_elementwise_op(op):
         as ``torch.nn.functional.gelu`` or ``torch.nn.functional.relu``.
         This method computes on either a normal tensor or a sharded tensor.
         """
-
-        output = op(input_tensor, *args, **kwargs)
-        if isinstance(input_tensor, ColoTensor):
-            if isinstance(output, str):
-                return output
-            if not isinstance(output, torch.Tensor):
-                raise NotImplementedError
-            return ColoTensor.from_torch_tensor(output,
-                                                spec=ColoTensorSpec(input_tensor.get_process_group(),
-                                                                    dist_attr=input_tensor.dist_spec))
+        if 'inplace' in kwargs:
+            # TODO(jiaruifang) inplace will cause bugs
+            input_tensor = input_tensor.clone()
+            return op(input_tensor, *args, **kwargs)
+        else:
+            output = op(input_tensor, *args, **kwargs)
+            # return output
+            if isinstance(input_tensor, ColoTensor):
+                if isinstance(output, str):
+                    return output
+                if not isinstance(output, torch.Tensor):
+                    raise NotImplementedError
+                return ColoTensor.from_torch_tensor(output,
+                                                    spec=ColoTensorSpec(input_tensor.get_process_group(),
+                                                                        dist_attr=input_tensor.dist_spec))
 
 
 # Tensor op
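
For context, the behavioural change in this hunk is the new `inplace` branch: when a caller requests an in-place elementwise op (e.g. `F.relu(..., inplace=True)`), the wrapper now clones the input before dispatching, so the caller's original tensor is never mutated in place. The sketch below illustrates only that cloning pattern in plain PyTorch; it is not the ColossalAI implementation (no `ColoTensor`, `colo_op_impl`, or process-group handling), and the helper name `elementwise_op_sketch` is made up for illustration.

    import torch
    import torch.nn.functional as F

    def elementwise_op_sketch(op, input_tensor, *args, **kwargs):
        # Hypothetical stand-in for the patched wrapper: clone first when the
        # op is asked to run in place, so the caller's tensor stays untouched.
        if 'inplace' in kwargs:
            input_tensor = input_tensor.clone()
            return op(input_tensor, *args, **kwargs)
        return op(input_tensor, *args, **kwargs)

    x = torch.tensor([-1.0, 2.0, -3.0])
    y = elementwise_op_sketch(F.relu, x, inplace=True)
    print(y)  # tensor([0., 2., 0.])  -- the relu result (computed on the clone)
    print(x)  # tensor([-1., 2., -3.]) -- original input is not modified in place

Without the clone, `inplace=True` would overwrite `x` itself, which is the failure mode the `# TODO(jiaruifang) inplace will cause bugs` comment in the patch refers to.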