From a2d3266648bada1151229a4fda5ba042d42db5c0 Mon Sep 17 00:00:00 2001
From: Jiarui Fang
Date: Tue, 22 Nov 2022 14:52:36 +0800
Subject: [PATCH] [hotfix] make Gemini work for conv DNN (#1998)

---
 colossalai/nn/_ops/element_wise.py | 31 +++++++++++++++++++------------
 1 file changed, 19 insertions(+), 12 deletions(-)

diff --git a/colossalai/nn/_ops/element_wise.py b/colossalai/nn/_ops/element_wise.py
index 462670e72..db711be9a 100644
--- a/colossalai/nn/_ops/element_wise.py
+++ b/colossalai/nn/_ops/element_wise.py
@@ -1,9 +1,11 @@
 import torch
 import torch.nn.functional as F
 from torch import Tensor
-from colossalai.tensor.op_wrapper import colo_op_impl
+
 from colossalai.tensor import ColoTensor, ColoTensorSpec
-from ._utils import GeneralTensor
+from colossalai.tensor.op_wrapper import colo_op_impl
+
+from ._utils import GeneralTensor, convert_to_colo_tensor
 
 
 def register_elementwise_op(op):
@@ -15,16 +17,21 @@ def register_elementwise_op(op):
         as ``torch.nn.functional.gelu`` or ``torch.nn.functional.relu``.
         This method computes on either a normal tensor or a sharded tensor.
         """
-
-        output = op(input_tensor, *args, **kwargs)
-        if isinstance(input_tensor, ColoTensor):
-            if isinstance(output, str):
-                return output
-            if not isinstance(output, torch.Tensor):
-                raise NotImplementedError
-            return ColoTensor.from_torch_tensor(output,
-                                                spec=ColoTensorSpec(input_tensor.get_process_group(),
-                                                                    dist_attr=input_tensor.dist_spec))
+        if 'inplace' in kwargs:
+            # TODO(jiaruifang) inplace will cause bugs
+            input_tensor = input_tensor.clone()
+            return op(input_tensor, *args, **kwargs)
+        else:
+            output = op(input_tensor, *args, **kwargs)
+            # return output
+            if isinstance(input_tensor, ColoTensor):
+                if isinstance(output, str):
+                    return output
+                if not isinstance(output, torch.Tensor):
+                    raise NotImplementedError
+                return ColoTensor.from_torch_tensor(output,
+                                                    spec=ColoTensorSpec(input_tensor.get_process_group(),
+                                                                        dist_attr=input_tensor.dist_spec))
 
 
 # Tensor op
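
For context, below is a minimal, self-contained sketch of the dispatch logic this patch introduces, using a plain torch.Tensor in place of ColoTensor. The helper name elementwise_op_sketch is illustrative only and not part of the ColossalAI API; it simply mirrors the "clone before an in-place op" behavior added above.

# Minimal sketch (assumptions noted above): clone the input when an `inplace`
# kwarg is passed, so the in-place op does not mutate the original storage.
import torch
import torch.nn.functional as F


def elementwise_op_sketch(op, input_tensor, *args, **kwargs):
    if 'inplace' in kwargs:
        # Clone first so an in-place op (e.g. F.relu(x, inplace=True)) runs on
        # a copy and leaves the original tensor untouched, as in the hotfix.
        input_tensor = input_tensor.clone()
        return op(input_tensor, *args, **kwargs)
    return op(input_tensor, *args, **kwargs)


x = torch.randn(4)
y = elementwise_op_sketch(F.relu, x, inplace=True)
# y holds the activated values; x keeps its original (possibly negative)
# entries because the in-place relu operated on the clone.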