diff --git a/colossalai/tensor/_ops/linear.py b/colossalai/tensor/_ops/linear.py
index e75f18609..d8bc338a5 100644
--- a/colossalai/tensor/_ops/linear.py
+++ b/colossalai/tensor/_ops/linear.py
@@ -19,8 +19,9 @@ def colo_linear(types, args, kwargs, pg):
         bias = None
     else:
         bias = kwargs.get('bias', None)
-        if isinstance(bias, ColoTensor):
-            bias = bias.torch_tensor()
+
+    if isinstance(bias, ColoTensor):
+        bias = bias.torch_tensor()
 
     # Add communication logic before and after linear call.
     if isinstance(weight, ColoTensor):
diff --git a/tests/test_tensor/test_op.py b/tests/test_tensor/test_op.py
index c45dca8da..6cd45df44 100644
--- a/tests/test_tensor/test_op.py
+++ b/tests/test_tensor/test_op.py
@@ -3,7 +3,6 @@ import torch
 from colossalai.tensor import ColoTensor
 from copy import deepcopy
 
-
 def test_linear():
     in_dim = 4
     out_dim = 5
@@ -45,7 +44,6 @@ def test_linear():
     # torch.nn.init.uniform_(t)
     # print(t)
 
-
 def test_element_wise():
     t_ref = torch.randn(3, 5)
     t = ColoTensor.init_from_torch_tensor(t_ref.clone())
@@ -66,6 +64,11 @@ def test_lazy_init_tensor():
     assert lazy_t._torch_tensor == None
     assert lazy_t.torch_tensor().numel() == 6
 
-if __name__ == '__main__':
+def check_all():
+    test_linear()
+    test_element_wise()
     test_no_wrap_op()
-    # test_element_wise()
+    test_lazy_init_tensor()
+
+if __name__ == '__main__':
+    check_all()
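
Note (not part of the patch): a minimal sketch of the wrap/unwrap round trip that the first hunk relies on, using only the ColoTensor calls already visible in this diff (init_from_torch_tensor and torch_tensor). It assumes torch and colossalai are importable and is illustrative only, not the library's colo_linear implementation.

    import torch
    from colossalai.tensor import ColoTensor

    # Wrap a plain torch tensor, as the tests above do.
    bias_ref = torch.randn(5)
    bias = ColoTensor.init_from_torch_tensor(bias_ref.clone())

    # After the first hunk, colo_linear unwraps a ColoTensor bias at function
    # scope, so the underlying torch op always receives a plain torch.Tensor.
    if isinstance(bias, ColoTensor):
        bias = bias.torch_tensor()

    assert torch.allclose(bias, bias_ref)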