diff --git a/tests/test_tensor/test_op.py b/tests/test_tensor/test_op.py index 4c9e72a92..12555f741 100644 --- a/tests/test_tensor/test_op.py +++ b/tests/test_tensor/test_op.py @@ -54,6 +54,13 @@ def test_element_wise(): assert allclose(torch.nn.functional.relu(t), torch.nn.functional.relu(t_ref)) +# Test a function not wrapped by +def test_no_wrap_op(): + t_ref = torch.randn(3, 5) + t = ColoTensor(t_ref.clone()) + assert torch.sum(t) == torch.sum(t_ref) + + if __name__ == '__main__': - test_linear() + test_no_wrap_op() # test_element_wise()