Browse Source

[hotfix[ fix colotensor.type() raise NotImplementedError (#1682)

pull/1683/head
jim 2 years ago committed by GitHub
parent
commit
e5ab6be72e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 2
      colossalai/nn/_ops/element_wise.py

2
colossalai/nn/_ops/element_wise.py

@ -18,6 +18,8 @@ def register_elementwise_op(op):
output = op(input_tensor, *args, **kwargs) output = op(input_tensor, *args, **kwargs)
if isinstance(input_tensor, ColoTensor): if isinstance(input_tensor, ColoTensor):
if isinstance(output, str):
return output
if not isinstance(output, torch.Tensor): if not isinstance(output, torch.Tensor):
raise NotImplementedError raise NotImplementedError
return ColoTensor.from_torch_tensor(output, return ColoTensor.from_torch_tensor(output,

Loading…
Cancel
Save