[hotfix] make Gemini work for conv DNN (#1998)

pull/1999/head
Jiarui Fang 2022-11-22 14:52:36 +08:00 committed by GitHub
parent 155891113e
commit a2d3266648
1 changed file with 19 additions and 12 deletions


@@ -1,9 +1,11 @@
 import torch
 import torch.nn.functional as F
 from torch import Tensor
-from colossalai.tensor.op_wrapper import colo_op_impl
+
 from colossalai.tensor import ColoTensor, ColoTensorSpec
-from ._utils import GeneralTensor
+from colossalai.tensor.op_wrapper import colo_op_impl
+
+from ._utils import GeneralTensor, convert_to_colo_tensor


 def register_elementwise_op(op):
@@ -15,16 +17,21 @@ def register_elementwise_op(op):
         as ``torch.nn.functional.gelu`` or ``torch.nn.functional.relu``.
         This method computes on either a normal tensor or a sharded tensor.
         """
-        output = op(input_tensor, *args, **kwargs)
-        if isinstance(input_tensor, ColoTensor):
-            if isinstance(output, str):
-                return output
-            if not isinstance(output, torch.Tensor):
-                raise NotImplementedError
-            return ColoTensor.from_torch_tensor(output,
-                                                spec=ColoTensorSpec(input_tensor.get_process_group(),
-                                                                    dist_attr=input_tensor.dist_spec))
+        if 'inplace' in kwargs:
+            # TODO(jiaruifang) inplace will cause bugs
+            input_tensor = input_tensor.clone()
+            return op(input_tensor, *args, **kwargs)
+        else:
+            output = op(input_tensor, *args, **kwargs)
+            # return output
+            if isinstance(input_tensor, ColoTensor):
+                if isinstance(output, str):
+                    return output
+                if not isinstance(output, torch.Tensor):
+                    raise NotImplementedError
+                return ColoTensor.from_torch_tensor(output,
+                                                    spec=ColoTensorSpec(input_tensor.get_process_group(),
+                                                                        dist_attr=input_tensor.dist_spec))


 # Tensor op
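
The essence of the hotfix is the clone-before-inplace guard: when a caller passes inplace=True (e.g. F.relu(x, inplace=True) on a Gemini-managed ColoTensor), the wrapped op now runs on a clone instead of mutating the original buffer; the TODO in the patch only notes that "inplace will cause bugs", so the exact failure mode is not spelled out here. Below is a minimal standalone sketch of the same pattern in plain PyTorch; call_elementwise is a hypothetical helper for illustration, not ColossalAI's colo_op_impl dispatcher.

import torch
import torch.nn.functional as F


def call_elementwise(op, input_tensor, *args, **kwargs):
    # Sketch of the patched behaviour: drop the in-place request by
    # working on a copy, so the caller's tensor is never mutated.
    if 'inplace' in kwargs:
        input_tensor = input_tensor.clone()
        return op(input_tensor, *args, **kwargs)
    return op(input_tensor, *args, **kwargs)


x = torch.randn(4)
y = call_elementwise(F.relu, x, inplace=True)
print(x)  # unchanged: the in-place update landed on the clone
print(y)  # clamped to >= 0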