mirror of https://github.com/hpcaitech/ColossalAI
[tensor] fix test_linear (#826)
parent 1a9e2c2dff
commit 8e6fdb4f29
@@ -19,8 +19,9 @@ def colo_linear(types, args, kwargs, pg):
        bias = None
    else:
        bias = kwargs.get('bias', None)
        if isinstance(bias, ColoTensor):
            bias = bias.torch_tensor()

    if isinstance(bias, ColoTensor):
        bias = bias.torch_tensor()

    # Add communication logic before and after linear call.
    if isinstance(weight, ColoTensor):
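For context: the hunk above edits colo_linear, the handler ColossalAI registers for torch.nn.functional.linear, so that a ColoTensor bias is unwrapped to a plain torch.Tensor before the underlying kernel runs. Below is a minimal, self-contained sketch of that unwrap-then-dispatch pattern; only the names ColoTensor, torch_tensor(), and init_from_torch_tensor mirror the diff — the toy class body and the colo_linear_sketch function are assumptions made for illustration, not the real implementation.

    import torch
    import torch.nn.functional as F


    class ColoTensor:
        # Toy wrapper holding a plain torch.Tensor; illustrative only.

        def __init__(self, t):
            self._torch_tensor = t

        @staticmethod
        def init_from_torch_tensor(t):
            return ColoTensor(t)

        def torch_tensor(self):
            return self._torch_tensor


    def colo_linear_sketch(input_tensor, weight, bias=None):
        # Unwrap every ColoTensor argument before calling the real kernel,
        # mirroring the bias/weight isinstance checks in the hunk above.
        if isinstance(input_tensor, ColoTensor):
            input_tensor = input_tensor.torch_tensor()
        if isinstance(weight, ColoTensor):
            weight = weight.torch_tensor()
        if isinstance(bias, ColoTensor):
            bias = bias.torch_tensor()
        return F.linear(input_tensor, weight, bias)


    # in_dim = 4, out_dim = 5, matching the test in the next hunk.
    x = torch.randn(2, 4)
    w = ColoTensor.init_from_torch_tensor(torch.randn(5, 4))
    b = ColoTensor.init_from_torch_tensor(torch.randn(5))
    assert colo_linear_sketch(x, w, b).shape == (2, 5)

The real handler additionally has to decide how to shard or replicate the computation across processes, which is what the "Add communication logic before and after linear call" comment in the hunk refers to.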
@@ -3,7 +3,6 @@ import torch
from colossalai.tensor import ColoTensor
from copy import deepcopy


def test_linear():
    in_dim = 4
    out_dim = 5
@@ -45,7 +44,6 @@ def test_linear():
    # torch.nn.init.uniform_(t)
    # print(t)


def test_element_wise():
    t_ref = torch.randn(3, 5)
    t = ColoTensor.init_from_torch_tensor(t_ref.clone())
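The test pattern visible in this hunk — keep a plain reference tensor t_ref, wrap a clone of it, then compare results — generalizes to any element-wise op. A short sketch under the same assumptions, reusing the toy ColoTensor from the sketch above:

    # Reference-vs-wrapped comparison in the style of test_element_wise.
    # Reuses the toy ColoTensor defined in the previous sketch.
    t_ref = torch.randn(3, 5)
    t = ColoTensor.init_from_torch_tensor(t_ref.clone())

    # The op applied to the wrapper's payload must match the same op
    # applied to the untouched reference tensor.
    assert torch.allclose(torch.relu(t.torch_tensor()), torch.relu(t_ref))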
@@ -66,6 +64,11 @@ def test_lazy_init_tensor():
    assert lazy_t._torch_tensor == None
    assert lazy_t.torch_tensor().numel() == 6


def check_all():
    test_linear()
    test_element_wise()
    test_no_wrap_op()
    # test_element_wise()
    test_lazy_init_tensor()


if __name__ == '__main__':
    check_all()
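The two assertions in this hunk pin down the lazy-init contract: _torch_tensor stays None until torch_tensor() is first called, at which point a buffer with 6 elements is materialized. A minimal sketch of that contract — the LazyColoTensor class and its constructor signature are assumed for illustration:

    import torch


    class LazyColoTensor:
        # Hypothetical stand-in: only the lazy-materialization contract
        # asserted above is modeled; the real ColoTensor API differs.

        def __init__(self, *size):
            self._size = size
            self._torch_tensor = None  # nothing allocated at construction

        def torch_tensor(self):
            # Allocate the payload on first access, then cache it.
            if self._torch_tensor is None:
                self._torch_tensor = torch.empty(self._size)
            return self._torch_tensor


    lazy_t = LazyColoTensor(2, 3)
    assert lazy_t._torch_tensor is None        # same check as the first assert
    assert lazy_t.torch_tensor().numel() == 6  # materialized on first access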