Browse Source

[tensor]fix test_linear (#826)

pull/717/merge
Ziyue Jiang 3 years ago committed by GitHub
parent
commit
8e6fdb4f29
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 5
      colossalai/tensor/_ops/linear.py
  2. 11
      tests/test_tensor/test_op.py

5
colossalai/tensor/_ops/linear.py

@@ -19,8 +19,9 @@ def colo_linear(types, args, kwargs, pg):
bias = None
else:
bias = kwargs.get('bias', None)
if isinstance(bias, ColoTensor):
bias = bias.torch_tensor()
if isinstance(bias, ColoTensor):
bias = bias.torch_tensor()
# Add communication logic before and after linear call.
if isinstance(weight, ColoTensor):

11
tests/test_tensor/test_op.py

@@ -3,7 +3,6 @@ import torch
from colossalai.tensor import ColoTensor
from copy import deepcopy
def test_linear():
in_dim = 4
out_dim = 5
@@ -45,7 +44,6 @@ def test_linear():
# torch.nn.init.uniform_(t)
# print(t)
def test_element_wise():
t_ref = torch.randn(3, 5)
t = ColoTensor.init_from_torch_tensor(t_ref.clone())
@@ -66,6 +64,11 @@ def test_lazy_init_tensor():
assert lazy_t._torch_tensor == None
assert lazy_t.torch_tensor().numel() == 6
if __name__ == '__main__':
def check_all():
test_linear()
test_element_wise()
test_no_wrap_op()
# test_element_wise()
test_lazy_init_tensor()
if __name__ == '__main__':
check_all()

Loading…
Cancel
Save