diff --git a/colossalai/nn/_ops/element_wise.py b/colossalai/nn/_ops/element_wise.py index c3c1421e7..462670e72 100644 --- a/colossalai/nn/_ops/element_wise.py +++ b/colossalai/nn/_ops/element_wise.py @@ -18,6 +18,8 @@ def register_elementwise_op(op): output = op(input_tensor, *args, **kwargs) if isinstance(input_tensor, ColoTensor): + if isinstance(output, str): + return output if not isinstance(output, torch.Tensor): raise NotImplementedError return ColoTensor.from_torch_tensor(output,